langfun 0.0.2.dev20240330__py3-none-any.whl → 0.1.2.dev202501140804__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (145)
  1. langfun/__init__.py +22 -2
  2. langfun/core/__init__.py +17 -5
  3. langfun/core/agentic/__init__.py +30 -0
  4. langfun/core/agentic/action.py +854 -0
  5. langfun/core/agentic/action_eval.py +150 -0
  6. langfun/core/agentic/action_eval_test.py +109 -0
  7. langfun/core/agentic/action_test.py +136 -0
  8. langfun/core/coding/python/__init__.py +5 -11
  9. langfun/core/coding/python/correction.py +37 -28
  10. langfun/core/coding/python/correction_test.py +29 -3
  11. langfun/core/coding/python/execution.py +40 -216
  12. langfun/core/coding/python/execution_test.py +29 -89
  13. langfun/core/coding/python/generation.py +21 -11
  14. langfun/core/coding/python/generation_test.py +2 -2
  15. langfun/core/coding/python/parsing.py +108 -193
  16. langfun/core/coding/python/parsing_test.py +2 -105
  17. langfun/core/component.py +69 -2
  18. langfun/core/component_test.py +54 -0
  19. langfun/core/concurrent.py +414 -117
  20. langfun/core/concurrent_test.py +111 -24
  21. langfun/core/console.py +18 -5
  22. langfun/core/console_test.py +17 -0
  23. langfun/core/eval/__init__.py +17 -0
  24. langfun/core/eval/base.py +767 -140
  25. langfun/core/eval/base_test.py +238 -53
  26. langfun/core/eval/matching.py +80 -76
  27. langfun/core/eval/matching_test.py +19 -9
  28. langfun/core/eval/patching.py +130 -0
  29. langfun/core/eval/patching_test.py +170 -0
  30. langfun/core/eval/scoring.py +37 -28
  31. langfun/core/eval/scoring_test.py +21 -3
  32. langfun/core/eval/v2/__init__.py +42 -0
  33. langfun/core/eval/v2/checkpointing.py +380 -0
  34. langfun/core/eval/v2/checkpointing_test.py +228 -0
  35. langfun/core/eval/v2/eval_test_helper.py +136 -0
  36. langfun/core/eval/v2/evaluation.py +725 -0
  37. langfun/core/eval/v2/evaluation_test.py +180 -0
  38. langfun/core/eval/v2/example.py +305 -0
  39. langfun/core/eval/v2/example_test.py +128 -0
  40. langfun/core/eval/v2/experiment.py +1048 -0
  41. langfun/core/eval/v2/experiment_test.py +433 -0
  42. langfun/core/eval/v2/metric_values.py +156 -0
  43. langfun/core/eval/v2/metric_values_test.py +80 -0
  44. langfun/core/eval/v2/metrics.py +357 -0
  45. langfun/core/eval/v2/metrics_test.py +203 -0
  46. langfun/core/eval/v2/progress.py +348 -0
  47. langfun/core/eval/v2/progress_test.py +82 -0
  48. langfun/core/eval/v2/progress_tracking.py +210 -0
  49. langfun/core/eval/v2/progress_tracking_test.py +66 -0
  50. langfun/core/eval/v2/reporting.py +270 -0
  51. langfun/core/eval/v2/reporting_test.py +158 -0
  52. langfun/core/eval/v2/runners.py +488 -0
  53. langfun/core/eval/v2/runners_test.py +334 -0
  54. langfun/core/langfunc.py +3 -21
  55. langfun/core/langfunc_test.py +26 -8
  56. langfun/core/language_model.py +686 -48
  57. langfun/core/language_model_test.py +681 -44
  58. langfun/core/llms/__init__.py +100 -12
  59. langfun/core/llms/anthropic.py +488 -0
  60. langfun/core/llms/anthropic_test.py +235 -0
  61. langfun/core/llms/cache/base.py +21 -2
  62. langfun/core/llms/cache/in_memory.py +13 -0
  63. langfun/core/llms/cache/in_memory_test.py +88 -28
  64. langfun/core/llms/compositional.py +101 -0
  65. langfun/core/llms/compositional_test.py +73 -0
  66. langfun/core/llms/deepseek.py +117 -0
  67. langfun/core/llms/deepseek_test.py +61 -0
  68. langfun/core/llms/fake.py +39 -26
  69. langfun/core/llms/fake_test.py +136 -11
  70. langfun/core/llms/gemini.py +507 -0
  71. langfun/core/llms/gemini_test.py +195 -0
  72. langfun/core/llms/google_genai.py +62 -218
  73. langfun/core/llms/google_genai_test.py +9 -197
  74. langfun/core/llms/groq.py +276 -0
  75. langfun/core/llms/groq_test.py +64 -0
  76. langfun/core/llms/llama_cpp.py +15 -40
  77. langfun/core/llms/llama_cpp_test.py +4 -30
  78. langfun/core/llms/openai.py +436 -226
  79. langfun/core/llms/openai_compatible.py +179 -0
  80. langfun/core/llms/openai_compatible_test.py +495 -0
  81. langfun/core/llms/openai_test.py +35 -174
  82. langfun/core/llms/rest.py +113 -0
  83. langfun/core/llms/rest_test.py +111 -0
  84. langfun/core/llms/vertexai.py +192 -0
  85. langfun/core/llms/vertexai_test.py +52 -0
  86. langfun/core/logging.py +284 -0
  87. langfun/core/logging_test.py +125 -0
  88. langfun/core/message.py +319 -9
  89. langfun/core/message_test.py +190 -13
  90. langfun/core/modalities/__init__.py +6 -2
  91. langfun/core/modalities/audio.py +30 -0
  92. langfun/core/modalities/audio_test.py +63 -0
  93. langfun/core/modalities/image.py +39 -20
  94. langfun/core/modalities/image_test.py +52 -9
  95. langfun/core/modalities/mime.py +206 -29
  96. langfun/core/modalities/mime_test.py +90 -9
  97. langfun/core/modalities/ms_office.py +117 -0
  98. langfun/core/modalities/ms_office_test.py +389 -0
  99. langfun/core/modalities/pdf.py +22 -0
  100. langfun/core/modalities/pdf_test.py +57 -0
  101. langfun/core/modalities/video.py +9 -23
  102. langfun/core/modalities/video_test.py +3 -3
  103. langfun/core/modality.py +26 -3
  104. langfun/core/modality_test.py +2 -2
  105. langfun/core/sampling.py +11 -11
  106. langfun/core/structured/__init__.py +15 -16
  107. langfun/core/structured/completion.py +32 -5
  108. langfun/core/structured/completion_test.py +9 -8
  109. langfun/core/structured/description.py +2 -2
  110. langfun/core/structured/description_test.py +3 -3
  111. langfun/core/structured/function_generation.py +278 -0
  112. langfun/core/structured/function_generation_test.py +399 -0
  113. langfun/core/structured/mapping.py +150 -46
  114. langfun/core/structured/mapping_test.py +105 -0
  115. langfun/core/structured/parsing.py +33 -21
  116. langfun/core/structured/parsing_test.py +71 -22
  117. langfun/core/structured/querying.py +746 -0
  118. langfun/core/structured/{prompting_test.py → querying_test.py} +545 -60
  119. langfun/core/structured/schema.py +208 -99
  120. langfun/core/structured/schema_generation.py +1 -1
  121. langfun/core/structured/schema_generation_test.py +2 -2
  122. langfun/core/structured/schema_test.py +133 -34
  123. langfun/core/structured/scoring.py +125 -19
  124. langfun/core/structured/scoring_test.py +30 -0
  125. langfun/core/structured/tokenization.py +64 -0
  126. langfun/core/structured/tokenization_test.py +48 -0
  127. langfun/core/template.py +240 -11
  128. langfun/core/template_test.py +146 -1
  129. langfun/core/templates/conversation.py +9 -0
  130. langfun/core/templates/conversation_test.py +4 -3
  131. langfun/core/templates/selfplay_test.py +14 -2
  132. langfun-0.1.2.dev202501140804.dist-info/METADATA +225 -0
  133. langfun-0.1.2.dev202501140804.dist-info/RECORD +153 -0
  134. {langfun-0.0.2.dev20240330.dist-info → langfun-0.1.2.dev202501140804.dist-info}/WHEEL +1 -1
  135. langfun/core/coding/python/errors.py +0 -108
  136. langfun/core/coding/python/errors_test.py +0 -99
  137. langfun/core/coding/python/permissions.py +0 -90
  138. langfun/core/coding/python/permissions_test.py +0 -86
  139. langfun/core/structured/prompting.py +0 -217
  140. langfun/core/text_formatting.py +0 -162
  141. langfun/core/text_formatting_test.py +0 -47
  142. langfun-0.0.2.dev20240330.dist-info/METADATA +0 -99
  143. langfun-0.0.2.dev20240330.dist-info/RECORD +0 -102
  144. {langfun-0.0.2.dev20240330.dist-info → langfun-0.1.2.dev202501140804.dist-info}/LICENSE +0 -0
  145. {langfun-0.0.2.dev20240330.dist-info → langfun-0.1.2.dev202501140804.dist-info}/top_level.txt +0 -0
langfun/core/llms/compositional_test.py ADDED
@@ -0,0 +1,73 @@
+ # Copyright 2024 The Langfun Authors
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ """Tests for compositional models."""
+ import unittest
+
+ import langfun.core as lf
+ from langfun.core.llms import compositional
+ from langfun.core.llms import fake
+
+
+ class RandomChoiceTest(unittest.TestCase):
+
+   def test_basic(self):
+     lm = compositional.RandomChoice([
+         fake.StaticResponse('hi'),
+         fake.StaticSequence(['hello', 'world'])
+     ])
+     self.assertEqual(
+         lm.model_id,
+         'RandomChoice(StaticResponse, StaticSequence)'
+     )
+     self.assertEqual(
+         lm.resource_id,
+         'RandomChoice(StaticResponse, StaticSequence)'
+     )
+     self.assertEqual(
+         [lm('a'), lm('b'), lm('c')],
+         ['hello', 'world', 'hi']
+     )
+     lm = lm.clone()
+     self.assertEqual(
+         [
+             x.samples[0].response for x in [
+                 lm.sample(['a'])[0],
+                 lm.sample(['b'])[0],
+                 lm.sample(['c'])[0],
+             ]
+         ],
+         ['hello', 'world', 'hi']
+     )
+     self.assertEqual(
+         lm.score('hello', ['world']),
+         [lf.LMScoringResult(0.0)]
+     )
+     self.assertEqual(
+         lm.tokenize('hello'),
+         [('hello', 0)]
+     )
+
+   def test_sampling_options(self):
+     lm = compositional.RandomChoice([
+         fake.StaticResponse('hi'),
+         fake.StaticSequence(['hello', 'world'])
+     ], temperature=0.5)
+     self.assertEqual(
+         lm.candidates[0].sampling_options.temperature,
+         0.5
+     )
+
+
+ if __name__ == '__main__':
+   unittest.main()
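
For orientation, here is a minimal usage sketch of the `RandomChoice` API exercised above. It only restates calls that appear in the test; using it to route between real (non-fake) models is an assumption of the sketch, not something this diff demonstrates.

# A minimal sketch of the RandomChoice API exercised in the test above.
# Routing real models this way is an assumption; the test only uses fakes.
from langfun.core.llms import compositional, fake

lm = compositional.RandomChoice([
    fake.StaticResponse('hi'),
    fake.StaticSequence(['hello', 'world']),
], temperature=0.5)  # sampling options propagate to every candidate

print(lm.model_id)   # 'RandomChoice(StaticResponse, StaticSequence)'
print(lm('prompt'))  # response comes from a randomly chosen candidate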
langfun/core/llms/deepseek.py ADDED
@@ -0,0 +1,117 @@
+ # Copyright 2024 The Langfun Authors
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ """Language models from DeepSeek."""
+
+ import os
+ from typing import Annotated, Any
+
+ import langfun.core as lf
+ from langfun.core.llms import openai_compatible
+ import pyglove as pg
+
+ SUPPORTED_MODELS_AND_SETTINGS = {
+     # pylint: disable=g-line-too-long
+     # TODO(yifenglu): The RPM and TPM are arbitrary numbers. Update them once DeepSeek provides concrete guidelines.
+     # DeepSeek doesn't control the rate limit at the moment: https://api-docs.deepseek.com/quick_start/rate_limit
+     # The cost is based on: https://api-docs.deepseek.com/quick_start/pricing
+     'deepseek-chat': pg.Dict(
+         in_service=True,
+         rpm=100,
+         tpm=1000000,
+         cost_per_1k_input_tokens=0.00014,
+         cost_per_1k_output_tokens=0.00028,
+     ),
+ }
+
+
+ # DeepSeek API uses an API format compatible with OpenAI.
+ # Reference: https://api-docs.deepseek.com/
+ @lf.use_init_args(['model'])
+ class DeepSeek(openai_compatible.OpenAICompatible):
+   """DeepSeek model."""
+
+   model: pg.typing.Annotated[
+       pg.typing.Enum(
+           pg.MISSING_VALUE, list(SUPPORTED_MODELS_AND_SETTINGS.keys())
+       ),
+       'The name of the model to use.',
+   ]
+
+   api_endpoint: str = 'https://api.deepseek.com/chat/completions'
+
+   api_key: Annotated[
+       str | None,
+       (
+           'API key. If None, the key will be read from environment variable '
+           "'DEEPSEEK_API_KEY'."
+       ),
+   ] = None
+
+   @property
+   def headers(self) -> dict[str, Any]:
+     api_key = self.api_key or os.environ.get('DEEPSEEK_API_KEY', None)
+     if not api_key:
+       raise ValueError(
+           'Please specify `api_key` during `__init__` or set environment '
+           'variable `DEEPSEEK_API_KEY` with your DeepSeek API key.'
+       )
+     headers = super().headers
+     headers.update({
+         'Authorization': f'Bearer {api_key}',
+     })
+     return headers
+
+   @property
+   def model_id(self) -> str:
+     """Returns a string to identify the model."""
+     return f'DeepSeek({self.model})'
+
+   @property
+   def max_concurrency(self) -> int:
+     rpm = SUPPORTED_MODELS_AND_SETTINGS[self.model].get('rpm', 0)
+     tpm = SUPPORTED_MODELS_AND_SETTINGS[self.model].get('tpm', 0)
+     return self.rate_to_max_concurrency(
+         requests_per_min=rpm, tokens_per_min=tpm
+     )
+
+   def estimate_cost(
+       self, num_input_tokens: int, num_output_tokens: int
+   ) -> float | None:
+     """Estimate the cost based on usage."""
+     cost_per_1k_input_tokens = SUPPORTED_MODELS_AND_SETTINGS[self.model].get(
+         'cost_per_1k_input_tokens', None
+     )
+     cost_per_1k_output_tokens = SUPPORTED_MODELS_AND_SETTINGS[self.model].get(
+         'cost_per_1k_output_tokens', None
+     )
+     if cost_per_1k_output_tokens is None or cost_per_1k_input_tokens is None:
+       return None
+     return (
+         cost_per_1k_input_tokens * num_input_tokens
+         + cost_per_1k_output_tokens * num_output_tokens
+     ) / 1000
+
+   @classmethod
+   def dir(cls):
+     return [k for k, v in SUPPORTED_MODELS_AND_SETTINGS.items() if v.in_service]
+
+
+ class DeepSeekChat(DeepSeek):
+   """DeepSeek Chat model.
+
+   Currently, it is powered by the DeepSeek-V3 model, with a 64K input context
+   window and 8K max output tokens.
+   """
+
+   model = 'deepseek-chat'
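
Plugging the `deepseek-chat` rates above into `estimate_cost` gives, for 100 input and 100 output tokens, (0.00014 * 100 + 0.00028 * 100) / 1000 = 4.2e-5 USD, which is exactly the value asserted by `test_estimate_cost` in the test file below. A standalone sketch of the same arithmetic (the free function is illustrative only; the diff defines it as a method on `DeepSeek`):

# Standalone sketch of the estimate_cost arithmetic with deepseek-chat rates.
# This free function is illustrative, not part of the diff itself.
def estimate_cost(num_input_tokens: int, num_output_tokens: int) -> float:
  cost_per_1k_input_tokens = 0.00014   # USD per 1k input tokens
  cost_per_1k_output_tokens = 0.00028  # USD per 1k output tokens
  return (
      cost_per_1k_input_tokens * num_input_tokens
      + cost_per_1k_output_tokens * num_output_tokens
  ) / 1000

assert estimate_cost(100, 100) == 4.2e-5  # matches test_estimate_cost below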
langfun/core/llms/deepseek_test.py ADDED
@@ -0,0 +1,61 @@
+ # Copyright 2023 The Langfun Authors
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ import unittest
+ from langfun.core.llms import deepseek
+
+
+ class DeepSeekTest(unittest.TestCase):
+   """Tests for DeepSeek language model."""
+
+   def test_dir(self):
+     self.assertIn('deepseek-chat', deepseek.DeepSeek.dir())
+
+   def test_key(self):
+     with self.assertRaisesRegex(ValueError, 'Please specify `api_key`'):
+       _ = deepseek.DeepSeekChat().headers
+     self.assertEqual(
+         deepseek.DeepSeekChat(api_key='test_key').headers,
+         {
+             'Content-Type': 'application/json',
+             'Authorization': 'Bearer test_key',
+         }
+     )
+
+   def test_model_id(self):
+     self.assertEqual(
+         deepseek.DeepSeekChat(api_key='test_key').model_id,
+         'DeepSeek(deepseek-chat)',
+     )
+
+   def test_resource_id(self):
+     self.assertEqual(
+         deepseek.DeepSeekChat(api_key='test_key').resource_id,
+         'DeepSeek(deepseek-chat)',
+     )
+
+   def test_max_concurrency(self):
+     self.assertGreater(
+         deepseek.DeepSeekChat(api_key='test_key').max_concurrency, 0
+     )
+
+   def test_estimate_cost(self):
+     self.assertEqual(
+         deepseek.DeepSeekChat(api_key='test_key').estimate_cost(
+             num_input_tokens=100, num_output_tokens=100
+         ),
+         4.2e-5
+     )
+
+ if __name__ == '__main__':
+   unittest.main()
langfun/core/llms/fake.py CHANGED
@@ -13,6 +13,7 @@
  # limitations under the License.
  """Fake LMs for testing."""

+ import abc
  from typing import Annotated
  import langfun.core as lf

@@ -20,18 +21,39 @@ import langfun.core as lf
  class Fake(lf.LanguageModel):
    """The base class for all fake language models."""

-   def _score(self, prompt: lf.Message, completions: list[lf.Message]):
+   def _score(self, prompt: lf.Message | list[lf.Message],
+              completions: list[lf.Message]):
      return [lf.LMScoringResult(score=-i * 1.0) for i in range(len(completions))]

+   def _tokenize(self, prompt: lf.Message) -> list[tuple[str | bytes, int]]:
+     return [(w, i) for i, w in enumerate(prompt.text.split(' '))]
+
+   def _sample(self, prompts: list[lf.Message]) -> list[lf.LMSamplingResult]:
+     results = []
+     for prompt in prompts:
+       response = self._response_from(prompt)
+       results.append(
+           lf.LMSamplingResult(
+               [lf.LMSample(response, 1.0)],
+               usage=lf.LMSamplingUsage(
+                   prompt_tokens=len(prompt.text),
+                   completion_tokens=len(response.text),
+                   total_tokens=len(prompt.text) + len(response.text),
+               )
+           )
+       )
+     return results
+
+   @abc.abstractmethod
+   def _response_from(self, prompt: lf.Message) -> lf.Message:
+     """Returns the response for the given prompt."""
+

  class Echo(Fake):
    """A simple echo language model for testing."""

-   def _sample(self, prompts: list[lf.Message]) -> list[lf.LMSamplingResult]:
-     return [
-         lf.LMSamplingResult([lf.LMSample(prompt.text, 1.0)])
-         for prompt in prompts
-     ]
+   def _response_from(self, prompt: lf.Message) -> lf.Message:
+     return lf.AIMessage(prompt.text)


  @lf.use_init_args(['response'])
@@ -39,15 +61,12 @@ class StaticResponse(Fake):
    """Language model that always gives the same canned response."""

    response: Annotated[
-       str,
+       str | lf.Message,
        'A canned response that will be returned regardless of the prompt.'
    ]

-   def _sample(self, prompts: list[lf.Message]) -> list[lf.LMSamplingResult]:
-     return [
-         lf.LMSamplingResult([lf.LMSample(self.response, 1.0)])
-         for _ in prompts
-     ]
+   def _response_from(self, prompt: lf.Message) -> lf.Message:
+     return lf.AIMessage.from_value(self.response)


  @lf.use_init_args(['mapping'])
@@ -55,15 +74,12 @@ class StaticMapping(Fake):
    """A static mapping from prompt to response."""

    mapping: Annotated[
-       dict[str, str],
+       dict[str, str | lf.Message],
        'A mapping from prompt to response.'
    ]

-   def _sample(self, prompts: list[str]) -> list[lf.LMSamplingResult]:
-     return [
-         lf.LMSamplingResult([lf.LMSample(self.mapping[prompt], 1.0)])
-         for prompt in prompts
-     ]
+   def _response_from(self, prompt: lf.Message) -> lf.Message:
+     return lf.AIMessage.from_value(self.mapping[prompt])


  @lf.use_init_args(['sequence'])
@@ -71,7 +87,7 @@ class StaticSequence(Fake):
    """A static sequence of responses to use."""

    sequence: Annotated[
-       list[str],
+       list[str | lf.Message],
        'A sequence of strings as the response.'
    ]

@@ -79,10 +95,7 @@ class StaticSequence(Fake):
      super()._on_bound()
      self._pos = 0

-   def _sample(self, prompts: list[str]) -> list[lf.LMSamplingResult]:
-     results = []
-     for _ in prompts:
-       results.append(lf.LMSamplingResult(
-           [lf.LMSample(self.sequence[self._pos], 1.0)]))
-       self._pos += 1
-     return results
+   def _response_from(self, prompt: lf.Message) -> lf.Message:
+     r = lf.AIMessage.from_value(self.sequence[self._pos])
+     self._pos += 1
+     return r
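
The net effect of this refactor is that `Fake._sample` now owns sampling and usage accounting (character counts stand in for token counts), so a new fake model only implements the `_response_from` hook. A minimal sketch of a custom fake under that contract (`ReversedEcho` is hypothetical, not part of this diff):

# A minimal sketch of a custom fake under the new _response_from contract.
# ReversedEcho is a hypothetical example, not part of this diff.
import langfun.core as lf
from langfun.core.llms import fake


class ReversedEcho(fake.Fake):
  """A fake LM that echoes the prompt text reversed."""

  def _response_from(self, prompt: lf.Message) -> lf.Message:
    return lf.AIMessage(prompt.text[::-1])


lm = ReversedEcho()
print(lm('abc'))  # 'cba'; usage is computed by Fake._sample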
langfun/core/llms/fake_test.py CHANGED
@@ -25,7 +25,25 @@ class EchoTest(unittest.TestCase):
    def test_sample(self):
      lm = fakelm.Echo()
      self.assertEqual(
-         lm.sample(['hi']), [lf.LMSamplingResult([lf.LMSample('hi', 1.0)])]
+         lm.sample(['hi']),
+         [
+             lf.LMSamplingResult(
+                 [
+                     lf.LMSample(
+                         lf.AIMessage(
+                             'hi',
+                             score=1.0,
+                             logprobs=None,
+                             is_cached=False,
+                             usage=lf.LMSamplingUsage(2, 2, 4),
+                             tags=[lf.Message.TAG_LM_RESPONSE],
+                         ),
+                         score=1.0,
+                         logprobs=None,
+                     )
+                 ],
+                 lf.LMSamplingUsage(2, 2, 4))
+         ]
      )


@@ -34,8 +52,8 @@ class EchoTest(unittest.TestCase):
      with contextlib.redirect_stdout(string_io):
        self.assertEqual(lm('hi'), 'hi')
      debug_info = string_io.getvalue()
-     self.assertIn('[0] LM INFO:', debug_info)
-     self.assertIn('[0] PROMPT SENT TO LM:', debug_info)
+     self.assertIn('[0] LM INFO', debug_info)
+     self.assertIn('[0] PROMPT SENT TO LM', debug_info)
      self.assertIn('[0] LM RESPONSE', debug_info)


@@ -45,6 +63,13 @@ class EchoTest(unittest.TestCase):
        [lf.LMScoringResult(0.0), lf.LMScoringResult(-1.0)],
      )

+   def test_tokenize(self):
+     lm = fakelm.Echo()
+     self.assertEqual(
+         lm.tokenize('hi'),
+         [('hi', 0)]
+     )
+

  class StaticResponseTest(unittest.TestCase):

@@ -53,11 +78,47 @@ class StaticResponseTest(unittest.TestCase):
      lm = fakelm.StaticResponse(canned_response)
      self.assertEqual(
          lm.sample(['hi']),
-         [lf.LMSamplingResult([lf.LMSample(canned_response, 1.0)])],
+         [
+             lf.LMSamplingResult(
+                 [
+                     lf.LMSample(
+                         lf.AIMessage(
+                             canned_response,
+                             score=1.0,
+                             logprobs=None,
+                             is_cached=False,
+                             usage=lf.LMSamplingUsage(2, 38, 40),
+                             tags=[lf.Message.TAG_LM_RESPONSE],
+                         ),
+                         score=1.0,
+                         logprobs=None,
+                     )
+                 ],
+                 usage=lf.LMSamplingUsage(2, 38, 40)
+             )
+         ],
      )
      self.assertEqual(
          lm.sample(['Tell me a joke.']),
-         [lf.LMSamplingResult([lf.LMSample(canned_response, 1.0)])],
+         [
+             lf.LMSamplingResult(
+                 [
+                     lf.LMSample(
+                         lf.AIMessage(
+                             canned_response,
+                             score=1.0,
+                             logprobs=None,
+                             is_cached=False,
+                             usage=lf.LMSamplingUsage(15, 38, 53),
+                             tags=[lf.Message.TAG_LM_RESPONSE],
+                         ),
+                         score=1.0,
+                         logprobs=None,
+                     )
+                 ],
+                 usage=lf.LMSamplingUsage(15, 38, 53)
+             )
+         ],
      )

@@ -69,8 +130,8 @@ class StaticResponseTest(unittest.TestCase):
      self.assertEqual(lm('hi'), canned_response)

      debug_info = string_io.getvalue()
-     self.assertIn('[0] LM INFO:', debug_info)
-     self.assertIn('[0] PROMPT SENT TO LM:', debug_info)
+     self.assertIn('[0] LM INFO', debug_info)
+     self.assertIn('[0] PROMPT SENT TO LM', debug_info)
      self.assertIn('[0] LM RESPONSE', debug_info)


@@ -85,8 +146,40 @@ class StaticMappingTest(unittest.TestCase):
      self.assertEqual(
          lm.sample(['Hi', 'How are you?']),
          [
-             lf.LMSamplingResult([lf.LMSample('Hello', 1.0)]),
-             lf.LMSamplingResult([lf.LMSample('I am fine, how about you?', 1.0)])
+             lf.LMSamplingResult(
+                 [
+                     lf.LMSample(
+                         lf.AIMessage(
+                             'Hello',
+                             score=1.0,
+                             logprobs=None,
+                             is_cached=False,
+                             usage=lf.LMSamplingUsage(2, 5, 7),
+                             tags=[lf.Message.TAG_LM_RESPONSE],
+                         ),
+                         score=1.0,
+                         logprobs=None,
+                     )
+                 ],
+                 usage=lf.LMSamplingUsage(2, 5, 7)
+             ),
+             lf.LMSamplingResult(
+                 [
+                     lf.LMSample(
+                         lf.AIMessage(
+                             'I am fine, how about you?',
+                             score=1.0,
+                             logprobs=None,
+                             is_cached=False,
+                             usage=lf.LMSamplingUsage(12, 25, 37),
+                             tags=[lf.Message.TAG_LM_RESPONSE],
+                         ),
+                         score=1.0,
+                         logprobs=None,
+                     )
+                 ],
+                 usage=lf.LMSamplingUsage(12, 25, 37)
+             )
          ]
      )
      with self.assertRaises(KeyError):
@@ -104,8 +197,40 @@ class StaticSequenceTest(unittest.TestCase):
      self.assertEqual(
          lm.sample(['Hi', 'How are you?']),
          [
-             lf.LMSamplingResult([lf.LMSample('Hello', 1.0)]),
-             lf.LMSamplingResult([lf.LMSample('I am fine, how about you?', 1.0)])
+             lf.LMSamplingResult(
+                 [
+                     lf.LMSample(
+                         lf.AIMessage(
+                             'Hello',
+                             score=1.0,
+                             logprobs=None,
+                             is_cached=False,
+                             usage=lf.LMSamplingUsage(2, 5, 7),
+                             tags=[lf.Message.TAG_LM_RESPONSE],
+                         ),
+                         score=1.0,
+                         logprobs=None,
+                     )
+                 ],
+                 usage=lf.LMSamplingUsage(2, 5, 7)
+             ),
+             lf.LMSamplingResult(
+                 [
+                     lf.LMSample(
+                         lf.AIMessage(
+                             'I am fine, how about you?',
+                             score=1.0,
+                             logprobs=None,
+                             is_cached=False,
+                             usage=lf.LMSamplingUsage(12, 25, 37),
+                             tags=[lf.Message.TAG_LM_RESPONSE],
+                         ),
+                         score=1.0,
+                         logprobs=None,
+                     )
+                 ],
+                 usage=lf.LMSamplingUsage(12, 25, 37)
+             )
          ]
      )
      with self.assertRaises(IndexError):