langfun 0.0.2.dev20240607__py3-none-any.whl → 0.0.2.dev20240609__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of langfun might be problematic. Click here for more details.
- langfun/core/__init__.py +4 -0
- langfun/core/eval/base.py +8 -7
- langfun/core/langfunc_test.py +10 -2
- langfun/core/language_model.py +78 -25
- langfun/core/language_model_test.py +26 -6
- langfun/core/llms/openai.py +5 -6
- langfun/core/llms/openai_test.py +28 -13
- langfun/core/llms/rest_test.py +1 -1
- langfun/core/repr_utils.py +83 -0
- langfun/core/repr_utils_test.py +58 -0
- langfun/core/templates/selfplay_test.py +8 -2
- {langfun-0.0.2.dev20240607.dist-info → langfun-0.0.2.dev20240609.dist-info}/METADATA +1 -1
- {langfun-0.0.2.dev20240607.dist-info → langfun-0.0.2.dev20240609.dist-info}/RECORD +16 -14
- {langfun-0.0.2.dev20240607.dist-info → langfun-0.0.2.dev20240609.dist-info}/LICENSE +0 -0
- {langfun-0.0.2.dev20240607.dist-info → langfun-0.0.2.dev20240609.dist-info}/WHEEL +0 -0
- {langfun-0.0.2.dev20240607.dist-info → langfun-0.0.2.dev20240609.dist-info}/top_level.txt +0 -0
langfun/core/__init__.py
CHANGED
@@ -101,6 +101,7 @@ from langfun.core.language_model import LanguageModel
|
|
101
101
|
from langfun.core.language_model import LMSample
|
102
102
|
from langfun.core.language_model import LMSamplingOptions
|
103
103
|
from langfun.core.language_model import LMSamplingUsage
|
104
|
+
from langfun.core.language_model import UsageNotAvailable
|
104
105
|
from langfun.core.language_model import LMSamplingResult
|
105
106
|
from langfun.core.language_model import LMScoringResult
|
106
107
|
from langfun.core.language_model import LMCache
|
@@ -120,6 +121,9 @@ from langfun.core.memory import Memory
|
|
120
121
|
# Utility for console output.
|
121
122
|
from langfun.core import console
|
122
123
|
|
124
|
+
# Helpers for implementing _repr_xxx_ methods.
|
125
|
+
from langfun.core import repr_utils
|
126
|
+
|
123
127
|
# Utility for event logging.
|
124
128
|
from langfun.core import logging
|
125
129
|
|
langfun/core/eval/base.py
CHANGED
@@ -549,7 +549,7 @@ class Evaluable(lf.Component):
|
|
549
549
|
)
|
550
550
|
s.write(html.escape(pg.format(m.result)))
|
551
551
|
s.write('</div>')
|
552
|
-
if
|
552
|
+
if m.metadata.get('usage', None):
|
553
553
|
s.write(
|
554
554
|
'<div style="background-color: #EEEEEE; color: black; '
|
555
555
|
'white-space: pre-wrap; padding: 10px; border: 0px solid; '
|
@@ -1321,7 +1321,7 @@ class Evaluation(Evaluable):
|
|
1321
1321
|
self._render_summary_metrics(s)
|
1322
1322
|
|
1323
1323
|
# Summarize average usage.
|
1324
|
-
if self.result.usage
|
1324
|
+
if self.result.usage:
|
1325
1325
|
self._render_summary_usage(s)
|
1326
1326
|
|
1327
1327
|
s.write('</td></tr></table></div>')
|
@@ -1441,9 +1441,10 @@ class Evaluation(Evaluable):
|
|
1441
1441
|
def audit_usage(self, message: lf.Message, dryrun: bool = False) -> None:
|
1442
1442
|
del dryrun
|
1443
1443
|
for m in message.trace():
|
1444
|
-
|
1445
|
-
|
1446
|
-
self.
|
1444
|
+
usage = m.metadata.get('usage', None)
|
1445
|
+
if usage:
|
1446
|
+
self._total_prompt_tokens += usage.prompt_tokens
|
1447
|
+
self._total_completion_tokens += usage.completion_tokens
|
1447
1448
|
self._num_usages += 1
|
1448
1449
|
|
1449
1450
|
def audit_processed(
|
@@ -1504,7 +1505,7 @@ class Evaluation(Evaluable):
|
|
1504
1505
|
'<td>Schema</td>'
|
1505
1506
|
'<td>Additional Args</td>'
|
1506
1507
|
)
|
1507
|
-
if self.result.usage
|
1508
|
+
if self.result.usage:
|
1508
1509
|
s.write('<td>Usage</td>')
|
1509
1510
|
s.write('<td>OOP Failures</td>')
|
1510
1511
|
s.write('<td>Non-OOP Failures</td>')
|
@@ -1533,7 +1534,7 @@ class Evaluation(Evaluable):
|
|
1533
1534
|
f'{_html_repr(self.additional_args, compact=False)}</td>'
|
1534
1535
|
)
|
1535
1536
|
# Usage.
|
1536
|
-
if self.result.usage
|
1537
|
+
if self.result.usage:
|
1537
1538
|
s.write('<td>')
|
1538
1539
|
self._render_summary_usage(s)
|
1539
1540
|
s.write('</td>')
|
langfun/core/langfunc_test.py
CHANGED
@@ -87,7 +87,11 @@ class LangFuncCallTest(unittest.TestCase):
|
|
87
87
|
|
88
88
|
r = l()
|
89
89
|
self.assertEqual(
|
90
|
-
r,
|
90
|
+
r,
|
91
|
+
message.AIMessage(
|
92
|
+
'Hello!!!', score=0.0, logprobs=None,
|
93
|
+
usage=language_model.UsageNotAvailable()
|
94
|
+
)
|
91
95
|
)
|
92
96
|
self.assertEqual(r.tags, ['lm-response', 'lm-output'])
|
93
97
|
self.assertEqual(r.source, message.UserMessage('Hello'))
|
@@ -112,7 +116,11 @@ class LangFuncCallTest(unittest.TestCase):
|
|
112
116
|
self.assertEqual(l.render(), 'Hello')
|
113
117
|
r = l()
|
114
118
|
self.assertEqual(
|
115
|
-
r,
|
119
|
+
r,
|
120
|
+
message.AIMessage(
|
121
|
+
'Hello!!!', score=0.0, logprobs=None,
|
122
|
+
usage=language_model.UsageNotAvailable()
|
123
|
+
)
|
116
124
|
)
|
117
125
|
self.assertEqual(r.tags, ['lm-response', 'lm-output'])
|
118
126
|
|
langfun/core/language_model.py
CHANGED
@@ -93,6 +93,16 @@ class LMSamplingUsage(pg.Object):
|
|
93
93
|
)
|
94
94
|
|
95
95
|
|
96
|
+
class UsageNotAvailable(LMSamplingUsage):
|
97
|
+
"""Usage information not available."""
|
98
|
+
prompt_tokens: pg.typing.Int(0).freeze() # pytype: disable=invalid-annotation
|
99
|
+
completion_tokens: pg.typing.Int(0).freeze() # pytype: disable=invalid-annotation
|
100
|
+
total_tokens: pg.typing.Int(0).freeze() # pytype: disable=invalid-annotation
|
101
|
+
|
102
|
+
def __bool__(self) -> bool:
|
103
|
+
return False
|
104
|
+
|
105
|
+
|
96
106
|
class LMSamplingResult(pg.Object):
|
97
107
|
"""Language model response."""
|
98
108
|
|
@@ -105,9 +115,9 @@ class LMSamplingResult(pg.Object):
|
|
105
115
|
] = []
|
106
116
|
|
107
117
|
usage: Annotated[
|
108
|
-
LMSamplingUsage
|
118
|
+
LMSamplingUsage,
|
109
119
|
'Usage information. Currently only OpenAI models are supported.',
|
110
|
-
] =
|
120
|
+
] = UsageNotAvailable()
|
111
121
|
|
112
122
|
|
113
123
|
class LMSamplingOptions(component.Component):
|
@@ -406,7 +416,7 @@ class LanguageModel(component.Component):
|
|
406
416
|
# which is accurate when n=1. For n > 1, we average the usage across
|
407
417
|
# multiple samples.
|
408
418
|
usage = result.usage
|
409
|
-
if len(result.samples) == 1 or usage
|
419
|
+
if len(result.samples) == 1 or not usage:
|
410
420
|
response.metadata.usage = usage
|
411
421
|
else:
|
412
422
|
n = len(result.samples)
|
@@ -417,14 +427,11 @@ class LanguageModel(component.Component):
|
|
417
427
|
)
|
418
428
|
|
419
429
|
# Track usage.
|
420
|
-
|
421
|
-
|
430
|
+
trackers = component.context_value('__usage_trackers__', [])
|
431
|
+
if trackers:
|
422
432
|
model_id = self.model_id
|
423
|
-
for
|
424
|
-
|
425
|
-
usage_dict[model_id] += usage
|
426
|
-
else:
|
427
|
-
usage_dict[model_id] = usage
|
433
|
+
for tracker in trackers:
|
434
|
+
tracker.track(model_id, usage)
|
428
435
|
|
429
436
|
# Track the prompt for corresponding response.
|
430
437
|
response.source = prompt
|
@@ -529,7 +536,7 @@ class LanguageModel(component.Component):
|
|
529
536
|
prompt: message_lib.Message,
|
530
537
|
response: message_lib.Message,
|
531
538
|
call_counter: int,
|
532
|
-
usage: LMSamplingUsage
|
539
|
+
usage: LMSamplingUsage,
|
533
540
|
elapse: float,
|
534
541
|
) -> None:
|
535
542
|
"""Outputs debugging information."""
|
@@ -547,12 +554,13 @@ class LanguageModel(component.Component):
|
|
547
554
|
self._debug_response(response, call_counter, usage, elapse)
|
548
555
|
|
549
556
|
def _debug_model_info(
|
550
|
-
self, call_counter: int, usage: LMSamplingUsage
|
557
|
+
self, call_counter: int, usage: LMSamplingUsage) -> None:
|
551
558
|
"""Outputs debugging information about the model."""
|
552
559
|
title_suffix = ''
|
553
|
-
if usage
|
560
|
+
if usage.total_tokens != 0:
|
554
561
|
title_suffix = console.colored(
|
555
|
-
f' (total {usage.total_tokens} tokens)', 'red'
|
562
|
+
f' (total {usage.total_tokens} tokens)', 'red'
|
563
|
+
)
|
556
564
|
|
557
565
|
console.write(
|
558
566
|
self.format(compact=True, use_inferred=True),
|
@@ -564,11 +572,11 @@ class LanguageModel(component.Component):
|
|
564
572
|
self,
|
565
573
|
prompt: message_lib.Message,
|
566
574
|
call_counter: int,
|
567
|
-
usage: LMSamplingUsage
|
575
|
+
usage: LMSamplingUsage,
|
568
576
|
) -> None:
|
569
577
|
"""Outputs debugging information about the prompt."""
|
570
578
|
title_suffix = ''
|
571
|
-
if usage
|
579
|
+
if usage.prompt_tokens != 0:
|
572
580
|
title_suffix = console.colored(f' ({usage.prompt_tokens} tokens)', 'red')
|
573
581
|
|
574
582
|
console.write(
|
@@ -592,12 +600,12 @@ class LanguageModel(component.Component):
|
|
592
600
|
self,
|
593
601
|
response: message_lib.Message,
|
594
602
|
call_counter: int,
|
595
|
-
usage: LMSamplingUsage
|
603
|
+
usage: LMSamplingUsage,
|
596
604
|
elapse: float
|
597
605
|
) -> None:
|
598
606
|
"""Outputs debugging information about the response."""
|
599
607
|
title_suffix = ' ('
|
600
|
-
if usage
|
608
|
+
if usage.completion_tokens != 0:
|
601
609
|
title_suffix += f'{usage.completion_tokens} tokens '
|
602
610
|
title_suffix += f'in {elapse:.2f} seconds)'
|
603
611
|
title_suffix = console.colored(title_suffix, 'red')
|
@@ -659,7 +667,7 @@ class LanguageModel(component.Component):
|
|
659
667
|
debug = LMDebugMode.ALL if debug else LMDebugMode.NONE
|
660
668
|
|
661
669
|
if debug & LMDebugMode.INFO:
|
662
|
-
self._debug_model_info(call_counter,
|
670
|
+
self._debug_model_info(call_counter, UsageNotAvailable())
|
663
671
|
|
664
672
|
if debug & LMDebugMode.PROMPT:
|
665
673
|
console.write(
|
@@ -712,13 +720,58 @@ class LanguageModel(component.Component):
|
|
712
720
|
return DEFAULT_MAX_CONCURRENCY # Default of 1
|
713
721
|
|
714
722
|
|
723
|
+
class _UsageTracker:
|
724
|
+
"""Usage tracker."""
|
725
|
+
|
726
|
+
def __init__(self, model_ids: set[str] | None):
|
727
|
+
self.model_ids = model_ids
|
728
|
+
self.usages = {
|
729
|
+
m: LMSamplingUsage(0, 0, 0) for m in model_ids
|
730
|
+
} if model_ids else {}
|
731
|
+
|
732
|
+
def track(self, model_id: str, usage: LMSamplingUsage):
|
733
|
+
if self.model_ids is not None and model_id not in self.model_ids:
|
734
|
+
return
|
735
|
+
if not isinstance(usage, UsageNotAvailable) and model_id in self.usages:
|
736
|
+
self.usages[model_id] += usage
|
737
|
+
else:
|
738
|
+
self.usages[model_id] = usage
|
739
|
+
|
740
|
+
|
715
741
|
@contextlib.contextmanager
|
716
|
-
def track_usages(
|
717
|
-
|
718
|
-
|
719
|
-
|
720
|
-
|
742
|
+
def track_usages(
|
743
|
+
*lm: Union[str, LanguageModel]
|
744
|
+
) -> Iterator[dict[str, LMSamplingUsage]]:
|
745
|
+
"""Context manager to track the usages of all language models in scope.
|
746
|
+
|
747
|
+
`lf.track_usages` works with threads spawned by `lf.concurrent_map` and
|
748
|
+
`lf.concurrent_execute`.
|
749
|
+
|
750
|
+
Example:
|
751
|
+
```
|
752
|
+
lm = lf.llms.GeminiPro1()
|
753
|
+
with lf.track_usages() as usages:
|
754
|
+
# invoke any code that will call LLMs.
|
755
|
+
|
756
|
+
print(usages[lm.model_id])
|
757
|
+
```
|
758
|
+
|
759
|
+
Args:
|
760
|
+
*lm: The language model(s) to track. If None, track all models in scope.
|
761
|
+
|
762
|
+
Yields:
|
763
|
+
A dictionary of model ID to usage. If a model does not supports usage
|
764
|
+
counting, the dict entry will be None.
|
765
|
+
"""
|
766
|
+
if not lm:
|
767
|
+
model_ids = None
|
768
|
+
else:
|
769
|
+
model_ids = [m.model_id if isinstance(m, LanguageModel) else m for m in lm]
|
770
|
+
|
771
|
+
trackers = component.context_value('__usage_trackers__', [])
|
772
|
+
tracker = _UsageTracker(set(model_ids) if model_ids else None)
|
773
|
+
with component.context(__usage_trackers__=trackers + [tracker]):
|
721
774
|
try:
|
722
|
-
yield
|
775
|
+
yield tracker.usages
|
723
776
|
finally:
|
724
777
|
pass
|
@@ -27,8 +27,8 @@ import pyglove as pg
|
|
27
27
|
@pg.use_init_args(['failures_before_attempt'])
|
28
28
|
class MockModel(lm_lib.LanguageModel):
|
29
29
|
"""A mock model that echo back user prompts."""
|
30
|
-
|
31
30
|
failures_before_attempt: int = 0
|
31
|
+
name: str = 'MockModel'
|
32
32
|
|
33
33
|
def _sample(self,
|
34
34
|
prompts: list[message_lib.Message]
|
@@ -63,6 +63,10 @@ class MockModel(lm_lib.LanguageModel):
|
|
63
63
|
retry_interval=1,
|
64
64
|
)(prompts)
|
65
65
|
|
66
|
+
@property
|
67
|
+
def model_id(self) -> str:
|
68
|
+
return self.name
|
69
|
+
|
66
70
|
|
67
71
|
class MockScoringModel(MockModel):
|
68
72
|
|
@@ -581,17 +585,33 @@ class LanguageModelTest(unittest.TestCase):
|
|
581
585
|
self.assertEqual(lm.rate_to_max_concurrency(tokens_per_min=1), 1)
|
582
586
|
|
583
587
|
def test_track_usages(self):
|
588
|
+
lm = MockModel(name='model1')
|
589
|
+
lm2 = MockModel(name='model2')
|
584
590
|
with lm_lib.track_usages() as usages1:
|
585
|
-
lm = MockModel()
|
586
591
|
_ = lm('hi')
|
587
|
-
with lm_lib.track_usages() as usages2:
|
588
|
-
|
592
|
+
with lm_lib.track_usages(lm2) as usages2:
|
593
|
+
with lm_lib.track_usages('model1') as usages3:
|
594
|
+
with lm_lib.track_usages('model1', lm2) as usages4:
|
595
|
+
def call_lm(prompt):
|
596
|
+
_ = lm.sample([prompt] * 2)
|
597
|
+
lm2('hi')
|
598
|
+
list(concurrent.concurrent_map(call_lm, ['hi', 'hello']))
|
599
|
+
|
589
600
|
self.assertEqual(usages2, {
|
590
|
-
'
|
601
|
+
'model2': lm_lib.LMSamplingUsage(100, 100, 200),
|
602
|
+
})
|
603
|
+
self.assertEqual(usages3, {
|
604
|
+
'model1': lm_lib.LMSamplingUsage(100 * 4, 100 * 4, 200 * 4),
|
605
|
+
})
|
606
|
+
self.assertEqual(usages4, {
|
607
|
+
'model1': lm_lib.LMSamplingUsage(100 * 4, 100 * 4, 200 * 4),
|
608
|
+
'model2': lm_lib.LMSamplingUsage(100, 100, 200),
|
591
609
|
})
|
592
610
|
self.assertEqual(usages1, {
|
593
|
-
'
|
611
|
+
'model1': lm_lib.LMSamplingUsage(100 * 5, 100 * 5, 200 * 5),
|
612
|
+
'model2': lm_lib.LMSamplingUsage(100, 100, 200),
|
594
613
|
})
|
595
614
|
|
615
|
+
|
596
616
|
if __name__ == '__main__':
|
597
617
|
unittest.main()
|
langfun/core/llms/openai.py
CHANGED
@@ -202,15 +202,14 @@ class OpenAI(lf.LanguageModel):
|
|
202
202
|
lf.LMSample(choice.text.strip(), score=choice.logprobs or 0.0)
|
203
203
|
)
|
204
204
|
|
205
|
+
n = len(samples_by_index)
|
205
206
|
usage = lf.LMSamplingUsage(
|
206
|
-
prompt_tokens=response.usage.prompt_tokens,
|
207
|
-
completion_tokens=response.usage.completion_tokens,
|
208
|
-
total_tokens=response.usage.total_tokens,
|
207
|
+
prompt_tokens=response.usage.prompt_tokens // n,
|
208
|
+
completion_tokens=response.usage.completion_tokens // n,
|
209
|
+
total_tokens=response.usage.total_tokens // n,
|
209
210
|
)
|
210
211
|
return [
|
211
|
-
lf.LMSamplingResult(
|
212
|
-
samples_by_index[index], usage=usage if index == 0 else None
|
213
|
-
)
|
212
|
+
lf.LMSamplingResult(samples_by_index[index], usage=usage)
|
214
213
|
for index in sorted(samples_by_index.keys())
|
215
214
|
]
|
216
215
|
|
langfun/core/llms/openai_test.py
CHANGED
@@ -191,9 +191,9 @@ class OpenAITest(unittest.TestCase):
|
|
191
191
|
score=0.0,
|
192
192
|
logprobs=None,
|
193
193
|
usage=lf.LMSamplingUsage(
|
194
|
-
prompt_tokens=
|
195
|
-
completion_tokens=
|
196
|
-
total_tokens=
|
194
|
+
prompt_tokens=16,
|
195
|
+
completion_tokens=16,
|
196
|
+
total_tokens=33
|
197
197
|
),
|
198
198
|
tags=[lf.Message.TAG_LM_RESPONSE],
|
199
199
|
),
|
@@ -206,9 +206,9 @@ class OpenAITest(unittest.TestCase):
|
|
206
206
|
score=0.1,
|
207
207
|
logprobs=None,
|
208
208
|
usage=lf.LMSamplingUsage(
|
209
|
-
prompt_tokens=
|
210
|
-
completion_tokens=
|
211
|
-
total_tokens=
|
209
|
+
prompt_tokens=16,
|
210
|
+
completion_tokens=16,
|
211
|
+
total_tokens=33
|
212
212
|
),
|
213
213
|
tags=[lf.Message.TAG_LM_RESPONSE],
|
214
214
|
),
|
@@ -221,9 +221,9 @@ class OpenAITest(unittest.TestCase):
|
|
221
221
|
score=0.2,
|
222
222
|
logprobs=None,
|
223
223
|
usage=lf.LMSamplingUsage(
|
224
|
-
prompt_tokens=
|
225
|
-
completion_tokens=
|
226
|
-
total_tokens=
|
224
|
+
prompt_tokens=16,
|
225
|
+
completion_tokens=16,
|
226
|
+
total_tokens=33
|
227
227
|
),
|
228
228
|
tags=[lf.Message.TAG_LM_RESPONSE],
|
229
229
|
),
|
@@ -232,7 +232,7 @@ class OpenAITest(unittest.TestCase):
|
|
232
232
|
),
|
233
233
|
],
|
234
234
|
usage=lf.LMSamplingUsage(
|
235
|
-
prompt_tokens=
|
235
|
+
prompt_tokens=50, completion_tokens=50, total_tokens=100
|
236
236
|
),
|
237
237
|
),
|
238
238
|
)
|
@@ -245,7 +245,11 @@ class OpenAITest(unittest.TestCase):
|
|
245
245
|
'Sample 0 for prompt 1.',
|
246
246
|
score=0.0,
|
247
247
|
logprobs=None,
|
248
|
-
usage=
|
248
|
+
usage=lf.LMSamplingUsage(
|
249
|
+
prompt_tokens=16,
|
250
|
+
completion_tokens=16,
|
251
|
+
total_tokens=33
|
252
|
+
),
|
249
253
|
tags=[lf.Message.TAG_LM_RESPONSE],
|
250
254
|
),
|
251
255
|
score=0.0,
|
@@ -256,7 +260,11 @@ class OpenAITest(unittest.TestCase):
|
|
256
260
|
'Sample 1 for prompt 1.',
|
257
261
|
score=0.1,
|
258
262
|
logprobs=None,
|
259
|
-
usage=
|
263
|
+
usage=lf.LMSamplingUsage(
|
264
|
+
prompt_tokens=16,
|
265
|
+
completion_tokens=16,
|
266
|
+
total_tokens=33
|
267
|
+
),
|
260
268
|
tags=[lf.Message.TAG_LM_RESPONSE],
|
261
269
|
),
|
262
270
|
score=0.1,
|
@@ -267,13 +275,20 @@ class OpenAITest(unittest.TestCase):
|
|
267
275
|
'Sample 2 for prompt 1.',
|
268
276
|
score=0.2,
|
269
277
|
logprobs=None,
|
270
|
-
usage=
|
278
|
+
usage=lf.LMSamplingUsage(
|
279
|
+
prompt_tokens=16,
|
280
|
+
completion_tokens=16,
|
281
|
+
total_tokens=33
|
282
|
+
),
|
271
283
|
tags=[lf.Message.TAG_LM_RESPONSE],
|
272
284
|
),
|
273
285
|
score=0.2,
|
274
286
|
logprobs=None,
|
275
287
|
),
|
276
288
|
],
|
289
|
+
usage=lf.LMSamplingUsage(
|
290
|
+
prompt_tokens=50, completion_tokens=50, total_tokens=100
|
291
|
+
),
|
277
292
|
),
|
278
293
|
)
|
279
294
|
|
langfun/core/llms/rest_test.py
CHANGED
@@ -89,7 +89,7 @@ class RestTest(unittest.TestCase):
|
|
89
89
|
"max_tokens=4096, stop=['\\n']."
|
90
90
|
),
|
91
91
|
)
|
92
|
-
self.
|
92
|
+
self.assertEqual(response.usage, lf.UsageNotAvailable())
|
93
93
|
|
94
94
|
def test_call_errors(self):
|
95
95
|
for status_code, error_type, error_message in [
|
@@ -0,0 +1,83 @@
|
|
1
|
+
# Copyright 2024 The Langfun Authors
|
2
|
+
#
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at
|
6
|
+
#
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
#
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
"""Helpers for implementing _repr_xxx_ methods."""
|
15
|
+
|
16
|
+
import collections
|
17
|
+
import contextlib
|
18
|
+
import io
|
19
|
+
from typing import Iterator
|
20
|
+
|
21
|
+
from langfun.core import component
|
22
|
+
|
23
|
+
|
24
|
+
@contextlib.contextmanager
|
25
|
+
def share_parts() -> Iterator[dict[str, int]]:
|
26
|
+
"""Context manager for defining the context (scope) of shared content.
|
27
|
+
|
28
|
+
Under the context manager, call to `lf.write_shared` with the same content
|
29
|
+
will be written only once. This is useful for writing shared content such as
|
30
|
+
shared style and script sections in the HTML.
|
31
|
+
|
32
|
+
Example:
|
33
|
+
```
|
34
|
+
class Foo(pg.Object):
|
35
|
+
def _repr_html_(self) -> str:
|
36
|
+
s = io.StringIO()
|
37
|
+
lf.repr_utils.write_shared_part(s, '<style>..</style>')
|
38
|
+
lf.repr_utils.write_shared_part(s, '<script>..</script>')
|
39
|
+
return s.getvalue()
|
40
|
+
|
41
|
+
with lf.repr_utils.share_parts() as share_parts:
|
42
|
+
# The <style> and <script> section will be written only once.
|
43
|
+
lf.console.display(Foo())
|
44
|
+
lf.console.display(Foo())
|
45
|
+
|
46
|
+
# Assert that the shared content is attempted to be written twice.
|
47
|
+
assert share_parts['<style>..</style>'] == 2
|
48
|
+
```
|
49
|
+
|
50
|
+
Yields:
|
51
|
+
A dictionary mapping the shared content to the number of times it is
|
52
|
+
attempted to be written.
|
53
|
+
"""
|
54
|
+
context = component.context_value(
|
55
|
+
'__shared_parts__', collections.defaultdict(int)
|
56
|
+
)
|
57
|
+
with component.context(__shared_parts__=context):
|
58
|
+
try:
|
59
|
+
yield context
|
60
|
+
finally:
|
61
|
+
pass
|
62
|
+
|
63
|
+
|
64
|
+
def write_maybe_shared(s: io.StringIO, content: str) -> bool:
|
65
|
+
"""Writes a maybe shared part to an string stream.
|
66
|
+
|
67
|
+
Args:
|
68
|
+
s: The string stream to write to.
|
69
|
+
content: A maybe shared content to write.
|
70
|
+
|
71
|
+
Returns:
|
72
|
+
True if the content is written to the string. False if the content is
|
73
|
+
already written under the same share context.
|
74
|
+
"""
|
75
|
+
context = component.context_value('__shared_parts__', None)
|
76
|
+
if context is None:
|
77
|
+
s.write(content)
|
78
|
+
return True
|
79
|
+
written = content in context
|
80
|
+
if not written:
|
81
|
+
s.write(content)
|
82
|
+
context[content] += 1
|
83
|
+
return not written
|
@@ -0,0 +1,58 @@
|
|
1
|
+
# Copyright 2024 The Langfun Authors
|
2
|
+
#
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at
|
6
|
+
#
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
#
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
"""Tests for langfun.core.repr_utils."""
|
15
|
+
|
16
|
+
import io
|
17
|
+
import unittest
|
18
|
+
|
19
|
+
from langfun.core import repr_utils
|
20
|
+
|
21
|
+
|
22
|
+
class SharingContentTest(unittest.TestCase):
|
23
|
+
|
24
|
+
def test_sharing(self):
|
25
|
+
s = io.StringIO()
|
26
|
+
|
27
|
+
self.assertTrue(repr_utils.write_maybe_shared(s, '<hr>'))
|
28
|
+
self.assertTrue(repr_utils.write_maybe_shared(s, '<hr>'))
|
29
|
+
|
30
|
+
with repr_utils.share_parts() as ctx1:
|
31
|
+
self.assertTrue(repr_utils.write_maybe_shared(s, '<style></style>'))
|
32
|
+
self.assertFalse(repr_utils.write_maybe_shared(s, '<style></style>'))
|
33
|
+
|
34
|
+
with repr_utils.share_parts() as ctx2:
|
35
|
+
self.assertIs(ctx2, ctx1)
|
36
|
+
self.assertFalse(repr_utils.write_maybe_shared(s, '<style></style>'))
|
37
|
+
self.assertTrue(repr_utils.write_maybe_shared(s, '<style>a</style>'))
|
38
|
+
self.assertFalse(repr_utils.write_maybe_shared(s, '<style>a</style>'))
|
39
|
+
self.assertTrue(repr_utils.write_maybe_shared(s, '<style>b</style>'))
|
40
|
+
|
41
|
+
with repr_utils.share_parts() as ctx3:
|
42
|
+
self.assertIs(ctx3, ctx1)
|
43
|
+
self.assertFalse(repr_utils.write_maybe_shared(s, '<style></style>'))
|
44
|
+
self.assertFalse(repr_utils.write_maybe_shared(s, '<style>a</style>'))
|
45
|
+
self.assertFalse(repr_utils.write_maybe_shared(s, '<style>a</style>'))
|
46
|
+
self.assertFalse(repr_utils.write_maybe_shared(s, '<style>b</style>'))
|
47
|
+
|
48
|
+
self.assertEqual(
|
49
|
+
s.getvalue(),
|
50
|
+
'<hr><hr><style></style><style>a</style><style>b</style>'
|
51
|
+
)
|
52
|
+
self.assertEqual(ctx1['<style></style>'], 4)
|
53
|
+
self.assertEqual(ctx1['<style>b</style>'], 2)
|
54
|
+
self.assertEqual(ctx1['<style>a</style>'], 4)
|
55
|
+
|
56
|
+
|
57
|
+
if __name__ == '__main__':
|
58
|
+
unittest.main()
|
@@ -57,7 +57,10 @@ class SelfPlayTest(unittest.TestCase):
|
|
57
57
|
|
58
58
|
with lf.context(lm=NumberGuesser(guesses=[50, 20, 5, 10])):
|
59
59
|
self.assertEqual(
|
60
|
-
g(),
|
60
|
+
g(),
|
61
|
+
lf.AIMessage(
|
62
|
+
'10', score=0.0, logprobs=None, usage=lf.UsageNotAvailable()
|
63
|
+
)
|
61
64
|
)
|
62
65
|
|
63
66
|
self.assertEqual(g.num_turns, 4)
|
@@ -67,7 +70,10 @@ class SelfPlayTest(unittest.TestCase):
|
|
67
70
|
|
68
71
|
with lf.context(lm=NumberGuesser(guesses=[50, 20, 5, 2, 5, 4])):
|
69
72
|
self.assertEqual(
|
70
|
-
g(),
|
73
|
+
g(),
|
74
|
+
lf.AIMessage(
|
75
|
+
'2', score=0.0, logprobs=None, usage=lf.UsageNotAvailable()
|
76
|
+
)
|
71
77
|
)
|
72
78
|
|
73
79
|
self.assertEqual(g.num_turns, 10)
|
@@ -1,5 +1,5 @@
|
|
1
1
|
langfun/__init__.py,sha256=P62MnqA6-f0h8iYfQ3MT6Yg7a4qRnQeb4GrIn6dcSnY,2274
|
2
|
-
langfun/core/__init__.py,sha256=
|
2
|
+
langfun/core/__init__.py,sha256=Mdp1a2YnXdSmfTfbUwuAnEWYbjA3rXXGtbxl5fljZyg,4812
|
3
3
|
langfun/core/component.py,sha256=Icyoj9ICoJoK2r2PHbrFXbxnseOr9QZZOvKWklLWNo8,10276
|
4
4
|
langfun/core/component_test.py,sha256=q15Xn51cVTu2RKxZ9U5VQgT3bm6RQ4638bKhWBtvW5o,8220
|
5
5
|
langfun/core/concurrent.py,sha256=TRc49pJ3HQro2kb5FtcWkHjhBm8UcgE8RJybU5cU3-0,24537
|
@@ -7,9 +7,9 @@ langfun/core/concurrent_test.py,sha256=mwFMZhDUdppnDr7vDSTwcbMHwrdsIoKJwRYNtl4ZW
|
|
7
7
|
langfun/core/console.py,sha256=bk5rNPNm9rMGW5YT2HixxU04p2umnoabn5SDz6Dqe88,2317
|
8
8
|
langfun/core/console_test.py,sha256=5SYJdxpJGLgdSSQqqMPoA1X6jpsLD8rgcyk-EgI65oE,1077
|
9
9
|
langfun/core/langfunc.py,sha256=RvIcRjIq0jWYRu1xim-FYe4HSrt97r3GMBO_PuagUmw,11060
|
10
|
-
langfun/core/langfunc_test.py,sha256=
|
11
|
-
langfun/core/language_model.py,sha256=
|
12
|
-
langfun/core/language_model_test.py,sha256=
|
10
|
+
langfun/core/langfunc_test.py,sha256=lyt-UzkD8972cxZwzCkps0_RMLeSsOBrcUFIW-fB6us,8653
|
11
|
+
langfun/core/language_model.py,sha256=q41EGhbgoCe68eYWWTAzgt3r6SbdCpINRdXnNGstukQ,23775
|
12
|
+
langfun/core/language_model_test.py,sha256=HmA0GACK4-6tCH32TFkfYj9w4CxbrynKqXdKBgiqgwo,21255
|
13
13
|
langfun/core/logging.py,sha256=FyZRxUy2TTF6tWLhQCRpCvfH55WGUdNgQjUTK_SQLnY,5320
|
14
14
|
langfun/core/logging_test.py,sha256=qvm3RObYP3knO2PnXR9evBRl4gH621GnjnwywbGbRfg,1833
|
15
15
|
langfun/core/memory.py,sha256=f-asN1F7Vehgdn_fK84v73GrEUOxRtaW934keutTKjk,2416
|
@@ -19,6 +19,8 @@ langfun/core/modality.py,sha256=Tla4t86DUYHpbZ2G7dy1r19fTj_Ga5XOvlYp6lbWa-Q,3512
|
|
19
19
|
langfun/core/modality_test.py,sha256=HyZ5xONKQ0Fw18SzoWAq-Ob9njOXIIjBo1hNtw-rudw,2400
|
20
20
|
langfun/core/natural_language.py,sha256=3ynSnaYQnjE60LIPK5fyMgdIjubnPYZwzGq4rWPeloE,1177
|
21
21
|
langfun/core/natural_language_test.py,sha256=LHGU_1ytbkGuSZQFIFP7vP3dBlcY4-A12fT6dbjUA0E,1424
|
22
|
+
langfun/core/repr_utils.py,sha256=HrN7FoGUvpTlv5aL_XISouwZN84z9LmrB6_2jEn1ukc,2590
|
23
|
+
langfun/core/repr_utils_test.py,sha256=-XId1A72Vbzo289dYuxC6TegNXuZhI28WbNrm1ghiwc,2206
|
22
24
|
langfun/core/sampling.py,sha256=vygWvgC8MFw0_AKNSmz-ywMXJYWf8cl0tI8QycvAmyI,5795
|
23
25
|
langfun/core/sampling_test.py,sha256=U7PANpMsl9E_pa4_Y4FzesSjcwg-u-LKHGCWSgv-8FY,3663
|
24
26
|
langfun/core/subscription.py,sha256=euawEuSZP-BHydaT-AQpfYFL0m5pWPGcW0upFhrojqc,10930
|
@@ -42,7 +44,7 @@ langfun/core/coding/python/parsing_test.py,sha256=9vAWF484kWIm6JZq8NFiMgKUDhXV-d
|
|
42
44
|
langfun/core/coding/python/permissions.py,sha256=1QWGHvzL8MM0Ok_auQ9tURqZHtdOfJaDpBzZ29GUE-c,2544
|
43
45
|
langfun/core/coding/python/permissions_test.py,sha256=w5EDb8QxpxgJyZkojyzVWQvDfg366zn99-g__6TbPQ0,2699
|
44
46
|
langfun/core/eval/__init__.py,sha256=Evt-E4FEhZF2tXL6-byh_AyA7Cc_ZoGmvnN7vkAZedk,1898
|
45
|
-
langfun/core/eval/base.py,sha256=
|
47
|
+
langfun/core/eval/base.py,sha256=GM98Zo4gxZui2ORX6Q7Zr94PfiEViQC5X_qz-uj6b2k,74220
|
46
48
|
langfun/core/eval/base_test.py,sha256=cHOTIWVW4Dp8gKKIKcZrAcJ-w84j2GIozTzJoiAX7p4,26743
|
47
49
|
langfun/core/eval/matching.py,sha256=Y4vFoNTQEOwko6IA8l9OZ52-vt52e3VGmcTtvLA67wM,9782
|
48
50
|
langfun/core/eval/matching_test.py,sha256=f7iVyXH5KGJBWt4Wp14Bt9J3X59A6Ayfog9MbuFvPew,5532
|
@@ -61,10 +63,10 @@ langfun/core/llms/groq.py,sha256=pqtyOZ_1_OJMOg8xATWT_B_SVbuT9nMRf4VkH9GzW8g,630
|
|
61
63
|
langfun/core/llms/groq_test.py,sha256=GYF_Qtq5S1H1TrKH38t6_lkdroqT7v-joYLDKnmS9e0,5274
|
62
64
|
langfun/core/llms/llama_cpp.py,sha256=9tXQntSCDtjTF3bnyJrAPCr4N6wycy5nXYvp9uduygE,2843
|
63
65
|
langfun/core/llms/llama_cpp_test.py,sha256=MWO_qaOeKjRniGjcaWPDScd7HPaIJemqUZoslrt4FPs,1806
|
64
|
-
langfun/core/llms/openai.py,sha256=
|
65
|
-
langfun/core/llms/openai_test.py,sha256=
|
66
|
+
langfun/core/llms/openai.py,sha256=0z9qIH9FlWj9VWUnhOX321T6JHO-vjY2IozT7OVI4GY,13654
|
67
|
+
langfun/core/llms/openai_test.py,sha256=3muDTnW7UBOSHq694Fi2bofqhe8Pkj0Tl8IShoLCTOM,15525
|
66
68
|
langfun/core/llms/rest.py,sha256=laopuq-zD8V-3Y6eFDngftHEbE66VlUkCD2-rvvRaLU,3388
|
67
|
-
langfun/core/llms/rest_test.py,sha256=
|
69
|
+
langfun/core/llms/rest_test.py,sha256=NZ3Nf0XQVpT9kLP5cBVo_yBHLI7vWTYhWQxYEJVMGs4,3472
|
68
70
|
langfun/core/llms/vertexai.py,sha256=wIpckH-rMHUBA1vhauQk4LVrSsPQEsVntz7kLDKwm9g,11359
|
69
71
|
langfun/core/llms/vertexai_test.py,sha256=G18BG36h5KvmX2zutDTLjtYCRjTuP_nWIFm4FMnLnyY,7651
|
70
72
|
langfun/core/llms/cache/__init__.py,sha256=QAo3InUMDM_YpteNnVCSejI4zOsnjSMWKJKzkb3VY64,993
|
@@ -114,9 +116,9 @@ langfun/core/templates/conversation_test.py,sha256=RryYyIhfc34dLWOs6GfPQ8HU8mXpK
|
|
114
116
|
langfun/core/templates/demonstration.py,sha256=vCrgYubdZM5Umqcgp8NUVGXgr4P_c-fikKhwhzwhpKI,1460
|
115
117
|
langfun/core/templates/demonstration_test.py,sha256=SafcDQ0WgI7pw05EmPI2S4v1t3ABKzup8jReCljHeK4,2162
|
116
118
|
langfun/core/templates/selfplay.py,sha256=yhgrJbiYwq47TgzThmHrDQTF4nDrTI09CWGhuQPNv-s,2273
|
117
|
-
langfun/core/templates/selfplay_test.py,sha256=
|
118
|
-
langfun-0.0.2.
|
119
|
-
langfun-0.0.2.
|
120
|
-
langfun-0.0.2.
|
121
|
-
langfun-0.0.2.
|
122
|
-
langfun-0.0.2.
|
119
|
+
langfun/core/templates/selfplay_test.py,sha256=rBW2Qr8yi-aWYwoTwRR-n1peKyMX9QXPZXURjLgoiRs,2264
|
120
|
+
langfun-0.0.2.dev20240609.dist-info/LICENSE,sha256=WNHhf_5RCaeuKWyq_K39vmp9F28LxKsB4SpomwSZ2L0,11357
|
121
|
+
langfun-0.0.2.dev20240609.dist-info/METADATA,sha256=uyk0OoHbseBguRcn-EDVaDAun1goszVD5Dt1WO1BbZc,3550
|
122
|
+
langfun-0.0.2.dev20240609.dist-info/WHEEL,sha256=GJ7t_kWBFywbagK5eo9IoUwLW6oyOeTKmQ-9iHFVNxQ,92
|
123
|
+
langfun-0.0.2.dev20240609.dist-info/top_level.txt,sha256=RhlEkHxs1qtzmmtWSwYoLVJAc1YrbPtxQ52uh8Z9VvY,8
|
124
|
+
langfun-0.0.2.dev20240609.dist-info/RECORD,,
|
File without changes
|
File without changes
|
File without changes
|