langfun 0.0.2.dev20240330__py3-none-any.whl → 0.0.2.dev20240511__py3-none-any.whl
This diff shows the changes between two publicly released versions of the package, as they appear in their respective public registries, and is provided for informational purposes only.
- langfun/__init__.py +7 -0
- langfun/core/__init__.py +1 -0
- langfun/core/coding/python/correction.py +0 -7
- langfun/core/component.py +6 -0
- langfun/core/component_test.py +1 -0
- langfun/core/eval/__init__.py +15 -0
- langfun/core/eval/base.py +665 -95
- langfun/core/eval/base_test.py +224 -53
- langfun/core/eval/matching.py +48 -30
- langfun/core/eval/matching_test.py +25 -3
- langfun/core/eval/patching.py +130 -0
- langfun/core/eval/patching_test.py +170 -0
- langfun/core/eval/scoring.py +19 -10
- langfun/core/eval/scoring_test.py +21 -3
- langfun/core/langfunc.py +1 -22
- langfun/core/langfunc_test.py +10 -4
- langfun/core/language_model.py +130 -24
- langfun/core/language_model_test.py +249 -26
- langfun/core/llms/__init__.py +27 -2
- langfun/core/llms/anthropic.py +263 -0
- langfun/core/llms/anthropic_test.py +167 -0
- langfun/core/llms/cache/in_memory_test.py +37 -28
- langfun/core/llms/fake.py +34 -25
- langfun/core/llms/fake_test.py +122 -11
- langfun/core/llms/google_genai.py +8 -0
- langfun/core/llms/google_genai_test.py +8 -3
- langfun/core/llms/groq.py +260 -0
- langfun/core/llms/groq_test.py +170 -0
- langfun/core/llms/llama_cpp.py +3 -1
- langfun/core/llms/openai.py +100 -81
- langfun/core/llms/openai_test.py +287 -60
- langfun/core/llms/vertexai.py +291 -0
- langfun/core/llms/vertexai_test.py +233 -0
- langfun/core/modalities/image.py +1 -3
- langfun/core/modalities/mime.py +6 -0
- langfun/core/modalities/video.py +6 -5
- langfun/core/structured/__init__.py +5 -0
- langfun/core/structured/completion_test.py +2 -2
- langfun/core/structured/function_generation.py +245 -0
- langfun/core/structured/function_generation_test.py +329 -0
- langfun/core/structured/mapping.py +61 -3
- langfun/core/structured/mapping_test.py +17 -0
- langfun/core/structured/parsing_test.py +18 -13
- langfun/core/structured/prompting.py +61 -12
- langfun/core/structured/prompting_test.py +122 -12
- langfun/core/structured/schema.py +38 -6
- langfun/core/structured/schema_generation_test.py +2 -2
- langfun/core/structured/schema_test.py +36 -7
- langfun/core/structured/scoring.py +4 -1
- langfun/core/structured/scoring_test.py +6 -0
- langfun/core/template.py +147 -11
- langfun/core/template_test.py +75 -0
- langfun/core/templates/selfplay_test.py +6 -2
- {langfun-0.0.2.dev20240330.dist-info → langfun-0.0.2.dev20240511.dist-info}/METADATA +3 -2
- langfun-0.0.2.dev20240511.dist-info/RECORD +112 -0
- langfun-0.0.2.dev20240330.dist-info/RECORD +0 -102
- {langfun-0.0.2.dev20240330.dist-info → langfun-0.0.2.dev20240511.dist-info}/LICENSE +0 -0
- {langfun-0.0.2.dev20240330.dist-info → langfun-0.0.2.dev20240511.dist-info}/WHEEL +0 -0
- {langfun-0.0.2.dev20240330.dist-info → langfun-0.0.2.dev20240511.dist-info}/top_level.txt +0 -0
langfun/core/eval/base.py
CHANGED
@@ -18,12 +18,14 @@ import collections
 import dataclasses
 import functools
 import hashlib
+import html
 import inspect
 import io
 import os
 import re
 import threading
 import time
+import types
 from typing import Annotated, Any, Callable, Iterator, Literal, Optional, Sequence, Type, Union

 import langfun.core as lf
@@ -38,7 +40,8 @@ class Evaluable(lf.Component):

   EXPERIMENT_JSON = 'experiment.json'
   RESULT_JSON = 'result.json'
-
+  OOP_FAILURES_JSON = 'oop_failures.json'
+  NON_OOP_FAILURES_JSON = 'non_oop_failures.json'
   INDEX_HTML = 'index.html'
   SUMMARY_HTML = 'summary.html'

@@ -356,7 +359,7 @@ class Evaluable(lf.Component):
           color='yellow')

     for node in self.nonleaf_nodes:
-      node._result = {c.id: c.result for c in node.
+      node._result = {c.id: c.result for c in node.leaf_nodes}  # pylint: disable=protected-access
       if should_save:
         node.save(result=False, report=False)

@@ -538,14 +541,24 @@ class Evaluable(lf.Component):
           f'<div style="color: {text_color}; white-space: pre-wrap;'
           'padding: 10px; border: 1px solid; margin-top: 10px">'
       )
-      s.write(m.text)
+      s.write(html.escape(m.get('formatted_text', m.text)))
       if m.result is not None:
         s.write(
             '<div style="color: magenta; white-space: pre-wrap;'
             'padding: 10px; border: 1px solid; margin: 10px">'
         )
-        s.write(pg.format(m.result))
+        s.write(html.escape(pg.format(m.result)))
         s.write('</div>')
+      if 'usage' in m.metadata and m.usage is not None:
+        s.write(
+            '<div style="background-color: #EEEEEE; color: black; '
+            'white-space: pre-wrap; padding: 10px; border: 0px solid; '
+            'margin: 10px">'
+            f'prompt: {m.usage.prompt_tokens} tokens, '
+            f'response: {m.usage.completion_tokens} tokens, '
+            f'total: {m.usage.total_tokens} tokens'
+            '</div>'
+        )
       s.write('</div>')

   @classmethod
@@ -586,7 +599,6 @@ class _LeafNode:
 @pg.use_init_args(['children'])
 class Suite(Evaluable):
   """Evaluation suite."""
-
   children: Annotated[list[Evaluable], 'Child evaluation sets or suites.']

   # Use empty ID as suite is just a container of child evaluations.
@@ -741,10 +753,12 @@ class Evaluation(Evaluable):

   # Constants.
   CACHE_JSON = 'cache.json'
-
+  OOP_FAILURES_HTML = 'oop_failures.html'
+  NON_OOP_FAILURES_HTML = 'non_oop_failures.html'

   @functools.cached_property
   def hash(self) -> str:
+    """Returns the semantic-based hash of the evaluation."""
     if self.is_deterministic:
       identity = pg.format(self._identifiers(), compact=True)
     else:
@@ -793,6 +807,10 @@ class Evaluation(Evaluable):
     """Returns the complete rate."""
     return self.num_completed / self.num_examples

+  #
+  # Properties on failures.
+  #
+
   @property
   def failures(self) -> list[tuple[Any, Exception]]:
     """Returns the failed examples and their errors."""
@@ -803,6 +821,15 @@ class Evaluation(Evaluable):
     """Returns the number of failed examples."""
     return len(self.failures)

+  @functools.cached_property
+  def failure_breakdown(self) -> dict[str, int]:
+    """Returns the breakdown of failures."""
+    breakdown = collections.defaultdict(int)
+    for _, error in self.failures:
+      breakdown[_error_key(error)] += 1
+    sorted_items = sorted(breakdown.items(), key=lambda x: x[1], reverse=True)
+    return pg.Dict({x[0]: x[1] for x in sorted_items})
+
   @property
   def failure_rate(self) -> float:
     """Returns the failure rate in range [0, 1]."""
@@ -810,17 +837,76 @@ class Evaluation(Evaluable):
       return 0.0
     return self.num_failures / self.num_completed

+  @functools.cached_property
+  def oop_failures(self) -> list[tuple[Any, lf_structured.MappingError]]:
+    """Returns the OOP failures."""
+    return [item for item in self.failures
+            if isinstance(item[1], lf_structured.MappingError)]
+
+  @property
+  def num_oop_failures(self) -> int:
+    """Returns the number of OOP failures."""
+    return len(self.oop_failures)
+
+  @property
+  def oop_failure_rate(self) -> float:
+    """Returns the OOP failure rate in range [0, 1]."""
+    if self.num_completed == 0:
+      return 0.0
+    return self.num_oop_failures / self.num_completed
+
+  @functools.cached_property
+  def non_oop_failures(self) -> list[tuple[Any, Exception]]:
+    """Returns the OOP failures."""
+    return [item for item in self.failures
+            if not isinstance(item[1], lf_structured.MappingError)]
+
+  @property
+  def num_non_oop_failures(self) -> int:
+    """Returns the number of non-OOP failures."""
+    return len(self.non_oop_failures)
+
+  @property
+  def non_oop_failure_rate(self) -> float:
+    """Returns the non-OOP failure rate in range [0, 1]."""
+    if self.num_completed == 0:
+      return 0.0
+    return self.num_non_oop_failures / self.num_completed
+
+  #
+  # Properties on usage.
+  #
+
+  @property
+  def has_usage(self) -> bool:
+    """Returns True if token usage is enabled."""
+    return self._num_usages > 0
+
+  @property
+  def average_prompt_tokens(self) -> int:
+    """Returns the average prompt tokens."""
+    if not self.has_usage:
+      return 0
+    return self._total_prompt_tokens // self._num_usages
+
+  @property
+  def average_completion_tokens(self) -> int:
+    """Returns the average completion tokens."""
+    if not self.has_usage:
+      return 0
+    return self._total_completion_tokens // self._num_usages
+
+  @property
+  def average_total_tokens(self) -> int:
+    """Returns the average total tokens."""
+    return self.average_prompt_tokens + self.average_completion_tokens
+
   @functools.cached_property
   def schema(self) -> lf_structured.Schema | None:
     """Schema."""
     if self.schema_fn is None:
       return None

-    kwargs = {}
-    # Allow schema to be a function based on current evaluation.
-    if 'evaluation' in self.schema_fn.__signature__.arg_names:
-      kwargs['evaluation'] = self
-
     schema = self._call_schema_fn()
     fewshot_examples = None
     if isinstance(schema, tuple):
@@ -861,7 +947,11 @@ class Evaluation(Evaluable):
           'Encountered: {annotation!r}.'
       )
     self._maybe_adjust_schema_for_completion(annotation)
-
+    schema = lf_structured.Schema.from_value(annotation)
+    # NOTE(daiyip): add references to the dependent classes of the returned type
+    # to prevent unused subclasses get garbage collected by Python.
+    setattr(schema, '__dependencies__', schema.class_dependencies())
+    return schema

   def _maybe_adjust_schema_for_completion(self, cls):
     if (self.completion_prompt_field is None
@@ -938,12 +1028,25 @@ class Evaluation(Evaluable):
     self._failures = []
     self._num_completed = 0

+    self._total_prompt_tokens = 0
+    self._total_completion_tokens = 0
+    self._num_usages = 0
+    self.__dict__.pop('oop_failures', None)
+    self.__dict__.pop('non_oop_failures', None)
+
+  @property
+  def oop_failures_link(self) -> str | None:
+    """Returns the link to the OOP failures page."""
+    if self.dir is None:
+      return None
+    return self.link(os.path.join(self.dir, Evaluation.OOP_FAILURES_HTML))
+
   @property
-  def
-  """Returns the link to
+  def non_oop_failures_link(self) -> str | None:
+    """Returns the link to then non-OOP failures page."""
     if self.dir is None:
       return None
-    return self.link(os.path.join(self.dir, Evaluation.
+    return self.link(os.path.join(self.dir, Evaluation.NON_OOP_FAILURES_HTML))

   def _dryrun(
       self,
@@ -953,11 +1056,11 @@ class Evaluation(Evaluable):
       verbose: bool,
       **kwargs,
   ) -> None:
-    # Set the example for dryrun.
-    example = example or self.examples[0]
-
     # We make a copy to avoid pollute the state of current object.
-    copy = self.clone()
+    copy: Evaluation = self.clone()
+
+    # Set the example for dryrun.
+    example = example or copy.examples[0]
     copy.__dict__['examples'] = [example]

     # We set the symbolic parent of the cloned to access contextual information
@@ -972,24 +1075,35 @@ class Evaluation(Evaluable):
         color='green',
     )

-
-    output_message = copy.process(example, **(self.additional_args or {}))
-    if self.schema is None:
-      output = output_message.text
-    else:
-      output = output_message.result
+    error, output_message = None, None

-
+    try:
+      with lf.use_settings(debug=debug):
+        output_message = copy.process(example, **(self.additional_args or {}))
+        if self.schema is None:
+          output = output_message.text
+        else:
+          output = output_message.result
+
+      if verbose:
+        lf.console.write('')
+        lf.console.write(
+            str(output),
+            title='OUTPUT',
+            color='blue',
+        )
+    except lf_structured.MappingError as e:
       lf.console.write('')
       lf.console.write(
-          str(
-          title='
-          color='
+          str(e),
+          title='ERROR',
+          color='red',
       )
+      error = e
+
+    copy.audit(example, output_message, error, dryrun=True)
+    result = copy.finalize()

-    # Audit the result.
-    copy.audit(example, output, output_message)
-    result = copy.summarize()
     if verbose:
       lf.console.write('')
       lf.console.write(
@@ -1012,6 +1126,9 @@ class Evaluation(Evaluable):
       **kwargs,
   ) -> None:
     # Setup examples.
+    # Reset examples so it could be read from the input functor.
+    self.__dict__.pop('examples', None)
+
     if end is None:
       end = len(self.examples)
     examples = self.examples[start:end]
@@ -1036,18 +1153,19 @@ class Evaluation(Evaluable):
           status_fn=self._status,
       ):
         if error is not None:
-
-
-
-
-
+          message = (
+              error.lm_response
+              if isinstance(error, lf_structured.MappingError)
+              else None
+          )
+          self.audit(example, message, error)
     finally:
       # Save cache upon completion or interruption.
       if self.dir and self.cache:
         self.cache.save()

     # Summarize result.
-    self._result = self.
+    self._result = self.finalize()
     if verbose:
       lf.console.write(
           str(self.result),
@@ -1061,7 +1179,7 @@ class Evaluation(Evaluable):

   def process(self, example: Any, **kwargs) -> lf.Message:
     """Process an example and returns its output."""
-    prompt = self.prompt
+    prompt = lf.Template.from_value(self.prompt, example=example)
     if self.method == 'call':
       return lf_structured.call(
           prompt,
@@ -1089,7 +1207,9 @@ class Evaluation(Evaluable):
     else:
       assert self.method == 'complete', self.method
       assert isinstance(self.schema.spec, pg.typing.Object), self.schema
-
+      # TODO(daiyip): Currently multi-modal inputs within the prompt for
+      # completion is not supported.
+      input_value = self.schema.spec.cls.partial(prompt.render().text)
       return lf_structured.complete(
           input_value,
           lm=self.lm,
@@ -1103,13 +1223,13 @@ class Evaluation(Evaluable):
   def _status(self, progress: lf.concurrent.Progress) -> dict[str, Any]:
     return {
         'Model': self.lm.model_id,
-        'Succeeded':
-            progress.success_rate
+        'Succeeded': '%s (%d/%d)' % (
+            self._format_rate(progress.success_rate),
             progress.succeeded,
             progress.completed,
         ),
-        'Failed':
-            progress.failure_rate
+        'Failed': '%s (%d/%d)' % (
+            self._format_rate(progress.failure_rate),
             progress.failed,
             progress.completed,
         ),
@@ -1119,21 +1239,20 @@ class Evaluation(Evaluable):
     assert self.result is not None
     m = self.result.metrics
     return (
-
-        f' Failures=%.{self.report_precision}f%% (%d/%d)'
+        'COMPLETED(%s): Successes=%s(%d/%d) Failures=%s (%d/%d)'
         % (
             run_status,
-            (1 - m.failure_rate)
+            self._format_rate(1 - m.failure_rate),
            m.total - m.failures,
            m.total,
-            m.failure_rate
+            self._format_rate(m.failure_rate),
            m.failures,
            m.total,
        )
    )

-  def
-  """
+  def finalize(self) -> pg.Dict:
+    """Finalizes the evaluation result."""
     if self.cache:
       cache_stats = dict(
           use_cache=True,
@@ -1143,6 +1262,19 @@ class Evaluation(Evaluable):
       )
     else:
       cache_stats = dict(use_cache=False)
+
+    if self.has_usage:
+      usage = pg.Dict(
+          total_prompt_tokens=self._total_prompt_tokens,
+          total_completion_tokens=self._total_completion_tokens,
+          num_usages=self._num_usages,
+          average_prompt_tokens=self.average_prompt_tokens,
+          average_completion_tokens=self.average_completion_tokens,
+          average_total_tokens=self.average_total_tokens,
+      )
+    else:
+      usage = None
+
     result = pg.Dict(
         experiment_setup=pg.Dict(
             id=self.id,
@@ -1157,11 +1289,18 @@ class Evaluation(Evaluable):
             total=self.num_completed,
             failures=self.num_failures,
             failure_rate=self.failure_rate,
+            oop_failures=self.num_oop_failures,
+            oop_failure_rate=self.oop_failure_rate,
+            non_oop_failures=self.num_non_oop_failures,
+            non_oop_failure_rate=self.non_oop_failure_rate,
+            failure_breakdown=self.failure_breakdown,
         ),
+        usage=usage,
     )
     return result

-  def
+  def summary_card(self) -> str:
+    """Returns summary card in HTML."""
     s = io.StringIO()
     definition = _html_repr(self, compact=False, escape=True)
     s.write('<div><table><tr><td>')
@@ -1176,37 +1315,141 @@ class Evaluation(Evaluable):
     s.write(
         f'<a target="_blank" title="{definition}" '
         f'href="{self.index_link}">{self.hash}</a>'
+        f' [<a href="{self.link(self.dir)}">dir</a>]'
         '</td></tr><tr><td>'
     )
-    self.
+    self._render_summary_metrics(s)
+
+    # Summarize average usage.
+    if self.result.usage is not None:
+      self._render_summary_usage(s)
+
     s.write('</td></tr></table></div>')
     return s.getvalue()

-  def
+  def _render_summary_usage(self, s: io.StringIO) -> None:
+    """Renders usage in HTML."""
+    usage = self.result.usage
+    total = usage.total_prompt_tokens + usage.total_completion_tokens
+    s.write(
+        ' <a title="'
+        f'# of usages: {usage.num_usages}&#013;'
+        f'total prompt: {usage.total_prompt_tokens}&#013;'
+        f'total response: {usage.total_completion_tokens}&#013;'
+        f'avg prompt: {usage.average_prompt_tokens}&#013;'
+        f'avg response: {usage.average_completion_tokens}'
+        f'" style="color:gray">({total} tokens)</a>'
+    )
+
+  def _render_summary_metrics(self, s: io.StringIO) -> None:
     """Renders metrics in HTML."""
     assert self.result is not None
     m = self.result.metrics
+
+    # OOP failures.
+    oop_failure_title = f'OOP failures ({m.oop_failures}/{m.total})'
+    if m.oop_failures:
+      oop_failure_title += '&#013;'
+      for name, count in m.failure_breakdown.items():
+        if name.startswith('MappingError'):
+          oop_failure_title += '&#013;%s: %s (%d/%d)' % (
+              name.removeprefix('MappingError.'),
+              self._format_rate(count / m.total),
+              count,
+              m.total,
+          )
+
+    extra_style = ''
+    if m.oop_failure_rate > 0.1 and m.oop_failures > 3:
+      extra_style = ';font-weight:bold'
     s.write(
-        '<a title="
+        '<a title="%s" href="%s" style="color:magenta%s">%s</a>'
         % (
-
-
-
-
+            oop_failure_title,
+            self.oop_failures_link,
+            extra_style,
+            self._format_rate(m.oop_failure_rate),
         )
     )
+    s.write(' | ')
+
+    # Non-OOP failures.
+    non_oop_failure_title = f'Non-OOP failures ({m.non_oop_failures}/{m.total})'
+    if m.non_oop_failures:
+      non_oop_failure_title += '&#013;'
+      for name, count in m.failure_breakdown.items():
+        if not name.startswith('MappingError'):
+          non_oop_failure_title += '&#013;%s: %s (%d/%d)' % (
+              name,
+              self._format_rate(count / m.total),
+              count,
+              m.total,
+          )

-
+    extra_style = ';font-weight:bold' if m.non_oop_failures > 0 else ''
+    s.write(
+        '<a title="%s" href="%s" style="color:red%s">%s</a>'
+        % (
+            non_oop_failure_title,
+            self.non_oop_failures_link,
+            extra_style,
+            self._format_rate(m.non_oop_failure_rate),
+        )
+    )
+
+  def _format_rate(self, rate: float) -> str:
+    """Formats a rate."""
+    return f'%.{self.report_precision}f%% ' % (rate * 100)
+
+  def audit(
+      self,
+      example: Any,
+      message: lf.Message | None,
+      error: Exception | None = None,
+      dryrun: bool = False,
+  ) -> None:
     """Audits the example against the output. Subclasses should override.

     Args:
       example: The input object.
-      output: The output from LM. For `lf.call`, if `schema_fn` is not provided,
-        it will be the raw LM response string. Otherwise it will be the
-        structured output from the LM.
       message: The entire message returned by the LM, which could be used to
-        trace the LM input, response and parsed structure.
+        trace the LM input, response and parsed structure. If error is raised
+        before LLM could return a response, None will be its value.
+      error: The exception during processing the example.
+      dryrun: Whether or not audition takes place during dryrun.
     """
+    if error is not None:
+      self._failures.append((example, error))
+
+      # Invalid cache of num_oop_failures.
+      self.__dict__.pop('oop_failures', None)
+      self.__dict__.pop('non_oop_failures', None)
+      self.__dict__.pop('failure_breakdown', None)
+
+      if isinstance(error, lf_structured.MappingError):
+        message = error.lm_response
+    else:
+      assert message is not None
+      output = message.text if self.schema is None else message.result
+      self.audit_processed(example, output, message, dryrun=dryrun)
+
+    # Audit usage.
+    if message is not None:
+      self.audit_usage(message, dryrun=dryrun)
+    self._num_completed += 1
+
+  def audit_usage(self, message: lf.Message, dryrun: bool = False) -> None:
+    del dryrun
+    for m in message.trace():
+      if m.metadata.get('usage', None) is not None:
+        self._total_prompt_tokens += m.usage.prompt_tokens
+        self._total_completion_tokens += m.usage.completion_tokens
+        self._num_usages += 1
+
+  def audit_processed(
+      self, example: Any, output: Any, message: lf.Message, dryrun: bool = False
+  ) -> None:
+    """Audits a successfully processed example. Subclass should override."""

   def save(
       self, definition: bool = True, result: bool = True, report: bool = True
@@ -1229,16 +1472,26 @@ class Evaluation(Evaluable):
       # Save failures.
       pg.save(
           [
-              pg.Dict(
-
-              )
-              for input, error in self.failures
+              pg.Dict(input=input, error=_format_error(error))
+              for input, error in self.oop_failures
           ],
-          os.path.join(self.dir, Evaluation.
+          os.path.join(self.dir, Evaluation.OOP_FAILURES_JSON),
       )
       pg.save(
-          self._html([self._render_result, self.
-          os.path.join(self.dir, Evaluation.
+          self._html([self._render_result, self._render_oop_failures]),
+          os.path.join(self.dir, Evaluation.OOP_FAILURES_HTML),
+          file_format='txt',
+      )
+      pg.save(
+          [
+              pg.Dict(input=input, error=_format_error(error))
+              for input, error in self.non_oop_failures
+          ],
+          os.path.join(self.dir, Evaluation.NON_OOP_FAILURES_JSON),
+      )
+      pg.save(
+          self._html([self._render_result, self._render_non_oop_failures]),
+          os.path.join(self.dir, Evaluation.NON_OOP_FAILURES_HTML),
           file_format='txt',
       )

@@ -1250,8 +1503,11 @@ class Evaluation(Evaluable):
         '<td>Prompt</td>'
         '<td>Schema</td>'
         '<td>Additional Args</td>'
-        '<td>Failures</td>'
     )
+    if self.result.usage is not None:
+      s.write('<td>Usage</td>')
+    s.write('<td>OOP Failures</td>')
+    s.write('<td>Non-OOP Failures</td>')

   def _render_result_row(self, s: io.StringIO) -> None:
     s.write(
@@ -1276,13 +1532,32 @@ class Evaluation(Evaluable):
         '<td style="color:purple" '
         f'{_html_repr(self.additional_args, compact=False)}</td>'
     )
-    #
+    # Usage.
+    if self.result.usage is not None:
+      s.write('<td>')
+      self._render_summary_usage(s)
+      s.write('</td>')
+
+    # OOP failures.
+    s.write(
+        '<td><span style="color:magenta">%s</span>%s</td>'
+        % (
+            self._format_rate(self.oop_failure_rate),
+            '<a href="%s">(%d/%d)</a>'
+            % (self.oop_failures_link,
+               self.num_oop_failures,
+               self.num_completed),
+        )
+    )
+    # Non-OOP failures.
     s.write(
-        '<td><span style="color:
+        '<td><span style="color:red">%s</span>%s</td>'
         % (
-
+            self._format_rate(self.non_oop_failure_rate),
             '<a href="%s">(%d/%d)</a>'
-            % (self.
+            % (self.non_oop_failures_link,
+               self.num_non_oop_failures,
+               self.num_completed),
         )
     )

@@ -1296,24 +1571,77 @@ class Evaluation(Evaluable):
     else:
       return 'cyan'

-  def
+  def _render_oop_failures(self, s: io.StringIO) -> None:
+    self._render_failures(s, '^MappingError.*', error_color='magenta')
+
+  def _render_non_oop_failures(self, s: io.StringIO) -> None:
+    self._render_failures(s, '^(?!MappingError).*', error_color='red')
+
+  def _render_failures(
+      self, s: io.StringIO, error_regex: str, error_color: str) -> None:
     """Formats the failed cases into html."""
+    # Failure summary.
     s.write(
-        '<h2>
+        '<h2> Error Summary </h2>'
         '<div style="white-space:pre">\n'
         '<table style="border:1px solid">'
-        '<tr class="header"><td>
+        '<tr class="header"><td>Error type</td><td>Stats</td></tr>'
     )
+    error_regex = re.compile(error_regex)
+    if self.result.metrics.failure_breakdown:
+      for name, count in self.result.metrics.failure_breakdown.items():
+        if not error_regex.match(name):
+          continue
+
+        link = f'<a href="#{name}">{name}</a>'
+        error_rate = self._format_rate(count / self.result.metrics.total)
+        stats = (f'<span style="color:{error_color}">{error_rate} '
+                 f'({count}/{self.result.metrics.total})</span>')
+        s.write(f'<tr><td>{link}</td><td>{stats})</td></tr>')
+    s.write(
+        '</table></div>'
+        '<h2> Failed Cases </h2>'
+        '<div style="white-space:pre">'
+    )
+    # Failure details by error type.
+    failures_by_error = collections.defaultdict(list)
+    for example, error in self.failures:
+      error_name = _error_key(error)
+      if error_regex.match(error_name):
+        failures_by_error[error_name].append((example, error))
+
+    for error_key, failures in failures_by_error.items():
+      s.write(
+          f'<h3 id="{error_key}"><a href="#{error_key}">{error_key}</a> '
+          f'(count={len(failures)})</h3>'
+          '<table style="border:1px solid">'
+          '<tr class="header"><td>No.</td><td>Input</td>'
+          '<td>LM invocation</td><td>Error</td></tr>'
+      )
+      for i, (example, error) in enumerate(failures):
+        lm_response = None
+        if isinstance(error, lf.structured.MappingError):
+          lm_response = error.lm_response
+          error = error.cause
+
+        bgcolor = 'white' if i % 2 == 0 else '#DDDDDD'
+        s.write(f'<tr style="background-color: {bgcolor}"><td>{i + 1}</td>')
+        s.write('<td style="color:green;white-space:pre-wrap">')
+        s.write(pg.format(example, verbose=False))
+        s.write('</td><td>')
+        if lm_response is not None:
+          self._render_message(lm_response, s)
+        s.write(f'</td><td style="color:{error_color};white-space:pre">')
+        s.write(_format_error(error))
+        s.write('</td></tr>')
+      s.write('</table>')
+    s.write('</div>')

-
-
-
-
-
-      error_str = lf.text_formatting.decolored(str(error))
-      s.write(f'<td style="color:red;white-space:pre">{error_str}</td>')
-      s.write('</tr>')
-    s.write('</table></div>')
+  @classmethod
+  def visualize(cls, evaluations: list['Evaluation']) -> str | None:
+    """Visualize the a list of evaluations of this task in HTML."""
+    del evaluations
+    return None


 @pg.functor()
@@ -1374,8 +1702,8 @@ class Summary(pg.Object):
           Type[lf.LanguageModel],
           tuple[lf.LanguageModel | Type[lf.LanguageModel], ...],
       ] = lf.LanguageModel,
-      method: Union[str, tuple[str], None] = None,
-      schema_fn: Union[pg.Functor, tuple[pg.Functor], None] = None,
+      method: Union[str, tuple[str, ...], None] = None,
+      schema_fn: Union[pg.Functor, tuple[pg.Functor, ...], None] = None,
       completed: bool | None = None,
       pivot_field: str | None = None,
   ) -> 'Summary':
@@ -1466,7 +1794,7 @@ class Summary(pg.Object):
         if e is None:
           s.write('<span style="color: gray">N/A<span>')
         else:
-          s.write(e.
+          s.write(e.summary_card())
         s.write('</td>')
       s.write('</tr>')
     s.write('</table>')
@@ -1541,13 +1869,22 @@ class Summary(pg.Object):
     s.write('<html><body>')
     for task in sorted(self.tasks(), key=lambda cls: cls.__name__):
       table_id = task.__name__.lower()
+      evaluations = self.select(task=task).evaluations
+      table = Summary.Table.from_evaluations(evaluations, pivot_field)
       s.write('<div>')
-      s.write(
-
-
-      table = Summary.Table.from_evaluations(
-          self.select(task=task).evaluations, pivot_field
+      s.write(
+          f'<a id="{table_id}" href="#{table_id}">'
+          f'<h2>{task.__name__}</h2></a>'
       )
+
+      # Allow users to plugin visualization code (e.g. matplot) in the summary
+      # page.
+      visual_part = task.visualize(evaluations)
+      if visual_part:
+        s.write(visual_part)
+
+      s.write(f'<h4 style="color:gray">{len(evaluations)} experiments</h4>')
+      s.write('<hr/>')
       s.write(table.html())
       s.write('</div>')
     s.write('</body></html>')
@@ -1556,8 +1893,36 @@ class Summary(pg.Object):
   def _repr_html_(self) -> str:
     return self.html()

+  def json(
+      self,
+  ) -> dict[
+      str,            # Task name
+      list[pg.Dict],  # List of pg.Dict with `experiment` and `metrics`.
+  ]:
+    """Returns the JSON representation of the summary."""
+    task_results = {}
+    for task in sorted(self.tasks(), key=lambda cls: cls.__name__):
+      results = []
+      for entry in self.select(task=task).evaluations:
+        results.append(
+            pg.Dict(
+                id=entry.id,
+                experiment=entry,
+                dir=entry.dir,
+                metrics=entry.result.metrics if entry.result else None,
+                usage=entry.result.usage if entry.result else None,
+            )
+        )
+      task_results[task.__name__] = results
+    return task_results
+
   def save(self, file: str, pivot_field: str | None = None) -> None:
     pg.save(self.html(pivot_field), file, file_format='txt')
+    if file.endswith('.html'):
+      json_file = file.replace('.html', '.json')
+    else:
+      json_file = os.path.join(file, '.json')
+    pg.save(self.json(), json_file)

   @classmethod
   def from_dirs(
@@ -1694,6 +2059,21 @@ class Summary(pg.Object):
   return result.join()


+def _format_error(error: Exception):
+  """Formats an error into a string."""
+  return (f'({error.__class__.__name__}) '
+          + lf.text_formatting.decolored(str(error)))
+
+
+def _error_key(error: Exception) -> str:
+  """Returns the key for an error."""
+  error_names = []
+  while error is not None:
+    error_names.append(error.__class__.__name__)
+    error = getattr(error, 'cause', None)
+  return '.'.join(error_names)
+
+
 def _html_repr(value: Any, compact: bool = True, escape: bool = False) -> str:
   """Formats prompt in HTML."""
   if type(value) is lf.Template:  # pylint: disable=unidiomatic-typecheck
@@ -1768,3 +2148,193 @@ def monitor_async(
       scan_interval=scan_interval,
       refresh_when_stop=refresh_when_stop,
   )
+
+
+#
+# Named evaluations and experiments support.
+#
+
+
+class _NamedEvaluationRegistry:
+  """Named evaluation registry."""
+
+  def __init__(self):
+    self._registry = {}
+
+  def names(self) -> list[str]:
+    """Returns all registered names."""
+    return sorted(self._registry.keys())
+
+  def get(self, name: str) -> Type[Evaluable]:
+    """Gets an evaluation by name."""
+    if name not in self._registry:
+      raise ValueError(
+          f'Evaluation {name!r} not found. '
+          'Did you forget to import the module that registers it?'
+      )
+    return self._registry[name]
+
+  def register(
+      self,
+      name: str,
+      experiment_cls: Type[Evaluable],
+  ):
+    """Register an experiment class."""
+    self._registry[name] = experiment_cls
+
+
+_eval_registry = _NamedEvaluationRegistry()
+
+
+def registered_names() -> list[str]:
+  """Returns all registered names."""
+  return _eval_registry.names()
+
+
+def get_evaluation(evaluation: str | Evaluable) -> Evaluable:
+  """Gets an evaluation experiment by name."""
+  if isinstance(evaluation, str):
+    return _eval_registry.get(evaluation)()
+  return evaluation
+
+
+def register(name: str):
+  """Decorator to create a named evaluation class."""
+
+  def _register(func_or_cls: Type[Evaluation] | types.FunctionType):
+    if inspect.isfunction(func_or_cls):
+      e = func_or_cls()
+      if not isinstance(e, Evaluable):
+        raise TypeError(
+            f'The return value of `{func_or_cls}` should be an instance of '
+            '`lf.eval.Evaluable` subclass.'
+        )
+
+      class GeneratedSuite(Suite):
+        # NOTE(daiyip): Delay serialization key registration for generated
+        # class.
+        auto_register = False
+        children = e.children if isinstance(e, Suite) else [e]
+
+      cls = GeneratedSuite
+      cls.__name__ = func_or_cls.__name__
+      cls.__doc__ = func_or_cls.__doc__
+      cls.__qualname__ = func_or_cls.__qualname__
+      cls.__module__ = getattr(func_or_cls, '__module__', 'wrapper')
+      cls.register_for_deserialization(cls.__type_name__)
+
+    elif issubclass(func_or_cls, Evaluable):
+      cls = func_or_cls
+    else:
+      raise ValueError(f'Unsupported type: {type(func_or_cls)}')
+
+    _eval_registry.register(name, cls)
+    return cls
+
+  return _register
+
+
+def get(
+    root_dir: str,
+    evaluations: list[str | Evaluable],
+    filter: Union[  # pylint: disable=redefined-builtin
+        str,                          # Regex to filter evaluation based on ID.
+        Callable[[Evaluable], bool],  # Custom filter function.
+        None                          # No filtering (Default).
+    ] = None,                         # pylint: disable=bad-whitespace
+    patches: list[Union[
+        str,                                    # String-based PyGlove patcher.
+        pg.patching.Patcher,                    # PyGlove patcher object.
+        Callable[[pg.KeyPath, Any, Any], Any],  # PyGlove rebind function.
+    ]] | None = None,                           # pylint: disable=bad-whitespace
+) -> Suite:
+  """Gets a suite from a list of patched evaluations.
+
+  Args:
+    root_dir: The root directory of the experiment.
+    evaluations: A list of evaluations to be included in the suite.
+    filter: A regular expression (str) for selecting sub-experiments of matched
+      IDs, or a filter function to filter the evaluations.
+    patches: A list of patches to be applied to the suite. Each element can be
+      a string (for string-based patcher), a `pg.patching.Patcher` object, or
+      a rebind function (e.g. `pg.rebind`). See `lf.eval.patch_*` for more
+      details.
+
+  Returns:
+    A suite of selected `lf.eval.Evaluation` objects.
+  """
+  evaluations = [get_evaluation(e) for e in evaluations]
+  suite = Suite(evaluations, root_dir=root_dir)
+  if patches:
+    suite = pg.patch(suite, patches)
+
+  if isinstance(filter, str):
+    regex = re.compile(filter)
+    filter = lambda x: bool(regex.match(x.id))
+
+  if filter:
+    suite = Suite(
+        [leaf for leaf in suite.leaf_nodes if filter(leaf)], root_dir=root_dir)
+  return suite
+
+
+def run(
+    root_dir: str,
+    evaluations: list[str | Evaluable],
+    filter: Union[  # pylint: disable=redefined-builtin
+        str,                          # Regex to filter evaluation based on ID.
+        Callable[[Evaluable], bool],  # Custom filter function.
+        None                          # No filtering (Default).
+    ] = None,                         # pylint: disable=bad-whitespace
+    patches: list[Union[
+        str,                                    # String-based PyGlove patcher.
+        pg.patching.Patcher,                    # PyGlove patcher object.
+        Callable[[pg.KeyPath, Any, Any], Any],  # PyGlove rebind function.
+    ]] | None = None,                           # pylint: disable=bad-whitespace
+    mode: Literal['run', 'rerun', 'dryrun', 'noop'] = 'run',
+    debug: bool = False,
+    print_definition: bool = False,
+    **kwargs,
+) -> Suite:
+  """Run selected evaluations with patching.
+
+  Args:
+    root_dir: The root directory of the experiment.
+    evaluations: A list of evaluations to be included in the suite.
+    filter: A regular expression (str) for selecting sub-experiments of matched
+      IDs, or a filter function to filter the evaluations.
+    patches: A list of patches to be applied to the suite. Each element can be
+      a string (for string-based patcher), a `pg.patching.Patcher` object, or
+      a rebind function (e.g. `pg.rebind`). See `lf.eval.patch_*` for more
+      details.
+    mode: The mode to run the suite. "run" to run the suite, with reusing
+      existing results if available; "rerun" to rerun all evaluations even if
+      there are existing results; "dryrun" to dryrun the suite; and "noop"
+      to do nothing.
+    debug: Whether to run in debug mode.
+    print_definition: Whether to print the experiment definition.
+    **kwargs: Additional arguments to be passed to dryrun/run the suite.
+
+  Returns:
+    A suite of selected `lf.eval.Evaluation` objects.
+  """
+  suite = get(root_dir, evaluations, patches=patches, filter=filter)
+  if print_definition:
+    lf.console.write(
+        pg.format(
+            suite,
+            compact=False,
+            verbose=False,
+            hide_default_values=True,
+            python_format=True,
+        ),
+        title='[EXPERIMENT DEFINITION]',
+        color='blue',
+    )
+
+  if mode == 'run':
+    rerun = mode == 'rerun'
+    suite.run(debug=debug, rerun=rerun, **kwargs)
+  elif mode == 'dryrun':
+    suite.dryrun(debug=debug, **kwargs)
+  return suite
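The largest addition in this diff is the named-evaluation registry and the `get()`/`run()` entry points at the bottom of langfun/core/eval/base.py. The sketch below illustrates how those new pieces might be wired together; it is inferred only from the diff above and assumes the symbols are re-exported under `lf.eval` (the accompanying `langfun/core/eval/__init__.py` change suggests this, but it is not shown here). `MyEval`, its configuration, and the output directory are hypothetical placeholders, not part of the package.

import langfun as lf

# Hypothetical evaluation: assume MyEval is a concrete lf.eval.Evaluation
# subclass whose fields (inputs, prompt, lm, schema_fn, ...) are defined
# elsewhere. Only the registration and run mechanics below come from this diff.
@lf.eval.register('my_task/my_eval')
class MyEval(lf.eval.Evaluation):
  ...

# Registered names can be listed and later resolved lazily by string.
assert 'my_task/my_eval' in lf.eval.registered_names()

# `run` resolves names through the registry, wraps the evaluations in a Suite
# rooted at `root_dir`, optionally filters leaf evaluations by ID regex and
# applies PyGlove patches, then runs (or dry-runs) the suite and returns it.
suite = lf.eval.run(
    '/tmp/my_eval_root',   # hypothetical output directory
    ['my_task/my_eval'],
    filter='.*',           # keep every matched leaf evaluation
    mode='dryrun',         # 'run', 'rerun', 'dryrun' or 'noop'
)

Per the registry code in the diff, a name can also be bound to a plain function that returns an `Evaluable`; in that case `register` wraps the result in a generated `Suite` subclass so it can be re-instantiated by name.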