langfun 0.0.2.dev20240330__py3-none-any.whl → 0.1.2.dev202501140804__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- langfun/__init__.py +22 -2
- langfun/core/__init__.py +17 -5
- langfun/core/agentic/__init__.py +30 -0
- langfun/core/agentic/action.py +854 -0
- langfun/core/agentic/action_eval.py +150 -0
- langfun/core/agentic/action_eval_test.py +109 -0
- langfun/core/agentic/action_test.py +136 -0
- langfun/core/coding/python/__init__.py +5 -11
- langfun/core/coding/python/correction.py +37 -28
- langfun/core/coding/python/correction_test.py +29 -3
- langfun/core/coding/python/execution.py +40 -216
- langfun/core/coding/python/execution_test.py +29 -89
- langfun/core/coding/python/generation.py +21 -11
- langfun/core/coding/python/generation_test.py +2 -2
- langfun/core/coding/python/parsing.py +108 -193
- langfun/core/coding/python/parsing_test.py +2 -105
- langfun/core/component.py +69 -2
- langfun/core/component_test.py +54 -0
- langfun/core/concurrent.py +414 -117
- langfun/core/concurrent_test.py +111 -24
- langfun/core/console.py +18 -5
- langfun/core/console_test.py +17 -0
- langfun/core/eval/__init__.py +17 -0
- langfun/core/eval/base.py +767 -140
- langfun/core/eval/base_test.py +238 -53
- langfun/core/eval/matching.py +80 -76
- langfun/core/eval/matching_test.py +19 -9
- langfun/core/eval/patching.py +130 -0
- langfun/core/eval/patching_test.py +170 -0
- langfun/core/eval/scoring.py +37 -28
- langfun/core/eval/scoring_test.py +21 -3
- langfun/core/eval/v2/__init__.py +42 -0
- langfun/core/eval/v2/checkpointing.py +380 -0
- langfun/core/eval/v2/checkpointing_test.py +228 -0
- langfun/core/eval/v2/eval_test_helper.py +136 -0
- langfun/core/eval/v2/evaluation.py +725 -0
- langfun/core/eval/v2/evaluation_test.py +180 -0
- langfun/core/eval/v2/example.py +305 -0
- langfun/core/eval/v2/example_test.py +128 -0
- langfun/core/eval/v2/experiment.py +1048 -0
- langfun/core/eval/v2/experiment_test.py +433 -0
- langfun/core/eval/v2/metric_values.py +156 -0
- langfun/core/eval/v2/metric_values_test.py +80 -0
- langfun/core/eval/v2/metrics.py +357 -0
- langfun/core/eval/v2/metrics_test.py +203 -0
- langfun/core/eval/v2/progress.py +348 -0
- langfun/core/eval/v2/progress_test.py +82 -0
- langfun/core/eval/v2/progress_tracking.py +210 -0
- langfun/core/eval/v2/progress_tracking_test.py +66 -0
- langfun/core/eval/v2/reporting.py +270 -0
- langfun/core/eval/v2/reporting_test.py +158 -0
- langfun/core/eval/v2/runners.py +488 -0
- langfun/core/eval/v2/runners_test.py +334 -0
- langfun/core/langfunc.py +3 -21
- langfun/core/langfunc_test.py +26 -8
- langfun/core/language_model.py +686 -48
- langfun/core/language_model_test.py +681 -44
- langfun/core/llms/__init__.py +100 -12
- langfun/core/llms/anthropic.py +488 -0
- langfun/core/llms/anthropic_test.py +235 -0
- langfun/core/llms/cache/base.py +21 -2
- langfun/core/llms/cache/in_memory.py +13 -0
- langfun/core/llms/cache/in_memory_test.py +88 -28
- langfun/core/llms/compositional.py +101 -0
- langfun/core/llms/compositional_test.py +73 -0
- langfun/core/llms/deepseek.py +117 -0
- langfun/core/llms/deepseek_test.py +61 -0
- langfun/core/llms/fake.py +39 -26
- langfun/core/llms/fake_test.py +136 -11
- langfun/core/llms/gemini.py +507 -0
- langfun/core/llms/gemini_test.py +195 -0
- langfun/core/llms/google_genai.py +62 -218
- langfun/core/llms/google_genai_test.py +9 -197
- langfun/core/llms/groq.py +276 -0
- langfun/core/llms/groq_test.py +64 -0
- langfun/core/llms/llama_cpp.py +15 -40
- langfun/core/llms/llama_cpp_test.py +4 -30
- langfun/core/llms/openai.py +436 -226
- langfun/core/llms/openai_compatible.py +179 -0
- langfun/core/llms/openai_compatible_test.py +495 -0
- langfun/core/llms/openai_test.py +35 -174
- langfun/core/llms/rest.py +113 -0
- langfun/core/llms/rest_test.py +111 -0
- langfun/core/llms/vertexai.py +192 -0
- langfun/core/llms/vertexai_test.py +52 -0
- langfun/core/logging.py +284 -0
- langfun/core/logging_test.py +125 -0
- langfun/core/message.py +319 -9
- langfun/core/message_test.py +190 -13
- langfun/core/modalities/__init__.py +6 -2
- langfun/core/modalities/audio.py +30 -0
- langfun/core/modalities/audio_test.py +63 -0
- langfun/core/modalities/image.py +39 -20
- langfun/core/modalities/image_test.py +52 -9
- langfun/core/modalities/mime.py +206 -29
- langfun/core/modalities/mime_test.py +90 -9
- langfun/core/modalities/ms_office.py +117 -0
- langfun/core/modalities/ms_office_test.py +389 -0
- langfun/core/modalities/pdf.py +22 -0
- langfun/core/modalities/pdf_test.py +57 -0
- langfun/core/modalities/video.py +9 -23
- langfun/core/modalities/video_test.py +3 -3
- langfun/core/modality.py +26 -3
- langfun/core/modality_test.py +2 -2
- langfun/core/sampling.py +11 -11
- langfun/core/structured/__init__.py +15 -16
- langfun/core/structured/completion.py +32 -5
- langfun/core/structured/completion_test.py +9 -8
- langfun/core/structured/description.py +2 -2
- langfun/core/structured/description_test.py +3 -3
- langfun/core/structured/function_generation.py +278 -0
- langfun/core/structured/function_generation_test.py +399 -0
- langfun/core/structured/mapping.py +150 -46
- langfun/core/structured/mapping_test.py +105 -0
- langfun/core/structured/parsing.py +33 -21
- langfun/core/structured/parsing_test.py +71 -22
- langfun/core/structured/querying.py +746 -0
- langfun/core/structured/{prompting_test.py → querying_test.py} +545 -60
- langfun/core/structured/schema.py +208 -99
- langfun/core/structured/schema_generation.py +1 -1
- langfun/core/structured/schema_generation_test.py +2 -2
- langfun/core/structured/schema_test.py +133 -34
- langfun/core/structured/scoring.py +125 -19
- langfun/core/structured/scoring_test.py +30 -0
- langfun/core/structured/tokenization.py +64 -0
- langfun/core/structured/tokenization_test.py +48 -0
- langfun/core/template.py +240 -11
- langfun/core/template_test.py +146 -1
- langfun/core/templates/conversation.py +9 -0
- langfun/core/templates/conversation_test.py +4 -3
- langfun/core/templates/selfplay_test.py +14 -2
- langfun-0.1.2.dev202501140804.dist-info/METADATA +225 -0
- langfun-0.1.2.dev202501140804.dist-info/RECORD +153 -0
- {langfun-0.0.2.dev20240330.dist-info → langfun-0.1.2.dev202501140804.dist-info}/WHEEL +1 -1
- langfun/core/coding/python/errors.py +0 -108
- langfun/core/coding/python/errors_test.py +0 -99
- langfun/core/coding/python/permissions.py +0 -90
- langfun/core/coding/python/permissions_test.py +0 -86
- langfun/core/structured/prompting.py +0 -217
- langfun/core/text_formatting.py +0 -162
- langfun/core/text_formatting_test.py +0 -47
- langfun-0.0.2.dev20240330.dist-info/METADATA +0 -99
- langfun-0.0.2.dev20240330.dist-info/RECORD +0 -102
- {langfun-0.0.2.dev20240330.dist-info → langfun-0.1.2.dev202501140804.dist-info}/LICENSE +0 -0
- {langfun-0.0.2.dev20240330.dist-info → langfun-0.1.2.dev202501140804.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,73 @@
|
|
1
|
+
# Copyright 2024 The Langfun Authors
|
2
|
+
#
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at
|
6
|
+
#
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
#
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
"""Tests for compositional models."""
|
15
|
+
import unittest
|
16
|
+
|
17
|
+
import langfun.core as lf
|
18
|
+
from langfun.core.llms import compositional
|
19
|
+
from langfun.core.llms import fake
|
20
|
+
|
21
|
+
|
22
|
+
class RandomChoiceTest(unittest.TestCase):
  """Tests for `compositional.RandomChoice`."""

  def _new_candidates(self):
    """Returns fresh candidate LMs: a static response and a 2-item sequence."""
    return [
        fake.StaticResponse('hi'),
        fake.StaticSequence(['hello', 'world']),
    ]

  def test_basic(self):
    lm = compositional.RandomChoice(self._new_candidates())
    expected_name = 'RandomChoice(StaticResponse, StaticSequence)'
    self.assertEqual(lm.model_id, expected_name)
    self.assertEqual(lm.resource_id, expected_name)

    prompts = ('a', 'b', 'c')
    expected_responses = ['hello', 'world', 'hi']
    self.assertEqual([lm(p) for p in prompts], expected_responses)

    # After cloning, the sampling API yields the same sequence of responses.
    lm = lm.clone()
    sampled = [lm.sample([p])[0] for p in prompts]
    self.assertEqual(
        [result.samples[0].response for result in sampled],
        expected_responses,
    )

    self.assertEqual(
        lm.score('hello', ['world']), [lf.LMScoringResult(0.0)]
    )
    self.assertEqual(lm.tokenize('hello'), [('hello', 0)])

  def test_sampling_options(self):
    lm = compositional.RandomChoice(self._new_candidates(), temperature=0.5)
    # Sampling options passed to RandomChoice reach its candidates.
    self.assertEqual(lm.candidates[0].sampling_options.temperature, 0.5)


if __name__ == '__main__':
  unittest.main()
|
@@ -0,0 +1,117 @@
|
|
1
|
+
# Copyright 2024 The Langfun Authors
|
2
|
+
#
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at
|
6
|
+
#
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
#
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
"""Language models from DeepSeek."""
|
15
|
+
|
16
|
+
import os
|
17
|
+
from typing import Annotated, Any
|
18
|
+
|
19
|
+
import langfun.core as lf
|
20
|
+
from langfun.core.llms import openai_compatible
|
21
|
+
import pyglove as pg
|
22
|
+
|
23
|
+
# Per-model settings used by `DeepSeek` for concurrency and cost estimation.
# DeepSeek does not control the rate limit at the moment
# (https://api-docs.deepseek.com/quick_start/rate_limit), so the RPM/TPM values
# below are placeholders. Costs follow
# https://api-docs.deepseek.com/quick_start/pricing.
SUPPORTED_MODELS_AND_SETTINGS = {
    # pylint: disable=g-line-too-long
    # TODO(yifenglu): The RPM and TPM are arbitrary numbers. Update them once DeepSeek provides concrete guidelines.
    'deepseek-chat': pg.Dict(
        in_service=True,
        rpm=100,
        tpm=1_000_000,
        cost_per_1k_input_tokens=0.00014,
        cost_per_1k_output_tokens=0.00028,
    ),
}


# DeepSeek serves an OpenAI-compatible API, so this class builds on
# `OpenAICompatible` and adds the endpoint, credentials and per-model settings.
# Reference: https://api-docs.deepseek.com/
@lf.use_init_args(['model'])
class DeepSeek(openai_compatible.OpenAICompatible):
  """DeepSeek model."""

  model: pg.typing.Annotated[
      pg.typing.Enum(
          pg.MISSING_VALUE, list(SUPPORTED_MODELS_AND_SETTINGS)
      ),
      'The name of the model to use.',
  ]

  api_endpoint: str = 'https://api.deepseek.com/chat/completions'

  api_key: Annotated[
      str | None,
      (
          'API key. If None, the key will be read from environment variable '
          "'DEEPSEEK_API_KEY'."
      ),
  ] = None

  @property
  def headers(self) -> dict[str, Any]:
    """HTTP headers from the parent class plus a bearer `Authorization` entry.

    Raises:
      ValueError: if `api_key` is not set and `DEEPSEEK_API_KEY` is unset or
        empty.
    """
    key = self.api_key or os.environ.get('DEEPSEEK_API_KEY', None)
    if not key:
      raise ValueError(
          'Please specify `api_key` during `__init__` or set environment '
          'variable `DEEPSEEK_API_KEY` with your DeepSeek API key.'
      )
    result = super().headers
    result['Authorization'] = f'Bearer {key}'
    return result

  @property
  def model_id(self) -> str:
    """Returns a string to identify the model."""
    return f'DeepSeek({self.model})'

  @property
  def max_concurrency(self) -> int:
    """Concurrency derived from the model's configured RPM/TPM (0 if unset)."""
    settings = SUPPORTED_MODELS_AND_SETTINGS[self.model]
    return self.rate_to_max_concurrency(
        requests_per_min=settings.get('rpm', 0),
        tokens_per_min=settings.get('tpm', 0),
    )

  def estimate_cost(
      self, num_input_tokens: int, num_output_tokens: int
  ) -> float | None:
    """Estimate the cost based on usage.

    Args:
      num_input_tokens: Number of prompt tokens.
      num_output_tokens: Number of completion tokens.

    Returns:
      The estimated cost, or None when either per-1k price is missing from the
      model's settings.
    """
    settings = SUPPORTED_MODELS_AND_SETTINGS[self.model]
    input_rate = settings.get('cost_per_1k_input_tokens', None)
    output_rate = settings.get('cost_per_1k_output_tokens', None)
    if input_rate is None or output_rate is None:
      return None
    scaled_cost = (
        input_rate * num_input_tokens + output_rate * num_output_tokens
    )
    return scaled_cost / 1000

  @classmethod
  def dir(cls):
    """Lists the names of models whose settings mark them as in service."""
    return [
        name
        for name, settings in SUPPORTED_MODELS_AND_SETTINGS.items()
        if settings.in_service
    ]
|
108
|
+
|
109
|
+
|
110
|
+
class DeepSeekChat(DeepSeek):
  """DeepSeek Chat model.

  Currently, it is powered by the DeepSeek-V3 model, with a 64K input context
  window and 8K max output tokens.
  """

  model = 'deepseek-chat'
|
@@ -0,0 +1,61 @@
|
|
1
|
+
# Copyright 2023 The Langfun Authors
|
2
|
+
#
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at
|
6
|
+
#
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
#
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
import unittest
|
15
|
+
from langfun.core.llms import deepseek
|
16
|
+
|
17
|
+
|
18
|
+
class DeepSeekTest(unittest.TestCase):
  """Tests for DeepSeek language model."""

  def test_dir(self):
    self.assertIn('deepseek-chat', deepseek.DeepSeek.dir())

  def test_key(self):
    import os  # pylint: disable=g-import-not-at-top
    from unittest import mock  # pylint: disable=g-import-not-at-top

    # `headers` falls back to the DEEPSEEK_API_KEY environment variable, so it
    # must be unset here; otherwise this check fails on machines where the key
    # is configured. `patch.dict` restores the environment afterwards.
    with mock.patch.dict(os.environ):
      os.environ.pop('DEEPSEEK_API_KEY', None)
      with self.assertRaisesRegex(ValueError, 'Please specify `api_key`'):
        _ = deepseek.DeepSeekChat().headers
    self.assertEqual(
        deepseek.DeepSeekChat(api_key='test_key').headers,
        {
            'Content-Type': 'application/json',
            'Authorization': 'Bearer test_key',
        }
    )

  def test_model_id(self):
    self.assertEqual(
        deepseek.DeepSeekChat(api_key='test_key').model_id,
        'DeepSeek(deepseek-chat)',
    )

  def test_resource_id(self):
    self.assertEqual(
        deepseek.DeepSeekChat(api_key='test_key').resource_id,
        'DeepSeek(deepseek-chat)',
    )

  def test_max_concurrency(self):
    self.assertGreater(
        deepseek.DeepSeekChat(api_key='test_key').max_concurrency, 0
    )

  def test_estimate_cost(self):
    # The cost is a float product/sum, so compare with a tolerance instead of
    # exact equality.
    self.assertAlmostEqual(
        deepseek.DeepSeekChat(api_key='test_key').estimate_cost(
            num_input_tokens=100, num_output_tokens=100
        ),
        4.2e-5,
        places=12,
    )


if __name__ == '__main__':
  unittest.main()
|
langfun/core/llms/fake.py
CHANGED
@@ -13,6 +13,7 @@
|
|
13
13
|
# limitations under the License.
|
14
14
|
"""Fake LMs for testing."""
|
15
15
|
|
16
|
+
import abc
|
16
17
|
from typing import Annotated
|
17
18
|
import langfun.core as lf
|
18
19
|
|
@@ -20,18 +21,39 @@ import langfun.core as lf
|
|
20
21
|
class Fake(lf.LanguageModel):
  """The base class for all fake language models.

  Subclasses only implement `_response_from`; sampling, scoring and
  tokenization are provided here with simple deterministic rules.
  """

  def _score(
      self,
      prompt: lf.Message | list[lf.Message],
      completions: list[lf.Message],
  ):
    # The prompt is ignored: the k-th completion always scores -k.
    return [
        lf.LMScoringResult(score=float(-rank))
        for rank in range(len(completions))
    ]

  def _tokenize(self, prompt: lf.Message) -> list[tuple[str | bytes, int]]:
    # Splits on single spaces and pairs each piece with its position.
    pieces = prompt.text.split(' ')
    return list(zip(pieces, range(len(pieces))))

  def _sample(self, prompts: list[lf.Message]) -> list[lf.LMSamplingResult]:
    # Prompts are handled in order, which matters for stateful subclasses.
    return [self._sampling_result_for(prompt) for prompt in prompts]

  def _sampling_result_for(self, prompt: lf.Message) -> lf.LMSamplingResult:
    """Wraps the response for `prompt` into a single-sample result."""
    response = self._response_from(prompt)
    # Usage is reported in characters of text, not real tokens.
    prompt_size = len(prompt.text)
    response_size = len(response.text)
    return lf.LMSamplingResult(
        [lf.LMSample(response, 1.0)],
        usage=lf.LMSamplingUsage(
            prompt_tokens=prompt_size,
            completion_tokens=response_size,
            total_tokens=prompt_size + response_size,
        ),
    )

  @abc.abstractmethod
  def _response_from(self, prompt: lf.Message) -> lf.Message:
    """Returns the response for the given prompt."""
|
50
|
+
|
26
51
|
|
27
52
|
class Echo(Fake):
  """A simple echo language model for testing.

  The text of every prompt is returned unchanged as an AI message.
  """

  def _response_from(self, prompt: lf.Message) -> lf.Message:
    echoed_text = prompt.text
    return lf.AIMessage(echoed_text)
|
35
57
|
|
36
58
|
|
37
59
|
@lf.use_init_args(['response'])
|
@@ -39,15 +61,12 @@ class StaticResponse(Fake):
|
|
39
61
|
"""Language model that always gives the same canned response."""
|
40
62
|
|
41
63
|
response: Annotated[
|
42
|
-
str,
|
64
|
+
str | lf.Message,
|
43
65
|
'A canned response that will be returned regardless of the prompt.'
|
44
66
|
]
|
45
67
|
|
46
|
-
def
|
47
|
-
return
|
48
|
-
lf.LMSamplingResult([lf.LMSample(self.response, 1.0)])
|
49
|
-
for _ in prompts
|
50
|
-
]
|
68
|
+
def _response_from(self, prompt: lf.Message) -> lf.Message:
|
69
|
+
return lf.AIMessage.from_value(self.response)
|
51
70
|
|
52
71
|
|
53
72
|
@lf.use_init_args(['mapping'])
|
@@ -55,15 +74,12 @@ class StaticMapping(Fake):
|
|
55
74
|
"""A static mapping from prompt to response."""
|
56
75
|
|
57
76
|
mapping: Annotated[
|
58
|
-
dict[str, str],
|
77
|
+
dict[str, str | lf.Message],
|
59
78
|
'A mapping from prompt to response.'
|
60
79
|
]
|
61
80
|
|
62
|
-
def
|
63
|
-
return [
|
64
|
-
lf.LMSamplingResult([lf.LMSample(self.mapping[prompt], 1.0)])
|
65
|
-
for prompt in prompts
|
66
|
-
]
|
81
|
+
def _response_from(self, prompt: lf.Message) -> lf.Message:
|
82
|
+
return lf.AIMessage.from_value(self.mapping[prompt])
|
67
83
|
|
68
84
|
|
69
85
|
@lf.use_init_args(['sequence'])
|
@@ -71,7 +87,7 @@ class StaticSequence(Fake):
|
|
71
87
|
"""A static sequence of responses to use."""
|
72
88
|
|
73
89
|
sequence: Annotated[
|
74
|
-
list[str],
|
90
|
+
list[str | lf.Message],
|
75
91
|
'A sequence of strings as the response.'
|
76
92
|
]
|
77
93
|
|
@@ -79,10 +95,7 @@ class StaticSequence(Fake):
|
|
79
95
|
super()._on_bound()
|
80
96
|
self._pos = 0
|
81
97
|
|
82
|
-
def
|
83
|
-
|
84
|
-
|
85
|
-
|
86
|
-
[lf.LMSample(self.sequence[self._pos], 1.0)]))
|
87
|
-
self._pos += 1
|
88
|
-
return results
|
98
|
+
def _response_from(self, prompt: lf.Message) -> lf.Message:
|
99
|
+
r = lf.AIMessage.from_value(self.sequence[self._pos])
|
100
|
+
self._pos += 1
|
101
|
+
return r
|
langfun/core/llms/fake_test.py
CHANGED
@@ -25,7 +25,25 @@ class EchoTest(unittest.TestCase):
|
|
25
25
|
def test_sample(self):
  lm = fakelm.Echo()
  # Echo replies with the prompt text, and usage counts characters, so the
  # 2-character prompt 'hi' yields usage (2, 2, 4).
  echoed = lf.AIMessage(
      'hi',
      score=1.0,
      logprobs=None,
      is_cached=False,
      usage=lf.LMSamplingUsage(2, 2, 4),
      tags=[lf.Message.TAG_LM_RESPONSE],
  )
  expected = lf.LMSamplingResult(
      [lf.LMSample(echoed, score=1.0, logprobs=None)],
      lf.LMSamplingUsage(2, 2, 4),
  )
  self.assertEqual(lm.sample(['hi']), [expected])
