langfun 0.0.2.dev20240516__py3-none-any.whl → 0.0.2.dev20240520__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of langfun might be problematic. Click here for more details.
- langfun/core/language_model.py +25 -13
- langfun/core/language_model_test.py +16 -1
- langfun/core/llms/fake.py +2 -1
- langfun/core/structured/scoring.py +22 -10
- {langfun-0.0.2.dev20240516.dist-info → langfun-0.0.2.dev20240520.dist-info}/METADATA +1 -1
- {langfun-0.0.2.dev20240516.dist-info → langfun-0.0.2.dev20240520.dist-info}/RECORD +9 -9
- {langfun-0.0.2.dev20240516.dist-info → langfun-0.0.2.dev20240520.dist-info}/LICENSE +0 -0
- {langfun-0.0.2.dev20240516.dist-info → langfun-0.0.2.dev20240520.dist-info}/WHEEL +0 -0
- {langfun-0.0.2.dev20240516.dist-info → langfun-0.0.2.dev20240520.dist-info}/top_level.txt +0 -0
langfun/core/language_model.py
CHANGED
@@ -566,12 +566,19 @@ class LanguageModel(component.Component):
|
|
566
566
|
|
567
567
|
def score(
|
568
568
|
self,
|
569
|
-
prompt: str | message_lib.Message,
|
569
|
+
prompt: str | message_lib.Message | list[message_lib.Message],
|
570
570
|
completions: list[str | message_lib.Message],
|
571
571
|
**kwargs,
|
572
572
|
) -> list[LMScoringResult]:
|
573
573
|
"""Scores the given prompt."""
|
574
|
-
|
574
|
+
if isinstance(prompt, list):
|
575
|
+
if len(prompt) != len(completions):
|
576
|
+
raise ValueError(
|
577
|
+
'prompt and completions must have the same length.'
|
578
|
+
)
|
579
|
+
prompt = [message_lib.UserMessage.from_value(p) for p in prompt]
|
580
|
+
else:
|
581
|
+
prompt = message_lib.UserMessage.from_value(prompt)
|
575
582
|
completions = [message_lib.UserMessage.from_value(c) for c in completions]
|
576
583
|
|
577
584
|
call_counter = self._call_counter
|
@@ -587,7 +594,8 @@ class LanguageModel(component.Component):
|
|
587
594
|
return scoring_results
|
588
595
|
|
589
596
|
def _score(
|
590
|
-
self, prompt: message_lib.Message
|
597
|
+
self, prompt: message_lib.Message | list[message_lib.Message],
|
598
|
+
completions: list[message_lib.Message]
|
591
599
|
) -> list[LMScoringResult]:
|
592
600
|
"""Subclass to implement."""
|
593
601
|
raise NotImplementedError(
|
@@ -596,7 +604,7 @@ class LanguageModel(component.Component):
|
|
596
604
|
|
597
605
|
def _debug_score(
|
598
606
|
self,
|
599
|
-
prompt: message_lib.Message,
|
607
|
+
prompt: message_lib.Message | list[message_lib.Message],
|
600
608
|
completions: list[message_lib.Message],
|
601
609
|
scoring_results: list[LMScoringResult],
|
602
610
|
call_counter: int,
|
@@ -615,15 +623,19 @@ class LanguageModel(component.Component):
|
|
615
623
|
title=f'\n[{call_counter}] SCORING LM WITH PROMPT:',
|
616
624
|
color='green',
|
617
625
|
)
|
618
|
-
|
619
|
-
|
620
|
-
|
621
|
-
|
622
|
-
|
623
|
-
|
624
|
-
|
625
|
-
|
626
|
-
|
626
|
+
if isinstance(prompt, list):
|
627
|
+
referred_modalities_lst = [p.referred_modalities() for p in prompt]
|
628
|
+
else:
|
629
|
+
referred_modalities_lst = [prompt.referred_modalities(),]
|
630
|
+
if referred_modalities_lst:
|
631
|
+
for referred_modalities in referred_modalities_lst:
|
632
|
+
console.write(
|
633
|
+
pg.object_utils.kvlist_str(
|
634
|
+
[(k, repr(v), None) for k, v in referred_modalities.items()]
|
635
|
+
),
|
636
|
+
title=f'\n[{call_counter}] MODALITY OBJECTS SENT TO LM:',
|
637
|
+
color='green',
|
638
|
+
)
|
627
639
|
|
628
640
|
if debug & LMDebugMode.RESPONSE:
|
629
641
|
console.write(
|
@@ -68,7 +68,7 @@ class MockScoringModel(MockModel):
|
|
68
68
|
|
69
69
|
def _score(
|
70
70
|
self,
|
71
|
-
prompt: message_lib.Message,
|
71
|
+
prompt: message_lib.Message | list[message_lib.Message],
|
72
72
|
completions: list[message_lib.Message],
|
73
73
|
**kwargs
|
74
74
|
) -> list[lm_lib.LMScoringResult]:
|
@@ -508,6 +508,17 @@ class LanguageModelTest(unittest.TestCase):
|
|
508
508
|
],
|
509
509
|
)
|
510
510
|
|
511
|
+
self.assertEqual(
|
512
|
+
lm.score(
|
513
|
+
[message_lib.UserMessage('hi {{image}}', image=Image()),
|
514
|
+
message_lib.UserMessage('hi {{image}}', image=Image())],
|
515
|
+
['1', '2'], debug=debug_mode),
|
516
|
+
[
|
517
|
+
lm_lib.LMScoringResult(score=-0.0),
|
518
|
+
lm_lib.LMScoringResult(score=-1.0),
|
519
|
+
],
|
520
|
+
)
|
521
|
+
|
511
522
|
debug_info = string_io.getvalue()
|
512
523
|
expected_included = [
|
513
524
|
debug_prints[f]
|
@@ -528,6 +539,10 @@ class LanguageModelTest(unittest.TestCase):
|
|
528
539
|
if debug_mode & lm_lib.LMDebugMode.PROMPT:
|
529
540
|
self.assertIn('[0] MODALITY OBJECTS SENT TO LM', debug_info)
|
530
541
|
|
542
|
+
def test_score_with_unmatched_prompt_and_completions(self):
|
543
|
+
with self.assertRaises(ValueError):
|
544
|
+
MockScoringModel().score(['hi',], ['1', '2', '3'])
|
545
|
+
|
531
546
|
def test_score_with_unsupported_model(self):
|
532
547
|
with self.assertRaises(NotImplementedError):
|
533
548
|
MockModel().score('hi', ['1', '2'])
|
langfun/core/llms/fake.py
CHANGED
@@ -21,7 +21,8 @@ import langfun.core as lf
|
|
21
21
|
class Fake(lf.LanguageModel):
|
22
22
|
"""The base class for all fake language models."""
|
23
23
|
|
24
|
-
def _score(self, prompt: lf.Message
|
24
|
+
def _score(self, prompt: lf.Message| list[lf.Message],
|
25
|
+
completions: list[lf.Message]):
|
25
26
|
return [lf.LMScoringResult(score=-i * 1.0) for i in range(len(completions))]
|
26
27
|
|
27
28
|
def _sample(self, prompts: list[lf.Message]) -> list[lf.LMSamplingResult]:
|
@@ -23,7 +23,7 @@ import pyglove as pg
|
|
23
23
|
|
24
24
|
|
25
25
|
def score(
|
26
|
-
prompt: Union[str, pg.Symbolic],
|
26
|
+
prompt: Union[str, pg.Symbolic] | list[str | pg.Symbolic],
|
27
27
|
completions: list[str | pg.Symbolic],
|
28
28
|
schema: Union[
|
29
29
|
schema_lib.Schema, Type[Any], list[Type[Any]], dict[str, Any], None
|
@@ -49,15 +49,27 @@ def score(
|
|
49
49
|
f'{[type(c) for c in completions]}.'
|
50
50
|
)
|
51
51
|
|
52
|
-
|
53
|
-
|
54
|
-
|
55
|
-
|
56
|
-
|
57
|
-
|
58
|
-
|
59
|
-
|
60
|
-
|
52
|
+
if isinstance(prompt, list):
|
53
|
+
prompts = []
|
54
|
+
for p in prompt:
|
55
|
+
prompts.append(
|
56
|
+
prompting.query_prompt(
|
57
|
+
p,
|
58
|
+
schema,
|
59
|
+
examples=examples,
|
60
|
+
protocol=protocol,
|
61
|
+
**kwargs,
|
62
|
+
)
|
63
|
+
)
|
64
|
+
input_message = prompts
|
65
|
+
else:
|
66
|
+
input_message = prompting.query_prompt(
|
67
|
+
prompt,
|
68
|
+
schema,
|
69
|
+
examples=examples,
|
70
|
+
protocol=protocol,
|
71
|
+
**kwargs,
|
72
|
+
)
|
61
73
|
if lm is None:
|
62
74
|
lm_override = lf.get_contextual_override('lm')
|
63
75
|
if lm_override is None:
|
@@ -8,8 +8,8 @@ langfun/core/console.py,sha256=bk5rNPNm9rMGW5YT2HixxU04p2umnoabn5SDz6Dqe88,2317
|
|
8
8
|
langfun/core/console_test.py,sha256=5SYJdxpJGLgdSSQqqMPoA1X6jpsLD8rgcyk-EgI65oE,1077
|
9
9
|
langfun/core/langfunc.py,sha256=RvIcRjIq0jWYRu1xim-FYe4HSrt97r3GMBO_PuagUmw,11060
|
10
10
|
langfun/core/langfunc_test.py,sha256=_mfARnakX3oji5HDigFSLMd6yQ2wma-2Mgbztwqn73g,8501
|
11
|
-
langfun/core/language_model.py,sha256=
|
12
|
-
langfun/core/language_model_test.py,sha256=
|
11
|
+
langfun/core/language_model.py,sha256=owNCgefGoPeRCHrxBhMtNdOj3orbeVml4eqLf1n211o,20760
|
12
|
+
langfun/core/language_model_test.py,sha256=36evArVJgSQ9lRgHfMmlLW3lwjjDoiAgfTEbk2FIKa4,20122
|
13
13
|
langfun/core/memory.py,sha256=f-asN1F7Vehgdn_fK84v73GrEUOxRtaW934keutTKjk,2416
|
14
14
|
langfun/core/message.py,sha256=QhvV9t5qaryPcruyxxcXi3gm9QDInkSldwTtK6sVJ3c,15734
|
15
15
|
langfun/core/message_test.py,sha256=Z23pUM5vPnDrYkIIibe2KL73D5HKur_awI0ut_EQFQA,9501
|
@@ -51,7 +51,7 @@ langfun/core/eval/scoring_test.py,sha256=O8olHbrUEg60gMxwOkWzKBJZpZoUlmVnBANX5Se
|
|
51
51
|
langfun/core/llms/__init__.py,sha256=h_kam-0fjWISAQ90KZ_ydBhwADVCzrhLPXmAki3GfU0,4175
|
52
52
|
langfun/core/llms/anthropic.py,sha256=7W9YdPN3SlAFhAIQlihMkrpo7tTY_4NvD0KIlCrqcsk,8505
|
53
53
|
langfun/core/llms/anthropic_test.py,sha256=TMM30myyEhwF99Le4RvJEXOn8RYl0q1FRkt9Q9nl1jk,5540
|
54
|
-
langfun/core/llms/fake.py,sha256=
|
54
|
+
langfun/core/llms/fake.py,sha256=Dd7-6ka9pFf3fcWZyczamjOqQ91MOI-m7We3Oc9Ffmo,2927
|
55
55
|
langfun/core/llms/fake_test.py,sha256=ipKfdOcuqVcJ8lDXVpnBVb9HHG0hAVkFkMoHpWjC2cI,7212
|
56
56
|
langfun/core/llms/google_genai.py,sha256=nDI_Adur_K458l6EWoiiAhzjfnjRSqfTiikdu7iLPyU,8808
|
57
57
|
langfun/core/llms/google_genai_test.py,sha256=_UcGTfl16-aDUlEWFC2W2F8y9jPUs53RBYA6MOCpGXw,7525
|
@@ -94,7 +94,7 @@ langfun/core/structured/schema.py,sha256=Zy9y6Vq9DrFwcuP5o5VL_PvMCmzavF-nuDqyviB
|
|
94
94
|
langfun/core/structured/schema_generation.py,sha256=U3nRQsqmMZg_qIVDh2fiY3K4JLfsAL1LcKzIFP1iXFg,5316
|
95
95
|
langfun/core/structured/schema_generation_test.py,sha256=RM9s71kMNg2jTePwInkiW9fK1ACN37eyPeF8OII-0zw,2950
|
96
96
|
langfun/core/structured/schema_test.py,sha256=NgQK1zGSliZVx_Af6gDBTqQxXRHvmAvGARv4dUs8IbI,23078
|
97
|
-
langfun/core/structured/scoring.py,sha256=
|
97
|
+
langfun/core/structured/scoring.py,sha256=QyT1S8FkLtKICfUbh4AXoKK3YJ_rgejyk6TI2OtOa68,2751
|
98
98
|
langfun/core/structured/scoring_test.py,sha256=39_dw6p_FkoqeUccO67yIqos-MccAWezoozS21i8mi0,1732
|
99
99
|
langfun/core/templates/__init__.py,sha256=bO0eMsVJbi7sxEB2YlInKRQ2EVP-RyyKUwcD-8msuN4,927
|
100
100
|
langfun/core/templates/completion.py,sha256=mUqZHOEV3ag6-A08XghpeEltcrBvCDxXP004eDDfeag,1931
|
@@ -105,8 +105,8 @@ langfun/core/templates/demonstration.py,sha256=vCrgYubdZM5Umqcgp8NUVGXgr4P_c-fik
|
|
105
105
|
langfun/core/templates/demonstration_test.py,sha256=SafcDQ0WgI7pw05EmPI2S4v1t3ABKzup8jReCljHeK4,2162
|
106
106
|
langfun/core/templates/selfplay.py,sha256=yhgrJbiYwq47TgzThmHrDQTF4nDrTI09CWGhuQPNv-s,2273
|
107
107
|
langfun/core/templates/selfplay_test.py,sha256=DYVrkk7uNKCqJGEHH31HssU2BPuMItU1vJLzfcXIlYg,2156
|
108
|
-
langfun-0.0.2.
|
109
|
-
langfun-0.0.2.
|
110
|
-
langfun-0.0.2.
|
111
|
-
langfun-0.0.2.
|
112
|
-
langfun-0.0.2.
|
108
|
+
langfun-0.0.2.dev20240520.dist-info/LICENSE,sha256=WNHhf_5RCaeuKWyq_K39vmp9F28LxKsB4SpomwSZ2L0,11357
|
109
|
+
langfun-0.0.2.dev20240520.dist-info/METADATA,sha256=ndSicgG9hzjEnpZrs3WR8qP-CxdoXFlTIhGXCbvCaho,3452
|
110
|
+
langfun-0.0.2.dev20240520.dist-info/WHEEL,sha256=GJ7t_kWBFywbagK5eo9IoUwLW6oyOeTKmQ-9iHFVNxQ,92
|
111
|
+
langfun-0.0.2.dev20240520.dist-info/top_level.txt,sha256=RhlEkHxs1qtzmmtWSwYoLVJAc1YrbPtxQ52uh8Z9VvY,8
|
112
|
+
langfun-0.0.2.dev20240520.dist-info/RECORD,,
|
File without changes
|
File without changes
|
File without changes
|