langfun 0.0.2.dev20240605__py3-none-any.whl → 0.0.2.dev20240608__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of langfun might be problematic. Click here for more details.
- langfun/core/__init__.py +7 -0
- langfun/core/component.py +10 -0
- langfun/core/component_test.py +5 -0
- langfun/core/eval/base.py +8 -7
- langfun/core/langfunc_test.py +10 -2
- langfun/core/language_model.py +96 -13
- langfun/core/language_model_test.py +33 -1
- langfun/core/llms/openai.py +5 -6
- langfun/core/llms/openai_test.py +28 -13
- langfun/core/llms/rest_test.py +1 -1
- langfun/core/logging.py +168 -0
- langfun/core/logging_test.py +51 -0
- langfun/core/templates/selfplay_test.py +8 -2
- {langfun-0.0.2.dev20240605.dist-info → langfun-0.0.2.dev20240608.dist-info}/METADATA +1 -1
- {langfun-0.0.2.dev20240605.dist-info → langfun-0.0.2.dev20240608.dist-info}/RECORD +18 -16
- {langfun-0.0.2.dev20240605.dist-info → langfun-0.0.2.dev20240608.dist-info}/LICENSE +0 -0
- {langfun-0.0.2.dev20240605.dist-info → langfun-0.0.2.dev20240608.dist-info}/WHEEL +0 -0
- {langfun-0.0.2.dev20240605.dist-info → langfun-0.0.2.dev20240608.dist-info}/top_level.txt +0 -0
langfun/core/__init__.py
CHANGED
@@ -101,6 +101,7 @@ from langfun.core.language_model import LanguageModel
|
|
101
101
|
from langfun.core.language_model import LMSample
|
102
102
|
from langfun.core.language_model import LMSamplingOptions
|
103
103
|
from langfun.core.language_model import LMSamplingUsage
|
104
|
+
from langfun.core.language_model import UsageNotAvailable
|
104
105
|
from langfun.core.language_model import LMSamplingResult
|
105
106
|
from langfun.core.language_model import LMScoringResult
|
106
107
|
from langfun.core.language_model import LMCache
|
@@ -111,12 +112,18 @@ from langfun.core.language_model import RetryableLMError
|
|
111
112
|
from langfun.core.language_model import RateLimitError
|
112
113
|
from langfun.core.language_model import TemporaryLMError
|
113
114
|
|
115
|
+
# Context manager for tracking usages.
|
116
|
+
from langfun.core.language_model import track_usages
|
117
|
+
|
114
118
|
# Components for building agents.
|
115
119
|
from langfun.core.memory import Memory
|
116
120
|
|
117
121
|
# Utility for console output.
|
118
122
|
from langfun.core import console
|
119
123
|
|
124
|
+
# Utility for event logging.
|
125
|
+
from langfun.core import logging
|
126
|
+
|
120
127
|
# Import internal modules.
|
121
128
|
|
122
129
|
# pylint: enable=g-import-not-at-top
|
langfun/core/component.py
CHANGED
@@ -210,6 +210,16 @@ def get_contextual_override(var_name: str) -> ContextualOverride | None:
|
|
210
210
|
return _get_scoped_value(_global_tls, _CONTEXT_OVERRIDES, var_name)
|
211
211
|
|
212
212
|
|
213
|
+
def context_value(var_name: str, default: Any = RAISE_IF_HAS_ERROR) -> Any:
|
214
|
+
"""Returns the value of a variable defined in `lf.context`."""
|
215
|
+
override = get_contextual_override(var_name)
|
216
|
+
if override is None:
|
217
|
+
if default == RAISE_IF_HAS_ERROR:
|
218
|
+
raise KeyError(f'{var_name!r} does not exist in current context.')
|
219
|
+
return default
|
220
|
+
return override.value
|
221
|
+
|
222
|
+
|
213
223
|
def all_contextual_values() -> dict[str, Any]:
|
214
224
|
"""Returns all contextual values provided from `lf.context` in scope."""
|
215
225
|
overrides = getattr(_global_tls, _CONTEXT_OVERRIDES, {})
|
langfun/core/component_test.py
CHANGED
@@ -84,6 +84,11 @@ class ComponentContextTest(unittest.TestCase):
|
|
84
84
|
lf.get_contextual_override('y'),
|
85
85
|
lf.ContextualOverride(3, cascade=False, override_attrs=False),
|
86
86
|
)
|
87
|
+
self.assertEqual(lf.context_value('x'), 3)
|
88
|
+
self.assertIsNone(lf.context_value('f', None))
|
89
|
+
with self.assertRaisesRegex(KeyError, '.* does not exist'):
|
90
|
+
lf.context_value('f')
|
91
|
+
|
87
92
|
self.assertEqual(lf.all_contextual_values(), dict(x=3, y=3, z=3))
|
88
93
|
|
89
94
|
# Member attributes take precedence over `lf.context`.
|
langfun/core/eval/base.py
CHANGED
@@ -549,7 +549,7 @@ class Evaluable(lf.Component):
|
|
549
549
|
)
|
550
550
|
s.write(html.escape(pg.format(m.result)))
|
551
551
|
s.write('</div>')
|
552
|
-
if
|
552
|
+
if m.metadata.get('usage', None):
|
553
553
|
s.write(
|
554
554
|
'<div style="background-color: #EEEEEE; color: black; '
|
555
555
|
'white-space: pre-wrap; padding: 10px; border: 0px solid; '
|
@@ -1321,7 +1321,7 @@ class Evaluation(Evaluable):
|
|
1321
1321
|
self._render_summary_metrics(s)
|
1322
1322
|
|
1323
1323
|
# Summarize average usage.
|
1324
|
-
if self.result.usage
|
1324
|
+
if self.result.usage:
|
1325
1325
|
self._render_summary_usage(s)
|
1326
1326
|
|
1327
1327
|
s.write('</td></tr></table></div>')
|
@@ -1441,9 +1441,10 @@ class Evaluation(Evaluable):
|
|
1441
1441
|
def audit_usage(self, message: lf.Message, dryrun: bool = False) -> None:
|
1442
1442
|
del dryrun
|
1443
1443
|
for m in message.trace():
|
1444
|
-
|
1445
|
-
|
1446
|
-
self.
|
1444
|
+
usage = m.metadata.get('usage', None)
|
1445
|
+
if usage:
|
1446
|
+
self._total_prompt_tokens += usage.prompt_tokens
|
1447
|
+
self._total_completion_tokens += usage.completion_tokens
|
1447
1448
|
self._num_usages += 1
|
1448
1449
|
|
1449
1450
|
def audit_processed(
|
@@ -1504,7 +1505,7 @@ class Evaluation(Evaluable):
|
|
1504
1505
|
'<td>Schema</td>'
|
1505
1506
|
'<td>Additional Args</td>'
|
1506
1507
|
)
|
1507
|
-
if self.result.usage
|
1508
|
+
if self.result.usage:
|
1508
1509
|
s.write('<td>Usage</td>')
|
1509
1510
|
s.write('<td>OOP Failures</td>')
|
1510
1511
|
s.write('<td>Non-OOP Failures</td>')
|
@@ -1533,7 +1534,7 @@ class Evaluation(Evaluable):
|
|
1533
1534
|
f'{_html_repr(self.additional_args, compact=False)}</td>'
|
1534
1535
|
)
|
1535
1536
|
# Usage.
|
1536
|
-
if self.result.usage
|
1537
|
+
if self.result.usage:
|
1537
1538
|
s.write('<td>')
|
1538
1539
|
self._render_summary_usage(s)
|
1539
1540
|
s.write('</td>')
|
langfun/core/langfunc_test.py
CHANGED
@@ -87,7 +87,11 @@ class LangFuncCallTest(unittest.TestCase):
|
|
87
87
|
|
88
88
|
r = l()
|
89
89
|
self.assertEqual(
|
90
|
-
r,
|
90
|
+
r,
|
91
|
+
message.AIMessage(
|
92
|
+
'Hello!!!', score=0.0, logprobs=None,
|
93
|
+
usage=language_model.UsageNotAvailable()
|
94
|
+
)
|
91
95
|
)
|
92
96
|
self.assertEqual(r.tags, ['lm-response', 'lm-output'])
|
93
97
|
self.assertEqual(r.source, message.UserMessage('Hello'))
|
@@ -112,7 +116,11 @@ class LangFuncCallTest(unittest.TestCase):
|
|
112
116
|
self.assertEqual(l.render(), 'Hello')
|
113
117
|
r = l()
|
114
118
|
self.assertEqual(
|
115
|
-
r,
|
119
|
+
r,
|
120
|
+
message.AIMessage(
|
121
|
+
'Hello!!!', score=0.0, logprobs=None,
|
122
|
+
usage=language_model.UsageNotAvailable()
|
123
|
+
)
|
116
124
|
)
|
117
125
|
self.assertEqual(r.tags, ['lm-response', 'lm-output'])
|
118
126
|
|
langfun/core/language_model.py
CHANGED
@@ -14,10 +14,11 @@
|
|
14
14
|
"""Interface for language model."""
|
15
15
|
|
16
16
|
import abc
|
17
|
+
import contextlib
|
17
18
|
import dataclasses
|
18
19
|
import enum
|
19
20
|
import time
|
20
|
-
from typing import Annotated, Any, Callable, Sequence, Tuple, Type, Union
|
21
|
+
from typing import Annotated, Any, Callable, Iterator, Sequence, Tuple, Type, Union
|
21
22
|
from langfun.core import component
|
22
23
|
from langfun.core import concurrent
|
23
24
|
from langfun.core import console
|
@@ -84,6 +85,23 @@ class LMSamplingUsage(pg.Object):
|
|
84
85
|
completion_tokens: int
|
85
86
|
total_tokens: int
|
86
87
|
|
88
|
+
def __add__(self, other: 'LMSamplingUsage') -> 'LMSamplingUsage':
|
89
|
+
return LMSamplingUsage(
|
90
|
+
prompt_tokens=self.prompt_tokens + other.prompt_tokens,
|
91
|
+
completion_tokens=self.completion_tokens + other.completion_tokens,
|
92
|
+
total_tokens=self.total_tokens + other.total_tokens,
|
93
|
+
)
|
94
|
+
|
95
|
+
|
96
|
+
class UsageNotAvailable(LMSamplingUsage):
|
97
|
+
"""Usage information not available."""
|
98
|
+
prompt_tokens: pg.typing.Int(0).freeze() # pytype: disable=invalid-annotation
|
99
|
+
completion_tokens: pg.typing.Int(0).freeze() # pytype: disable=invalid-annotation
|
100
|
+
total_tokens: pg.typing.Int(0).freeze() # pytype: disable=invalid-annotation
|
101
|
+
|
102
|
+
def __bool__(self) -> bool:
|
103
|
+
return False
|
104
|
+
|
87
105
|
|
88
106
|
class LMSamplingResult(pg.Object):
|
89
107
|
"""Language model response."""
|
@@ -97,9 +115,9 @@ class LMSamplingResult(pg.Object):
|
|
97
115
|
] = []
|
98
116
|
|
99
117
|
usage: Annotated[
|
100
|
-
LMSamplingUsage
|
118
|
+
LMSamplingUsage,
|
101
119
|
'Usage information. Currently only OpenAI models are supported.',
|
102
|
-
] =
|
120
|
+
] = UsageNotAvailable()
|
103
121
|
|
104
122
|
|
105
123
|
class LMSamplingOptions(component.Component):
|
@@ -398,7 +416,7 @@ class LanguageModel(component.Component):
|
|
398
416
|
# which is accurate when n=1. For n > 1, we average the usage across
|
399
417
|
# multiple samples.
|
400
418
|
usage = result.usage
|
401
|
-
if len(result.samples) == 1 or usage
|
419
|
+
if len(result.samples) == 1 or not usage:
|
402
420
|
response.metadata.usage = usage
|
403
421
|
else:
|
404
422
|
n = len(result.samples)
|
@@ -408,6 +426,13 @@ class LanguageModel(component.Component):
|
|
408
426
|
total_tokens=usage.total_tokens // n,
|
409
427
|
)
|
410
428
|
|
429
|
+
# Track usage.
|
430
|
+
trackers = component.context_value('__usage_trackers__', [])
|
431
|
+
if trackers:
|
432
|
+
model_id = self.model_id
|
433
|
+
for tracker in trackers:
|
434
|
+
tracker.track(model_id, usage)
|
435
|
+
|
411
436
|
# Track the prompt for corresponding response.
|
412
437
|
response.source = prompt
|
413
438
|
|
@@ -511,7 +536,7 @@ class LanguageModel(component.Component):
|
|
511
536
|
prompt: message_lib.Message,
|
512
537
|
response: message_lib.Message,
|
513
538
|
call_counter: int,
|
514
|
-
usage: LMSamplingUsage
|
539
|
+
usage: LMSamplingUsage,
|
515
540
|
elapse: float,
|
516
541
|
) -> None:
|
517
542
|
"""Outputs debugging information."""
|
@@ -529,12 +554,13 @@ class LanguageModel(component.Component):
|
|
529
554
|
self._debug_response(response, call_counter, usage, elapse)
|
530
555
|
|
531
556
|
def _debug_model_info(
|
532
|
-
self, call_counter: int, usage: LMSamplingUsage
|
557
|
+
self, call_counter: int, usage: LMSamplingUsage) -> None:
|
533
558
|
"""Outputs debugging information about the model."""
|
534
559
|
title_suffix = ''
|
535
|
-
if usage
|
560
|
+
if usage.total_tokens != 0:
|
536
561
|
title_suffix = console.colored(
|
537
|
-
f' (total {usage.total_tokens} tokens)', 'red'
|
562
|
+
f' (total {usage.total_tokens} tokens)', 'red'
|
563
|
+
)
|
538
564
|
|
539
565
|
console.write(
|
540
566
|
self.format(compact=True, use_inferred=True),
|
@@ -546,11 +572,11 @@ class LanguageModel(component.Component):
|
|
546
572
|
self,
|
547
573
|
prompt: message_lib.Message,
|
548
574
|
call_counter: int,
|
549
|
-
usage: LMSamplingUsage
|
575
|
+
usage: LMSamplingUsage,
|
550
576
|
) -> None:
|
551
577
|
"""Outputs debugging information about the prompt."""
|
552
578
|
title_suffix = ''
|
553
|
-
if usage
|
579
|
+
if usage.prompt_tokens != 0:
|
554
580
|
title_suffix = console.colored(f' ({usage.prompt_tokens} tokens)', 'red')
|
555
581
|
|
556
582
|
console.write(
|
@@ -574,12 +600,12 @@ class LanguageModel(component.Component):
|
|
574
600
|
self,
|
575
601
|
response: message_lib.Message,
|
576
602
|
call_counter: int,
|
577
|
-
usage: LMSamplingUsage
|
603
|
+
usage: LMSamplingUsage,
|
578
604
|
elapse: float
|
579
605
|
) -> None:
|
580
606
|
"""Outputs debugging information about the response."""
|
581
607
|
title_suffix = ' ('
|
582
|
-
if usage
|
608
|
+
if usage.completion_tokens != 0:
|
583
609
|
title_suffix += f'{usage.completion_tokens} tokens '
|
584
610
|
title_suffix += f'in {elapse:.2f} seconds)'
|
585
611
|
title_suffix = console.colored(title_suffix, 'red')
|
@@ -641,7 +667,7 @@ class LanguageModel(component.Component):
|
|
641
667
|
debug = LMDebugMode.ALL if debug else LMDebugMode.NONE
|
642
668
|
|
643
669
|
if debug & LMDebugMode.INFO:
|
644
|
-
self._debug_model_info(call_counter,
|
670
|
+
self._debug_model_info(call_counter, UsageNotAvailable())
|
645
671
|
|
646
672
|
if debug & LMDebugMode.PROMPT:
|
647
673
|
console.write(
|
@@ -692,3 +718,60 @@ class LanguageModel(component.Component):
|
|
692
718
|
return max(int(requests_per_min / 60), 1) # Max concurrency can't be zero
|
693
719
|
else:
|
694
720
|
return DEFAULT_MAX_CONCURRENCY # Default of 1
|
721
|
+
|
722
|
+
|
723
|
+
class _UsageTracker:
|
724
|
+
"""Usage tracker."""
|
725
|
+
|
726
|
+
def __init__(self, model_ids: set[str] | None):
|
727
|
+
self.model_ids = model_ids
|
728
|
+
self.usages = {
|
729
|
+
m: LMSamplingUsage(0, 0, 0) for m in model_ids
|
730
|
+
} if model_ids else {}
|
731
|
+
|
732
|
+
def track(self, model_id: str, usage: LMSamplingUsage):
|
733
|
+
if self.model_ids is not None and model_id not in self.model_ids:
|
734
|
+
return
|
735
|
+
if not isinstance(usage, UsageNotAvailable) and model_id in self.usages:
|
736
|
+
self.usages[model_id] += usage
|
737
|
+
else:
|
738
|
+
self.usages[model_id] = usage
|
739
|
+
|
740
|
+
|
741
|
+
@contextlib.contextmanager
|
742
|
+
def track_usages(
|
743
|
+
*lm: Union[str, LanguageModel]
|
744
|
+
) -> Iterator[dict[str, LMSamplingUsage]]:
|
745
|
+
"""Context manager to track the usages of all language models in scope.
|
746
|
+
|
747
|
+
`lf.track_usages` works with threads spawned by `lf.concurrent_map` and
|
748
|
+
`lf.concurrent_execute`.
|
749
|
+
|
750
|
+
Example:
|
751
|
+
```
|
752
|
+
lm = lf.llms.GeminiPro1()
|
753
|
+
with lf.track_usages() as usages:
|
754
|
+
# invoke any code that will call LLMs.
|
755
|
+
|
756
|
+
print(usages[lm.model_id])
|
757
|
+
```
|
758
|
+
|
759
|
+
Args:
|
760
|
+
*lm: The language model(s) to track. If omitted, track all models in scope.
|
761
|
+
|
762
|
+
Yields:
|
763
|
+
A dictionary of model ID to usage. If a model does not support usage
|
764
|
+
counting, the dict entry will be None.
|
765
|
+
"""
|
766
|
+
if not lm:
|
767
|
+
model_ids = None
|
768
|
+
else:
|
769
|
+
model_ids = [m.model_id if isinstance(m, LanguageModel) else m for m in lm]
|
770
|
+
|
771
|
+
trackers = component.context_value('__usage_trackers__', [])
|
772
|
+
tracker = _UsageTracker(set(model_ids) if model_ids else None)
|
773
|
+
with component.context(__usage_trackers__=trackers + [tracker]):
|
774
|
+
try:
|
775
|
+
yield tracker.usages
|
776
|
+
finally:
|
777
|
+
pass
|
@@ -27,8 +27,8 @@ import pyglove as pg
|
|
27
27
|
@pg.use_init_args(['failures_before_attempt'])
|
28
28
|
class MockModel(lm_lib.LanguageModel):
|
29
29
|
"""A mock model that echoes back user prompts."""
|
30
|
-
|
31
30
|
failures_before_attempt: int = 0
|
31
|
+
name: str = 'MockModel'
|
32
32
|
|
33
33
|
def _sample(self,
|
34
34
|
prompts: list[message_lib.Message]
|
@@ -63,6 +63,10 @@ class MockModel(lm_lib.LanguageModel):
|
|
63
63
|
retry_interval=1,
|
64
64
|
)(prompts)
|
65
65
|
|
66
|
+
@property
|
67
|
+
def model_id(self) -> str:
|
68
|
+
return self.name
|
69
|
+
|
66
70
|
|
67
71
|
class MockScoringModel(MockModel):
|
68
72
|
|
@@ -580,6 +584,34 @@ class LanguageModelTest(unittest.TestCase):
|
|
580
584
|
self.assertEqual(lm.rate_to_max_concurrency(requests_per_min=1), 1)
|
581
585
|
self.assertEqual(lm.rate_to_max_concurrency(tokens_per_min=1), 1)
|
582
586
|
|
587
|
+
def test_track_usages(self):
|
588
|
+
lm = MockModel(name='model1')
|
589
|
+
lm2 = MockModel(name='model2')
|
590
|
+
with lm_lib.track_usages() as usages1:
|
591
|
+
_ = lm('hi')
|
592
|
+
with lm_lib.track_usages(lm2) as usages2:
|
593
|
+
with lm_lib.track_usages('model1') as usages3:
|
594
|
+
with lm_lib.track_usages('model1', lm2) as usages4:
|
595
|
+
def call_lm(prompt):
|
596
|
+
_ = lm.sample([prompt] * 2)
|
597
|
+
lm2('hi')
|
598
|
+
list(concurrent.concurrent_map(call_lm, ['hi', 'hello']))
|
599
|
+
|
600
|
+
self.assertEqual(usages2, {
|
601
|
+
'model2': lm_lib.LMSamplingUsage(100, 100, 200),
|
602
|
+
})
|
603
|
+
self.assertEqual(usages3, {
|
604
|
+
'model1': lm_lib.LMSamplingUsage(100 * 4, 100 * 4, 200 * 4),
|
605
|
+
})
|
606
|
+
self.assertEqual(usages4, {
|
607
|
+
'model1': lm_lib.LMSamplingUsage(100 * 4, 100 * 4, 200 * 4),
|
608
|
+
'model2': lm_lib.LMSamplingUsage(100, 100, 200),
|
609
|
+
})
|
610
|
+
self.assertEqual(usages1, {
|
611
|
+
'model1': lm_lib.LMSamplingUsage(100 * 5, 100 * 5, 200 * 5),
|
612
|
+
'model2': lm_lib.LMSamplingUsage(100, 100, 200),
|
613
|
+
})
|
614
|
+
|
583
615
|
|
584
616
|
if __name__ == '__main__':
|
585
617
|
unittest.main()
|
langfun/core/llms/openai.py
CHANGED
@@ -202,15 +202,14 @@ class OpenAI(lf.LanguageModel):
|
|
202
202
|
lf.LMSample(choice.text.strip(), score=choice.logprobs or 0.0)
|
203
203
|
)
|
204
204
|
|
205
|
+
n = len(samples_by_index)
|
205
206
|
usage = lf.LMSamplingUsage(
|
206
|
-
prompt_tokens=response.usage.prompt_tokens,
|
207
|
-
completion_tokens=response.usage.completion_tokens,
|
208
|
-
total_tokens=response.usage.total_tokens,
|
207
|
+
prompt_tokens=response.usage.prompt_tokens // n,
|
208
|
+
completion_tokens=response.usage.completion_tokens // n,
|
209
|
+
total_tokens=response.usage.total_tokens // n,
|
209
210
|
)
|
210
211
|
return [
|
211
|
-
lf.LMSamplingResult(
|
212
|
-
samples_by_index[index], usage=usage if index == 0 else None
|
213
|
-
)
|
212
|
+
lf.LMSamplingResult(samples_by_index[index], usage=usage)
|
214
213
|
for index in sorted(samples_by_index.keys())
|
215
214
|
]
|
216
215
|
|
langfun/core/llms/openai_test.py
CHANGED
@@ -191,9 +191,9 @@ class OpenAITest(unittest.TestCase):
|
|
191
191
|
score=0.0,
|
192
192
|
logprobs=None,
|
193
193
|
usage=lf.LMSamplingUsage(
|
194
|
-
prompt_tokens=
|
195
|
-
completion_tokens=
|
196
|
-
total_tokens=
|
194
|
+
prompt_tokens=16,
|
195
|
+
completion_tokens=16,
|
196
|
+
total_tokens=33
|
197
197
|
),
|
198
198
|
tags=[lf.Message.TAG_LM_RESPONSE],
|
199
199
|
),
|
@@ -206,9 +206,9 @@ class OpenAITest(unittest.TestCase):
|
|
206
206
|
score=0.1,
|
207
207
|
logprobs=None,
|
208
208
|
usage=lf.LMSamplingUsage(
|
209
|
-
prompt_tokens=
|
210
|
-
completion_tokens=
|
211
|
-
total_tokens=
|
209
|
+
prompt_tokens=16,
|
210
|
+
completion_tokens=16,
|
211
|
+
total_tokens=33
|
212
212
|
),
|
213
213
|
tags=[lf.Message.TAG_LM_RESPONSE],
|
214
214
|
),
|
@@ -221,9 +221,9 @@ class OpenAITest(unittest.TestCase):
|
|
221
221
|
score=0.2,
|
222
222
|
logprobs=None,
|
223
223
|
usage=lf.LMSamplingUsage(
|
224
|
-
prompt_tokens=
|
225
|
-
completion_tokens=
|
226
|
-
total_tokens=
|
224
|
+
prompt_tokens=16,
|
225
|
+
completion_tokens=16,
|
226
|
+
total_tokens=33
|
227
227
|
),
|
228
228
|
tags=[lf.Message.TAG_LM_RESPONSE],
|
229
229
|
),
|
@@ -232,7 +232,7 @@ class OpenAITest(unittest.TestCase):
|
|
232
232
|
),
|
233
233
|
],
|
234
234
|
usage=lf.LMSamplingUsage(
|
235
|
-
prompt_tokens=
|
235
|
+
prompt_tokens=50, completion_tokens=50, total_tokens=100
|
236
236
|
),
|
237
237
|
),
|
238
238
|
)
|
@@ -245,7 +245,11 @@ class OpenAITest(unittest.TestCase):
|
|
245
245
|
'Sample 0 for prompt 1.',
|
246
246
|
score=0.0,
|
247
247
|
logprobs=None,
|
248
|
-
usage=
|
248
|
+
usage=lf.LMSamplingUsage(
|
249
|
+
prompt_tokens=16,
|
250
|
+
completion_tokens=16,
|
251
|
+
total_tokens=33
|
252
|
+
),
|
249
253
|
tags=[lf.Message.TAG_LM_RESPONSE],
|
250
254
|
),
|
251
255
|
score=0.0,
|
@@ -256,7 +260,11 @@ class OpenAITest(unittest.TestCase):
|
|
256
260
|
'Sample 1 for prompt 1.',
|
257
261
|
score=0.1,
|
258
262
|
logprobs=None,
|
259
|
-
usage=
|
263
|
+
usage=lf.LMSamplingUsage(
|
264
|
+
prompt_tokens=16,
|
265
|
+
completion_tokens=16,
|
266
|
+
total_tokens=33
|
267
|
+
),
|
260
268
|
tags=[lf.Message.TAG_LM_RESPONSE],
|
261
269
|
),
|
262
270
|
score=0.1,
|
@@ -267,13 +275,20 @@ class OpenAITest(unittest.TestCase):
|
|
267
275
|
'Sample 2 for prompt 1.',
|
268
276
|
score=0.2,
|
269
277
|
logprobs=None,
|
270
|
-
usage=
|
278
|
+
usage=lf.LMSamplingUsage(
|
279
|
+
prompt_tokens=16,
|
280
|
+
completion_tokens=16,
|
281
|
+
total_tokens=33
|
282
|
+
),
|
271
283
|
tags=[lf.Message.TAG_LM_RESPONSE],
|
272
284
|
),
|
273
285
|
score=0.2,
|
274
286
|
logprobs=None,
|
275
287
|
),
|
276
288
|
],
|
289
|
+
usage=lf.LMSamplingUsage(
|
290
|
+
prompt_tokens=50, completion_tokens=50, total_tokens=100
|
291
|
+
),
|
277
292
|
),
|
278
293
|
)
|
279
294
|
|
langfun/core/llms/rest_test.py
CHANGED
@@ -89,7 +89,7 @@ class RestTest(unittest.TestCase):
|
|
89
89
|
"max_tokens=4096, stop=['\\n']."
|
90
90
|
),
|
91
91
|
)
|
92
|
-
self.
|
92
|
+
self.assertEqual(response.usage, lf.UsageNotAvailable())
|
93
93
|
|
94
94
|
def test_call_errors(self):
|
95
95
|
for status_code, error_type, error_message in [
|
langfun/core/logging.py
ADDED
@@ -0,0 +1,168 @@
|
|
1
|
+
# Copyright 2024 The Langfun Authors
|
2
|
+
#
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at
|
6
|
+
#
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
#
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
"""Langfun event logging."""
|
15
|
+
|
16
|
+
import datetime
|
17
|
+
import io
|
18
|
+
import typing
|
19
|
+
from typing import Any, Literal, ContextManager
|
20
|
+
|
21
|
+
from langfun.core import console
|
22
|
+
import pyglove as pg
|
23
|
+
|
24
|
+
|
25
|
+
LogLevel = Literal['debug', 'info', 'error', 'warning', 'fatal']
|
26
|
+
_LOG_LEVELS = list(typing.get_args(LogLevel))
|
27
|
+
_TLS_KEY_MIN_LOG_LEVEL = '_event_log_level'
|
28
|
+
|
29
|
+
|
30
|
+
def use_log_level(log_level: LogLevel | None = 'info') -> ContextManager[None]:
|
31
|
+
"""Context manager to enable logging at a given level."""
|
32
|
+
return pg.object_utils.thread_local_value_scope(
|
33
|
+
_TLS_KEY_MIN_LOG_LEVEL, log_level, 'info')
|
34
|
+
|
35
|
+
|
36
|
+
def get_log_level() -> LogLevel | None:
|
37
|
+
"""Gets the current minimum log level."""
|
38
|
+
return pg.object_utils.thread_local_get(_TLS_KEY_MIN_LOG_LEVEL, 'info')
|
39
|
+
|
40
|
+
|
41
|
+
class LogEntry(pg.Object):
|
42
|
+
"""Event log entry."""
|
43
|
+
time: datetime.datetime
|
44
|
+
level: LogLevel
|
45
|
+
message: str
|
46
|
+
metadata: dict[str, Any] = pg.Dict()
|
47
|
+
indent: int = 0
|
48
|
+
|
49
|
+
def should_output(self, min_log_level: LogLevel) -> bool:
|
50
|
+
return _LOG_LEVELS.index(self.level) >= _LOG_LEVELS.index(min_log_level)
|
51
|
+
|
52
|
+
def _repr_html_(self) -> str:
|
53
|
+
s = io.StringIO()
|
54
|
+
padding_left = 50 * self.indent
|
55
|
+
s.write(f'<div style="padding-left: {padding_left}px;">')
|
56
|
+
s.write(self._message_display)
|
57
|
+
if self.metadata:
|
58
|
+
s.write('<div style="padding-left: 20px; margin-top: 10px">')
|
59
|
+
s.write('<table style="border-top: 1px solid #EEEEEE;">')
|
60
|
+
for k, v in self.metadata.items():
|
61
|
+
if hasattr(v, '_repr_html_'):
|
62
|
+
cs = v._repr_html_() # pylint: disable=protected-access
|
63
|
+
else:
|
64
|
+
cs = f'<span style="white-space: pre-wrap">{str(v)}</span>'
|
65
|
+
key_span = self._round_text(k, color='#F1C40F', margin_bottom='0px')
|
66
|
+
s.write(
|
67
|
+
'<tr>'
|
68
|
+
'<td style="padding: 5px; vertical-align: top; '
|
69
|
+
f'border-bottom: 1px solid #EEEEEE">{key_span}</td>'
|
70
|
+
'<td style="padding: 5px; vertical-align: top; '
|
71
|
+
f'border-bottom: 1px solid #EEEEEE">{cs}</td></tr>'
|
72
|
+
)
|
73
|
+
s.write('</table></div>')
|
74
|
+
return s.getvalue()
|
75
|
+
|
76
|
+
@property
|
77
|
+
def _message_text_color(self) -> str:
|
78
|
+
match self.level:
|
79
|
+
case 'debug':
|
80
|
+
return '#EEEEEE'
|
81
|
+
case 'info':
|
82
|
+
return '#A3E4D7'
|
83
|
+
case 'error':
|
84
|
+
return '#F5C6CB'
|
85
|
+
case 'fatal':
|
86
|
+
return '#F19CBB'
|
87
|
+
case _:
|
88
|
+
raise ValueError(f'Unknown log level: {self.level}')
|
89
|
+
|
90
|
+
@property
|
91
|
+
def _time_display(self) -> str:
|
92
|
+
display_text = self.time.strftime('%H:%M:%S')
|
93
|
+
alt_text = self.time.strftime('%Y-%m-%d %H:%M:%S.%f')
|
94
|
+
return (
|
95
|
+
'<span style="background-color: #BBBBBB; color: white; '
|
96
|
+
'border-radius:5px; padding:0px 5px 0px 5px;" '
|
97
|
+
f'title="{alt_text}">{display_text}</span>'
|
98
|
+
)
|
99
|
+
|
100
|
+
@property
|
101
|
+
def _message_display(self) -> str:
|
102
|
+
return self._round_text(
|
103
|
+
self._time_display + ' ' + self.message,
|
104
|
+
color=self._message_text_color,
|
105
|
+
)
|
106
|
+
|
107
|
+
def _round_text(
|
108
|
+
self,
|
109
|
+
text: str,
|
110
|
+
*,
|
111
|
+
color: str = '#EEEEEE',
|
112
|
+
display: str = 'inline-block',
|
113
|
+
margin_top: str = '5px',
|
114
|
+
margin_bottom: str = '5px',
|
115
|
+
whitespace: str = 'pre-wrap') -> str:
|
116
|
+
return (
|
117
|
+
f'<span style="background:{color}; display:{display};'
|
118
|
+
f'border-radius:10px; padding:5px; '
|
119
|
+
f'margin-top: {margin_top}; margin-bottom: {margin_bottom}; '
|
120
|
+
f'white-space: {whitespace}">{text}</span>'
|
121
|
+
)
|
122
|
+
|
123
|
+
|
124
|
+
def log(level: LogLevel,
|
125
|
+
message: str,
|
126
|
+
*,
|
127
|
+
indent: int = 0,
|
128
|
+
**kwargs) -> LogEntry:
|
129
|
+
"""Logs a message."""
|
130
|
+
entry = LogEntry(
|
131
|
+
indent=indent,
|
132
|
+
level=level,
|
133
|
+
time=datetime.datetime.now(),
|
134
|
+
message=message,
|
135
|
+
metadata=kwargs,
|
136
|
+
)
|
137
|
+
if entry.should_output(get_log_level()):
|
138
|
+
if console.under_notebook():
|
139
|
+
console.display(entry)
|
140
|
+
else:
|
141
|
+
# TODO(daiyip): Improve the console output formatting.
|
142
|
+
console.write(entry)
|
143
|
+
return entry
|
144
|
+
|
145
|
+
|
146
|
+
def debug(message: str, *, indent: int = 0, **kwargs) -> LogEntry:
|
147
|
+
"""Logs a debug message to the session."""
|
148
|
+
return log('debug', message, indent=indent, **kwargs)
|
149
|
+
|
150
|
+
|
151
|
+
def info(message: str, *, indent: int = 0, **kwargs) -> LogEntry:
|
152
|
+
"""Logs an info message to the session."""
|
153
|
+
return log('info', message, indent=indent, **kwargs)
|
154
|
+
|
155
|
+
|
156
|
+
def warning(message: str, *, indent: int = 0, **kwargs) -> LogEntry:
|
157
|
+
"""Logs a warning message to the session."""
|
158
|
+
return log('warning', message, indent=indent, **kwargs)
|
159
|
+
|
160
|
+
|
161
|
+
def error(message: str, *, indent: int = 0, **kwargs) -> LogEntry:
|
162
|
+
"""Logs an error message to the session."""
|
163
|
+
return log('error', message, indent=indent, **kwargs)
|
164
|
+
|
165
|
+
|
166
|
+
def fatal(message: str, *, indent: int = 0, **kwargs) -> LogEntry:
|
167
|
+
"""Logs a fatal message to the session."""
|
168
|
+
return log('fatal', message, indent=indent, **kwargs)
|
@@ -0,0 +1,51 @@
|
|
1
|
+
# Copyright 2024 The Langfun Authors
|
2
|
+
#
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at
|
6
|
+
#
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
#
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
"""Tests for langfun.core.logging."""
|
15
|
+
|
16
|
+
import unittest
|
17
|
+
|
18
|
+
from langfun.core import logging
|
19
|
+
|
20
|
+
|
21
|
+
class LoggingTest(unittest.TestCase):
|
22
|
+
|
23
|
+
def test_use_log_level(self):
|
24
|
+
self.assertEqual(logging.get_log_level(), 'info')
|
25
|
+
with logging.use_log_level('debug'):
|
26
|
+
self.assertEqual(logging.get_log_level(), 'debug')
|
27
|
+
with logging.use_log_level(None):
|
28
|
+
self.assertIsNone(logging.get_log_level(), None)
|
29
|
+
self.assertEqual(logging.get_log_level(), 'debug')
|
30
|
+
self.assertEqual(logging.get_log_level(), 'info')
|
31
|
+
|
32
|
+
def test_log(self):
|
33
|
+
entry = logging.log('info', 'hi', indent=1, x=1, y=2)
|
34
|
+
self.assertEqual(entry.level, 'info')
|
35
|
+
self.assertEqual(entry.message, 'hi')
|
36
|
+
self.assertEqual(entry.indent, 1)
|
37
|
+
self.assertEqual(entry.metadata, {'x': 1, 'y': 2})
|
38
|
+
|
39
|
+
self.assertEqual(logging.debug('hi').level, 'debug')
|
40
|
+
self.assertEqual(logging.info('hi').level, 'info')
|
41
|
+
self.assertEqual(logging.warning('hi').level, 'warning')
|
42
|
+
self.assertEqual(logging.error('hi').level, 'error')
|
43
|
+
self.assertEqual(logging.fatal('hi').level, 'fatal')
|
44
|
+
|
45
|
+
def test_repr_html(self):
|
46
|
+
entry = logging.log('info', 'hi', indent=1, x=1, y=2)
|
47
|
+
self.assertIn('<div', entry._repr_html_())
|
48
|
+
|
49
|
+
|
50
|
+
if __name__ == '__main__':
|
51
|
+
unittest.main()
|
@@ -57,7 +57,10 @@ class SelfPlayTest(unittest.TestCase):
|
|
57
57
|
|
58
58
|
with lf.context(lm=NumberGuesser(guesses=[50, 20, 5, 10])):
|
59
59
|
self.assertEqual(
|
60
|
-
g(),
|
60
|
+
g(),
|
61
|
+
lf.AIMessage(
|
62
|
+
'10', score=0.0, logprobs=None, usage=lf.UsageNotAvailable()
|
63
|
+
)
|
61
64
|
)
|
62
65
|
|
63
66
|
self.assertEqual(g.num_turns, 4)
|
@@ -67,7 +70,10 @@ class SelfPlayTest(unittest.TestCase):
|
|
67
70
|
|
68
71
|
with lf.context(lm=NumberGuesser(guesses=[50, 20, 5, 2, 5, 4])):
|
69
72
|
self.assertEqual(
|
70
|
-
g(),
|
73
|
+
g(),
|
74
|
+
lf.AIMessage(
|
75
|
+
'2', score=0.0, logprobs=None, usage=lf.UsageNotAvailable()
|
76
|
+
)
|
71
77
|
)
|
72
78
|
|
73
79
|
self.assertEqual(g.num_turns, 10)
|
@@ -1,15 +1,17 @@
|
|
1
1
|
langfun/__init__.py,sha256=P62MnqA6-f0h8iYfQ3MT6Yg7a4qRnQeb4GrIn6dcSnY,2274
|
2
|
-
langfun/core/__init__.py,sha256=
|
3
|
-
langfun/core/component.py,sha256=
|
4
|
-
langfun/core/component_test.py,sha256=
|
2
|
+
langfun/core/__init__.py,sha256=nFQgTyhNLA83b_V_nPNByyO2JlF576AdS61AfA0SaN8,4728
|
3
|
+
langfun/core/component.py,sha256=Icyoj9ICoJoK2r2PHbrFXbxnseOr9QZZOvKWklLWNo8,10276
|
4
|
+
langfun/core/component_test.py,sha256=q15Xn51cVTu2RKxZ9U5VQgT3bm6RQ4638bKhWBtvW5o,8220
|
5
5
|
langfun/core/concurrent.py,sha256=TRc49pJ3HQro2kb5FtcWkHjhBm8UcgE8RJybU5cU3-0,24537
|
6
6
|
langfun/core/concurrent_test.py,sha256=mwFMZhDUdppnDr7vDSTwcbMHwrdsIoKJwRYNtl4ZWL4,15185
|
7
7
|
langfun/core/console.py,sha256=bk5rNPNm9rMGW5YT2HixxU04p2umnoabn5SDz6Dqe88,2317
|
8
8
|
langfun/core/console_test.py,sha256=5SYJdxpJGLgdSSQqqMPoA1X6jpsLD8rgcyk-EgI65oE,1077
|
9
9
|
langfun/core/langfunc.py,sha256=RvIcRjIq0jWYRu1xim-FYe4HSrt97r3GMBO_PuagUmw,11060
|
10
|
-
langfun/core/langfunc_test.py,sha256=
|
11
|
-
langfun/core/language_model.py,sha256=
|
12
|
-
langfun/core/language_model_test.py,sha256=
|
10
|
+
langfun/core/langfunc_test.py,sha256=lyt-UzkD8972cxZwzCkps0_RMLeSsOBrcUFIW-fB6us,8653
|
11
|
+
langfun/core/language_model.py,sha256=q41EGhbgoCe68eYWWTAzgt3r6SbdCpINRdXnNGstukQ,23775
|
12
|
+
langfun/core/language_model_test.py,sha256=HmA0GACK4-6tCH32TFkfYj9w4CxbrynKqXdKBgiqgwo,21255
|
13
|
+
langfun/core/logging.py,sha256=FyZRxUy2TTF6tWLhQCRpCvfH55WGUdNgQjUTK_SQLnY,5320
|
14
|
+
langfun/core/logging_test.py,sha256=qvm3RObYP3knO2PnXR9evBRl4gH621GnjnwywbGbRfg,1833
|
13
15
|
langfun/core/memory.py,sha256=f-asN1F7Vehgdn_fK84v73GrEUOxRtaW934keutTKjk,2416
|
14
16
|
langfun/core/message.py,sha256=Rw3yC9HyGRjMhfDgyNjGlSCALEyDDbJ0_o6qTXeeDiQ,15738
|
15
17
|
langfun/core/message_test.py,sha256=b6DDRoQ5j3uK-dc0QPSLelNTKaXX10MxJrRiI61iGX4,9574
|
@@ -40,7 +42,7 @@ langfun/core/coding/python/parsing_test.py,sha256=9vAWF484kWIm6JZq8NFiMgKUDhXV-d
|
|
40
42
|
langfun/core/coding/python/permissions.py,sha256=1QWGHvzL8MM0Ok_auQ9tURqZHtdOfJaDpBzZ29GUE-c,2544
|
41
43
|
langfun/core/coding/python/permissions_test.py,sha256=w5EDb8QxpxgJyZkojyzVWQvDfg366zn99-g__6TbPQ0,2699
|
42
44
|
langfun/core/eval/__init__.py,sha256=Evt-E4FEhZF2tXL6-byh_AyA7Cc_ZoGmvnN7vkAZedk,1898
|
43
|
-
langfun/core/eval/base.py,sha256=
|
45
|
+
langfun/core/eval/base.py,sha256=GM98Zo4gxZui2ORX6Q7Zr94PfiEViQC5X_qz-uj6b2k,74220
|
44
46
|
langfun/core/eval/base_test.py,sha256=cHOTIWVW4Dp8gKKIKcZrAcJ-w84j2GIozTzJoiAX7p4,26743
|
45
47
|
langfun/core/eval/matching.py,sha256=Y4vFoNTQEOwko6IA8l9OZ52-vt52e3VGmcTtvLA67wM,9782
|
46
48
|
langfun/core/eval/matching_test.py,sha256=f7iVyXH5KGJBWt4Wp14Bt9J3X59A6Ayfog9MbuFvPew,5532
|
@@ -59,10 +61,10 @@ langfun/core/llms/groq.py,sha256=pqtyOZ_1_OJMOg8xATWT_B_SVbuT9nMRf4VkH9GzW8g,630
|
|
59
61
|
langfun/core/llms/groq_test.py,sha256=GYF_Qtq5S1H1TrKH38t6_lkdroqT7v-joYLDKnmS9e0,5274
|
60
62
|
langfun/core/llms/llama_cpp.py,sha256=9tXQntSCDtjTF3bnyJrAPCr4N6wycy5nXYvp9uduygE,2843
|
61
63
|
langfun/core/llms/llama_cpp_test.py,sha256=MWO_qaOeKjRniGjcaWPDScd7HPaIJemqUZoslrt4FPs,1806
|
62
|
-
langfun/core/llms/openai.py,sha256=
|
63
|
-
langfun/core/llms/openai_test.py,sha256=
|
64
|
+
langfun/core/llms/openai.py,sha256=0z9qIH9FlWj9VWUnhOX321T6JHO-vjY2IozT7OVI4GY,13654
|
65
|
+
langfun/core/llms/openai_test.py,sha256=3muDTnW7UBOSHq694Fi2bofqhe8Pkj0Tl8IShoLCTOM,15525
|
64
66
|
langfun/core/llms/rest.py,sha256=laopuq-zD8V-3Y6eFDngftHEbE66VlUkCD2-rvvRaLU,3388
|
65
|
-
langfun/core/llms/rest_test.py,sha256=
|
67
|
+
langfun/core/llms/rest_test.py,sha256=NZ3Nf0XQVpT9kLP5cBVo_yBHLI7vWTYhWQxYEJVMGs4,3472
|
66
68
|
langfun/core/llms/vertexai.py,sha256=wIpckH-rMHUBA1vhauQk4LVrSsPQEsVntz7kLDKwm9g,11359
|
67
69
|
langfun/core/llms/vertexai_test.py,sha256=G18BG36h5KvmX2zutDTLjtYCRjTuP_nWIFm4FMnLnyY,7651
|
68
70
|
langfun/core/llms/cache/__init__.py,sha256=QAo3InUMDM_YpteNnVCSejI4zOsnjSMWKJKzkb3VY64,993
|
@@ -112,9 +114,9 @@ langfun/core/templates/conversation_test.py,sha256=RryYyIhfc34dLWOs6GfPQ8HU8mXpK
|
|
112
114
|
langfun/core/templates/demonstration.py,sha256=vCrgYubdZM5Umqcgp8NUVGXgr4P_c-fikKhwhzwhpKI,1460
|
113
115
|
langfun/core/templates/demonstration_test.py,sha256=SafcDQ0WgI7pw05EmPI2S4v1t3ABKzup8jReCljHeK4,2162
|
114
116
|
langfun/core/templates/selfplay.py,sha256=yhgrJbiYwq47TgzThmHrDQTF4nDrTI09CWGhuQPNv-s,2273
|
115
|
-
langfun/core/templates/selfplay_test.py,sha256=
|
116
|
-
langfun-0.0.2.
|
117
|
-
langfun-0.0.2.
|
118
|
-
langfun-0.0.2.
|
119
|
-
langfun-0.0.2.
|
120
|
-
langfun-0.0.2.
|
117
|
+
langfun/core/templates/selfplay_test.py,sha256=rBW2Qr8yi-aWYwoTwRR-n1peKyMX9QXPZXURjLgoiRs,2264
|
118
|
+
langfun-0.0.2.dev20240608.dist-info/LICENSE,sha256=WNHhf_5RCaeuKWyq_K39vmp9F28LxKsB4SpomwSZ2L0,11357
|
119
|
+
langfun-0.0.2.dev20240608.dist-info/METADATA,sha256=HhaObQS5WL8Y2uRuwNX389KbnRiRPsV3Vl-XzVGTN0A,3550
|
120
|
+
langfun-0.0.2.dev20240608.dist-info/WHEEL,sha256=GJ7t_kWBFywbagK5eo9IoUwLW6oyOeTKmQ-9iHFVNxQ,92
|
121
|
+
langfun-0.0.2.dev20240608.dist-info/top_level.txt,sha256=RhlEkHxs1qtzmmtWSwYoLVJAc1YrbPtxQ52uh8Z9VvY,8
|
122
|
+
langfun-0.0.2.dev20240608.dist-info/RECORD,,
|
File without changes
|
File without changes
|
File without changes
|