langfun 0.0.2.dev20240330__py3-none-any.whl → 0.0.2.dev20240429__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- langfun/__init__.py +2 -0
- langfun/core/__init__.py +1 -0
- langfun/core/coding/python/correction.py +0 -7
- langfun/core/component.py +6 -0
- langfun/core/component_test.py +1 -0
- langfun/core/eval/__init__.py +2 -0
- langfun/core/eval/base.py +202 -23
- langfun/core/eval/base_test.py +49 -10
- langfun/core/eval/matching.py +26 -9
- langfun/core/eval/matching_test.py +2 -1
- langfun/core/eval/scoring.py +15 -6
- langfun/core/eval/scoring_test.py +2 -1
- langfun/core/langfunc.py +0 -5
- langfun/core/langfunc_test.py +6 -4
- langfun/core/language_model.py +124 -24
- langfun/core/language_model_test.py +249 -26
- langfun/core/llms/__init__.py +19 -2
- langfun/core/llms/anthropic.py +263 -0
- langfun/core/llms/anthropic_test.py +167 -0
- langfun/core/llms/cache/in_memory_test.py +37 -28
- langfun/core/llms/fake.py +31 -22
- langfun/core/llms/fake_test.py +122 -11
- langfun/core/llms/google_genai_test.py +8 -3
- langfun/core/llms/groq.py +260 -0
- langfun/core/llms/groq_test.py +170 -0
- langfun/core/llms/llama_cpp.py +3 -1
- langfun/core/llms/openai.py +97 -79
- langfun/core/llms/openai_test.py +285 -59
- langfun/core/modalities/video.py +5 -2
- langfun/core/structured/__init__.py +3 -0
- langfun/core/structured/completion_test.py +2 -2
- langfun/core/structured/function_generation.py +245 -0
- langfun/core/structured/function_generation_test.py +329 -0
- langfun/core/structured/mapping.py +56 -2
- langfun/core/structured/mapping_test.py +17 -0
- langfun/core/structured/parsing_test.py +18 -13
- langfun/core/structured/prompting.py +27 -6
- langfun/core/structured/prompting_test.py +79 -12
- langfun/core/structured/schema.py +4 -2
- langfun/core/structured/schema_generation_test.py +2 -2
- langfun/core/structured/schema_test.py +4 -6
- langfun/core/template.py +125 -10
- langfun/core/template_test.py +75 -0
- langfun/core/templates/selfplay_test.py +6 -2
- {langfun-0.0.2.dev20240330.dist-info → langfun-0.0.2.dev20240429.dist-info}/METADATA +3 -2
- {langfun-0.0.2.dev20240330.dist-info → langfun-0.0.2.dev20240429.dist-info}/RECORD +49 -43
- {langfun-0.0.2.dev20240330.dist-info → langfun-0.0.2.dev20240429.dist-info}/LICENSE +0 -0
- {langfun-0.0.2.dev20240330.dist-info → langfun-0.0.2.dev20240429.dist-info}/WHEEL +0 -0
- {langfun-0.0.2.dev20240330.dist-info → langfun-0.0.2.dev20240429.dist-info}/top_level.txt +0 -0
langfun/__init__.py
CHANGED
@@ -34,6 +34,7 @@ score = structured.score
 generate_class = structured.generate_class

 source_form = structured.source_form
+function_gen = structured.function_gen

 from langfun.core import eval  # pylint: disable=redefined-builtin
 from langfun.core import templates
@@ -54,6 +55,7 @@ Video = modalities.Video
 PDF = modalities.PDF

 # Error types.
+MappingError = structured.MappingError
 SchemaError = structured.SchemaError
 JsonError = structured.JsonError
 CodeError = coding.CodeError
langfun/core/__init__.py
CHANGED
@@ -99,6 +99,7 @@ from langfun.core.modality import ModalityRef
 from langfun.core.language_model import LanguageModel
 from langfun.core.language_model import LMSample
 from langfun.core.language_model import LMSamplingOptions
+from langfun.core.language_model import LMSamplingUsage
 from langfun.core.language_model import LMSamplingResult
 from langfun.core.language_model import LMScoringResult
 from langfun.core.language_model import LMCache
langfun/core/coding/python/correction.py
CHANGED
@@ -12,7 +12,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 """Python code error correction."""
-import re
 from typing import Any
 import langfun.core as lf
 from langfun.core.coding.python import errors
@@ -31,11 +30,6 @@ class CorrectedCode(pg.Object):
   corrected_code: str


-def remove_docstrings(code):
-  pattern = re.compile(r"(def .+?:\s*?)('''|\"\"\")((.|\s)*?)(\2)", re.DOTALL)
-  return pattern.sub(r"\1", code)
-
-
 def run_with_correction(
     code: str,
     error: str | None = None,
@@ -86,7 +80,6 @@ def run_with_correction(
   # pytype: enable=import-error
   # pylint: enable=g-import-not-at-top

-  code = remove_docstrings(code)
   if max_attempts == 0:
     result = execution.run(
         code,
langfun/core/component.py
CHANGED
@@ -210,6 +210,12 @@ def get_contextual_override(var_name: str) -> ContextualOverride | None:
   return _get_scoped_value(_global_tls, _CONTEXT_OVERRIDES, var_name)


+def all_contextual_values() -> dict[str, Any]:
+  """Returns all contextual values provided from `lf.context` in scope."""
+  overrides = getattr(_global_tls, _CONTEXT_OVERRIDES, {})
+  return {k: v.value for k, v in overrides.items()}
+
+
 @contextlib.contextmanager
 def _contextual_scope(
     tls: threading.local, tls_key, **variables
langfun/core/component_test.py
CHANGED
@@ -84,6 +84,7 @@ class ComponentContextTest(unittest.TestCase):
         lf.get_contextual_override('y'),
         lf.ContextualOverride(3, cascade=False, override_attrs=False),
     )
+    self.assertEqual(lf.all_contextual_values(), dict(x=3, y=3, z=3))

     # Member attributes take precedence over `lf.context`.
     self.assertEqual(a1.x, 1)
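The two changes above add and test a small read-only helper. A minimal usage sketch, assuming `lf.context` behaves as exercised in the test (variable names are illustrative):

```python
import langfun as lf

# `lf.context` installs contextual overrides for the enclosed scope; the new
# `lf.all_contextual_values` snapshots every override currently in scope as a
# plain dict mapping variable name -> value.
with lf.context(x=3, y=3, z=3):
  assert lf.all_contextual_values() == dict(x=3, y=3, z=3)
```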
langfun/core/eval/__init__.py
CHANGED
@@ -16,6 +16,8 @@
 # pylint: disable=g-importing-member
 # pylint: disable=g-bad-import-order

+from langfun.core.eval.base import app_run
+
 from langfun.core.eval.base import Evaluable
 from langfun.core.eval.base import Evaluation
 from langfun.core.eval.base import Suite
langfun/core/eval/base.py
CHANGED
@@ -26,6 +26,8 @@ import threading
 import time
 from typing import Annotated, Any, Callable, Iterator, Literal, Optional, Sequence, Type, Union

+from absl import app
+from absl import flags
 import langfun.core as lf
 import langfun.core.coding as lf_coding
 from langfun.core.llms.cache import in_memory
@@ -538,7 +540,7 @@ class Evaluable(lf.Component):
            f'<div style="color: {text_color}; white-space: pre-wrap;'
            'padding: 10px; border: 1px solid; margin-top: 10px">'
        )
-        s.write(m.text)
+        s.write(m.get('formatted_text', m.text))
        if m.result is not None:
          s.write(
              '<div style="color: magenta; white-space: pre-wrap;'
@@ -546,6 +548,16 @@ class Evaluable(lf.Component):
          )
          s.write(pg.format(m.result))
          s.write('</div>')
+        if 'usage' in m.metadata:
+          s.write(
+              '<div style="background-color: #EEEEEE; color: black; '
+              'white-space: pre-wrap; padding: 10px; border: 0px solid; '
+              'margin: 10px">'
+              f'prompt: {m.usage.prompt_tokens} tokens, '
+              f'response: {m.usage.completion_tokens} tokens, '
+              f'total: {m.usage.total_tokens} tokens'
+              '</div>'
+          )
        s.write('</div>')

  @classmethod
@@ -810,17 +822,36 @@ class Evaluation(Evaluable):
       return 0.0
     return self.num_failures / self.num_completed

+  @property
+  def has_usage(self) -> bool:
+    """Returns True if token usage is enabled."""
+    return self._num_usages > 0
+
+  @property
+  def average_prompt_tokens(self) -> int:
+    """Returns the average prompt tokens."""
+    if not self.has_usage:
+      return 0
+    return self._total_prompt_tokens // self._num_usages
+
+  @property
+  def average_completion_tokens(self) -> int:
+    """Returns the average completion tokens."""
+    if not self.has_usage:
+      return 0
+    return self._total_completion_tokens // self._num_usages
+
+  @property
+  def average_total_tokens(self) -> int:
+    """Returns the average total tokens."""
+    return self.average_prompt_tokens + self.average_completion_tokens
+
   @functools.cached_property
   def schema(self) -> lf_structured.Schema | None:
     """Schema."""
     if self.schema_fn is None:
       return None

-    kwargs = {}
-    # Allow schema to be a function based on current evaluation.
-    if 'evaluation' in self.schema_fn.__signature__.arg_names:
-      kwargs['evaluation'] = self
-
     schema = self._call_schema_fn()
     fewshot_examples = None
     if isinstance(schema, tuple):
@@ -861,7 +892,11 @@ class Evaluation(Evaluable):
           'Encountered: {annotation!r}.'
       )
     self._maybe_adjust_schema_for_completion(annotation)
-    return lf_structured.Schema.from_value(annotation)
+    schema = lf_structured.Schema.from_value(annotation)
+    # NOTE(daiyip): add references to the dependent classes of the returned type
+    # to prevent unused subclasses get garbage collected by Python.
+    setattr(schema, '__dependencies__', schema.class_dependencies())
+    return schema

   def _maybe_adjust_schema_for_completion(self, cls):
     if (self.completion_prompt_field is None
@@ -938,6 +973,10 @@ class Evaluation(Evaluable):
     self._failures = []
     self._num_completed = 0

+    self._total_prompt_tokens = 0
+    self._total_completion_tokens = 0
+    self._num_usages = 0
+
   @property
   def failures_link(self) -> str | None:
     """Returns the link to the failures page."""
@@ -957,7 +996,7 @@ class Evaluation(Evaluable):
     example = example or self.examples[0]

     # We make a copy to avoid pollute the state of current object.
-    copy = self.clone()
+    copy: Evaluation = self.clone()
     copy.__dict__['examples'] = [example]

     # We set the symbolic parent of the cloned to access contextual information
@@ -987,9 +1026,9 @@ class Evaluation(Evaluable):
         color='blue',
     )

-
-    copy.audit(example, output, output_message)
+    copy.audit(example, output_message, None, dryrun=True)
     result = copy.summarize()
+
     if verbose:
       lf.console.write('')
       lf.console.write(
@@ -1036,11 +1075,12 @@ class Evaluation(Evaluable):
           status_fn=self._status,
       ):
         if error is not None:
-          self._failures.append((example, str(error)))
-        else:
-          output = message.text if self.schema is None else message.result
-          self.audit(example, output, message)
-        self._num_completed += 1
+          message = (
+              error.lm_response
+              if isinstance(error, lf_structured.MappingError)
+              else None
+          )
+        self.audit(example, message, error)
     finally:
       # Save cache upon completion or interruption.
       if self.dir and self.cache:
@@ -1143,6 +1183,19 @@ class Evaluation(Evaluable):
       )
     else:
       cache_stats = dict(use_cache=False)
+
+    if self.has_usage:
+      usage = pg.Dict(
+          total_prompt_tokens=self._total_prompt_tokens,
+          total_completion_tokens=self._total_completion_tokens,
+          num_usages=self._num_usages,
+          average_prompt_tokens=self.average_prompt_tokens,
+          average_completion_tokens=self.average_completion_tokens,
+          average_total_tokens=self.average_total_tokens,
+      )
+    else:
+      usage = None
+
     result = pg.Dict(
         experiment_setup=pg.Dict(
             id=self.id,
@@ -1158,6 +1211,7 @@ class Evaluation(Evaluable):
             failures=self.num_failures,
             failure_rate=self.failure_rate,
         ),
+        usage=usage,
     )
     return result

@@ -1179,9 +1233,28 @@ class Evaluation(Evaluable):
         '</td></tr><tr><td>'
     )
     self._render_metric(s)
+
+    # Summarize average usage.
+    if self.result.usage is not None:
+      self._render_usage(s)
+
     s.write('</td></tr></table></div>')
     return s.getvalue()

+  def _render_usage(self, s: io.StringIO) -> None:
+    """Renders usage in HTML."""
+    usage = self.result.usage
+    total = usage.total_prompt_tokens + usage.total_completion_tokens
+    s.write(
+        ' <a title="'
+        f'# of usages: {usage.num_usages}&#013;'
+        f'total prompt: {usage.total_prompt_tokens}&#013;'
+        f'total response: {usage.total_completion_tokens}&#013;'
+        f'avg prompt: {usage.average_prompt_tokens}&#013;'
+        f'avg response: {usage.average_completion_tokens}'
+        f'" style="color:gray">({total} tokens)</a>'
+    )
+
   def _render_metric(self, s: io.StringIO) -> None:
     """Renders metrics in HTML."""
     assert self.result is not None
@@ -1196,17 +1269,48 @@ class Evaluation(Evaluable):
         )
     )

-  def audit(self, example: Any, output: Any, message: lf.Message) -> None:
+  def audit(
+      self,
+      example: Any,
+      message: lf.Message | None,
+      error: Exception | None = None,
+      dryrun: bool = False,
+  ) -> None:
     """Audits the example against the output. Subclasses should override.

     Args:
       example: The input object.
-      output: The output from LM. For `lf.call`, if `schema_fn` is not provided,
-        it will be the raw LM response string. Otherwise it will be the
-        structured output from the LM.
       message: The entire message returned by the LM, which could be used to
-        trace the LM input, response and parsed structure.
+        trace the LM input, response and parsed structure. If error is raised
+        before LLM could return a response, None will be its value.
+      error: The exception during processing the example.
+      dryrun: Whether or not audition takes place during dryrun.
     """
+    if error is not None:
+      self._failures.append((example, str(error)))
+      if isinstance(error, lf_structured.MappingError):
+        message = error.lm_response
+    else:
+      assert message is not None
+      output = message.text if self.schema is None else message.result
+      self.audit_processed(example, output, message, dryrun=dryrun)
+
+    # Audit usage.
+    if message is not None:
+      self.audit_usage(message, dryrun=dryrun)
+    self._num_completed += 1
+
+  def audit_usage(self, message: lf.Message, dryrun: bool = False) -> None:
+    for m in message.trace():
+      if 'usage' in m.metadata:
+        self._total_prompt_tokens += m.usage.prompt_tokens
+        self._total_completion_tokens += m.usage.completion_tokens
+        self._num_usages += 1
+
+  def audit_processed(
+      self, example: Any, output: Any, message: lf.Message, dryrun: bool = False
+  ) -> None:
+    """Audits a successfully processed example. Subclass should override."""

   def save(
       self, definition: bool = True, result: bool = True, report: bool = True
@@ -1250,8 +1354,10 @@ class Evaluation(Evaluable):
         '<td>Prompt</td>'
         '<td>Schema</td>'
         '<td>Additional Args</td>'
-        '<td>Failures</td>'
     )
+    if self.result.usage is not None:
+      s.write('<td>Usage</td>')
+    s.write('<td>Failures</td>')

   def _render_result_row(self, s: io.StringIO) -> None:
     s.write(
@@ -1276,6 +1382,12 @@ class Evaluation(Evaluable):
         '<td style="color:purple" '
         f'{_html_repr(self.additional_args, compact=False)}</td>'
     )
+    # Usage.
+    if self.result.usage is not None:
+      s.write('<td>')
+      self._render_usage(s)
+      s.write('</td>')
+
     # Failures.
     s.write(
         '<td><span style="color:orange">%s</span>%s</td>'
@@ -1374,8 +1486,8 @@ class Summary(pg.Object):
           Type[lf.LanguageModel],
           tuple[lf.LanguageModel | Type[lf.LanguageModel], ...],
       ] = lf.LanguageModel,
-      method: Union[str, tuple[str], None] = None,
-      schema_fn: Union[pg.Functor, tuple[pg.Functor], None] = None,
+      method: Union[str, tuple[str, ...], None] = None,
+      schema_fn: Union[pg.Functor, tuple[pg.Functor, ...], None] = None,
       completed: bool | None = None,
       pivot_field: str | None = None,
   ) -> 'Summary':
@@ -1556,8 +1668,35 @@ class Summary(pg.Object):
   def _repr_html_(self) -> str:
     return self.html()

+  def json(
+      self,
+  ) -> dict[
+      str,  # Task name
+      list[pg.Dict],  # List of pg.Dict with `experiment` and `metrics`.
+  ]:
+    """Returns the JSON representation of the summary."""
+    task_results = {}
+    for task in sorted(self.tasks(), key=lambda cls: cls.__name__):
+      results = []
+      for entry in self.select(task=task).evaluations:
+        results.append(
+            pg.Dict(
+                id=entry.id,
+                experiment=entry,
+                dir=entry.dir,
+                metrics=entry.result.metrics if entry.result else None,
+            )
+        )
+      task_results[task.__name__] = results
+    return task_results
+
   def save(self, file: str, pivot_field: str | None = None) -> None:
     pg.save(self.html(pivot_field), file, file_format='txt')
+    if file.endswith('.html'):
+      json_file = file.replace('.html', '.json')
+    else:
+      json_file = os.path.join(file, '.json')
+    pg.save(self.json(), json_file)

   @classmethod
   def from_dirs(
@@ -1768,3 +1907,43 @@ def monitor_async(
       scan_interval=scan_interval,
       refresh_when_stop=refresh_when_stop,
   )
+
+
+def app_run(target: Evaluable):
+  """Runs the target evaluation as an absl app.
+
+  Args:
+    target: An Langfun evaluable object.
+  """
+  flags.DEFINE_string(
+      'root_dir', None, 'Root directory for running the evaluation.'
+  )
+
+  flags.DEFINE_bool(
+      'dryrun', False, 'If True, dryrun the experiment instead of running it.'
+  )
+
+  flags.DEFINE_bool(
+      'debug', False, 'If True, output prompt and response to the console.'
+  )
+
+  flags.DEFINE_bool(
+      'rerun',
+      False,
+      'If True, rerun the experiment even a cached result is found.',
+  )
+
+  FLAGS = flags.FLAGS  # pylint: disable=invalid-name
+
+  def _main(argv):
+    if len(argv) > 1:
+      raise app.UsageError('Too many command-line arguments.')
+
+    if FLAGS.root_dir:
+      target.rebind(root_dir=FLAGS.root_dir, raise_on_no_change=False)
+    if FLAGS.dryrun:
+      target.dryrun(debug=FLAGS.debug)
+    else:
+      target.run(debug=FLAGS.debug, rerun=FLAGS.rerun)
+
+  app.run(_main)
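With `app_run` now exported from `langfun.core.eval`, an evaluation can be launched as an absl binary. A hypothetical launcher sketch (only `app_run` and the `--root_dir`/`--dryrun`/`--debug`/`--rerun` flags come from this diff; `MyEvaluation` stands in for any `Evaluable` defined elsewhere):

```python
# run_eval.py -- hypothetical script; `MyEvaluation` is a placeholder for a
# real lf.eval.Evaluation subclass.
import langfun as lf
from my_project import MyEvaluation  # hypothetical import

if __name__ == '__main__':
  # Defines and parses --root_dir/--dryrun/--debug/--rerun, then either
  # dryruns one example or runs the full evaluation.
  lf.eval.app_run(MyEvaluation())
```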
langfun/core/eval/base_test.py
CHANGED
@@ -101,7 +101,7 @@ class EvaluationTest(unittest.TestCase):
     self.assertEqual(s.dir, os.path.join(s.root_dir, s.id))
     self.assertEqual(s.hash, s.clone().hash)
     # Test persistent hash.
-    self.assertEqual(s.hash, '
+    self.assertEqual(s.hash, 'ae86c703')
     self.assertEqual(
         s.hash, s.clone(override={'max_workers': 2, 'lm.timeout': 20}).hash
     )
@@ -194,6 +194,7 @@ class EvaluationTest(unittest.TestCase):
             cache_seed=0,
             score=1.0,
             logprobs=None,
+            usage=lf.LMSamplingUsage(387, 24, 411),
             tags=['lm-response', 'lm-output', 'transformed'],
         ),
     )
@@ -209,7 +210,7 @@ class EvaluationTest(unittest.TestCase):
         s.result,
         dict(
             experiment_setup=dict(
-                id='Evaluation@
+                id='Evaluation@0fade07d',
                 dir=s.dir,
                 model='StaticSequence',
                 prompt_template='{{example.question}}',
@@ -220,6 +221,14 @@ class EvaluationTest(unittest.TestCase):
                 use_cache=True, num_queries=2, num_hits=0, num_updates=2
             ),
             metrics=dict(total=2, failures=1, failure_rate=0.5),
+            usage=dict(
+                total_prompt_tokens=774,
+                total_completion_tokens=25,
+                num_usages=2,
+                average_prompt_tokens=387,
+                average_completion_tokens=12,
+                average_total_tokens=399,
+            ),
         ),
     )
     self.assertTrue(
@@ -228,13 +237,23 @@ class EvaluationTest(unittest.TestCase):
         os.path.exists(os.path.join(s.dir, base.Evaluation.RESULT_JSON)))
     self.assertTrue(
         os.path.exists(os.path.join(s.dir, base.Evaluation.CACHE_JSON)))
-    self.assertTrue(
-        os.path.exists(os.path.join(s.root_dir, base.Evaluation.SUMMARY_HTML))
-    )
     self.assertTrue(
         os.path.exists(os.path.join(s.dir, base.Evaluation.INDEX_HTML)))
     self.assertTrue(
         os.path.exists(os.path.join(s.dir, base.Evaluation.FAILURES_HTML)))
+    self.assertTrue(
+        os.path.exists(os.path.join(s.root_dir, base.Evaluation.SUMMARY_HTML))
+    )
+    # Check summary JSON.
+    summary_json = os.path.join(
+        s.root_dir, base.Evaluation.SUMMARY_HTML.replace('.html', '.json')
+    )
+    self.assertTrue(os.path.exists(summary_json))
+    summary = pg.load(summary_json, force_dict=True)
+    self.assertIn('Evaluation', summary)
+    self.assertEqual(len(summary['Evaluation']), 1)
+    self.assertIsNotNone(summary['Evaluation'][0].experiment)
+    self.assertIsNotNone(summary['Evaluation'][0].metrics)

   def test_run_wihtout_save(self):
     lm = fake.StaticSequence([
@@ -274,8 +293,11 @@ class EvaluationTest(unittest.TestCase):
     s = eval_set(
         'run_filter_test', pg.oneof(['call', 'query']),
         schema_fn=answer_schema(), lm=lm)
+    result = s.run(
+        filter=lambda x: x.method == 'query', dryrun=True, summary=False
+    )
     self.assertEqual(
-        s.run(filter=lambda x: x.method == 'query', dryrun=True),
+        result,
         {
             s.children[0].id: None,
             s.children[1].id: dict(
@@ -291,7 +313,8 @@ class EvaluationTest(unittest.TestCase):
                 use_cache=True, num_queries=2, num_hits=0, num_updates=2
             ),
             metrics=dict(total=2, failures=0, failure_rate=0.0),
-            ),
+                usage=s.children[1].result.usage,
+            ),
         },
     )
@@ -321,11 +344,10 @@ class EvaluationTest(unittest.TestCase):
         s.children[0].dir, os.path.join(s.root_dir, s.children[0].id)
     )
     # Test persistent hash.
-    self.assertEqual(s.hash, '
+    self.assertEqual(s.hash, 'b66a4e88')

     summary = s.run(verbose=True)
     self.assertEqual(len(summary.evaluations), 2)
-
     self.assertEqual(
         s.result,
         {
@@ -342,6 +364,7 @@ class EvaluationTest(unittest.TestCase):
                 use_cache=True, num_queries=2, num_hits=0, num_updates=2
             ),
             metrics=dict(total=2, failures=1, failure_rate=0.5),
+            usage=s.children[0].result.usage,
         ),
         s.children[1].id: dict(
             experiment_setup=dict(
@@ -356,6 +379,7 @@ class EvaluationTest(unittest.TestCase):
                 use_cache=True, num_queries=2, num_hits=0, num_updates=2
             ),
             metrics=dict(total=2, failures=1, failure_rate=0.5),
+            usage=s.children[1].result.usage,
         ),
     },
 )
@@ -448,7 +472,7 @@ class SuiteTest(unittest.TestCase):
         lm=lm
     )
     # Test for persistent hash.
-    self.assertEqual(s.hash, '
+    self.assertEqual(s.hash, '26e6cc25')
     s.run()
     expected = {
         s.children[0].id: dict(
@@ -464,6 +488,7 @@ class SuiteTest(unittest.TestCase):
                 use_cache=True, num_queries=2, num_hits=0, num_updates=2
             ),
             metrics=dict(total=2, failures=1, failure_rate=0.5),
+            usage=s.children[0].result.usage,
         ),
         s.children[1].id: {
             s.children[1]
@@ -481,6 +506,7 @@ class SuiteTest(unittest.TestCase):
                 use_cache=True, num_queries=4, num_hits=1, num_updates=3
             ),
             metrics=dict(total=2, failures=2, failure_rate=1.0),
+            usage=s.children[1].children[0].result.usage,
         ),
         s.children[1]
         .children[2]
@@ -500,6 +526,7 @@ class SuiteTest(unittest.TestCase):
                 num_updates=2,
             ),
             metrics=dict(total=2, failures=1, failure_rate=0.5),
+            usage=s.children[1].children[2].result.usage,
         ),
     },
 }
@@ -671,5 +698,17 @@ class SummaryTest(unittest.TestCase):
     self.assertTrue(pg.io.path_exists(summary_file))


+class AppRunTest(unittest.TestCase):
+
+  def test_app_run(self):
+    lm = fake.StaticSequence(['two', 'Solution(final_answer=2)'])
+    try:
+      base.app_run(
+          eval_set('app_run_test', 'query', schema_fn=answer_schema(), lm=lm)
+      )
+    except SystemExit:
+      pass
+
+
 if __name__ == '__main__':
   unittest.main()
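The expected `usage` numbers in these tests follow from the integer arithmetic added to `base.py`: totals accumulate per LM call, and averages use floor division over `num_usages`. For example:

```python
# Reproducing the test's expected values with the formulas from base.py.
total_prompt_tokens, total_completion_tokens, num_usages = 774, 25, 2

average_prompt_tokens = total_prompt_tokens // num_usages          # 387
average_completion_tokens = total_completion_tokens // num_usages  # 12, not 12.5
average_total_tokens = average_prompt_tokens + average_completion_tokens  # 399
```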
langfun/core/eval/matching.py
CHANGED
@@ -86,9 +86,26 @@ class Matching(base.Evaluation):
     self._matches = []
     self._mismatches = []

-  def audit(self, example: Any, output: Any, message: lf.Message) -> None:
+  def audit_processed(
+      self, example: Any, output: Any, message: lf.Message, dryrun: bool = False
+  ) -> None:
     groundtruth = self.groundtruth(example)
     answer = self.answer(output, example)
+
+    if dryrun:
+      lf.console.write('')
+      lf.console.write(
+          str(groundtruth),
+          title='GROUDTRUTH',
+          color='green',
+      )
+      lf.console.write('')
+      lf.console.write(
+          str(answer),
+          title='ANSWER',
+          color='blue',
+      )
+
     if self.match(answer, groundtruth):
       self._matches.append((example, output, message))
     else:
@@ -155,19 +172,16 @@ class Matching(base.Evaluation):
     super().save(definition, result, report)

     if result:
-
-      def force_dict(v):
-        return pg.object_utils.json_conversion.strip_types(pg.to_json(v))
-
       # Save matches.
       pg.save(
           [
-              # We force the output to be dict as its type may be defined
-              # within functors which could be deserialized.
-              pg.Dict(input=input, output=force_dict(output))
+              pg.Dict(input=input, output=output)
               for input, output, _ in self.matches
           ],
           os.path.join(self.dir, Matching.MATCHES_JSON),
+          # We force the input and output to be dict so it does not depend on
+          # the downstream to serialize.
+          force_dict=True,
       )

       # Save mismatches.
@@ -175,10 +189,13 @@ class Matching(base.Evaluation):
           [
               # We force the output to be dict as its type may be defined
               # within functors which could be deserialized.
-              pg.Dict(input=input, output=force_dict(output))
+              pg.Dict(input=input, output=output)
               for input, output, _ in self.mismatches
           ],
           os.path.join(self.dir, Matching.MISMATCHES_JSON),
+          # We force the input and output to be dict so it does not depend on
+          # the downstream to serialize.
+          force_dict=True,
       )

       if report: