langfun 0.0.2.dev20240423__py3-none-any.whl → 0.0.2.dev20240428__py3-none-any.whl
This diff shows the changes between two publicly released versions of this package as they appear in their public registry. It is provided for informational purposes only.
- langfun/__init__.py +1 -0
- langfun/core/eval/__init__.py +2 -0
- langfun/core/eval/base.py +176 -18
- langfun/core/eval/base_test.py +34 -6
- langfun/core/eval/matching.py +18 -1
- langfun/core/eval/matching_test.py +2 -1
- langfun/core/eval/scoring.py +11 -1
- langfun/core/eval/scoring_test.py +2 -1
- langfun/core/langfunc.py +0 -5
- langfun/core/language_model.py +39 -9
- langfun/core/language_model_test.py +156 -18
- langfun/core/llms/fake_test.py +91 -7
- langfun/core/llms/openai_test.py +202 -17
- langfun/core/structured/__init__.py +1 -0
- langfun/core/structured/completion_test.py +1 -2
- langfun/core/structured/mapping.py +38 -1
- langfun/core/structured/mapping_test.py +17 -0
- langfun/core/structured/parsing_test.py +2 -4
- langfun/core/structured/prompting.py +14 -4
- langfun/core/structured/prompting_test.py +35 -4
- langfun/core/structured/schema_generation_test.py +2 -2
- langfun/core/template.py +99 -2
- langfun/core/template_test.py +66 -0
- {langfun-0.0.2.dev20240423.dist-info → langfun-0.0.2.dev20240428.dist-info}/METADATA +3 -2
- {langfun-0.0.2.dev20240423.dist-info → langfun-0.0.2.dev20240428.dist-info}/RECORD +28 -28
- {langfun-0.0.2.dev20240423.dist-info → langfun-0.0.2.dev20240428.dist-info}/LICENSE +0 -0
- {langfun-0.0.2.dev20240423.dist-info → langfun-0.0.2.dev20240428.dist-info}/WHEEL +0 -0
- {langfun-0.0.2.dev20240423.dist-info → langfun-0.0.2.dev20240428.dist-info}/top_level.txt +0 -0
langfun/__init__.py
CHANGED
langfun/core/eval/__init__.py
CHANGED
@@ -16,6 +16,8 @@
 # pylint: disable=g-importing-member
 # pylint: disable=g-bad-import-order

+from langfun.core.eval.base import app_run
+
 from langfun.core.eval.base import Evaluable
 from langfun.core.eval.base import Evaluation
 from langfun.core.eval.base import Suite
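With this export, the new launcher is importable straight from the eval package. A minimal usage sketch (the `Evaluable` being launched is a placeholder, not part of this diff):

```python
# Sketch: `app_run` is re-exported by langfun/core/eval/__init__.py in this release.
from langfun.core.eval import app_run

app_run(my_evaluation)  # my_evaluation: any configured Evaluable (placeholder here)
```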
langfun/core/eval/base.py
CHANGED
@@ -26,6 +26,8 @@ import threading
 import time
 from typing import Annotated, Any, Callable, Iterator, Literal, Optional, Sequence, Type, Union

+from absl import app
+from absl import flags
 import langfun.core as lf
 import langfun.core.coding as lf_coding
 from langfun.core.llms.cache import in_memory
@@ -538,7 +540,7 @@ class Evaluable(lf.Component):
          f'<div style="color: {text_color}; white-space: pre-wrap;'
          'padding: 10px; border: 1px solid; margin-top: 10px">'
      )
-      s.write(m.text)
+      s.write(m.get('formatted_text', m.text))
      if m.result is not None:
        s.write(
            '<div style="color: magenta; white-space: pre-wrap;'
@@ -546,6 +548,16 @@
        )
        s.write(pg.format(m.result))
        s.write('</div>')
+      if 'usage' in m.metadata:
+        s.write(
+            '<div style="background-color: #EEEEEE; color: black; '
+            'white-space: pre-wrap; padding: 10px; border: 0px solid; '
+            'margin: 10px">'
+            f'prompt: {m.usage.prompt_tokens} tokens, '
+            f'response: {m.usage.completion_tokens} tokens, '
+            f'total: {m.usage.total_tokens} tokens'
+            '</div>'
+        )
      s.write('</div>')

  @classmethod
@@ -810,6 +822,30 @@ class Evaluation(Evaluable):
      return 0.0
    return self.num_failures / self.num_completed

+  @property
+  def has_usage(self) -> bool:
+    """Returns True if token usage is enabled."""
+    return self._num_usages > 0
+
+  @property
+  def average_prompt_tokens(self) -> int:
+    """Returns the average prompt tokens."""
+    if not self.has_usage:
+      return 0
+    return self._total_prompt_tokens // self._num_usages
+
+  @property
+  def average_completion_tokens(self) -> int:
+    """Returns the average completion tokens."""
+    if not self.has_usage:
+      return 0
+    return self._total_completion_tokens // self._num_usages
+
+  @property
+  def average_total_tokens(self) -> int:
+    """Returns the average total tokens."""
+    return self.average_prompt_tokens + self.average_completion_tokens
+
  @functools.cached_property
  def schema(self) -> lf_structured.Schema | None:
    """Schema."""
@@ -856,7 +892,11 @@ class Evaluation(Evaluable):
        'Encountered: {annotation!r}.'
    )
    self._maybe_adjust_schema_for_completion(annotation)
-    return lf_structured.Schema.from_value(annotation)
+    schema = lf_structured.Schema.from_value(annotation)
+    # NOTE(daiyip): add references to the dependent classes of the returned type
+    # to prevent unused subclasses get garbage collected by Python.
+    setattr(schema, '__dependencies__', schema.class_dependencies())
+    return schema

  def _maybe_adjust_schema_for_completion(self, cls):
    if (self.completion_prompt_field is None
@@ -933,6 +973,10 @@ class Evaluation(Evaluable):
    self._failures = []
    self._num_completed = 0

+    self._total_prompt_tokens = 0
+    self._total_completion_tokens = 0
+    self._num_usages = 0
+
  @property
  def failures_link(self) -> str | None:
    """Returns the link to the failures page."""
@@ -952,7 +996,7 @@ class Evaluation(Evaluable):
    example = example or self.examples[0]

    # We make a copy to avoid pollute the state of current object.
-    copy = self.clone()
+    copy: Evaluation = self.clone()
    copy.__dict__['examples'] = [example]

    # We set the symbolic parent of the cloned to access contextual information
@@ -982,9 +1026,9 @@ class Evaluation(Evaluable):
          color='blue',
      )

-
-    copy.audit(example, output, output_message)
+    copy.audit(example, output_message, None, dryrun=True)
    result = copy.summarize()
+
    if verbose:
      lf.console.write('')
      lf.console.write(
@@ -1031,11 +1075,12 @@ class Evaluation(Evaluable):
          status_fn=self._status,
      ):
        if error is not None:
-
-
-
-
-
+          message = (
+              error.lm_response
+              if isinstance(error, lf_structured.MappingError)
+              else None
+          )
+        self.audit(example, message, error)
    finally:
      # Save cache upon completion or interruption.
      if self.dir and self.cache:
@@ -1138,6 +1183,19 @@ class Evaluation(Evaluable):
      )
    else:
      cache_stats = dict(use_cache=False)
+
+    if self.has_usage:
+      usage = pg.Dict(
+          total_prompt_tokens=self._total_prompt_tokens,
+          total_completion_tokens=self._total_completion_tokens,
+          num_usages=self._num_usages,
+          average_prompt_tokens=self.average_prompt_tokens,
+          average_completion_tokens=self.average_completion_tokens,
+          average_total_tokens=self.average_total_tokens,
+      )
+    else:
+      usage = None
+
    result = pg.Dict(
        experiment_setup=pg.Dict(
            id=self.id,
@@ -1153,6 +1211,7 @@ class Evaluation(Evaluable):
            failures=self.num_failures,
            failure_rate=self.failure_rate,
        ),
+        usage=usage,
    )
    return result

@@ -1174,9 +1233,28 @@ class Evaluation(Evaluable):
        '</td></tr><tr><td>'
    )
    self._render_metric(s)
+
+    # Summarize average usage.
+    if self.result.usage is not None:
+      self._render_usage(s)
+
    s.write('</td></tr></table></div>')
    return s.getvalue()

+  def _render_usage(self, s: io.StringIO) -> None:
+    """Renders usage in HTML."""
+    usage = self.result.usage
+    total = usage.total_prompt_tokens + usage.total_completion_tokens
+    s.write(
+        ' <a title="'
+        f'# of usages: {usage.num_usages}&#13;'
+        f'total prompt: {usage.total_prompt_tokens}&#13;'
+        f'total response: {usage.total_completion_tokens}&#13;'
+        f'avg prompt: {usage.average_prompt_tokens}&#13;'
+        f'avg response: {usage.average_completion_tokens}'
+        f'" style="color:gray">({total} tokens)</a>'
+    )
+
  def _render_metric(self, s: io.StringIO) -> None:
    """Renders metrics in HTML."""
    assert self.result is not None
@@ -1191,17 +1269,48 @@ class Evaluation(Evaluable):
        )
    )

-  def audit(
+  def audit(
+      self,
+      example: Any,
+      message: lf.Message | None,
+      error: Exception | None = None,
+      dryrun: bool = False,
+  ) -> None:
    """Audits the example against the output. Subclasses should override.

    Args:
      example: The input object.
-      output: The output from LM. For `lf.call`, if `schema_fn` is not provided,
-        it will be the raw LM response string. Otherwise it will be the
-        structured output from the LM.
      message: The entire message returned by the LM, which could be used to
-        trace the LM input, response and parsed structure.
+        trace the LM input, response and parsed structure. If error is raised
+        before LLM could return a response, None will be its value.
+      error: The exception during processing the example.
+      dryrun: Whether or not audition takes place during dryrun.
    """
+    if error is not None:
+      self._failures.append((example, str(error)))
+      if isinstance(error, lf_structured.MappingError):
+        message = error.lm_response
+    else:
+      assert message is not None
+      output = message.text if self.schema is None else message.result
+      self.audit_processed(example, output, message, dryrun=dryrun)
+
+    # Audit usage.
+    if message is not None:
+      self.audit_usage(message, dryrun=dryrun)
+    self._num_completed += 1
+
+  def audit_usage(self, message: lf.Message, dryrun: bool = False) -> None:
+    for m in message.trace():
+      if 'usage' in m.metadata:
+        self._total_prompt_tokens += m.usage.prompt_tokens
+        self._total_completion_tokens += m.usage.completion_tokens
+        self._num_usages += 1
+
+  def audit_processed(
+      self, example: Any, output: Any, message: lf.Message, dryrun: bool = False
+  ) -> None:
+    """Audits a successfully processed example. Subclass should override."""

  def save(
      self, definition: bool = True, result: bool = True, report: bool = True
@@ -1245,8 +1354,10 @@ class Evaluation(Evaluable):
        '<td>Prompt</td>'
        '<td>Schema</td>'
        '<td>Additional Args</td>'
-        '<td>Failures</td>'
    )
+    if self.result.usage is not None:
+      s.write('<td>Usage</td>')
+    s.write('<td>Failures</td>')

  def _render_result_row(self, s: io.StringIO) -> None:
    s.write(
@@ -1271,6 +1382,12 @@ class Evaluation(Evaluable):
        '<td style="color:purple" '
        f'{_html_repr(self.additional_args, compact=False)}</td>'
    )
+    # Usage.
+    if self.result.usage is not None:
+      s.write('<td>')
+      self._render_usage(s)
+      s.write('</td>')
+
    # Failures.
    s.write(
        '<td><span style="color:orange">%s</span>%s</td>'
@@ -1369,8 +1486,8 @@ class Summary(pg.Object):
          Type[lf.LanguageModel],
          tuple[lf.LanguageModel | Type[lf.LanguageModel], ...],
      ] = lf.LanguageModel,
-      method: Union[str, tuple[str], None] = None,
-      schema_fn: Union[pg.Functor, tuple[pg.Functor], None] = None,
+      method: Union[str, tuple[str, ...], None] = None,
+      schema_fn: Union[pg.Functor, tuple[pg.Functor, ...], None] = None,
      completed: bool | None = None,
      pivot_field: str | None = None,
  ) -> 'Summary':
@@ -1564,6 +1681,7 @@ class Summary(pg.Object):
      for entry in self.select(task=task).evaluations:
        results.append(
            pg.Dict(
+                id=entry.id,
                experiment=entry,
                dir=entry.dir,
                metrics=entry.result.metrics if entry.result else None,
@@ -1789,3 +1907,43 @@ def monitor_async(
      scan_interval=scan_interval,
      refresh_when_stop=refresh_when_stop,
  )
+
+
+def app_run(target: Evaluable):
+  """Runs the target evaluation as an absl app.
+
+  Args:
+    target: An Langfun evaluable object.
+  """
+  flags.DEFINE_string(
+      'root_dir', None, 'Root directory for running the evaluation.'
+  )
+
+  flags.DEFINE_bool(
+      'dryrun', False, 'If True, dryrun the experiment instead of running it.'
+  )
+
+  flags.DEFINE_bool(
+      'debug', False, 'If True, output prompt and response to the console.'
+  )
+
+  flags.DEFINE_bool(
+      'rerun',
+      False,
+      'If True, rerun the experiment even a cached result is found.',
+  )
+
+  FLAGS = flags.FLAGS  # pylint: disable=invalid-name
+
+  def _main(argv):
+    if len(argv) > 1:
+      raise app.UsageError('Too many command-line arguments.')
+
+    if FLAGS.root_dir:
+      target.rebind(root_dir=FLAGS.root_dir, raise_on_no_change=False)
+    if FLAGS.dryrun:
+      target.dryrun(debug=FLAGS.debug)
+    else:
+      target.run(debug=FLAGS.debug, rerun=FLAGS.rerun)
+
+  app.run(_main)
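The new `app_run` helper turns any `Evaluable` into an absl launcher with `--root_dir`, `--dryrun`, `--debug` and `--rerun` flags, and `summarize()` now reports aggregated token usage. A sketch of what a launcher script might look like; `MyEval` and its configuration are placeholders, not part of this diff:

```python
# eval_launcher.py -- sketch only; MyEval stands in for a concrete Evaluation
# subclass whose inputs/prompt/lm fields are configured elsewhere.
from langfun.core import eval as lf_eval

experiment = MyEval(root_dir='/tmp/my_eval')  # hypothetical Evaluation subclass

if __name__ == '__main__':
  # Defines the --root_dir/--dryrun/--debug/--rerun flags and runs the experiment.
  lf_eval.app_run(experiment)

# After a run, aggregated token usage is available on the summarized result
# (None when the LM reported no usage metadata), e.g.:
#   experiment.result.usage.average_prompt_tokens
#   experiment.result.usage.average_completion_tokens
#   experiment.result.usage.average_total_tokens
```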
langfun/core/eval/base_test.py
CHANGED
@@ -101,7 +101,7 @@ class EvaluationTest(unittest.TestCase):
    self.assertEqual(s.dir, os.path.join(s.root_dir, s.id))
    self.assertEqual(s.hash, s.clone().hash)
    # Test persistent hash.
-    self.assertEqual(s.hash, '
+    self.assertEqual(s.hash, 'ae86c703')
    self.assertEqual(
        s.hash, s.clone(override={'max_workers': 2, 'lm.timeout': 20}).hash
    )
@@ -210,7 +210,7 @@ class EvaluationTest(unittest.TestCase):
        s.result,
        dict(
            experiment_setup=dict(
-                id='Evaluation@
+                id='Evaluation@0fade07d',
                dir=s.dir,
                model='StaticSequence',
                prompt_template='{{example.question}}',
@@ -221,6 +221,14 @@ class EvaluationTest(unittest.TestCase):
                use_cache=True, num_queries=2, num_hits=0, num_updates=2
            ),
            metrics=dict(total=2, failures=1, failure_rate=0.5),
+            usage=dict(
+                total_prompt_tokens=774,
+                total_completion_tokens=25,
+                num_usages=2,
+                average_prompt_tokens=387,
+                average_completion_tokens=12,
+                average_total_tokens=399,
+            ),
        ),
    )
    self.assertTrue(
@@ -285,8 +293,11 @@ class EvaluationTest(unittest.TestCase):
    s = eval_set(
        'run_filter_test', pg.oneof(['call', 'query']),
        schema_fn=answer_schema(), lm=lm)
+    result = s.run(
+        filter=lambda x: x.method == 'query', dryrun=True, summary=False
+    )
    self.assertEqual(
-
+        result,
        {
            s.children[0].id: None,
            s.children[1].id: dict(
@@ -302,7 +313,8 @@ class EvaluationTest(unittest.TestCase):
                    use_cache=True, num_queries=2, num_hits=0, num_updates=2
                ),
                metrics=dict(total=2, failures=0, failure_rate=0.0),
-
+                usage=s.children[1].result.usage,
+            ),
        },
    )

@@ -336,7 +348,6 @@ class EvaluationTest(unittest.TestCase):

    summary = s.run(verbose=True)
    self.assertEqual(len(summary.evaluations), 2)
-
    self.assertEqual(
        s.result,
        {
@@ -353,6 +364,7 @@ class EvaluationTest(unittest.TestCase):
                    use_cache=True, num_queries=2, num_hits=0, num_updates=2
                ),
                metrics=dict(total=2, failures=1, failure_rate=0.5),
+                usage=s.children[0].result.usage,
            ),
            s.children[1].id: dict(
                experiment_setup=dict(
@@ -367,6 +379,7 @@ class EvaluationTest(unittest.TestCase):
                    use_cache=True, num_queries=2, num_hits=0, num_updates=2
                ),
                metrics=dict(total=2, failures=1, failure_rate=0.5),
+                usage=s.children[1].result.usage,
            ),
        },
    )
@@ -459,7 +472,7 @@ class SuiteTest(unittest.TestCase):
        lm=lm
    )
    # Test for persistent hash.
-    self.assertEqual(s.hash, '
+    self.assertEqual(s.hash, '26e6cc25')
    s.run()
    expected = {
        s.children[0].id: dict(
@@ -475,6 +488,7 @@ class SuiteTest(unittest.TestCase):
                use_cache=True, num_queries=2, num_hits=0, num_updates=2
            ),
            metrics=dict(total=2, failures=1, failure_rate=0.5),
+            usage=s.children[0].result.usage,
        ),
        s.children[1].id: {
            s.children[1]
@@ -492,6 +506,7 @@ class SuiteTest(unittest.TestCase):
                    use_cache=True, num_queries=4, num_hits=1, num_updates=3
                ),
                metrics=dict(total=2, failures=2, failure_rate=1.0),
+                usage=s.children[1].children[0].result.usage,
            ),
            s.children[1]
            .children[2]
@@ -511,6 +526,7 @@ class SuiteTest(unittest.TestCase):
                    num_updates=2,
                ),
                metrics=dict(total=2, failures=1, failure_rate=0.5),
+                usage=s.children[1].children[2].result.usage,
            ),
        },
    }
@@ -682,5 +698,17 @@ class SummaryTest(unittest.TestCase):
    self.assertTrue(pg.io.path_exists(summary_file))


+class AppRunTest(unittest.TestCase):
+
+  def test_app_run(self):
+    lm = fake.StaticSequence(['two', 'Solution(final_answer=2)'])
+    try:
+      base.app_run(
+          eval_set('app_run_test', 'query', schema_fn=answer_schema(), lm=lm)
+      )
+    except SystemExit:
+      pass
+
+
 if __name__ == '__main__':
   unittest.main()
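The updated tests illustrate two behaviors worth noting: `run()` accepts a filter and returns a per-child mapping when `summary=False`, and each evaluation's `result` now carries a `usage` entry next to `metrics` and `cache_stats`. A hedged sketch of that call shape; `suite` stands for any configured evaluation set and the filter predicate is illustrative:

```python
# Sketch: run only children whose `method` is 'query'. Children excluded by the
# filter map to None; the rest map to their `result` dict, whose `usage` key is
# None when the LM reported no token usage.
results = suite.run(
    filter=lambda x: x.method == 'query', dryrun=True, summary=False
)
for child_id, result in results.items():
  if result is not None and result.usage is not None:
    print(child_id, result.usage.average_total_tokens)
```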
langfun/core/eval/matching.py
CHANGED
@@ -86,9 +86,26 @@ class Matching(base.Evaluation):
    self._matches = []
    self._mismatches = []

-  def
+  def audit_processed(
+      self, example: Any, output: Any, message: lf.Message, dryrun: bool = False
+  ) -> None:
    groundtruth = self.groundtruth(example)
    answer = self.answer(output, example)
+
+    if dryrun:
+      lf.console.write('')
+      lf.console.write(
+          str(groundtruth),
+          title='GROUDTRUTH',
+          color='green',
+      )
+      lf.console.write('')
+      lf.console.write(
+          str(answer),
+          title='ANSWER',
+          color='blue',
+      )
+
    if self.match(answer, groundtruth):
      self._matches.append((example, output, message))
    else:
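For context, `Matching` now routes successful examples through `audit_processed` (the base `Evaluation.audit` handles failures and usage), so a concrete task still only supplies the hooks called above. A sketch under assumptions: the required Evaluation fields (inputs, prompt, lm, schema_fn) and the example attributes are placeholders, not part of this diff.

```python
# Sketch of a Matching subclass; only the hooks referenced in the diff are shown.
from typing import Any

from langfun.core.eval import matching


class MyMatchingTask(matching.Matching):
  # Required Evaluation fields (inputs, prompt, lm, schema_fn, ...) omitted here.

  def groundtruth(self, example: Any) -> Any:
    return example.answer          # assumes each example carries an `answer` field

  def answer(self, output: Any, example: Any) -> Any:
    return output                  # the structured output parsed from the LM message

  def match(self, answer: Any, groundtruth: Any) -> bool:
    return answer == groundtruth
```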
langfun/core/eval/matching_test.py
CHANGED
@@ -103,7 +103,7 @@ class MatchingTest(unittest.TestCase):
        s.result,
        dict(
            experiment_setup=dict(
-                id='MyTask@
+                id='MyTask@739a174b',
                dir=s.dir,
                model='StaticSequence',
                prompt_template='{{example.question}}',
@@ -125,6 +125,7 @@ class MatchingTest(unittest.TestCase):
                num_mismatches=1,
                mismatch_rate=0.25,
            ),
+            usage=s.result.usage,
        ),
    )
    self.assertTrue(
langfun/core/eval/scoring.py
CHANGED
@@ -61,8 +61,18 @@ class Scoring(base.Evaluation):
    super()._reset()
    self._scored = []

-  def
+  def audit_processed(
+      self, example: Any, output: Any, message: lf.Message, dryrun: bool = False
+  ) -> None:
    score = self.score(example, output)
+
+    if dryrun:
+      lf.console.write('')
+      lf.console.write(
+          str(score),
+          title='SCORE',
+          color='blue',
+      )
    self._scored.append((example, output, score, message))

  @abc.abstractmethod
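As with `Matching`, a scoring task keeps only its domain logic: the `score` hook called above. A sketch with assumed placeholders; the field configuration and the scoring rule are illustrative, not from this diff.

```python
# Sketch of a Scoring subclass; Evaluation fields (inputs, prompt, lm, ...) omitted.
from typing import Any

from langfun.core.eval import scoring


class MyScoringTask(scoring.Scoring):

  def score(self, example: Any, output: Any) -> float:
    # Hypothetical rule: full credit when the output mentions the example verbatim.
    return 1.0 if str(example) in str(output) else 0.0
```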
langfun/core/eval/scoring_test.py
CHANGED
@@ -81,7 +81,7 @@ class ScoringTest(unittest.TestCase):
        s.result,
        dict(
            experiment_setup=dict(
-                id='ConstraintFollowing@
+                id='ConstraintFollowing@5c88a5eb',
                dir=s.dir,
                model='StaticSequence',
                prompt_template='{{example}}',
@@ -102,6 +102,7 @@ class ScoringTest(unittest.TestCase):
                score_rate=1.0,
                avg_score=0.5,
            ),
+            usage=s.result.usage,
        ),
    )
    self.assertTrue(
langfun/core/langfunc.py
CHANGED
@@ -261,7 +261,6 @@ class LangFunc(
    if lm_input is None:
      lm_input = self.render(**kwargs)

-    lm_input.tag(message_lib.Message.TAG_LM_INPUT)
    if skip_lm:
      return lm_input

@@ -270,10 +269,6 @@ class LangFunc(
    # Send rendered text to LM.
    lm_output = self.lm(lm_input, cache_seed=cache_seed)

-    # Track the input as the source of the output.
-    lm_output.source = lm_input
-    lm_output.tag(message_lib.Message.TAG_LM_RESPONSE)
-
    # Transform the output message.
    lm_output = self.transform_output(lm_output)
    lm_output.tag(message_lib.Message.TAG_LM_OUTPUT)
langfun/core/language_model.py
CHANGED
@@ -346,9 +346,42 @@ class LanguageModel(component.Component):

    with component.context(override_attrs=True, **kwargs):
      if self.cache is None:
-        return self._sample(prompts)
+        results = self._sample(prompts)
      else:
-        return self._sample_with_cache_lookup(prompts, cache_seed)
+        results = self._sample_with_cache_lookup(prompts, cache_seed)
+
+      for prompt, result in zip(prompts, results):
+
+        # Tag LM input.
+        prompt.tag(message_lib.Message.TAG_LM_INPUT)
+
+        for sample in result.samples:
+          # Update metadata for response message.
+
+          response = sample.response
+          response.metadata.score = sample.score
+          response.metadata.logprobs = sample.logprobs
+
+          # NOTE(daiyip): Current usage is computed at per-result level,
+          # which is accurate when n=1. For n > 1, we average the usage across
+          # multiple samples.
+          usage = result.usage
+          if len(result.samples) == 1 or usage is None:
+            response.metadata.usage = usage
+          else:
+            n = len(result.samples)
+            response.metadata.usage = LMSamplingUsage(
+                prompt_tokens=usage.prompt_tokens // n,
+                completion_tokens=usage.completion_tokens // n,
+                total_tokens=usage.total_tokens // n,
+            )
+
+          # Track the prompt for corresponding response.
+          response.source = prompt
+
+          # Tag LM response.
+          response.tag(message_lib.Message.TAG_LM_RESPONSE)
+      return results

  def _sample_with_cache_lookup(
      self, prompts: list[str | message_lib.Message], cache_seed: int
@@ -436,13 +469,8 @@ class LanguageModel(component.Component):
    result = self.sample(
        [prompt], sampling_options=sampling_options, cache_seed=cache_seed
    )[0]
-    response = result.samples[0].response
-    logprobs = result.samples[0].logprobs
-    response.set('score', result.samples[0].score)
-    response.metadata.logprobs = logprobs
-    response.metadata.usage = result.usage
-
    elapse = time.time() - request_start
+    response = result.samples[0].response
    self._debug(prompt, response, call_counter, result.usage, elapse)
    return response

@@ -494,7 +522,9 @@ class LanguageModel(component.Component):
      title_suffix = console.colored(f' ({usage.prompt_tokens} tokens)', 'red')

    console.write(
-        prompt
+        # We use metadata 'formatted_text' for scenarios where the prompt text
+        # is formatted by the LM.
+        prompt.get('formatted_text', prompt.text),
        title=f'\n[{call_counter}] PROMPT SENT TO LM{title_suffix}:',
        color='green',
    )