langfun 0.0.2.dev20240423__py3-none-any.whl → 0.0.2.dev20240428__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
langfun/__init__.py CHANGED
@@ -55,6 +55,7 @@ Video = modalities.Video
  PDF = modalities.PDF

  # Error types.
+ MappingError = structured.MappingError
  SchemaError = structured.SchemaError
  JsonError = structured.JsonError
  CodeError = coding.CodeError
@@ -16,6 +16,8 @@
  # pylint: disable=g-importing-member
  # pylint: disable=g-bad-import-order

+ from langfun.core.eval.base import app_run
+
  from langfun.core.eval.base import Evaluable
  from langfun.core.eval.base import Evaluation
  from langfun.core.eval.base import Suite
langfun/core/eval/base.py CHANGED
@@ -26,6 +26,8 @@ import threading
  import time
  from typing import Annotated, Any, Callable, Iterator, Literal, Optional, Sequence, Type, Union

+ from absl import app
+ from absl import flags
  import langfun.core as lf
  import langfun.core.coding as lf_coding
  from langfun.core.llms.cache import in_memory
@@ -538,7 +540,7 @@ class Evaluable(lf.Component):
  f'<div style="color: {text_color}; white-space: pre-wrap;'
  'padding: 10px; border: 1px solid; margin-top: 10px">'
  )
- s.write(m.text)
+ s.write(m.get('formatted_text', m.text))
  if m.result is not None:
  s.write(
  '<div style="color: magenta; white-space: pre-wrap;'
@@ -546,6 +548,16 @@ class Evaluable(lf.Component):
  )
  s.write(pg.format(m.result))
  s.write('</div>')
+ if 'usage' in m.metadata:
+ s.write(
+ '<div style="background-color: #EEEEEE; color: black; '
+ 'white-space: pre-wrap; padding: 10px; border: 0px solid; '
+ 'margin: 10px">'
+ f'prompt: {m.usage.prompt_tokens} tokens, '
+ f'response: {m.usage.completion_tokens} tokens, '
+ f'total: {m.usage.total_tokens} tokens'
+ '</div>'
+ )
  s.write('</div>')

  @classmethod
@@ -810,6 +822,30 @@ class Evaluation(Evaluable):
  return 0.0
  return self.num_failures / self.num_completed

+ @property
+ def has_usage(self) -> bool:
+ """Returns True if token usage is enabled."""
+ return self._num_usages > 0
+
+ @property
+ def average_prompt_tokens(self) -> int:
+ """Returns the average prompt tokens."""
+ if not self.has_usage:
+ return 0
+ return self._total_prompt_tokens // self._num_usages
+
+ @property
+ def average_completion_tokens(self) -> int:
+ """Returns the average completion tokens."""
+ if not self.has_usage:
+ return 0
+ return self._total_completion_tokens // self._num_usages
+
+ @property
+ def average_total_tokens(self) -> int:
+ """Returns the average total tokens."""
+ return self.average_prompt_tokens + self.average_completion_tokens
+
  @functools.cached_property
  def schema(self) -> lf_structured.Schema | None:
  """Schema."""
@@ -856,7 +892,11 @@ class Evaluation(Evaluable):
  'Encountered: {annotation!r}.'
  )
  self._maybe_adjust_schema_for_completion(annotation)
- return lf_structured.Schema.from_value(annotation)
+ schema = lf_structured.Schema.from_value(annotation)
+ # NOTE(daiyip): add references to the dependent classes of the returned type
+ # to prevent unused subclasses get garbage collected by Python.
+ setattr(schema, '__dependencies__', schema.class_dependencies())
+ return schema

  def _maybe_adjust_schema_for_completion(self, cls):
  if (self.completion_prompt_field is None
@@ -933,6 +973,10 @@ class Evaluation(Evaluable):
  self._failures = []
  self._num_completed = 0

+ self._total_prompt_tokens = 0
+ self._total_completion_tokens = 0
+ self._num_usages = 0
+
  @property
  def failures_link(self) -> str | None:
  """Returns the link to the failures page."""
@@ -952,7 +996,7 @@ class Evaluation(Evaluable):
  example = example or self.examples[0]

  # We make a copy to avoid pollute the state of current object.
- copy = self.clone()
+ copy: Evaluation = self.clone()
  copy.__dict__['examples'] = [example]

  # We set the symbolic parent of the cloned to access contextual information
@@ -982,9 +1026,9 @@ class Evaluation(Evaluable):
  color='blue',
  )

- # Audit the result.
- copy.audit(example, output, output_message)
+ copy.audit(example, output_message, None, dryrun=True)
  result = copy.summarize()
+
  if verbose:
  lf.console.write('')
  lf.console.write(
@@ -1031,11 +1075,12 @@ class Evaluation(Evaluable):
  status_fn=self._status,
  ):
  if error is not None:
- self._failures.append((example, str(error)))
- else:
- output = message.text if self.schema is None else message.result
- self.audit(example, output, message)
- self._num_completed += 1
+ message = (
+ error.lm_response
+ if isinstance(error, lf_structured.MappingError)
+ else None
+ )
+ self.audit(example, message, error)
  finally:
  # Save cache upon completion or interruption.
  if self.dir and self.cache:
@@ -1138,6 +1183,19 @@ class Evaluation(Evaluable):
  )
  else:
  cache_stats = dict(use_cache=False)
+
+ if self.has_usage:
+ usage = pg.Dict(
+ total_prompt_tokens=self._total_prompt_tokens,
+ total_completion_tokens=self._total_completion_tokens,
+ num_usages=self._num_usages,
+ average_prompt_tokens=self.average_prompt_tokens,
+ average_completion_tokens=self.average_completion_tokens,
+ average_total_tokens=self.average_total_tokens,
+ )
+ else:
+ usage = None
+
  result = pg.Dict(
  experiment_setup=pg.Dict(
  id=self.id,
@@ -1153,6 +1211,7 @@
  failures=self.num_failures,
  failure_rate=self.failure_rate,
  ),
+ usage=usage,
  )
  return result

@@ -1174,9 +1233,28 @@ class Evaluation(Evaluable):
  '</td></tr><tr><td>'
  )
  self._render_metric(s)
+
+ # Summarize average usage.
+ if self.result.usage is not None:
+ self._render_usage(s)
+
  s.write('</td></tr></table></div>')
  return s.getvalue()

+ def _render_usage(self, s: io.StringIO) -> None:
+ """Renders usage in HTML."""
+ usage = self.result.usage
+ total = usage.total_prompt_tokens + usage.total_completion_tokens
+ s.write(
+ '&nbsp;<a title="'
+ f'# of usages: {usage.num_usages}&#013;'
+ f'total prompt: {usage.total_prompt_tokens}&#013;'
+ f'total response: {usage.total_completion_tokens}&#013;'
+ f'avg prompt: {usage.average_prompt_tokens}&#013;'
+ f'avg response: {usage.average_completion_tokens}'
+ f'" style="color:gray">({total} tokens)</a>'
+ )
+
  def _render_metric(self, s: io.StringIO) -> None:
  """Renders metrics in HTML."""
  assert self.result is not None
@@ -1191,17 +1269,48 @@ class Evaluation(Evaluable):
  )
  )

- def audit(self, example: Any, output: Any, message: lf.Message) -> None:
+ def audit(
+ self,
+ example: Any,
+ message: lf.Message | None,
+ error: Exception | None = None,
+ dryrun: bool = False,
+ ) -> None:
  """Audits the example against the output. Subclasses should override.

  Args:
  example: The input object.
- output: The output from LM. For `lf.call`, if `schema_fn` is not provided,
- it will be the raw LM response string. Otherwise it will be the
- structured output from the LM.
  message: The entire message returned by the LM, which could be used to
- trace the LM input, response and parsed structure.
+ trace the LM input, response and parsed structure. If error is raised
+ before LLM could return a response, None will be its value.
+ error: The exception during processing the example.
+ dryrun: Whether or not audition takes place during dryrun.
  """
+ if error is not None:
+ self._failures.append((example, str(error)))
+ if isinstance(error, lf_structured.MappingError):
+ message = error.lm_response
+ else:
+ assert message is not None
+ output = message.text if self.schema is None else message.result
+ self.audit_processed(example, output, message, dryrun=dryrun)
+
+ # Audit usage.
+ if message is not None:
+ self.audit_usage(message, dryrun=dryrun)
+ self._num_completed += 1
+
+ def audit_usage(self, message: lf.Message, dryrun: bool = False) -> None:
+ for m in message.trace():
+ if 'usage' in m.metadata:
+ self._total_prompt_tokens += m.usage.prompt_tokens
+ self._total_completion_tokens += m.usage.completion_tokens
+ self._num_usages += 1
+
+ def audit_processed(
+ self, example: Any, output: Any, message: lf.Message, dryrun: bool = False
+ ) -> None:
+ """Audits a successfully processed example. Subclass should override."""

  def save(
  self, definition: bool = True, result: bool = True, report: bool = True
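
With this split, per-example bookkeeping belongs in audit_processed(), while the base audit() records failures and accumulates token usage; the Matching and Scoring changes later in this diff follow exactly this pattern. Below is a minimal hypothetical subclass for illustration (the class name, its counter, and the `answer` field are not part of the package):

    # Hypothetical subclass sketch built on the new audit_processed() hook.
    class ExactMatch(Evaluation):

      def _reset(self):
        super()._reset()
        self._num_exact = 0

      def audit_processed(
          self, example, output, message, dryrun: bool = False
      ) -> None:
        # Only successfully processed examples reach this hook; failures and
        # usage accounting are already handled by Evaluation.audit().
        if output == example.answer:  # assumes examples expose an `answer` field
          self._num_exact += 1
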
@@ -1245,8 +1354,10 @@ class Evaluation(Evaluable):
  '<td>Prompt</td>'
  '<td>Schema</td>'
  '<td>Additional Args</td>'
- '<td>Failures</td>'
  )
+ if self.result.usage is not None:
+ s.write('<td>Usage</td>')
+ s.write('<td>Failures</td>')

  def _render_result_row(self, s: io.StringIO) -> None:
  s.write(
@@ -1271,6 +1382,12 @@ class Evaluation(Evaluable):
  '<td style="color:purple" '
  f'{_html_repr(self.additional_args, compact=False)}</td>'
  )
+ # Usage.
+ if self.result.usage is not None:
+ s.write('<td>')
+ self._render_usage(s)
+ s.write('</td>')
+
  # Failures.
  s.write(
  '<td><span style="color:orange">%s</span>%s</td>'
@@ -1369,8 +1486,8 @@ class Summary(pg.Object):
  Type[lf.LanguageModel],
  tuple[lf.LanguageModel | Type[lf.LanguageModel], ...],
  ] = lf.LanguageModel,
- method: Union[str, tuple[str], None] = None,
- schema_fn: Union[pg.Functor, tuple[pg.Functor], None] = None,
+ method: Union[str, tuple[str, ...], None] = None,
+ schema_fn: Union[pg.Functor, tuple[pg.Functor, ...], None] = None,
  completed: bool | None = None,
  pivot_field: str | None = None,
  ) -> 'Summary':
@@ -1564,6 +1681,7 @@ class Summary(pg.Object):
  for entry in self.select(task=task).evaluations:
  results.append(
  pg.Dict(
+ id=entry.id,
  experiment=entry,
  dir=entry.dir,
  metrics=entry.result.metrics if entry.result else None,
@@ -1789,3 +1907,43 @@ def monitor_async(
  scan_interval=scan_interval,
  refresh_when_stop=refresh_when_stop,
  )
+
+
+ def app_run(target: Evaluable):
+ """Runs the target evaluation as an absl app.
+
+ Args:
+ target: An Langfun evaluable object.
+ """
+ flags.DEFINE_string(
+ 'root_dir', None, 'Root directory for running the evaluation.'
+ )
+
+ flags.DEFINE_bool(
+ 'dryrun', False, 'If True, dryrun the experiment instead of running it.'
+ )
+
+ flags.DEFINE_bool(
+ 'debug', False, 'If True, output prompt and response to the console.'
+ )
+
+ flags.DEFINE_bool(
+ 'rerun',
+ False,
+ 'If True, rerun the experiment even a cached result is found.',
+ )
+
+ FLAGS = flags.FLAGS  # pylint: disable=invalid-name
+
+ def _main(argv):
+ if len(argv) > 1:
+ raise app.UsageError('Too many command-line arguments.')
+
+ if FLAGS.root_dir:
+ target.rebind(root_dir=FLAGS.root_dir, raise_on_no_change=False)
+ if FLAGS.dryrun:
+ target.dryrun(debug=FLAGS.debug)
+ else:
+ target.run(debug=FLAGS.debug, rerun=FLAGS.rerun)
+
+ app.run(_main)
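
Since app_run is also re-exported from the eval package __init__ (see the import hunk near the top of this diff), a small launcher script is enough to expose the --root_dir, --dryrun, --debug and --rerun flags defined above. A hedged sketch; my_evaluation() is a placeholder for code that builds a real Evaluable:

    # Hypothetical launcher script for the new app_run helper.
    from langfun.core.eval import base as lf_eval_base


    def my_evaluation():
      """Placeholder: construct and return an Evaluable here."""
      raise NotImplementedError()


    if __name__ == '__main__':
      # e.g. python run_eval.py --root_dir=/tmp/my_eval --dryrun
      lf_eval_base.app_run(my_evaluation())
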
@@ -101,7 +101,7 @@ class EvaluationTest(unittest.TestCase):
  self.assertEqual(s.dir, os.path.join(s.root_dir, s.id))
  self.assertEqual(s.hash, s.clone().hash)
  # Test persistent hash.
- self.assertEqual(s.hash, '436dc80c')
+ self.assertEqual(s.hash, 'ae86c703')
  self.assertEqual(
  s.hash, s.clone(override={'max_workers': 2, 'lm.timeout': 20}).hash
  )
@@ -210,7 +210,7 @@ class EvaluationTest(unittest.TestCase):
  s.result,
  dict(
  experiment_setup=dict(
- id='Evaluation@f1aa5126',
+ id='Evaluation@0fade07d',
  dir=s.dir,
  model='StaticSequence',
  prompt_template='{{example.question}}',
@@ -221,6 +221,14 @@ class EvaluationTest(unittest.TestCase):
  use_cache=True, num_queries=2, num_hits=0, num_updates=2
  ),
  metrics=dict(total=2, failures=1, failure_rate=0.5),
+ usage=dict(
+ total_prompt_tokens=774,
+ total_completion_tokens=25,
+ num_usages=2,
+ average_prompt_tokens=387,
+ average_completion_tokens=12,
+ average_total_tokens=399,
+ ),
  ),
  )
  self.assertTrue(
@@ -285,8 +293,11 @@ class EvaluationTest(unittest.TestCase):
  s = eval_set(
  'run_filter_test', pg.oneof(['call', 'query']),
  schema_fn=answer_schema(), lm=lm)
+ result = s.run(
+ filter=lambda x: x.method == 'query', dryrun=True, summary=False
+ )
  self.assertEqual(
- s.run(filter=lambda x: x.method == 'query', dryrun=True, summary=False),
+ result,
  {
  s.children[0].id: None,
  s.children[1].id: dict(
@@ -302,7 +313,8 @@ class EvaluationTest(unittest.TestCase):
  use_cache=True, num_queries=2, num_hits=0, num_updates=2
  ),
  metrics=dict(total=2, failures=0, failure_rate=0.0),
- )
+ usage=s.children[1].result.usage,
+ ),
  },
  )

@@ -336,7 +348,6 @@ class EvaluationTest(unittest.TestCase):

  summary = s.run(verbose=True)
  self.assertEqual(len(summary.evaluations), 2)
-
  self.assertEqual(
  s.result,
  {
@@ -353,6 +364,7 @@ class EvaluationTest(unittest.TestCase):
  use_cache=True, num_queries=2, num_hits=0, num_updates=2
  ),
  metrics=dict(total=2, failures=1, failure_rate=0.5),
+ usage=s.children[0].result.usage,
  ),
  s.children[1].id: dict(
  experiment_setup=dict(
@@ -367,6 +379,7 @@ class EvaluationTest(unittest.TestCase):
  use_cache=True, num_queries=2, num_hits=0, num_updates=2
  ),
  metrics=dict(total=2, failures=1, failure_rate=0.5),
+ usage=s.children[1].result.usage,
  ),
  },
  )
@@ -459,7 +472,7 @@ class SuiteTest(unittest.TestCase):
  lm=lm
  )
  # Test for persistent hash.
- self.assertEqual(s.hash, 'bbfdc7a8')
+ self.assertEqual(s.hash, '26e6cc25')
  s.run()
  expected = {
  s.children[0].id: dict(
@@ -475,6 +488,7 @@ class SuiteTest(unittest.TestCase):
  use_cache=True, num_queries=2, num_hits=0, num_updates=2
  ),
  metrics=dict(total=2, failures=1, failure_rate=0.5),
+ usage=s.children[0].result.usage,
  ),
  s.children[1].id: {
  s.children[1]
@@ -492,6 +506,7 @@ class SuiteTest(unittest.TestCase):
  use_cache=True, num_queries=4, num_hits=1, num_updates=3
  ),
  metrics=dict(total=2, failures=2, failure_rate=1.0),
+ usage=s.children[1].children[0].result.usage,
  ),
  s.children[1]
  .children[2]
@@ -511,6 +526,7 @@
  num_updates=2,
  ),
  metrics=dict(total=2, failures=1, failure_rate=0.5),
+ usage=s.children[1].children[2].result.usage,
  ),
  },
  }
@@ -682,5 +698,17 @@ class SummaryTest(unittest.TestCase):
  self.assertTrue(pg.io.path_exists(summary_file))


+ class AppRunTest(unittest.TestCase):
+
+ def test_app_run(self):
+ lm = fake.StaticSequence(['two', 'Solution(final_answer=2)'])
+ try:
+ base.app_run(
+ eval_set('app_run_test', 'query', schema_fn=answer_schema(), lm=lm)
+ )
+ except SystemExit:
+ pass
+
+
  if __name__ == '__main__':
  unittest.main()
@@ -86,9 +86,26 @@ class Matching(base.Evaluation):
  self._matches = []
  self._mismatches = []

- def audit(self, example: Any, output: Any, message: lf.Message) -> None:
+ def audit_processed(
+ self, example: Any, output: Any, message: lf.Message, dryrun: bool = False
+ ) -> None:
  groundtruth = self.groundtruth(example)
  answer = self.answer(output, example)
+
+ if dryrun:
+ lf.console.write('')
+ lf.console.write(
+ str(groundtruth),
+ title='GROUDTRUTH',
+ color='green',
+ )
+ lf.console.write('')
+ lf.console.write(
+ str(answer),
+ title='ANSWER',
+ color='blue',
+ )
+
  if self.match(answer, groundtruth):
  self._matches.append((example, output, message))
  else:
@@ -103,7 +103,7 @@ class MatchingTest(unittest.TestCase):
  s.result,
  dict(
  experiment_setup=dict(
- id='MyTask@acd56a61',
+ id='MyTask@739a174b',
  dir=s.dir,
  model='StaticSequence',
  prompt_template='{{example.question}}',
@@ -125,6 +125,7 @@ class MatchingTest(unittest.TestCase):
  num_mismatches=1,
  mismatch_rate=0.25,
  ),
+ usage=s.result.usage,
  ),
  )
  self.assertTrue(
@@ -61,8 +61,18 @@ class Scoring(base.Evaluation):
  super()._reset()
  self._scored = []

- def audit(self, example: Any, output: Any, message: lf.Message) -> None:
+ def audit_processed(
+ self, example: Any, output: Any, message: lf.Message, dryrun: bool = False
+ ) -> None:
  score = self.score(example, output)
+
+ if dryrun:
+ lf.console.write('')
+ lf.console.write(
+ str(score),
+ title='SCORE',
+ color='blue',
+ )
  self._scored.append((example, output, score, message))

  @abc.abstractmethod
@@ -81,7 +81,7 @@ class ScoringTest(unittest.TestCase):
  s.result,
  dict(
  experiment_setup=dict(
- id='ConstraintFollowing@a44d8b89',
+ id='ConstraintFollowing@5c88a5eb',
  dir=s.dir,
  model='StaticSequence',
  prompt_template='{{example}}',
@@ -102,6 +102,7 @@ class ScoringTest(unittest.TestCase):
  score_rate=1.0,
  avg_score=0.5,
  ),
+ usage=s.result.usage,
  ),
  )
  self.assertTrue(
langfun/core/langfunc.py CHANGED
@@ -261,7 +261,6 @@ class LangFunc(
  if lm_input is None:
  lm_input = self.render(**kwargs)

- lm_input.tag(message_lib.Message.TAG_LM_INPUT)
  if skip_lm:
  return lm_input

@@ -270,10 +269,6 @@ class LangFunc(
  # Send rendered text to LM.
  lm_output = self.lm(lm_input, cache_seed=cache_seed)

- # Track the input as the source of the output.
- lm_output.source = lm_input
- lm_output.tag(message_lib.Message.TAG_LM_RESPONSE)
-
  # Transform the output message.
  lm_output = self.transform_output(lm_output)
  lm_output.tag(message_lib.Message.TAG_LM_OUTPUT)
@@ -346,9 +346,42 @@ class LanguageModel(component.Component):

  with component.context(override_attrs=True, **kwargs):
  if self.cache is None:
- return self._sample(prompts)
+ results = self._sample(prompts)
  else:
- return self._sample_with_cache_lookup(prompts, cache_seed)
+ results = self._sample_with_cache_lookup(prompts, cache_seed)
+
+ for prompt, result in zip(prompts, results):
+
+ # Tag LM input.
+ prompt.tag(message_lib.Message.TAG_LM_INPUT)
+
+ for sample in result.samples:
+ # Update metadata for response message.
+
+ response = sample.response
+ response.metadata.score = sample.score
+ response.metadata.logprobs = sample.logprobs
+
+ # NOTE(daiyip): Current usage is computed at per-result level,
+ # which is accurate when n=1. For n > 1, we average the usage across
+ # multiple samples.
+ usage = result.usage
+ if len(result.samples) == 1 or usage is None:
+ response.metadata.usage = usage
+ else:
+ n = len(result.samples)
+ response.metadata.usage = LMSamplingUsage(
+ prompt_tokens=usage.prompt_tokens // n,
+ completion_tokens=usage.completion_tokens // n,
+ total_tokens=usage.total_tokens // n,
+ )
+
+ # Track the prompt for corresponding response.
+ response.source = prompt
+
+ # Tag LM response.
+ response.tag(message_lib.Message.TAG_LM_RESPONSE)
+ return results

  def _sample_with_cache_lookup(
  self, prompts: list[str | message_lib.Message], cache_seed: int
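
The NOTE in this hunk means a multi-sample result divides one result-level usage evenly across its samples, so per-response token counts are approximate when n > 1. A small sketch with made-up numbers:

    # Illustration only: distributing result-level usage across n samples.
    n = 2                                          # hypothetical sample count
    prompt_tokens, completion_tokens, total_tokens = 100, 30, 130

    per_sample_usage = dict(
        prompt_tokens=prompt_tokens // n,          # 50
        completion_tokens=completion_tokens // n,  # 15
        total_tokens=total_tokens // n,            # 65
    )
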
@@ -436,13 +469,8 @@ class LanguageModel(component.Component):
  result = self.sample(
  [prompt], sampling_options=sampling_options, cache_seed=cache_seed
  )[0]
- response = result.samples[0].response
- logprobs = result.samples[0].logprobs
- response.set('score', result.samples[0].score)
- response.metadata.logprobs = logprobs
- response.metadata.usage = result.usage
-
  elapse = time.time() - request_start
+ response = result.samples[0].response
  self._debug(prompt, response, call_counter, result.usage, elapse)
  return response

@@ -494,7 +522,9 @@ class LanguageModel(component.Component):
  title_suffix = console.colored(f' ({usage.prompt_tokens} tokens)', 'red')

  console.write(
- prompt,
+ # We use metadata 'formatted_text' for scenarios where the prompt text
+ # is formatted by the LM.
+ prompt.get('formatted_text', prompt.text),
  title=f'\n[{call_counter}] PROMPT SENT TO LM{title_suffix}:',
  color='green',
  )
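
Taken together with the sample() changes above, token usage now lives in the metadata of each LM response message, which is how Evaluation.audit_usage earlier in this diff accumulates it. A hedged sketch of the same traversal on a message returned from an LM call; `response` stands for any lf.Message produced by such a call:

    # Sketch: summing usage over a message trace, mirroring audit_usage above.
    total_prompt = total_completion = 0
    for m in response.trace():
        if 'usage' in m.metadata:
            total_prompt += m.usage.prompt_tokens
            total_completion += m.usage.completion_tokens
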