|
30
48
|
|
31
49
|
def test_call(self):
|
@@ -34,8 +52,8 @@ class EchoTest(unittest.TestCase):
|
|
34
52
|
with contextlib.redirect_stdout(string_io):
|
35
53
|
self.assertEqual(lm('hi'), 'hi')
|
36
54
|
debug_info = string_io.getvalue()
|
37
|
-
self.assertIn('[0] LM INFO
|
38
|
-
self.assertIn('[0] PROMPT SENT TO LM
|
55
|
+
self.assertIn('[0] LM INFO', debug_info)
|
56
|
+
self.assertIn('[0] PROMPT SENT TO LM', debug_info)
|
39
57
|
self.assertIn('[0] LM RESPONSE', debug_info)
|
40
58
|
|
41
59
|
def test_score(self):
|
@@ -45,6 +63,13 @@ class EchoTest(unittest.TestCase):
|
|
45
63
|
[lf.LMScoringResult(0.0), lf.LMScoringResult(-1.0)],
|
46
64
|
)
|
47
65
|
|
66
|
+
def test_tokenize(self):
  # A one-word prompt becomes a single (word, position) pair.
  tokens = fakelm.Echo().tokenize('hi')
  self.assertEqual(tokens, [('hi', 0)])
|
72
|
+
|
48
73
|
|
49
74
|
class StaticResponseTest(unittest.TestCase):
|
50
75
|
|
@@ -53,11 +78,47 @@ class StaticResponseTest(unittest.TestCase):
|
|
53
78
|
lm = fakelm.StaticResponse(canned_response)
|
54
79
|
self.assertEqual(
|
55
80
|
lm.sample(['hi']),
|
56
|
-
[
|
81
|
+
[
|
82
|
+
lf.LMSamplingResult(
|
83
|
+
[
|
84
|
+
lf.LMSample(
|
85
|
+
lf.AIMessage(
|
86
|
+
canned_response,
|
87
|
+
score=1.0,
|
88
|
+
logprobs=None,
|
89
|
+
is_cached=False,
|
90
|
+
usage=lf.LMSamplingUsage(2, 38, 40),
|
91
|
+
tags=[lf.Message.TAG_LM_RESPONSE],
|
92
|
+
),
|
93
|
+
score=1.0,
|
94
|
+
logprobs=None,
|
95
|
+
)
|
96
|
+
],
|
97
|
+
usage=lf.LMSamplingUsage(2, 38, 40)
|
98
|
+
)
|
99
|
+
],
|
57
100
|
)
|
58
101
|
self.assertEqual(
|
59
102
|
lm.sample(['Tell me a joke.']),
|
60
|
-
[
|
103
|
+
[
|
104
|
+
lf.LMSamplingResult(
|
105
|
+
[
|
106
|
+
lf.LMSample(
|
107
|
+
lf.AIMessage(
|
108
|
+
canned_response,
|
109
|
+
score=1.0,
|
110
|
+
logprobs=None,
|
111
|
+
is_cached=False,
|
112
|
+
usage=lf.LMSamplingUsage(15, 38, 53),
|
113
|
+
tags=[lf.Message.TAG_LM_RESPONSE],
|
114
|
+
),
|
115
|
+
score=1.0,
|
116
|
+
logprobs=None,
|
117
|
+
)
|
118
|
+
],
|
119
|
+
usage=lf.LMSamplingUsage(15, 38, 53)
|
120
|
+
)
|
121
|
+
],
|
61
122
|
)
|
62
123
|
|
63
124
|
def test_call(self):
|
@@ -69,8 +130,8 @@ class StaticResponseTest(unittest.TestCase):
|
|
69
130
|
self.assertEqual(lm('hi'), canned_response)
|
70
131
|
|
71
132
|
debug_info = string_io.getvalue()
|
72
|
-
self.assertIn('[0] LM INFO
|
73
|
-
self.assertIn('[0] PROMPT SENT TO LM
|
133
|
+
self.assertIn('[0] LM INFO', debug_info)
|
134
|
+
self.assertIn('[0] PROMPT SENT TO LM', debug_info)
|
74
135
|
self.assertIn('[0] LM RESPONSE', debug_info)
|
75
136
|
|
76
137
|
|
@@ -85,8 +146,40 @@ class StaticMappingTest(unittest.TestCase):
|
|
85
146
|
self.assertEqual(
|
86
147
|
lm.sample(['Hi', 'How are you?']),
|
87
148
|
[
|
88
|
-
lf.LMSamplingResult(
|
89
|
-
|
149
|
+
lf.LMSamplingResult(
|
150
|
+
[
|
151
|
+
lf.LMSample(
|
152
|
+
lf.AIMessage(
|
153
|
+
'Hello',
|
154
|
+
score=1.0,
|
155
|
+
logprobs=None,
|
156
|
+
is_cached=False,
|
157
|
+
usage=lf.LMSamplingUsage(2, 5, 7),
|
158
|
+
tags=[lf.Message.TAG_LM_RESPONSE],
|
159
|
+
),
|
160
|
+
score=1.0,
|
161
|
+
logprobs=None,
|
162
|
+
)
|
163
|
+
],
|
164
|
+
usage=lf.LMSamplingUsage(2, 5, 7)
|
165
|
+
),
|
166
|
+
lf.LMSamplingResult(
|
167
|
+
[
|
168
|
+
lf.LMSample(
|
169
|
+
lf.AIMessage(
|
170
|
+
'I am fine, how about you?',
|
171
|
+
score=1.0,
|
172
|
+
logprobs=None,
|
173
|
+
is_cached=False,
|
174
|
+
usage=lf.LMSamplingUsage(12, 25, 37),
|
175
|
+
tags=[lf.Message.TAG_LM_RESPONSE],
|
176
|
+
),
|
177
|
+
score=1.0,
|
178
|
+
logprobs=None,
|
179
|
+
)
|
180
|
+
],
|
181
|
+
usage=lf.LMSamplingUsage(12, 25, 37)
|
182
|
+
)
|
90
183
|
]
|
91
184
|
)
|
92
185
|
with self.assertRaises(KeyError):
|
@@ -104,8 +197,40 @@ class StaticSequenceTest(unittest.TestCase):
|
|
104
197
|
self.assertEqual(
|
105
198
|
lm.sample(['Hi', 'How are you?']),
|
106
199
|
[
|
107
|
-
lf.LMSamplingResult(
|
108
|
-
|
200
|
+
lf.LMSamplingResult(
|
201
|
+
[
|
202
|
+
lf.LMSample(
|
203
|
+
lf.AIMessage(
|
204
|
+
'Hello',
|
205
|
+
score=1.0,
|
206
|
+
logprobs=None,
|
207
|
+
is_cached=False,
|
208
|
+
usage=lf.LMSamplingUsage(2, 5, 7),
|
209
|
+
tags=[lf.Message.TAG_LM_RESPONSE],
|
210
|
+
),
|
211
|
+
score=1.0,
|
212
|
+
logprobs=None,
|
213
|
+
)
|
214
|
+
],
|
215
|
+
usage=lf.LMSamplingUsage(2, 5, 7)
|
216
|
+
),
|
217
|
+
lf.LMSamplingResult(
|
218
|
+
[
|
219
|
+
lf.LMSample(
|
220
|
+
lf.AIMessage(
|
221
|
+
'I am fine, how about you?',
|
222
|
+
score=1.0,
|
223
|
+
logprobs=None,
|
224
|
+
is_cached=False,
|
225
|
+
usage=lf.LMSamplingUsage(12, 25, 37),
|
226
|
+
tags=[lf.Message.TAG_LM_RESPONSE],
|
227
|
+
),
|
228
|
+
score=1.0,
|
229
|
+
logprobs=None,
|
230
|
+
)
|
231
|
+
],
|
232
|
+
usage=lf.LMSamplingUsage(12, 25, 37)
|
233
|
+
)
|
109
234
|
]
|
110
235
|
)
|
111
236
|
with self.assertRaises(IndexError):
|