langfun 0.0.2.dev20240605__py3-none-any.whl → 0.0.2.dev20240613__py3-none-any.whl

This diff shows the changes between two publicly released versions of the package, as published to their public registry. It is provided for informational purposes only.

Note: this release of langfun has been flagged as potentially problematic.

langfun/core/__init__.py CHANGED
@@ -101,6 +101,7 @@ from langfun.core.language_model import LanguageModel
 from langfun.core.language_model import LMSample
 from langfun.core.language_model import LMSamplingOptions
 from langfun.core.language_model import LMSamplingUsage
+from langfun.core.language_model import UsageNotAvailable
 from langfun.core.language_model import LMSamplingResult
 from langfun.core.language_model import LMScoringResult
 from langfun.core.language_model import LMCache
@@ -111,12 +112,21 @@ from langfun.core.language_model import RetryableLMError
 from langfun.core.language_model import RateLimitError
 from langfun.core.language_model import TemporaryLMError
 
+# Context manager for tracking usages.
+from langfun.core.language_model import track_usages
+
 # Components for building agents.
 from langfun.core.memory import Memory
 
 # Utility for console output.
 from langfun.core import console
 
+# Helpers for implementing _repr_xxx_ methods.
+from langfun.core import repr_utils
+
+# Utility for event logging.
+from langfun.core import logging
+
 # Import internal modules.
 
 # pylint: enable=g-import-not-at-top
langfun/core/component.py CHANGED
@@ -210,6 +210,16 @@ def get_contextual_override(var_name: str) -> ContextualOverride | None:
   return _get_scoped_value(_global_tls, _CONTEXT_OVERRIDES, var_name)
 
 
+def context_value(var_name: str, default: Any = RAISE_IF_HAS_ERROR) -> Any:
+  """Returns the value of a variable defined in `lf.context`."""
+  override = get_contextual_override(var_name)
+  if override is None:
+    if default == RAISE_IF_HAS_ERROR:
+      raise KeyError(f'{var_name!r} does not exist in current context.')
+    return default
+  return override.value
+
+
 def all_contextual_values() -> dict[str, Any]:
   """Returns all contextual values provided from `lf.context` in scope."""
   overrides = getattr(_global_tls, _CONTEXT_OVERRIDES, {})
langfun/core/component_test.py CHANGED
@@ -84,6 +84,11 @@ class ComponentContextTest(unittest.TestCase):
         lf.get_contextual_override('y'),
         lf.ContextualOverride(3, cascade=False, override_attrs=False),
     )
+    self.assertEqual(lf.context_value('x'), 3)
+    self.assertIsNone(lf.context_value('f', None))
+    with self.assertRaisesRegex(KeyError, '.* does not exist'):
+      lf.context_value('f')
+
     self.assertEqual(lf.all_contextual_values(), dict(x=3, y=3, z=3))
 
     # Member attributes take precedence over `lf.context`.
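Taken together, `lf.context` and the new `lf.context_value` give code running under a context manager direct access to ambient values. A minimal usage sketch (assuming `context_value` is re-exported on the top-level `lf` namespace, as the test above exercises):

```
import langfun as lf

with lf.context(x=3):
  assert lf.context_value('x') == 3           # defined in scope
  assert lf.context_value('y', None) is None  # default avoids the KeyError
  # lf.context_value('y')                     # without a default: KeyError
```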
langfun/core/eval/base.py CHANGED
@@ -549,7 +549,7 @@ class Evaluable(lf.Component):
       )
       s.write(html.escape(pg.format(m.result)))
       s.write('</div>')
-      if 'usage' in m.metadata and m.usage is not None:
+      if m.metadata.get('usage', None):
        s.write(
            '<div style="background-color: #EEEEEE; color: black; '
            'white-space: pre-wrap; padding: 10px; border: 0px solid; '
@@ -1321,7 +1321,7 @@ class Evaluation(Evaluable):
     self._render_summary_metrics(s)
 
     # Summarize average usage.
-    if self.result.usage is not None:
+    if self.result.usage:
       self._render_summary_usage(s)
 
     s.write('</td></tr></table></div>')
@@ -1441,9 +1441,10 @@
   def audit_usage(self, message: lf.Message, dryrun: bool = False) -> None:
     del dryrun
     for m in message.trace():
-      if m.metadata.get('usage', None) is not None:
-        self._total_prompt_tokens += m.usage.prompt_tokens
-        self._total_completion_tokens += m.usage.completion_tokens
+      usage = m.metadata.get('usage', None)
+      if usage:
+        self._total_prompt_tokens += usage.prompt_tokens
+        self._total_completion_tokens += usage.completion_tokens
         self._num_usages += 1
 
   def audit_processed(
@@ -1504,7 +1505,7 @@
         '<td>Schema</td>'
         '<td>Additional Args</td>'
     )
-    if self.result.usage is not None:
+    if self.result.usage:
       s.write('<td>Usage</td>')
     s.write('<td>OOP Failures</td>')
     s.write('<td>Non-OOP Failures</td>')
@@ -1533,7 +1534,7 @@
         f'{_html_repr(self.additional_args, compact=False)}</td>'
     )
     # Usage.
-    if self.result.usage is not None:
+    if self.result.usage:
       s.write('<td>')
       self._render_summary_usage(s)
       s.write('</td>')
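These call sites switch from `is not None` checks to plain truthiness because usage is no longer optional: the absence of counts is now modeled by the `UsageNotAvailable` sentinel introduced in language_model.py below, which is falsy. A minimal sketch of the contract these checks rely on:

```
import langfun as lf

usage = lf.UsageNotAvailable()
assert isinstance(usage, lf.LMSamplingUsage)  # still a usage object (0 tokens)
assert not usage                              # falsy, so `if usage:` skips it
assert lf.LMSamplingUsage(5, 7, 12)           # real counts remain truthy
```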
langfun/core/langfunc_test.py CHANGED
@@ -87,7 +87,11 @@ class LangFuncCallTest(unittest.TestCase):
 
     r = l()
     self.assertEqual(
-        r, message.AIMessage('Hello!!!', score=0.0, logprobs=None, usage=None)
+        r,
+        message.AIMessage(
+            'Hello!!!', score=0.0, logprobs=None,
+            usage=language_model.UsageNotAvailable()
+        )
     )
     self.assertEqual(r.tags, ['lm-response', 'lm-output'])
     self.assertEqual(r.source, message.UserMessage('Hello'))
@@ -112,7 +116,11 @@
     self.assertEqual(l.render(), 'Hello')
     r = l()
     self.assertEqual(
-        r, message.AIMessage('Hello!!!', score=0.0, logprobs=None, usage=None)
+        r,
+        message.AIMessage(
+            'Hello!!!', score=0.0, logprobs=None,
+            usage=language_model.UsageNotAvailable()
+        )
     )
     self.assertEqual(r.tags, ['lm-response', 'lm-output'])
 
langfun/core/language_model.py CHANGED
@@ -14,10 +14,11 @@
 """Interface for language model."""
 
 import abc
+import contextlib
 import dataclasses
 import enum
 import time
-from typing import Annotated, Any, Callable, Sequence, Tuple, Type, Union
+from typing import Annotated, Any, Callable, Iterator, Sequence, Tuple, Type, Union
 from langfun.core import component
 from langfun.core import concurrent
 from langfun.core import console
@@ -84,6 +85,23 @@ class LMSamplingUsage(pg.Object):
   completion_tokens: int
   total_tokens: int
 
+  def __add__(self, other: 'LMSamplingUsage') -> 'LMSamplingUsage':
+    return LMSamplingUsage(
+        prompt_tokens=self.prompt_tokens + other.prompt_tokens,
+        completion_tokens=self.completion_tokens + other.completion_tokens,
+        total_tokens=self.total_tokens + other.total_tokens,
+    )
+
+
+class UsageNotAvailable(LMSamplingUsage):
+  """Usage information not available."""
+  prompt_tokens: pg.typing.Int(0).freeze()  # pytype: disable=invalid-annotation
+  completion_tokens: pg.typing.Int(0).freeze()  # pytype: disable=invalid-annotation
+  total_tokens: pg.typing.Int(0).freeze()  # pytype: disable=invalid-annotation
+
+  def __bool__(self) -> bool:
+    return False
+
 
 class LMSamplingResult(pg.Object):
   """Language model response."""
@@ -97,9 +115,9 @@ class LMSamplingResult(pg.Object):
   ] = []
 
   usage: Annotated[
-      LMSamplingUsage | None,
+      LMSamplingUsage,
       'Usage information. Currently only OpenAI models are supported.',
-  ] = None
+  ] = UsageNotAvailable()
 
 
 class LMSamplingOptions(component.Component):
@@ -398,7 +416,7 @@ class LanguageModel(component.Component):
       # which is accurate when n=1. For n > 1, we average the usage across
       # multiple samples.
       usage = result.usage
-      if len(result.samples) == 1 or usage is None:
+      if len(result.samples) == 1 or not usage:
        response.metadata.usage = usage
      else:
        n = len(result.samples)
@@ -408,6 +426,13 @@
            total_tokens=usage.total_tokens // n,
        )
 
+      # Track usage.
+      trackers = component.context_value('__usage_trackers__', [])
+      if trackers:
+        model_id = self.model_id
+        for tracker in trackers:
+          tracker.track(model_id, usage)
+
       # Track the prompt for corresponding response.
       response.source = prompt
 
@@ -511,7 +536,7 @@
       prompt: message_lib.Message,
       response: message_lib.Message,
       call_counter: int,
-      usage: LMSamplingUsage | None,
+      usage: LMSamplingUsage,
       elapse: float,
   ) -> None:
     """Outputs debugging information."""
@@ -529,12 +554,13 @@
       self._debug_response(response, call_counter, usage, elapse)
 
   def _debug_model_info(
-      self, call_counter: int, usage: LMSamplingUsage | None) -> None:
+      self, call_counter: int, usage: LMSamplingUsage) -> None:
     """Outputs debugging information about the model."""
     title_suffix = ''
-    if usage and usage.total_tokens != 0:
+    if usage.total_tokens != 0:
       title_suffix = console.colored(
-          f' (total {usage.total_tokens} tokens)', 'red')
+          f' (total {usage.total_tokens} tokens)', 'red'
+      )
 
     console.write(
         self.format(compact=True, use_inferred=True),
@@ -546,11 +572,11 @@
       self,
       prompt: message_lib.Message,
       call_counter: int,
-      usage: LMSamplingUsage | None,
+      usage: LMSamplingUsage,
   ) -> None:
     """Outputs debugging information about the prompt."""
     title_suffix = ''
-    if usage and usage.prompt_tokens != 0:
+    if usage.prompt_tokens != 0:
       title_suffix = console.colored(f' ({usage.prompt_tokens} tokens)', 'red')
 
     console.write(
@@ -574,12 +600,12 @@
       self,
       response: message_lib.Message,
       call_counter: int,
-      usage: LMSamplingUsage | None,
+      usage: LMSamplingUsage,
       elapse: float
   ) -> None:
     """Outputs debugging information about the response."""
     title_suffix = ' ('
-    if usage and usage.completion_tokens != 0:
+    if usage.completion_tokens != 0:
       title_suffix += f'{usage.completion_tokens} tokens '
     title_suffix += f'in {elapse:.2f} seconds)'
     title_suffix = console.colored(title_suffix, 'red')
@@ -641,7 +667,7 @@
       debug = LMDebugMode.ALL if debug else LMDebugMode.NONE
 
     if debug & LMDebugMode.INFO:
-      self._debug_model_info(call_counter, None)
+      self._debug_model_info(call_counter, UsageNotAvailable())
 
     if debug & LMDebugMode.PROMPT:
       console.write(
@@ -692,3 +718,60 @@
     return max(int(requests_per_min / 60), 1)  # Max concurrency can't be zero
   else:
     return DEFAULT_MAX_CONCURRENCY  # Default of 1
+
+
+class _UsageTracker:
+  """Usage tracker."""
+
+  def __init__(self, model_ids: set[str] | None):
+    self.model_ids = model_ids
+    self.usages = {
+        m: LMSamplingUsage(0, 0, 0) for m in model_ids
+    } if model_ids else {}
+
+  def track(self, model_id: str, usage: LMSamplingUsage):
+    if self.model_ids is not None and model_id not in self.model_ids:
+      return
+    if not isinstance(usage, UsageNotAvailable) and model_id in self.usages:
+      self.usages[model_id] += usage
+    else:
+      self.usages[model_id] = usage
+
+
+@contextlib.contextmanager
+def track_usages(
+    *lm: Union[str, LanguageModel]
+) -> Iterator[dict[str, LMSamplingUsage]]:
+  """Context manager to track the usages of all language models in scope.
+
+  `lf.track_usages` works with threads spawned by `lf.concurrent_map` and
+  `lf.concurrent_execute`.
+
+  Example:
+    ```
+    lm = lf.llms.GeminiPro1()
+    with lf.track_usages() as usages:
+      ...  # Invoke any code that will call LLMs.
+
+    print(usages[lm.model_id])
+    ```
+
+  Args:
+    *lm: The language model(s) to track. If None, track all models in scope.
+
+  Yields:
+    A dictionary of model ID to usage. If a model does not support usage
+    counting, the dict entry will be `UsageNotAvailable`.
+  """
+  if not lm:
+    model_ids = None
+  else:
+    model_ids = [m.model_id if isinstance(m, LanguageModel) else m for m in lm]
+
+  trackers = component.context_value('__usage_trackers__', [])
+  tracker = _UsageTracker(set(model_ids) if model_ids else None)
+  with component.context(__usage_trackers__=trackers + [tracker]):
+    try:
+      yield tracker.usages
+    finally:
+      pass
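A usage sketch of the new tracker, expanding the docstring example above (the model choice is illustrative; any `LanguageModel` works):

```
import langfun as lf

lm = lf.llms.GeminiPro1()

with lf.track_usages() as all_usages:       # track every model in scope
  with lf.track_usages(lm) as gemini_only:  # or restrict to specific models
    lm('Tell me a joke.')

usage = all_usages[lm.model_id]
print(usage.prompt_tokens, usage.completion_tokens, usage.total_tokens)
```

Because trackers are stacked via `lf.context`, both dictionaries above receive the usage, and worker threads spawned by `lf.concurrent_map` report into the same trackers.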
langfun/core/language_model_test.py CHANGED
@@ -27,8 +27,8 @@ import pyglove as pg
 @pg.use_init_args(['failures_before_attempt'])
 class MockModel(lm_lib.LanguageModel):
   """A mock model that echoes back user prompts."""
-
   failures_before_attempt: int = 0
+  name: str = 'MockModel'
 
   def _sample(self,
               prompts: list[message_lib.Message]
@@ -63,6 +63,10 @@
         retry_interval=1,
     )(prompts)
 
+  @property
+  def model_id(self) -> str:
+    return self.name
+
 
 class MockScoringModel(MockModel):
 
@@ -580,6 +584,34 @@
     self.assertEqual(lm.rate_to_max_concurrency(requests_per_min=1), 1)
     self.assertEqual(lm.rate_to_max_concurrency(tokens_per_min=1), 1)
 
+  def test_track_usages(self):
+    lm = MockModel(name='model1')
+    lm2 = MockModel(name='model2')
+    with lm_lib.track_usages() as usages1:
+      _ = lm('hi')
+      with lm_lib.track_usages(lm2) as usages2:
+        with lm_lib.track_usages('model1') as usages3:
+          with lm_lib.track_usages('model1', lm2) as usages4:
+            def call_lm(prompt):
+              _ = lm.sample([prompt] * 2)
+            lm2('hi')
+            list(concurrent.concurrent_map(call_lm, ['hi', 'hello']))
+
+    self.assertEqual(usages2, {
+        'model2': lm_lib.LMSamplingUsage(100, 100, 200),
+    })
+    self.assertEqual(usages3, {
+        'model1': lm_lib.LMSamplingUsage(100 * 4, 100 * 4, 200 * 4),
+    })
+    self.assertEqual(usages4, {
+        'model1': lm_lib.LMSamplingUsage(100 * 4, 100 * 4, 200 * 4),
+        'model2': lm_lib.LMSamplingUsage(100, 100, 200),
+    })
+    self.assertEqual(usages1, {
+        'model1': lm_lib.LMSamplingUsage(100 * 5, 100 * 5, 200 * 5),
+        'model2': lm_lib.LMSamplingUsage(100, 100, 200),
+    })
+
 
 if __name__ == '__main__':
   unittest.main()
langfun/core/llms/openai.py CHANGED
@@ -202,15 +202,14 @@ class OpenAI(lf.LanguageModel):
           lf.LMSample(choice.text.strip(), score=choice.logprobs or 0.0)
       )
 
+    n = len(samples_by_index)
     usage = lf.LMSamplingUsage(
-        prompt_tokens=response.usage.prompt_tokens,
-        completion_tokens=response.usage.completion_tokens,
-        total_tokens=response.usage.total_tokens,
+        prompt_tokens=response.usage.prompt_tokens // n,
+        completion_tokens=response.usage.completion_tokens // n,
+        total_tokens=response.usage.total_tokens // n,
     )
     return [
-        lf.LMSamplingResult(
-            samples_by_index[index], usage=usage if index == 0 else None
-        )
+        lf.LMSamplingResult(samples_by_index[index], usage=usage)
         for index in sorted(samples_by_index.keys())
     ]
 
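The batch-level usage reported by the OpenAI API is now divided evenly across the prompts in a batch, instead of being attached to the first `LMSamplingResult` only. A quick sketch of the arithmetic behind the updated test expectations below (the fixture reports 100 prompt tokens for a batch of 2 prompts with 3 samples each):

```
batch_prompt_tokens = 100                      # reported once per API call
n_prompts = 2
n_samples = 3                                  # samples per prompt

per_result = batch_prompt_tokens // n_prompts  # 50, on each LMSamplingResult
per_message = per_result // n_samples          # 16, on each response message
```

The same integer division explains the total-token expectations: 200 // 2 = 100 per result, and 100 // 3 = 33 per message.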
langfun/core/llms/openai_test.py CHANGED
@@ -191,9 +191,9 @@ class OpenAITest(unittest.TestCase):
                 score=0.0,
                 logprobs=None,
                 usage=lf.LMSamplingUsage(
-                    prompt_tokens=33,
-                    completion_tokens=33,
-                    total_tokens=66
+                    prompt_tokens=16,
+                    completion_tokens=16,
+                    total_tokens=33
                 ),
                 tags=[lf.Message.TAG_LM_RESPONSE],
             ),
@@ -206,9 +206,9 @@
                 score=0.1,
                 logprobs=None,
                 usage=lf.LMSamplingUsage(
-                    prompt_tokens=33,
-                    completion_tokens=33,
-                    total_tokens=66
+                    prompt_tokens=16,
+                    completion_tokens=16,
+                    total_tokens=33
                 ),
                 tags=[lf.Message.TAG_LM_RESPONSE],
             ),
@@ -221,9 +221,9 @@
                 score=0.2,
                 logprobs=None,
                 usage=lf.LMSamplingUsage(
-                    prompt_tokens=33,
-                    completion_tokens=33,
-                    total_tokens=66
+                    prompt_tokens=16,
+                    completion_tokens=16,
+                    total_tokens=33
                 ),
                 tags=[lf.Message.TAG_LM_RESPONSE],
             ),
@@ -232,7 +232,7 @@
             ),
         ],
         usage=lf.LMSamplingUsage(
-            prompt_tokens=100, completion_tokens=100, total_tokens=200
+            prompt_tokens=50, completion_tokens=50, total_tokens=100
         ),
     ),
 )
@@ -245,7 +245,11 @@
                 'Sample 0 for prompt 1.',
                 score=0.0,
                 logprobs=None,
-                usage=None,
+                usage=lf.LMSamplingUsage(
+                    prompt_tokens=16,
+                    completion_tokens=16,
+                    total_tokens=33
+                ),
                 tags=[lf.Message.TAG_LM_RESPONSE],
             ),
             score=0.0,
@@ -256,7 +260,11 @@
                 'Sample 1 for prompt 1.',
                 score=0.1,
                 logprobs=None,
-                usage=None,
+                usage=lf.LMSamplingUsage(
+                    prompt_tokens=16,
+                    completion_tokens=16,
+                    total_tokens=33
+                ),
                 tags=[lf.Message.TAG_LM_RESPONSE],
             ),
             score=0.1,
@@ -267,13 +275,20 @@
                 'Sample 2 for prompt 1.',
                 score=0.2,
                 logprobs=None,
-                usage=None,
+                usage=lf.LMSamplingUsage(
+                    prompt_tokens=16,
+                    completion_tokens=16,
+                    total_tokens=33
+                ),
                 tags=[lf.Message.TAG_LM_RESPONSE],
            ),
            score=0.2,
            logprobs=None,
        ),
    ],
+    usage=lf.LMSamplingUsage(
+        prompt_tokens=50, completion_tokens=50, total_tokens=100
+    ),
 ),
 )
 
langfun/core/llms/rest_test.py CHANGED
@@ -89,7 +89,7 @@ class RestTest(unittest.TestCase):
             "max_tokens=4096, stop=['\\n']."
         ),
     )
-    self.assertIsNone(response.usage)
+    self.assertEqual(response.usage, lf.UsageNotAvailable())
 
   def test_call_errors(self):
     for status_code, error_type, error_message in [
langfun/core/logging.py ADDED
@@ -0,0 +1,168 @@
+# Copyright 2024 The Langfun Authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Langfun event logging."""
+
+import datetime
+import io
+import typing
+from typing import Any, Literal, ContextManager
+
+from langfun.core import console
+import pyglove as pg
+
+
+LogLevel = Literal['debug', 'info', 'error', 'warning', 'fatal']
+_LOG_LEVELS = list(typing.get_args(LogLevel))
+_TLS_KEY_MIN_LOG_LEVEL = '_event_log_level'
+
+
+def use_log_level(log_level: LogLevel | None = 'info') -> ContextManager[None]:
+  """Context manager to enable logging at a given level."""
+  return pg.object_utils.thread_local_value_scope(
+      _TLS_KEY_MIN_LOG_LEVEL, log_level, 'info')
+
+
+def get_log_level() -> LogLevel | None:
+  """Gets the current minimum log level."""
+  return pg.object_utils.thread_local_get(_TLS_KEY_MIN_LOG_LEVEL, 'info')
+
+
+class LogEntry(pg.Object):
+  """Event log entry."""
+  time: datetime.datetime
+  level: LogLevel
+  message: str
+  metadata: dict[str, Any] = pg.Dict()
+  indent: int = 0
+
+  def should_output(self, min_log_level: LogLevel) -> bool:
+    return _LOG_LEVELS.index(self.level) >= _LOG_LEVELS.index(min_log_level)
+
+  def _repr_html_(self) -> str:
+    s = io.StringIO()
+    padding_left = 50 * self.indent
+    s.write(f'<div style="padding-left: {padding_left}px;">')
+    s.write(self._message_display)
+    if self.metadata:
+      s.write('<div style="padding-left: 20px; margin-top: 10px">')
+      s.write('<table style="border-top: 1px solid #EEEEEE;">')
+      for k, v in self.metadata.items():
+        if hasattr(v, '_repr_html_'):
+          cs = v._repr_html_()  # pylint: disable=protected-access
+        else:
+          cs = f'<span style="white-space: pre-wrap">{str(v)}</span>'
+        key_span = self._round_text(k, color='#F1C40F', margin_bottom='0px')
+        s.write(
+            '<tr>'
+            '<td style="padding: 5px; vertical-align: top; '
+            f'border-bottom: 1px solid #EEEEEE">{key_span}</td>'
+            '<td style="padding: 5px; vertical-align: top; '
+            f'border-bottom: 1px solid #EEEEEE">{cs}</td></tr>'
+        )
+      s.write('</table></div>')
+    return s.getvalue()
+
+  @property
+  def _message_text_color(self) -> str:
+    match self.level:
+      case 'debug':
+        return '#EEEEEE'
+      case 'info':
+        return '#A3E4D7'
+      case 'error':
+        return '#F5C6CB'
+      case 'fatal':
+        return '#F19CBB'
+      case _:
+        raise ValueError(f'Unknown log level: {self.level}')
+
+  @property
+  def _time_display(self) -> str:
+    display_text = self.time.strftime('%H:%M:%S')
+    alt_text = self.time.strftime('%Y-%m-%d %H:%M:%S.%f')
+    return (
+        '<span style="background-color: #BBBBBB; color: white; '
+        'border-radius:5px; padding:0px 5px 0px 5px;" '
+        f'title="{alt_text}">{display_text}</span>'
+    )
+
+  @property
+  def _message_display(self) -> str:
+    return self._round_text(
+        self._time_display + '&nbsp;' + self.message,
+        color=self._message_text_color,
+    )
+
+  def _round_text(
+      self,
+      text: str,
+      *,
+      color: str = '#EEEEEE',
+      display: str = 'inline-block',
+      margin_top: str = '5px',
+      margin_bottom: str = '5px',
+      whitespace: str = 'pre-wrap') -> str:
+    return (
+        f'<span style="background:{color}; display:{display};'
+        f'border-radius:10px; padding:5px; '
+        f'margin-top: {margin_top}; margin-bottom: {margin_bottom}; '
+        f'white-space: {whitespace}">{text}</span>'
+    )
+
+
+def log(level: LogLevel,
+        message: str,
+        *,
+        indent: int = 0,
+        **kwargs) -> LogEntry:
+  """Logs a message."""
+  entry = LogEntry(
+      indent=indent,
+      level=level,
+      time=datetime.datetime.now(),
+      message=message,
+      metadata=kwargs,
+  )
+  if entry.should_output(get_log_level()):
+    if console.under_notebook():
+      console.display(entry)
+    else:
+      # TODO(daiyip): Improve the console output formatting.
+      console.write(entry)
+  return entry
+
+
+def debug(message: str, *, indent: int = 0, **kwargs) -> LogEntry:
+  """Logs a debug message to the session."""
+  return log('debug', message, indent=indent, **kwargs)
+
+
+def info(message: str, *, indent: int = 0, **kwargs) -> LogEntry:
+  """Logs an info message to the session."""
+  return log('info', message, indent=indent, **kwargs)
+
+
+def warning(message: str, *, indent: int = 0, **kwargs) -> LogEntry:
+  """Logs a warning message to the session."""
+  return log('warning', message, indent=indent, **kwargs)
+
+
+def error(message: str, *, indent: int = 0, **kwargs) -> LogEntry:
+  """Logs an error message to the session."""
+  return log('error', message, indent=indent, **kwargs)
+
+
+def fatal(message: str, *, indent: int = 0, **kwargs) -> LogEntry:
+  """Logs a fatal message to the session."""
+  return log('fatal', message, indent=indent, **kwargs)
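A usage sketch of the new module (reachable as `lf.logging` per the `__init__.py` change above; the metadata keys are illustrative):

```
import langfun as lf

with lf.logging.use_log_level('warning'):
  lf.logging.info('cache hit')  # below 'warning', not rendered
  entry = lf.logging.warning('retrying request', model_id='mock', attempt=2)

# Every call returns a structured LogEntry, rendered or not.
print(entry.level, entry.message, entry.metadata)
```

Note that `_LOG_LEVELS` orders `'error'` below `'warning'`, so at the `'warning'` level, `error` entries are also suppressed. Keyword arguments become the entry's `metadata` and are rendered as a table in notebooks.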