langfun 0.0.2.dev20240422__py3-none-any.whl → 0.0.2.dev20240425__py3-none-any.whl

This diff shows the changes between two publicly released versions of the package, as published to a supported registry. It is provided for informational purposes only and reflects the package contents as they appear in the public registry.
Files changed (32)
  1. langfun/__init__.py +1 -0
  2. langfun/core/component.py +6 -0
  3. langfun/core/component_test.py +1 -0
  4. langfun/core/eval/__init__.py +2 -0
  5. langfun/core/eval/base.py +175 -17
  6. langfun/core/eval/base_test.py +34 -6
  7. langfun/core/eval/matching.py +18 -1
  8. langfun/core/eval/matching_test.py +2 -1
  9. langfun/core/eval/scoring.py +11 -1
  10. langfun/core/eval/scoring_test.py +2 -1
  11. langfun/core/language_model.py +14 -0
  12. langfun/core/language_model_test.py +32 -0
  13. langfun/core/llms/anthropic.py +36 -22
  14. langfun/core/llms/anthropic_test.py +7 -7
  15. langfun/core/llms/groq.py +27 -18
  16. langfun/core/llms/groq_test.py +5 -5
  17. langfun/core/llms/openai.py +55 -50
  18. langfun/core/llms/openai_test.py +3 -3
  19. langfun/core/structured/__init__.py +1 -0
  20. langfun/core/structured/completion_test.py +1 -2
  21. langfun/core/structured/mapping.py +38 -1
  22. langfun/core/structured/mapping_test.py +17 -0
  23. langfun/core/structured/parsing_test.py +2 -4
  24. langfun/core/structured/prompting_test.py +2 -4
  25. langfun/core/structured/schema_generation_test.py +2 -2
  26. langfun/core/template.py +26 -8
  27. langfun/core/template_test.py +9 -0
  28. {langfun-0.0.2.dev20240422.dist-info → langfun-0.0.2.dev20240425.dist-info}/METADATA +3 -2
  29. {langfun-0.0.2.dev20240422.dist-info → langfun-0.0.2.dev20240425.dist-info}/RECORD +32 -32
  30. {langfun-0.0.2.dev20240422.dist-info → langfun-0.0.2.dev20240425.dist-info}/LICENSE +0 -0
  31. {langfun-0.0.2.dev20240422.dist-info → langfun-0.0.2.dev20240425.dist-info}/WHEEL +0 -0
  32. {langfun-0.0.2.dev20240422.dist-info → langfun-0.0.2.dev20240425.dist-info}/top_level.txt +0 -0
langfun/__init__.py CHANGED
@@ -55,6 +55,7 @@ Video = modalities.Video
  PDF = modalities.PDF

  # Error types.
+ MappingError = structured.MappingError
  SchemaError = structured.SchemaError
  JsonError = structured.JsonError
  CodeError = coding.CodeError
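
For orientation, a small usage sketch (not part of the diff) of the newly re-exported `lf.MappingError`; the prompt, schema, and fake LM below are illustrative only:

    import langfun as lf

    # Illustrative: a fake LM whose reply cannot be parsed into the int schema.
    lm = lf.llms.fake.StaticSequence(['not a number'])
    try:
      lf.query('Compute 1 + 1', int, lm=lm)
    except lf.MappingError as e:  # Now available directly at the package root.
      print('Structured mapping failed:', e)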
langfun/core/component.py CHANGED
@@ -210,6 +210,12 @@ def get_contextual_override(var_name: str) -> ContextualOverride | None:
    return _get_scoped_value(_global_tls, _CONTEXT_OVERRIDES, var_name)


+ def all_contextual_values() -> dict[str, Any]:
+   """Returns all contextual values provided from `lf.context` in scope."""
+   overrides = getattr(_global_tls, _CONTEXT_OVERRIDES, {})
+   return {k: v.value for k, v in overrides.items()}
+
+
  @contextlib.contextmanager
  def _contextual_scope(
      tls: threading.local, tls_key, **variables
langfun/core/component_test.py CHANGED
@@ -84,6 +84,7 @@ class ComponentContextTest(unittest.TestCase):
          lf.get_contextual_override('y'),
          lf.ContextualOverride(3, cascade=False, override_attrs=False),
      )
+     self.assertEqual(lf.all_contextual_values(), dict(x=3, y=3, z=3))

      # Member attributes take precedence over `lf.context`.
      self.assertEqual(a1.x, 1)
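
For context, a small sketch of how the new `all_contextual_values()` helper pairs with `lf.context`; the variable names and values are illustrative:

    import langfun as lf

    with lf.context(x=3, y=3, z=3):
      # Collects every value injected by the enclosing `lf.context` scopes.
      print(lf.all_contextual_values())  # -> {'x': 3, 'y': 3, 'z': 3}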
langfun/core/eval/__init__.py CHANGED
@@ -16,6 +16,8 @@
  # pylint: disable=g-importing-member
  # pylint: disable=g-bad-import-order

+ from langfun.core.eval.base import app_run
+
  from langfun.core.eval.base import Evaluable
  from langfun.core.eval.base import Evaluation
  from langfun.core.eval.base import Suite
langfun/core/eval/base.py CHANGED
@@ -26,6 +26,8 @@ import threading
  import time
  from typing import Annotated, Any, Callable, Iterator, Literal, Optional, Sequence, Type, Union

+ from absl import app
+ from absl import flags
  import langfun.core as lf
  import langfun.core.coding as lf_coding
  from langfun.core.llms.cache import in_memory
@@ -546,6 +548,16 @@ class Evaluable(lf.Component):
          )
          s.write(pg.format(m.result))
          s.write('</div>')
+         if 'usage' in m.metadata:
+           s.write(
+               '<div style="background-color: #EEEEEE; color: black; '
+               'white-space: pre-wrap; padding: 10px; border: 0px solid; '
+               'margin: 10px">'
+               f'prompt: {m.usage.prompt_tokens} tokens, '
+               f'response: {m.usage.completion_tokens} tokens, '
+               f'total: {m.usage.total_tokens} tokens'
+               '</div>'
+           )
        s.write('</div>')

  @classmethod
@@ -810,6 +822,30 @@ class Evaluation(Evaluable):
      return 0.0
    return self.num_failures / self.num_completed

+ @property
+ def has_usage(self) -> bool:
+   """Returns True if token usage is enabled."""
+   return self._num_usages > 0
+
+ @property
+ def average_prompt_tokens(self) -> int:
+   """Returns the average prompt tokens."""
+   if not self.has_usage:
+     return 0
+   return self._total_prompt_tokens // self._num_usages
+
+ @property
+ def average_completion_tokens(self) -> int:
+   """Returns the average completion tokens."""
+   if not self.has_usage:
+     return 0
+   return self._total_completion_tokens // self._num_usages
+
+ @property
+ def average_total_tokens(self) -> int:
+   """Returns the average total tokens."""
+   return self.average_prompt_tokens + self.average_completion_tokens
+
  @functools.cached_property
  def schema(self) -> lf_structured.Schema | None:
    """Schema."""
@@ -856,7 +892,11 @@ class Evaluation(Evaluable):
          'Encountered: {annotation!r}.'
      )
    self._maybe_adjust_schema_for_completion(annotation)
-   return lf_structured.Schema.from_value(annotation)
+   schema = lf_structured.Schema.from_value(annotation)
+   # NOTE(daiyip): add references to the dependent classes of the returned type
+   # to prevent unused subclasses get garbage collected by Python.
+   setattr(schema, '__dependencies__', schema.class_dependencies())
+   return schema

  def _maybe_adjust_schema_for_completion(self, cls):
    if (self.completion_prompt_field is None
@@ -933,6 +973,10 @@ class Evaluation(Evaluable):
    self._failures = []
    self._num_completed = 0

+   self._total_prompt_tokens = 0
+   self._total_completion_tokens = 0
+   self._num_usages = 0
+
  @property
  def failures_link(self) -> str | None:
    """Returns the link to the failures page."""
@@ -952,7 +996,7 @@ class Evaluation(Evaluable):
    example = example or self.examples[0]

    # We make a copy to avoid pollute the state of current object.
-   copy = self.clone()
+   copy: Evaluation = self.clone()
    copy.__dict__['examples'] = [example]

    # We set the symbolic parent of the cloned to access contextual information
@@ -982,9 +1026,9 @@ class Evaluation(Evaluable):
          color='blue',
      )

-   # Audit the result.
-   copy.audit(example, output, output_message)
+   copy.audit(example, output_message, None, dryrun=True)
    result = copy.summarize()
+
    if verbose:
      lf.console.write('')
      lf.console.write(
@@ -1031,11 +1075,12 @@ class Evaluation(Evaluable):
          status_fn=self._status,
      ):
        if error is not None:
-         self._failures.append((example, str(error)))
-       else:
-         output = message.text if self.schema is None else message.result
-         self.audit(example, output, message)
-       self._num_completed += 1
+         message = (
+             error.lm_response
+             if isinstance(error, lf_structured.MappingError)
+             else None
+         )
+       self.audit(example, message, error)
    finally:
      # Save cache upon completion or interruption.
      if self.dir and self.cache:
@@ -1138,6 +1183,19 @@ class Evaluation(Evaluable):
      )
    else:
      cache_stats = dict(use_cache=False)
+
+   if self.has_usage:
+     usage = pg.Dict(
+         total_prompt_tokens=self._total_prompt_tokens,
+         total_completion_tokens=self._total_completion_tokens,
+         num_usages=self._num_usages,
+         average_prompt_tokens=self.average_prompt_tokens,
+         average_completion_tokens=self.average_completion_tokens,
+         average_total_tokens=self.average_total_tokens,
+     )
+   else:
+     usage = None
+
    result = pg.Dict(
        experiment_setup=pg.Dict(
            id=self.id,
@@ -1153,6 +1211,7 @@ class Evaluation(Evaluable):
            failures=self.num_failures,
            failure_rate=self.failure_rate,
        ),
+       usage=usage,
    )
    return result

@@ -1174,9 +1233,28 @@ class Evaluation(Evaluable):
        '</td></tr><tr><td>'
    )
    self._render_metric(s)
+
+   # Summarize average usage.
+   if self.result.usage is not None:
+     self._render_usage(s)
+
    s.write('</td></tr></table></div>')
    return s.getvalue()

+ def _render_usage(self, s: io.StringIO) -> None:
+   """Renders usage in HTML."""
+   usage = self.result.usage
+   total = usage.total_prompt_tokens + usage.total_completion_tokens
+   s.write(
+       '&nbsp;<a title="'
+       f'# of usages: {usage.num_usages}&#013;'
+       f'total prompt: {usage.total_prompt_tokens}&#013;'
+       f'total response: {usage.total_completion_tokens}&#013;'
+       f'avg prompt: {usage.average_prompt_tokens}&#013;'
+       f'avg response: {usage.average_completion_tokens}'
+       f'" style="color:gray">({total} tokens)</a>'
+   )
+
  def _render_metric(self, s: io.StringIO) -> None:
    """Renders metrics in HTML."""
    assert self.result is not None
@@ -1191,17 +1269,48 @@ class Evaluation(Evaluable):
      )
    )

- def audit(self, example: Any, output: Any, message: lf.Message) -> None:
+ def audit(
+     self,
+     example: Any,
+     message: lf.Message | None,
+     error: Exception | None = None,
+     dryrun: bool = False,
+ ) -> None:
    """Audits the example against the output. Subclasses should override.

    Args:
      example: The input object.
-     output: The output from LM. For `lf.call`, if `schema_fn` is not provided,
-       it will be the raw LM response string. Otherwise it will be the
-       structured output from the LM.
      message: The entire message returned by the LM, which could be used to
-       trace the LM input, response and parsed structure.
+       trace the LM input, response and parsed structure. If error is raised
+       before LLM could return a response, None will be its value.
+     error: The exception during processing the example.
+     dryrun: Whether or not audition takes place during dryrun.
    """
+   if error is not None:
+     self._failures.append((example, str(error)))
+     if isinstance(error, lf_structured.MappingError):
+       message = error.lm_response
+   else:
+     assert message is not None
+     output = message.text if self.schema is None else message.result
+     self.audit_processed(example, output, message, dryrun=dryrun)
+
+   # Audit usage.
+   if message is not None:
+     self.audit_usage(message, dryrun=dryrun)
+   self._num_completed += 1
+
+ def audit_usage(self, message: lf.Message, dryrun: bool = False) -> None:
+   for m in message.trace():
+     if 'usage' in m.metadata:
+       self._total_prompt_tokens += m.usage.prompt_tokens
+       self._total_completion_tokens += m.usage.completion_tokens
+       self._num_usages += 1
+
+ def audit_processed(
+     self, example: Any, output: Any, message: lf.Message, dryrun: bool = False
+ ) -> None:
+   """Audits a successfully processed example. Subclass should override."""

  def save(
      self, definition: bool = True, result: bool = True, report: bool = True
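
A rough sketch of how an evaluation subclass fits the refactored audit flow, overriding `audit_processed` instead of `audit`; the class, its counter, and the `example.answer` field below are hypothetical:

    from typing import Any

    import langfun as lf
    from langfun.core.eval import base


    class MyEval(base.Evaluation):  # Hypothetical subclass for illustration.

      def _reset(self):
        super()._reset()
        self._correct = 0

      def audit_processed(
          self, example: Any, output: Any, message: lf.Message, dryrun: bool = False
      ) -> None:
        # Only successfully processed examples reach this hook; failures and
        # token-usage accounting are handled by the base class `audit()`.
        if output == example.answer:
          self._correct += 1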
@@ -1245,8 +1354,10 @@ class Evaluation(Evaluable):
        '<td>Prompt</td>'
        '<td>Schema</td>'
        '<td>Additional Args</td>'
-       '<td>Failures</td>'
    )
+   if self.result.usage is not None:
+     s.write('<td>Usage</td>')
+   s.write('<td>Failures</td>')

  def _render_result_row(self, s: io.StringIO) -> None:
    s.write(
@@ -1271,6 +1382,12 @@ class Evaluation(Evaluable):
        '<td style="color:purple" '
        f'{_html_repr(self.additional_args, compact=False)}</td>'
    )
+   # Usage.
+   if self.result.usage is not None:
+     s.write('<td>')
+     self._render_usage(s)
+     s.write('</td>')
+
    # Failures.
    s.write(
        '<td><span style="color:orange">%s</span>%s</td>'
@@ -1369,8 +1486,8 @@ class Summary(pg.Object):
          Type[lf.LanguageModel],
          tuple[lf.LanguageModel | Type[lf.LanguageModel], ...],
      ] = lf.LanguageModel,
-     method: Union[str, tuple[str], None] = None,
-     schema_fn: Union[pg.Functor, tuple[pg.Functor], None] = None,
+     method: Union[str, tuple[str, ...], None] = None,
+     schema_fn: Union[pg.Functor, tuple[pg.Functor, ...], None] = None,
      completed: bool | None = None,
      pivot_field: str | None = None,
  ) -> 'Summary':
@@ -1564,6 +1681,7 @@ class Summary(pg.Object):
      for entry in self.select(task=task).evaluations:
        results.append(
            pg.Dict(
+               id=entry.id,
                experiment=entry,
                dir=entry.dir,
                metrics=entry.result.metrics if entry.result else None,
@@ -1789,3 +1907,43 @@ def monitor_async(
      scan_interval=scan_interval,
      refresh_when_stop=refresh_when_stop,
  )
+
+
+ def app_run(target: Evaluable):
+   """Runs the target evaluation as an absl app.
+
+   Args:
+     target: An Langfun evaluable object.
+   """
+   flags.DEFINE_string(
+       'root_dir', None, 'Root directory for running the evaluation.'
+   )
+
+   flags.DEFINE_bool(
+       'dryrun', False, 'If True, dryrun the experiment instead of running it.'
+   )
+
+   flags.DEFINE_bool(
+       'debug', False, 'If True, output prompt and response to the console.'
+   )
+
+   flags.DEFINE_bool(
+       'rerun',
+       False,
+       'If True, rerun the experiment even a cached result is found.',
+   )
+
+   FLAGS = flags.FLAGS  # pylint: disable=invalid-name
+
+   def _main(argv):
+     if len(argv) > 1:
+       raise app.UsageError('Too many command-line arguments.')
+
+     if FLAGS.root_dir:
+       target.rebind(root_dir=FLAGS.root_dir, raise_on_no_change=False)
+     if FLAGS.dryrun:
+       target.dryrun(debug=FLAGS.debug)
+     else:
+       target.run(debug=FLAGS.debug, rerun=FLAGS.rerun)
+
+   app.run(_main)
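
A minimal launcher sketch for the new `app_run` entry point, in the spirit of the test below; the module and evaluation object names are hypothetical:

    # run_my_eval.py (hypothetical launcher script)
    from langfun.core.eval import base

    import my_evaluations  # Hypothetical module that defines `my_suite`.

    if __name__ == '__main__':
      # app_run defines --root_dir, --dryrun, --debug and --rerun, then
      # dryruns or runs the target evaluation accordingly.
      base.app_run(my_evaluations.my_suite)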
langfun/core/eval/base_test.py CHANGED
@@ -101,7 +101,7 @@ class EvaluationTest(unittest.TestCase):
    self.assertEqual(s.dir, os.path.join(s.root_dir, s.id))
    self.assertEqual(s.hash, s.clone().hash)
    # Test persistent hash.
-   self.assertEqual(s.hash, '436dc80c')
+   self.assertEqual(s.hash, 'ae86c703')
    self.assertEqual(
        s.hash, s.clone(override={'max_workers': 2, 'lm.timeout': 20}).hash
    )
@@ -210,7 +210,7 @@ class EvaluationTest(unittest.TestCase):
        s.result,
        dict(
            experiment_setup=dict(
-               id='Evaluation@f1aa5126',
+               id='Evaluation@0fade07d',
                dir=s.dir,
                model='StaticSequence',
                prompt_template='{{example.question}}',
@@ -221,6 +221,14 @@ class EvaluationTest(unittest.TestCase):
                use_cache=True, num_queries=2, num_hits=0, num_updates=2
            ),
            metrics=dict(total=2, failures=1, failure_rate=0.5),
+           usage=dict(
+               total_prompt_tokens=774,
+               total_completion_tokens=25,
+               num_usages=2,
+               average_prompt_tokens=387,
+               average_completion_tokens=12,
+               average_total_tokens=399,
+           ),
        ),
    )
    self.assertTrue(
@@ -285,8 +293,11 @@ class EvaluationTest(unittest.TestCase):
    s = eval_set(
        'run_filter_test', pg.oneof(['call', 'query']),
        schema_fn=answer_schema(), lm=lm)
+   result = s.run(
+       filter=lambda x: x.method == 'query', dryrun=True, summary=False
+   )
    self.assertEqual(
-       s.run(filter=lambda x: x.method == 'query', dryrun=True, summary=False),
+       result,
        {
            s.children[0].id: None,
            s.children[1].id: dict(
@@ -302,7 +313,8 @@ class EvaluationTest(unittest.TestCase):
                    use_cache=True, num_queries=2, num_hits=0, num_updates=2
                ),
                metrics=dict(total=2, failures=0, failure_rate=0.0),
-           )
+               usage=s.children[1].result.usage,
+           ),
        },
    )

@@ -336,7 +348,6 @@ class EvaluationTest(unittest.TestCase):

    summary = s.run(verbose=True)
    self.assertEqual(len(summary.evaluations), 2)
-
    self.assertEqual(
        s.result,
        {
@@ -353,6 +364,7 @@ class EvaluationTest(unittest.TestCase):
                use_cache=True, num_queries=2, num_hits=0, num_updates=2
            ),
            metrics=dict(total=2, failures=1, failure_rate=0.5),
+           usage=s.children[0].result.usage,
        ),
        s.children[1].id: dict(
            experiment_setup=dict(
@@ -367,6 +379,7 @@ class EvaluationTest(unittest.TestCase):
                use_cache=True, num_queries=2, num_hits=0, num_updates=2
            ),
            metrics=dict(total=2, failures=1, failure_rate=0.5),
+           usage=s.children[1].result.usage,
        ),
    },
)
@@ -459,7 +472,7 @@ class SuiteTest(unittest.TestCase):
        lm=lm
    )
    # Test for persistent hash.
-   self.assertEqual(s.hash, 'bbfdc7a8')
+   self.assertEqual(s.hash, '26e6cc25')
    s.run()
    expected = {
        s.children[0].id: dict(
@@ -475,6 +488,7 @@ class SuiteTest(unittest.TestCase):
                use_cache=True, num_queries=2, num_hits=0, num_updates=2
            ),
            metrics=dict(total=2, failures=1, failure_rate=0.5),
+           usage=s.children[0].result.usage,
        ),
        s.children[1].id: {
            s.children[1]
@@ -492,6 +506,7 @@ class SuiteTest(unittest.TestCase):
                    use_cache=True, num_queries=4, num_hits=1, num_updates=3
                ),
                metrics=dict(total=2, failures=2, failure_rate=1.0),
+               usage=s.children[1].children[0].result.usage,
            ),
            s.children[1]
            .children[2]
@@ -511,6 +526,7 @@ class SuiteTest(unittest.TestCase):
                    num_updates=2,
                ),
                metrics=dict(total=2, failures=1, failure_rate=0.5),
+               usage=s.children[1].children[2].result.usage,
            ),
        },
    }
@@ -682,5 +698,17 @@ class SummaryTest(unittest.TestCase):
    self.assertTrue(pg.io.path_exists(summary_file))


+ class AppRunTest(unittest.TestCase):
+
+   def test_app_run(self):
+     lm = fake.StaticSequence(['two', 'Solution(final_answer=2)'])
+     try:
+       base.app_run(
+           eval_set('app_run_test', 'query', schema_fn=answer_schema(), lm=lm)
+       )
+     except SystemExit:
+       pass
+
+
  if __name__ == '__main__':
    unittest.main()
langfun/core/eval/matching.py CHANGED
@@ -86,9 +86,26 @@ class Matching(base.Evaluation):
    self._matches = []
    self._mismatches = []

- def audit(self, example: Any, output: Any, message: lf.Message) -> None:
+ def audit_processed(
+     self, example: Any, output: Any, message: lf.Message, dryrun: bool = False
+ ) -> None:
    groundtruth = self.groundtruth(example)
    answer = self.answer(output, example)
+
+   if dryrun:
+     lf.console.write('')
+     lf.console.write(
+         str(groundtruth),
+         title='GROUDTRUTH',
+         color='green',
+     )
+     lf.console.write('')
+     lf.console.write(
+         str(answer),
+         title='ANSWER',
+         color='blue',
+     )
+
    if self.match(answer, groundtruth):
      self._matches.append((example, output, message))
    else:
langfun/core/eval/matching_test.py CHANGED
@@ -103,7 +103,7 @@ class MatchingTest(unittest.TestCase):
        s.result,
        dict(
            experiment_setup=dict(
-               id='MyTask@acd56a61',
+               id='MyTask@739a174b',
                dir=s.dir,
                model='StaticSequence',
                prompt_template='{{example.question}}',
@@ -125,6 +125,7 @@ class MatchingTest(unittest.TestCase):
                num_mismatches=1,
                mismatch_rate=0.25,
            ),
+           usage=s.result.usage,
        ),
    )
    self.assertTrue(
langfun/core/eval/scoring.py CHANGED
@@ -61,8 +61,18 @@ class Scoring(base.Evaluation):
    super()._reset()
    self._scored = []

- def audit(self, example: Any, output: Any, message: lf.Message) -> None:
+ def audit_processed(
+     self, example: Any, output: Any, message: lf.Message, dryrun: bool = False
+ ) -> None:
    score = self.score(example, output)
+
+   if dryrun:
+     lf.console.write('')
+     lf.console.write(
+         str(score),
+         title='SCORE',
+         color='blue',
+     )
    self._scored.append((example, output, score, message))

  @abc.abstractmethod
langfun/core/eval/scoring_test.py CHANGED
@@ -81,7 +81,7 @@ class ScoringTest(unittest.TestCase):
        s.result,
        dict(
            experiment_setup=dict(
-               id='ConstraintFollowing@a44d8b89',
+               id='ConstraintFollowing@5c88a5eb',
                dir=s.dir,
                model='StaticSequence',
                prompt_template='{{example}}',
@@ -102,6 +102,7 @@ class ScoringTest(unittest.TestCase):
                score_rate=1.0,
                avg_score=0.5,
            ),
+           usage=s.result.usage,
        ),
    )
    self.assertTrue(
langfun/core/language_model.py CHANGED
@@ -24,6 +24,9 @@ from langfun.core import console
  from langfun.core import message as message_lib
  import pyglove as pg

+ TOKENS_PER_REQUEST = 250  # Estimated num tokens for a single request
+ DEFAULT_MAX_CONCURRENCY = 1  # Use this as max concurrency if no RPM or TPM data
+

  class LMSample(pg.Object):
    """Response candidate."""
@@ -604,3 +607,14 @@ class LanguageModel(component.Component):
          f'score: {r.score}',
          color='blue',
      )
+
+ def rate_to_max_concurrency(
+     self, requests_per_min: float = 0, tokens_per_min: float = 0
+ ) -> int:
+   """Converts a rate to a max concurrency."""
+   if tokens_per_min > 0:
+     return max(int(tokens_per_min / TOKENS_PER_REQUEST / 60), 1)
+   elif requests_per_min > 0:
+     return max(int(requests_per_min / 60), 1)  # Max concurrency can't be zero
+   else:
+     return DEFAULT_MAX_CONCURRENCY  # Default of 1
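
To make the conversion concrete, a quick sketch of the arithmetic (a fake `Echo` model stands in for any `LanguageModel` subclass):

    import langfun as lf

    lm = lf.llms.fake.Echo()
    print(lm.rate_to_max_concurrency(tokens_per_min=1e7))    # 1e7 / 250 / 60 -> 666
    print(lm.rate_to_max_concurrency(requests_per_min=1e4))  # 1e4 / 60       -> 166
    print(lm.rate_to_max_concurrency())                      # no rate data   -> 1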
langfun/core/language_model_test.py CHANGED
@@ -394,6 +394,38 @@ class LanguageModelTest(unittest.TestCase):
    with self.assertRaises(NotImplementedError):
      MockModel().score('hi', ['1', '2'])

+ def test_rate_to_max_concurrency_no_rpm_no_tpm(self) -> None:
+   lm = MockModel()
+   self.assertEqual(
+       lm_lib.DEFAULT_MAX_CONCURRENCY,
+       lm.rate_to_max_concurrency(requests_per_min=0, tokens_per_min=0),
+   )
+   self.assertEqual(
+       lm_lib.DEFAULT_MAX_CONCURRENCY,
+       lm.rate_to_max_concurrency(requests_per_min=-1, tokens_per_min=-1),
+   )
+
+ def test_rate_to_max_concurrency_only_rpm_specified_uses_rpm(self) -> None:
+   lm = MockModel()
+   test_rpm = 1e4
+   self.assertEqual(
+       lm.rate_to_max_concurrency(requests_per_min=test_rpm),
+       int(test_rpm / 60)
+   )
+
+ def test_rate_to_max_concurrency_tpm_specified_uses_tpm(self) -> None:
+   lm = MockModel()
+   test_tpm = 1e7
+   self.assertEqual(
+       lm.rate_to_max_concurrency(requests_per_min=1, tokens_per_min=test_tpm),
+       int(test_tpm / lm_lib.TOKENS_PER_REQUEST / 60),
+   )
+
+ def test_rate_to_max_concurrency_small_rate_returns_one(self) -> None:
+   lm = MockModel()
+   self.assertEqual(lm.rate_to_max_concurrency(requests_per_min=1), 1)
+   self.assertEqual(lm.rate_to_max_concurrency(tokens_per_min=1), 1)
+

  if __name__ == '__main__':
    unittest.main()