langfun 0.0.2.dev20240330__py3-none-any.whl → 0.0.2.dev20240511__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of langfun might be problematic.
- langfun/__init__.py +7 -0
- langfun/core/__init__.py +1 -0
- langfun/core/coding/python/correction.py +0 -7
- langfun/core/component.py +6 -0
- langfun/core/component_test.py +1 -0
- langfun/core/eval/__init__.py +15 -0
- langfun/core/eval/base.py +665 -95
- langfun/core/eval/base_test.py +224 -53
- langfun/core/eval/matching.py +48 -30
- langfun/core/eval/matching_test.py +25 -3
- langfun/core/eval/patching.py +130 -0
- langfun/core/eval/patching_test.py +170 -0
- langfun/core/eval/scoring.py +19 -10
- langfun/core/eval/scoring_test.py +21 -3
- langfun/core/langfunc.py +1 -22
- langfun/core/langfunc_test.py +10 -4
- langfun/core/language_model.py +130 -24
- langfun/core/language_model_test.py +249 -26
- langfun/core/llms/__init__.py +27 -2
- langfun/core/llms/anthropic.py +263 -0
- langfun/core/llms/anthropic_test.py +167 -0
- langfun/core/llms/cache/in_memory_test.py +37 -28
- langfun/core/llms/fake.py +34 -25
- langfun/core/llms/fake_test.py +122 -11
- langfun/core/llms/google_genai.py +8 -0
- langfun/core/llms/google_genai_test.py +8 -3
- langfun/core/llms/groq.py +260 -0
- langfun/core/llms/groq_test.py +170 -0
- langfun/core/llms/llama_cpp.py +3 -1
- langfun/core/llms/openai.py +100 -81
- langfun/core/llms/openai_test.py +287 -60
- langfun/core/llms/vertexai.py +291 -0
- langfun/core/llms/vertexai_test.py +233 -0
- langfun/core/modalities/image.py +1 -3
- langfun/core/modalities/mime.py +6 -0
- langfun/core/modalities/video.py +6 -5
- langfun/core/structured/__init__.py +5 -0
- langfun/core/structured/completion_test.py +2 -2
- langfun/core/structured/function_generation.py +245 -0
- langfun/core/structured/function_generation_test.py +329 -0
- langfun/core/structured/mapping.py +61 -3
- langfun/core/structured/mapping_test.py +17 -0
- langfun/core/structured/parsing_test.py +18 -13
- langfun/core/structured/prompting.py +61 -12
- langfun/core/structured/prompting_test.py +122 -12
- langfun/core/structured/schema.py +38 -6
- langfun/core/structured/schema_generation_test.py +2 -2
- langfun/core/structured/schema_test.py +36 -7
- langfun/core/structured/scoring.py +4 -1
- langfun/core/structured/scoring_test.py +6 -0
- langfun/core/template.py +147 -11
- langfun/core/template_test.py +75 -0
- langfun/core/templates/selfplay_test.py +6 -2
- {langfun-0.0.2.dev20240330.dist-info → langfun-0.0.2.dev20240511.dist-info}/METADATA +3 -2
- langfun-0.0.2.dev20240511.dist-info/RECORD +112 -0
- langfun-0.0.2.dev20240330.dist-info/RECORD +0 -102
- {langfun-0.0.2.dev20240330.dist-info → langfun-0.0.2.dev20240511.dist-info}/LICENSE +0 -0
- {langfun-0.0.2.dev20240330.dist-info → langfun-0.0.2.dev20240511.dist-info}/WHEEL +0 -0
- {langfun-0.0.2.dev20240330.dist-info → langfun-0.0.2.dev20240511.dist-info}/top_level.txt +0 -0
langfun/core/eval/patching.py ADDED
@@ -0,0 +1,130 @@
+# Copyright 2024 The Langfun Authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Experiment patching for Langfun evaluations."""
+
+import inspect
+from typing import Union
+import langfun.core as lf
+from langfun.core import llms as lf_llms
+from langfun.core.eval import base
+import pyglove as pg
+
+
+#
+# Program-based patchers.
+#
+
+
+def patch_member(cls, key, value, parent_key: str | None = None):
+  """Patches a member of a class."""
+
+  def _rebind_fn(k, v, p):
+    if (
+        isinstance(p, cls)
+        and k.key == key
+        and (parent_key is None or (p and p.sym_path.key == parent_key))
+    ):
+      if inspect.isfunction(value):
+        return value(k, v, p)
+      return value
+    return v
+
+  return _rebind_fn
+
+
+def patch_lm(lm: Union[lf.LanguageModel, pg.hyper.OneOf]):  # pylint: disable=redefined-outer-name
+  """Patches the LLM of evaluations."""
+  return patch_member(base.Evaluable, "lm", lm)
+
+
+def patch_parsing_lm(lm: Union[lf.LanguageModel, pg.hyper.OneOf]):  # pylint: disable=redefined-outer-name
+  """Patches the parsing LLM of evaluations."""
+  return patch_member(base.Evaluable, "parsing_lm", lm)
+
+
+def patch_schema_fn(schema_fn: Union[pg.Functor, pg.hyper.OneOf]):
+  """Patches the schema_fn of evaluations."""
+  return patch_member(base.Evaluable, "schema_fn", schema_fn)
+
+
+def patch_prompt(prompt: Union[str, lf.Template, pg.hyper.OneOf]):
+  """Patches the prompt of evaluations."""
+  return patch_member(base.Evaluable, "prompt", prompt)
+
+
+def patch_inputs(inputs: Union[pg.Functor, pg.hyper.OneOf]):
+  """Patches the inputs used in evaluations."""
+  return patch_member(base.Evaluable, "inputs", inputs)
+
+
+def patch_additional_args(**kwargs):
+  """Patches additional_args."""
+
+  def value_fn(k, unused_v, p):
+    # We infer the symbolic value for the old args, as it might be a
+    # contextual attribute referring to its containing object.
+    old_args = p.sym_inferred(k.key)
+    if old_args:
+      old_args = dict(old_args)
+      old_args.update(kwargs)
+      return old_args
+    return kwargs
+
+  return patch_member(base.Evaluable, "additional_args", value_fn)
+
+
+#
+# String-based patching.
+#
+
+_NAMED_MODELS = {
+    # GPT models.
+    "gpt35turbo": lf_llms.Gpt35Turbo,
+    "gpt35turbo16k": lf_llms.Gpt35Turbo16K,
+    "gpt4": lf_llms.Gpt4,
+    "gpt4turbo": lf_llms.Gpt4Turbo,
+    # Anthropic models.
+    "haiku": lf_llms.Claude3Haiku,
+    "claude3haiku": lf_llms.Claude3Haiku,
+    "opus": lf_llms.Claude3Opus,
+    "claude3opus": lf_llms.Claude3Opus,
+    "sonnet": lf_llms.Claude3Sonnet,
+    "claude3sonnet": lf_llms.Claude3Opus,
+}
+
+
+def model_by_name(name: str) -> lf.LanguageModel:
+  """Gets model by name."""
+  name = name.strip().lower()
+  if name in _NAMED_MODELS:
+    return _NAMED_MODELS[name]()
+  raise ValueError(f"Unknown model name: {name}")
+
+
+@pg.patcher(auto_typing=True)
+def lm(unused_eval, models: list[str]):
+  """Patch the LM used for benchmarking."""
+  return patch_lm(pg.oneof([model_by_name(name) for name in models]))
+
+
+@pg.patcher(auto_typing=True)
+def temperature(unused_eval, value: float):
+  """Patch the temperature used for benchmarking."""
+  return patch_member(lf.LMSamplingOptions, "temperature", value)
+
+
+@pg.patcher(auto_typing=True)
+def max_tokens(unused_eval, value: int | None):
+  """Patch the temperature used for benchmarking."""
+  return patch_member(lf.LMSamplingOptions, "max_tokens", value)
langfun/core/eval/patching_test.py ADDED
@@ -0,0 +1,170 @@
+# Copyright 2024 The Langfun Authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Tests for evaluation patching."""
+
+import unittest
+from langfun.core import llms as lf_llms
+from langfun.core.eval import base
+from langfun.core.eval import patching
+import pyglove as pg
+
+
+class PatchingCommonTest(unittest.TestCase):
+
+  def test_patch_member(self):
+    class A(pg.Object):
+      x: int = 1
+
+    class B(pg.Object):
+      a: A
+
+    b = B(A())
+    pg.patch(b, [patching.patch_member(A, 'x', 2)])
+    self.assertEqual(b, B(A(2)))
+
+  def test_patch_args(self):
+    s = base.Suite(
+        [base.Evaluation(inputs=base.as_inputs([1]))],
+        additional_args=dict(x=1, y=2),
+    )
+    pg.patch(s, [patching.patch_additional_args(x=3, z=4)])
+    self.assertTrue(
+        pg.eq(
+            s,
+            base.Suite(
+                [
+                    base.Evaluation(
+                        inputs=base.as_inputs([1]),
+                        additional_args=dict(x=3, y=2, z=4),
+                    )
+                ],
+                additional_args=dict(x=3, y=2, z=4),
+            ),
+        )
+    )
+
+  def test_patch_lm(self):
+    s = base.Suite(
+        [base.Evaluation(inputs=base.as_inputs([1]))],
+        lm=lf_llms.Gpt35Turbo(),
+    )
+    pg.patch(
+        s, [patching.patch_lm(pg.oneof([lf_llms.Gpt35Turbo(), lf_llms.Gpt4()]))]
+    )
+    self.assertTrue(
+        pg.eq(
+            s,
+            base.Suite(
+                [
+                    base.Evaluation(
+                        inputs=base.as_inputs([1]),
+                        lm=pg.oneof([lf_llms.Gpt35Turbo(), lf_llms.Gpt4()]),
+                    )
+                ],
+                lm=pg.oneof([lf_llms.Gpt35Turbo(), lf_llms.Gpt4()]),
+            ),
+        )
+    )
+
+  def test_patch_parsing_lm(self):
+    s = base.Suite(
+        [base.Evaluation(inputs=base.as_inputs([1]))],
+        lm=lf_llms.Gpt4(),
+    )
+    pg.patch(s, [patching.patch_parsing_lm(lf_llms.Gpt35Turbo())])
+    self.assertTrue(
+        pg.eq(
+            s,
+            base.Suite(
+                [
+                    base.Evaluation(
+                        inputs=base.as_inputs([1]),
+                        lm=lf_llms.Gpt4(),
+                        parsing_lm=lf_llms.Gpt35Turbo(),
+                    )
+                ],
+                # NOTE(daiyip): Suite does not have `parsing_lm` as one of its
+                # variable keyword fields yet, so patching does not add to it.
+                # This is okay since we only care about the leaf nodes.
+                lm=lf_llms.Gpt4(),
+            ),
+        )
+    )
+
+  def test_patch_prompt(self):
+    e = base.Evaluation(inputs=base.as_inputs([1]))
+    pg.patch(e, [patching.patch_prompt('Q: {{example.question}}')])
+    self.assertTrue(
+        pg.eq(
+            e,
+            base.Evaluation(
+                inputs=base.as_inputs([1]),
+                prompt='Q: {{example.question}}',
+            ),
+        )
+    )
+
+  def test_patch_inputs(self):
+    e = base.Evaluation(inputs=base.as_inputs([1]))
+    pg.patch(e, [patching.patch_inputs(base.as_inputs([2]))])
+    self.assertTrue(
+        pg.eq(
+            e,
+            base.Evaluation(
+                inputs=base.as_inputs([2]),
+            ),
+        )
+    )
+
+  def test_patch_schema_fn(self):
+    @pg.functor()
+    def int_schema():
+      return int
+
+    e = base.Evaluation(inputs=base.as_inputs([1]))
+    pg.patch(e, [patching.patch_schema_fn(int_schema())])
+    self.assertTrue(
+        pg.eq(
+            e,
+            base.Evaluation(
+                inputs=base.as_inputs([1]),
+                schema_fn=int_schema(),
+            ),
+        )
+    )
+
+
+class StringPatcheTest(unittest.TestCase):
+
+  def test_lm(self):
+    target = pg.patch(
+        base.Evaluation(inputs=base.as_inputs([1])),
+        ['lm?haiku:gpt4', 'max_tokens?1024', 'temperature?0.7'],
+    )
+    self.assertEqual(
+        target.lm,
+        pg.oneof([
+            lf_llms.Claude3Haiku(temperature=0.7, max_tokens=1024),
+            lf_llms.Gpt4(temperature=0.7, max_tokens=1024),
+        ]),
+    )
+    with self.assertRaisesRegex(ValueError, 'Unknown model name'):
+      pg.patch(
+          base.Evaluation(inputs=base.as_inputs([1])),
+          ['lm?gpt2'],
+      )
+
+
+if __name__ == '__main__':
+  unittest.main()
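
One behavior worth calling out from test_patch_args: patch_additional_args merges new keyword arguments into the existing additional_args dict rather than replacing it, so untouched keys survive. A compact restatement of what the test asserts (values copied from the test above):

old = dict(x=1, y=2)        # additional_args before patching
old.update(dict(x=3, z=4))  # the value_fn merge performed in patching.py
assert old == dict(x=3, y=2, z=4)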
langfun/core/eval/scoring.py CHANGED
@@ -61,8 +61,18 @@ class Scoring(base.Evaluation):
     super()._reset()
     self._scored = []
 
-  def
+  def audit_processed(
+      self, example: Any, output: Any, message: lf.Message, dryrun: bool = False
+  ) -> None:
     score = self.score(example, output)
+
+    if dryrun:
+      lf.console.write('')
+      lf.console.write(
+          str(score),
+          title='SCORE',
+          color='blue',
+      )
     self._scored.append((example, output, score, message))
 
   @abc.abstractmethod
@@ -103,8 +113,8 @@ class Scoring(base.Evaluation):
         m.total,
     )
 
-  def
-    result = super().
+  def finalize(self) -> pg.Dict:
+    result = super().finalize()
     result.metrics.update(
         num_scored=self.num_scored,
         score_rate=self.score_rate,
@@ -118,19 +128,18 @@ class Scoring(base.Evaluation):
     super().save(definition, result, report)
 
     if result:
-
-      def force_dict(v):
-        return pg.object_utils.json_conversion.strip_types(pg.to_json(v))
-
       # Save scored.
       pg.save(
           [
              # We force the output to be dict as its type may be defined
              # within functors which could be deserialized.
-              pg.Dict(input=input, output=
+              pg.Dict(input=input, output=output, score=score)
              for input, output, score, _ in self.scored
          ],
          os.path.join(self.dir, Scoring.SCORED_JSON),
+          # We force the input and output to be dict so it does not depend on
+          # the downstream to serialize.
+          force_dict=True,
      )
 
     if report:
@@ -159,7 +168,7 @@ class Scoring(base.Evaluation):
         )
     )
 
-  def
+  def _render_summary_metrics(self, s: io.StringIO) -> None:
     """Renders metrics in HTML."""
     assert self.result is not None
     m = self.result.metrics
@@ -173,7 +182,7 @@ class Scoring(base.Evaluation):
         )
     )
     s.write(' | ')
-    super().
+    super()._render_summary_metrics(s)
 
   def _render_scored(self, s: io.StringIO) -> None:
     """Formats the matched cases into html."""
langfun/core/eval/scoring_test.py CHANGED
@@ -81,7 +81,7 @@ class ScoringTest(unittest.TestCase):
         s.result,
         dict(
             experiment_setup=dict(
-                id='ConstraintFollowing@
+                id='ConstraintFollowing@5c88a5eb',
                 dir=s.dir,
                 model='StaticSequence',
                 prompt_template='{{example}}',
@@ -98,10 +98,16 @@ class ScoringTest(unittest.TestCase):
                 total=2,
                 failures=0,
                 failure_rate=0.0,
+                oop_failures=0,
+                oop_failure_rate=0.0,
+                non_oop_failures=0,
+                non_oop_failure_rate=0.0,
+                failure_breakdown={},
                 num_scored=2,
                 score_rate=1.0,
                 avg_score=0.5,
             ),
+            usage=s.result.usage,
         ),
     )
     self.assertTrue(
@@ -123,7 +129,12 @@ class ScoringTest(unittest.TestCase):
     )
     self.assertTrue(
         os.path.exists(
-            os.path.join(s.dir, scoring.Scoring.
+            os.path.join(s.dir, scoring.Scoring.OOP_FAILURES_JSON)
+        )
+    )
+    self.assertTrue(
+        os.path.exists(
+            os.path.join(s.dir, scoring.Scoring.NON_OOP_FAILURES_JSON)
         )
     )
     self.assertTrue(
@@ -142,7 +153,14 @@ class ScoringTest(unittest.TestCase):
     self.assertTrue(
         os.path.exists(
             os.path.join(
-                s.dir, scoring.Scoring.
+                s.dir, scoring.Scoring.OOP_FAILURES_HTML
+            )
+        )
+    )
+    self.assertTrue(
+        os.path.exists(
+            os.path.join(
+                s.dir, scoring.Scoring.NON_OOP_FAILURES_HTML
             )
         )
     )
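
The net effect of these hunks: the per-example audit hook is now audit_processed, which echoes each score to the console when dryrun=True, and the summary hook is now finalize, which adds num_scored, score_rate, and avg_score on top of super().finalize(). A user-side subclass apparently only supplies the abstract score hook called above. A hypothetical sketch (class name and scoring rule are illustrative, not from the release):

from typing import Any
from langfun.core.eval import scoring

class LengthScoring(scoring.Scoring):
  """Hypothetical metric: full credit when output length equals the example."""

  def score(self, example: Any, output: Any) -> float:
    # audit_processed() records this value per example; finalize()
    # averages the recorded values into `avg_score`.
    return 1.0 if len(str(output)) == example else 0.0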
langfun/core/langfunc.py CHANGED
@@ -14,7 +14,7 @@
 """LangFunc: Language-based functions."""
 
 import dataclasses
-from typing import Annotated, Type, Union
+from typing import Annotated, Type
 
 from langfun.core import component
 from langfun.core import language_model
@@ -261,7 +261,6 @@ class LangFunc(
     if lm_input is None:
       lm_input = self.render(**kwargs)
 
-    lm_input.tag(message_lib.Message.TAG_LM_INPUT)
     if skip_lm:
       return lm_input
 
@@ -270,10 +269,6 @@ class LangFunc(
     # Send rendered text to LM.
     lm_output = self.lm(lm_input, cache_seed=cache_seed)
 
-    # Track the input as the source of the output.
-    lm_output.source = lm_input
-    lm_output.tag(message_lib.Message.TAG_LM_RESPONSE)
-
     # Transform the output message.
     lm_output = self.transform_output(lm_output)
     lm_output.tag(message_lib.Message.TAG_LM_OUTPUT)
@@ -333,22 +328,6 @@ class LangFunc(
     """Transforms the output message before returning from __call__."""
     return lm_output
 
-  @classmethod
-  def from_value(
-      cls, value: Union[str, template_lib.Template], **kwargs
-  ) -> 'LangFunc':
-    """Create a LangFunc object from a string or template."""
-    if isinstance(value, LangFunc):
-      return value
-    if isinstance(value, template_lib.Template):
-      lfun = LangFunc(value.template_str, **kwargs)
-      # So lfun could acccess all attributes from value.
-      lfun.sym_setparent(value)
-      return lfun
-    if isinstance(value, str):
-      return LangFunc(template_str=value, **kwargs)
-    return LangFunc('{{input}}', input=value, **kwargs)
-
 
 # Register converter from str to LangFunc, therefore we can always
 # pass strs to attributes that accept LangFunc.
langfun/core/langfunc_test.py CHANGED
@@ -57,6 +57,10 @@ class BasicTest(unittest.TestCase):
     l2 = LangFunc.from_value(l1)
     self.assertIs(l2, l1)
 
+    l3 = LangFunc.from_value(l1, x=1)
+    self.assertIsNot(l3, l1)
+    self.assertTrue(pg.eq(l3, LangFunc('Hello', x=1)))
+
     c = template_lib.Template(
         '{{x}} + {{l}}',
         x=1,
@@ -82,7 +86,9 @@ class LangFuncCallTest(unittest.TestCase):
     self.assertEqual(i.tags, ['rendered'])
 
     r = l()
-    self.assertEqual(
+    self.assertEqual(
+        r, message.AIMessage('Hello!!!', score=0.0, logprobs=None, usage=None)
+    )
     self.assertEqual(r.tags, ['lm-response', 'lm-output'])
     self.assertEqual(r.source, message.UserMessage('Hello'))
     self.assertEqual(r.source.tags, ['rendered', 'lm-input'])
@@ -92,8 +98,8 @@ class LangFuncCallTest(unittest.TestCase):
     self.assertEqual(
         repr(l),
         "LangFunc(template_str='Hello', clean=True,"
-        ' lm=ExcitedEchoer(sampling_options=LMSamplingOptions(temperature=
-        ' max_tokens=
+        ' lm=ExcitedEchoer(sampling_options=LMSamplingOptions(temperature=None,'
+        ' max_tokens=None, n=1, top_k=40, top_p=None, stop=None,'
         ' random_seed=None, logprobs=False, top_logprobs=None), cache=None,'
         ' max_concurrency=None, timeout=120.0, max_attempts=5,'
         ' retry_interval=(5, 60), exponential_backoff=True, debug=False))',
@@ -106,7 +112,7 @@ class LangFuncCallTest(unittest.TestCase):
     self.assertEqual(l.render(), 'Hello')
     r = l()
     self.assertEqual(
-        r, message.AIMessage('Hello!!!', score=0.0, logprobs=None)
+        r, message.AIMessage('Hello!!!', score=0.0, logprobs=None, usage=None)
     )
     self.assertEqual(r.tags, ['lm-response', 'lm-output'])
 