langfun 0.0.2.dev20240319__py3-none-any.whl → 0.0.2.dev20240429__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- langfun/__init__.py +2 -0
- langfun/core/__init__.py +1 -0
- langfun/core/coding/python/correction.py +0 -7
- langfun/core/component.py +6 -0
- langfun/core/component_test.py +1 -0
- langfun/core/eval/__init__.py +2 -0
- langfun/core/eval/base.py +240 -37
- langfun/core/eval/base_test.py +52 -18
- langfun/core/eval/matching.py +26 -9
- langfun/core/eval/matching_test.py +3 -4
- langfun/core/eval/scoring.py +15 -6
- langfun/core/eval/scoring_test.py +2 -2
- langfun/core/langfunc.py +0 -5
- langfun/core/langfunc_test.py +6 -4
- langfun/core/language_model.py +124 -24
- langfun/core/language_model_test.py +249 -26
- langfun/core/llms/__init__.py +24 -5
- langfun/core/llms/anthropic.py +263 -0
- langfun/core/llms/anthropic_test.py +167 -0
- langfun/core/llms/cache/in_memory_test.py +37 -28
- langfun/core/llms/fake.py +31 -22
- langfun/core/llms/fake_test.py +122 -11
- langfun/core/llms/{gemini.py → google_genai.py} +117 -15
- langfun/core/llms/{gemini_test.py → google_genai_test.py} +83 -15
- langfun/core/llms/groq.py +260 -0
- langfun/core/llms/groq_test.py +170 -0
- langfun/core/llms/llama_cpp.py +3 -1
- langfun/core/llms/openai.py +97 -79
- langfun/core/llms/openai_test.py +285 -59
- langfun/core/modalities/video.py +5 -2
- langfun/core/structured/__init__.py +3 -0
- langfun/core/structured/completion_test.py +2 -2
- langfun/core/structured/function_generation.py +245 -0
- langfun/core/structured/function_generation_test.py +329 -0
- langfun/core/structured/mapping.py +59 -3
- langfun/core/structured/mapping_test.py +17 -0
- langfun/core/structured/parsing.py +2 -1
- langfun/core/structured/parsing_test.py +18 -13
- langfun/core/structured/prompting.py +27 -6
- langfun/core/structured/prompting_test.py +79 -12
- langfun/core/structured/schema.py +25 -22
- langfun/core/structured/schema_generation.py +2 -3
- langfun/core/structured/schema_generation_test.py +2 -2
- langfun/core/structured/schema_test.py +42 -27
- langfun/core/template.py +125 -10
- langfun/core/template_test.py +75 -0
- langfun/core/templates/selfplay_test.py +6 -2
- {langfun-0.0.2.dev20240319.dist-info → langfun-0.0.2.dev20240429.dist-info}/METADATA +3 -2
- {langfun-0.0.2.dev20240319.dist-info → langfun-0.0.2.dev20240429.dist-info}/RECORD +52 -46
- {langfun-0.0.2.dev20240319.dist-info → langfun-0.0.2.dev20240429.dist-info}/LICENSE +0 -0
- {langfun-0.0.2.dev20240319.dist-info → langfun-0.0.2.dev20240429.dist-info}/WHEEL +0 -0
- {langfun-0.0.2.dev20240319.dist-info → langfun-0.0.2.dev20240429.dist-info}/top_level.txt +0 -0
langfun/core/eval/base_test.py
CHANGED
@@ -70,8 +70,7 @@ def eval_set(
   """Creates an evaluation object for testing."""
   tmp_dir = tempfile.gettempdir()
   return cls(
-
-      root_dir=tmp_dir,
+      root_dir=os.path.join(tmp_dir, eval_id),
       inputs=base.as_inputs([
           pg.Dict(question='Compute 1 + 1'),
           pg.Dict(question='Compute 1 + 2'),
@@ -102,7 +101,7 @@ class EvaluationTest(unittest.TestCase):
     self.assertEqual(s.dir, os.path.join(s.root_dir, s.id))
     self.assertEqual(s.hash, s.clone().hash)
     # Test persistent hash.
-    self.assertEqual(s.hash, '
+    self.assertEqual(s.hash, 'ae86c703')
     self.assertEqual(
         s.hash, s.clone(override={'max_workers': 2, 'lm.timeout': 20}).hash
     )
@@ -195,6 +194,7 @@ class EvaluationTest(unittest.TestCase):
             cache_seed=0,
             score=1.0,
             logprobs=None,
+            usage=lf.LMSamplingUsage(387, 24, 411),
             tags=['lm-response', 'lm-output', 'transformed'],
         ),
     )
@@ -210,7 +210,7 @@ class EvaluationTest(unittest.TestCase):
         s.result,
         dict(
             experiment_setup=dict(
-                id='
+                id='Evaluation@0fade07d',
                 dir=s.dir,
                 model='StaticSequence',
                 prompt_template='{{example.question}}',
@@ -221,6 +221,14 @@ class EvaluationTest(unittest.TestCase):
                 use_cache=True, num_queries=2, num_hits=0, num_updates=2
             ),
             metrics=dict(total=2, failures=1, failure_rate=0.5),
+            usage=dict(
+                total_prompt_tokens=774,
+                total_completion_tokens=25,
+                num_usages=2,
+                average_prompt_tokens=387,
+                average_completion_tokens=12,
+                average_total_tokens=399,
+            ),
         ),
     )
     self.assertTrue(
@@ -229,13 +237,23 @@ class EvaluationTest(unittest.TestCase):
         os.path.exists(os.path.join(s.dir, base.Evaluation.RESULT_JSON)))
     self.assertTrue(
         os.path.exists(os.path.join(s.dir, base.Evaluation.CACHE_JSON)))
-    self.assertTrue(
-        os.path.exists(os.path.join(s.root_dir, base.Evaluation.SUMMARY_HTML))
-    )
     self.assertTrue(
         os.path.exists(os.path.join(s.dir, base.Evaluation.INDEX_HTML)))
     self.assertTrue(
         os.path.exists(os.path.join(s.dir, base.Evaluation.FAILURES_HTML)))
+    self.assertTrue(
+        os.path.exists(os.path.join(s.root_dir, base.Evaluation.SUMMARY_HTML))
+    )
+    # Check summary JSON.
+    summary_json = os.path.join(
+        s.root_dir, base.Evaluation.SUMMARY_HTML.replace('.html', '.json')
+    )
+    self.assertTrue(os.path.exists(summary_json))
+    summary = pg.load(summary_json, force_dict=True)
+    self.assertIn('Evaluation', summary)
+    self.assertEqual(len(summary['Evaluation']), 1)
+    self.assertIsNotNone(summary['Evaluation'][0].experiment)
+    self.assertIsNotNone(summary['Evaluation'][0].metrics)

   def test_run_wihtout_save(self):
     lm = fake.StaticSequence([
@@ -275,8 +293,11 @@ class EvaluationTest(unittest.TestCase):
     s = eval_set(
         'run_filter_test', pg.oneof(['call', 'query']),
         schema_fn=answer_schema(), lm=lm)
+    result = s.run(
+        filter=lambda x: x.method == 'query', dryrun=True, summary=False
+    )
     self.assertEqual(
-
+        result,
         {
             s.children[0].id: None,
             s.children[1].id: dict(
@@ -292,7 +313,8 @@ class EvaluationTest(unittest.TestCase):
                     use_cache=True, num_queries=2, num_hits=0, num_updates=2
                 ),
                 metrics=dict(total=2, failures=0, failure_rate=0.0),
-
+                usage=s.children[1].result.usage,
+            ),
         },
     )

@@ -302,7 +324,6 @@ class EvaluationTest(unittest.TestCase):
         '3',
     ])
     s = base.Evaluation(
-        id='search_space_test',
         root_dir=tempfile.gettempdir(),
         inputs=base.as_inputs([
             pg.Dict(question='Compute 1 + 1'),
@@ -323,11 +344,10 @@ class EvaluationTest(unittest.TestCase):
         s.children[0].dir, os.path.join(s.root_dir, s.children[0].id)
     )
     # Test persistent hash.
-    self.assertEqual(s.hash, '
+    self.assertEqual(s.hash, 'b66a4e88')

     summary = s.run(verbose=True)
     self.assertEqual(len(summary.evaluations), 2)
-
     self.assertEqual(
         s.result,
         {
@@ -344,6 +364,7 @@ class EvaluationTest(unittest.TestCase):
                     use_cache=True, num_queries=2, num_hits=0, num_updates=2
                 ),
                 metrics=dict(total=2, failures=1, failure_rate=0.5),
+                usage=s.children[0].result.usage,
             ),
             s.children[1].id: dict(
                 experiment_setup=dict(
@@ -358,6 +379,7 @@ class EvaluationTest(unittest.TestCase):
                     use_cache=True, num_queries=2, num_hits=0, num_updates=2
                 ),
                 metrics=dict(total=2, failures=1, failure_rate=0.5),
+                usage=s.children[1].result.usage,
             ),
         },
     )
@@ -439,7 +461,6 @@ class SuiteTest(unittest.TestCase):
         '3',
     ] * 5)
     s = base.Suite(
-        'suite_run_test',
         [
             eval_set('run_test_1', 'query', schema_fn=answer_schema()),
             # A suite of search space. Two of the sub-experiments are identical,
@@ -451,7 +472,7 @@ class SuiteTest(unittest.TestCase):
         lm=lm
     )
     # Test for persistent hash.
-    self.assertEqual(s.hash, '
+    self.assertEqual(s.hash, '26e6cc25')
     s.run()
     expected = {
         s.children[0].id: dict(
@@ -467,6 +488,7 @@ class SuiteTest(unittest.TestCase):
                 use_cache=True, num_queries=2, num_hits=0, num_updates=2
             ),
             metrics=dict(total=2, failures=1, failure_rate=0.5),
+            usage=s.children[0].result.usage,
         ),
         s.children[1].id: {
             s.children[1]
@@ -484,6 +506,7 @@ class SuiteTest(unittest.TestCase):
                     use_cache=True, num_queries=4, num_hits=1, num_updates=3
                 ),
                 metrics=dict(total=2, failures=2, failure_rate=1.0),
+                usage=s.children[1].children[0].result.usage,
             ),
             s.children[1]
             .children[2]
@@ -503,6 +526,7 @@ class SuiteTest(unittest.TestCase):
                     num_updates=2,
                 ),
                 metrics=dict(total=2, failures=1, failure_rate=0.5),
+                usage=s.children[1].children[2].result.usage,
             ),
         },
     }
@@ -548,7 +572,6 @@ class SummaryTest(unittest.TestCase):
   def _eval_set(self, root_dir):
     return base.Suite(id='select_test', children=[
         TaskA(
-            id='task_a',
            inputs=base.as_inputs([
                pg.Dict(question='Compute 1 + 1'),
            ]),
@@ -569,7 +592,6 @@ class SummaryTest(unittest.TestCase):
            max_workers=1,
        ),
        TaskB(
-            id='task_b',
            inputs=base.as_inputs([
                pg.Dict(question='Compute 1 + 1'),
            ]),
@@ -650,10 +672,10 @@ class SummaryTest(unittest.TestCase):
         len(base.Summary.from_dirs(root_dir)), 2 * 2 * 2 * 2 + 2 * 1 * 1 * 2
     )
     self.assertEqual(
-        len(base.Summary.from_dirs(root_dir, '
+        len(base.Summary.from_dirs(root_dir, 'TaskB')), 2 * 1 * 1 * 2
     )
     self.assertEqual(
-        len(base.Summary.from_dirs(root_dir, ('
+        len(base.Summary.from_dirs(root_dir, ('TaskA'))), 2 * 2 * 2 * 2
     )

   def test_monitor(self):
@@ -676,5 +698,17 @@ class SummaryTest(unittest.TestCase):
     self.assertTrue(pg.io.path_exists(summary_file))


+class AppRunTest(unittest.TestCase):
+
+  def test_app_run(self):
+    lm = fake.StaticSequence(['two', 'Solution(final_answer=2)'])
+    try:
+      base.app_run(
+          eval_set('app_run_test', 'query', schema_fn=answer_schema(), lm=lm)
+      )
+    except SystemExit:
+      pass
+
+
 if __name__ == '__main__':
   unittest.main()
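Note on the new usage block asserted above: the aggregate numbers follow from summing the per-query usages and taking integer averages. Below is a minimal, self-contained sketch of that arithmetic; the Usage dataclass and aggregate() helper are illustrative only (not langfun's implementation), and the second per-query usage (387, 1, 388) is inferred from the expected totals rather than taken from the diff.

from dataclasses import dataclass


@dataclass
class Usage:
  prompt_tokens: int
  completion_tokens: int
  total_tokens: int


def aggregate(usages: list[Usage]) -> dict:
  """Sums per-query usages and reports integer averages, as in the test."""
  n = len(usages)
  total_prompt = sum(u.prompt_tokens for u in usages)
  total_completion = sum(u.completion_tokens for u in usages)
  total = sum(u.total_tokens for u in usages)
  return dict(
      total_prompt_tokens=total_prompt,
      total_completion_tokens=total_completion,
      num_usages=n,
      average_prompt_tokens=total_prompt // n,
      average_completion_tokens=total_completion // n,
      average_total_tokens=total // n,
  )


# First usage matches lf.LMSamplingUsage(387, 24, 411) asserted in the test.
print(aggregate([Usage(387, 24, 411), Usage(387, 1, 388)]))
# -> total_prompt_tokens=774, total_completion_tokens=25, num_usages=2,
#    average_prompt_tokens=387, average_completion_tokens=12,
#    average_total_tokens=399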
langfun/core/eval/matching.py
CHANGED
@@ -86,9 +86,26 @@ class Matching(base.Evaluation):
     self._matches = []
     self._mismatches = []

-  def
+  def audit_processed(
+      self, example: Any, output: Any, message: lf.Message, dryrun: bool = False
+  ) -> None:
     groundtruth = self.groundtruth(example)
     answer = self.answer(output, example)
+
+    if dryrun:
+      lf.console.write('')
+      lf.console.write(
+          str(groundtruth),
+          title='GROUDTRUTH',
+          color='green',
+      )
+      lf.console.write('')
+      lf.console.write(
+          str(answer),
+          title='ANSWER',
+          color='blue',
+      )
+
     if self.match(answer, groundtruth):
       self._matches.append((example, output, message))
     else:
@@ -155,19 +172,16 @@ class Matching(base.Evaluation):
     super().save(definition, result, report)

     if result:
-
-      def force_dict(v):
-        return pg.object_utils.json_conversion.strip_types(pg.to_json(v))
-
       # Save matches.
       pg.save(
           [
-
-              # within functors which could be deserialized.
-              pg.Dict(input=input, output=force_dict(output))
+              pg.Dict(input=input, output=output)
               for input, output, _ in self.matches
           ],
           os.path.join(self.dir, Matching.MATCHES_JSON),
+          # We force the input and output to be dict so it does not depend on
+          # the downstream to serialize.
+          force_dict=True,
       )

       # Save mismatches.
@@ -175,10 +189,13 @@ class Matching(base.Evaluation):
           [
               # We force the output to be dict as its type may be defined
              # within functors which could be deserialized.
-              pg.Dict(input=input, output=
+              pg.Dict(input=input, output=output)
               for input, output, _ in self.mismatches
           ],
           os.path.join(self.dir, Matching.MISMATCHES_JSON),
+          # We force the input and output to be dict so it does not depend on
+          # the downstream to serialize.
+          force_dict=True,
       )

       if report:
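The new audit_processed hook above echoes the groundtruth and parsed answer to the console when an evaluation is dry-run (run(dryrun=True), as exercised in base_test.py). A small runnable sketch of just the console part, assuming langfun is installed; the example values are made up, while the lf.console.write(..., title=..., color=...) calls mirror the diff (including the title's spelling).

import langfun.core as lf

groundtruth, answer = 2, 2  # Made-up values for illustration.
lf.console.write('')
lf.console.write(str(groundtruth), title='GROUDTRUTH', color='green')
lf.console.write('')
lf.console.write(str(answer), title='ANSWER', color='blue')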
langfun/core/eval/matching_test.py
CHANGED
@@ -65,10 +65,8 @@ def eval_set(
     use_cache: bool = True,
 ):
   """Creates an evaluation object for testing."""
-  tmp_dir = tempfile.gettempdir()
   return MyTask(
-
-      root_dir=tmp_dir,
+      root_dir=os.path.join(tempfile.gettempdir(), eval_id),
       inputs=base.as_inputs([
           pg.Dict(question='Compute 1 + 1', groundtruth=2),
           pg.Dict(question='Compute 1 + 2', groundtruth=3),
@@ -105,7 +103,7 @@ class MatchingTest(unittest.TestCase):
         s.result,
         dict(
             experiment_setup=dict(
-                id='
+                id='MyTask@739a174b',
                 dir=s.dir,
                 model='StaticSequence',
                 prompt_template='{{example.question}}',
@@ -127,6 +125,7 @@ class MatchingTest(unittest.TestCase):
                 num_mismatches=1,
                 mismatch_rate=0.25,
             ),
+            usage=s.result.usage,
         ),
     )
     self.assertTrue(
langfun/core/eval/scoring.py
CHANGED
@@ -61,8 +61,18 @@ class Scoring(base.Evaluation):
     super()._reset()
     self._scored = []

-  def
+  def audit_processed(
+      self, example: Any, output: Any, message: lf.Message, dryrun: bool = False
+  ) -> None:
     score = self.score(example, output)
+
+    if dryrun:
+      lf.console.write('')
+      lf.console.write(
+          str(score),
+          title='SCORE',
+          color='blue',
+      )
     self._scored.append((example, output, score, message))

   @abc.abstractmethod
@@ -118,19 +128,18 @@ class Scoring(base.Evaluation):
     super().save(definition, result, report)

     if result:
-
-      def force_dict(v):
-        return pg.object_utils.json_conversion.strip_types(pg.to_json(v))
-
       # Save scored.
       pg.save(
           [
               # We force the output to be dict as its type may be defined
               # within functors which could be deserialized.
-              pg.Dict(input=input, output=
+              pg.Dict(input=input, output=output, score=score)
               for input, output, score, _ in self.scored
           ],
           os.path.join(self.dir, Scoring.SCORED_JSON),
+          # We force the input and output to be dict so it does not depend on
+          # the downstream to serialize.
+          force_dict=True,
       )

     if report:
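Both Matching.save and Scoring.save above drop the local force_dict helper and instead pass force_dict=True to pg.save, so the records written to MATCHES_JSON, MISMATCHES_JSON and SCORED_JSON no longer embed type information. A small sketch of the round trip, assuming pg.save and pg.load forward force_dict as the updated code does; the Solution class and the file name are placeholders, not part of the diff.

import os
import tempfile

import pyglove as pg


class Solution(pg.Object):
  final_answer: int


path = os.path.join(tempfile.gettempdir(), 'scored_example.json')
pg.save(
    [
        pg.Dict(
            input=pg.Dict(question='Compute 1 + 1'),
            output=Solution(2),
            score=1.0,
        )
    ],
    path,
    force_dict=True,  # Strips type markers, as Scoring.save now does.
)
records = pg.load(path, force_dict=True)
print(records[0].output.final_answer, records[0].score)  # -> 2 1.0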
langfun/core/eval/scoring_test.py
CHANGED
@@ -43,7 +43,6 @@ def constrained_by_upperbound(upper_bound: int):


 class ConstraintFollowing(scoring.Scoring):
-  id = 'constraint_following'
   inputs = constrained_by_upperbound(1)
   prompt = '{{example}}'
   method = 'query'
@@ -82,7 +81,7 @@ class ScoringTest(unittest.TestCase):
         s.result,
         dict(
             experiment_setup=dict(
-                id='
+                id='ConstraintFollowing@5c88a5eb',
                 dir=s.dir,
                 model='StaticSequence',
                 prompt_template='{{example}}',
@@ -103,6 +102,7 @@ class ScoringTest(unittest.TestCase):
                 score_rate=1.0,
                 avg_score=0.5,
             ),
+            usage=s.result.usage,
         ),
     )
     self.assertTrue(
langfun/core/langfunc.py
CHANGED
@@ -261,7 +261,6 @@ class LangFunc(
     if lm_input is None:
       lm_input = self.render(**kwargs)

-    lm_input.tag(message_lib.Message.TAG_LM_INPUT)
     if skip_lm:
       return lm_input

@@ -270,10 +269,6 @@ class LangFunc(
     # Send rendered text to LM.
     lm_output = self.lm(lm_input, cache_seed=cache_seed)

-    # Track the input as the source of the output.
-    lm_output.source = lm_input
-    lm_output.tag(message_lib.Message.TAG_LM_RESPONSE)
-
     # Transform the output message.
     lm_output = self.transform_output(lm_output)
     lm_output.tag(message_lib.Message.TAG_LM_OUTPUT)
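The removal above takes the lm-input/lm-response tagging and source tracking out of LangFunc.__call__; given the companion changes to language_model.py in this release, that bookkeeping presumably now happens where the LM response is produced. The observable behavior is unchanged, as the assertions in langfunc_test.py below show. A small sketch mirroring those assertions, using the StaticSequence fake LM from langfun.core.llms instead of the tests' ExcitedEchoer; running it requires langfun installed.

import langfun.core as lf
from langfun.core.llms import fake

# The returned message still carries the same tags and source chain as before.
l = lf.LangFunc('Hello', lm=fake.StaticSequence(['Hello!!!']))
r = l()
assert r.text == 'Hello!!!'
assert r.tags == ['lm-response', 'lm-output']
assert r.source.tags == ['rendered', 'lm-input']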
langfun/core/langfunc_test.py
CHANGED
@@ -82,7 +82,9 @@ class LangFuncCallTest(unittest.TestCase):
     self.assertEqual(i.tags, ['rendered'])

     r = l()
-    self.assertEqual(
+    self.assertEqual(
+        r, message.AIMessage('Hello!!!', score=0.0, logprobs=None, usage=None)
+    )
     self.assertEqual(r.tags, ['lm-response', 'lm-output'])
     self.assertEqual(r.source, message.UserMessage('Hello'))
     self.assertEqual(r.source.tags, ['rendered', 'lm-input'])
@@ -92,8 +94,8 @@ class LangFuncCallTest(unittest.TestCase):
     self.assertEqual(
         repr(l),
         "LangFunc(template_str='Hello', clean=True,"
-        ' lm=ExcitedEchoer(sampling_options=LMSamplingOptions(temperature=
-        ' max_tokens=
+        ' lm=ExcitedEchoer(sampling_options=LMSamplingOptions(temperature=None,'
+        ' max_tokens=None, n=1, top_k=40, top_p=None, stop=None,'
         ' random_seed=None, logprobs=False, top_logprobs=None), cache=None,'
         ' max_concurrency=None, timeout=120.0, max_attempts=5,'
         ' retry_interval=(5, 60), exponential_backoff=True, debug=False))',
@@ -106,7 +108,7 @@ class LangFuncCallTest(unittest.TestCase):
     self.assertEqual(l.render(), 'Hello')
     r = l()
     self.assertEqual(
-        r, message.AIMessage('Hello!!!', score=0.0, logprobs=None)
+        r, message.AIMessage('Hello!!!', score=0.0, logprobs=None, usage=None)
     )
     self.assertEqual(r.tags, ['lm-response', 'lm-output'])

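The updated assertions above show two related shifts in this release: messages now carry a usage field (None when the backend reports nothing), and the LMSamplingOptions defaults in the repr moved to temperature=None and max_tokens=None so each model can supply its own defaults. A tiny sketch of constructing the usage record the way base_test.py does; the attribute names printed below are inferred from the summary fields (total_prompt_tokens, total_completion_tokens) and are an assumption, not confirmed by the diff.

import langfun.core as lf

# Positional construction as in base_test.py: prompt, completion, total tokens.
usage = lf.LMSamplingUsage(387, 24, 411)
print(usage.prompt_tokens, usage.completion_tokens, usage.total_tokens)  # -> 387 24 411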