langfun 0.0.2.dev20240327__tar.gz → 0.0.2.dev20240330__tar.gz
This diff compares the contents of two publicly released versions of the package, as published to a supported registry. It is provided for informational purposes only and reflects the packages exactly as they appear in their public registries.
- {langfun-0.0.2.dev20240327 → langfun-0.0.2.dev20240330}/PKG-INFO +2 -2
- {langfun-0.0.2.dev20240327 → langfun-0.0.2.dev20240330}/langfun/core/eval/base.py +39 -18
- {langfun-0.0.2.dev20240327 → langfun-0.0.2.dev20240330}/langfun/core/eval/base_test.py +4 -9
- {langfun-0.0.2.dev20240327 → langfun-0.0.2.dev20240330}/langfun/core/eval/matching_test.py +2 -4
- {langfun-0.0.2.dev20240327 → langfun-0.0.2.dev20240330}/langfun/core/eval/scoring_test.py +1 -2
- {langfun-0.0.2.dev20240327 → langfun-0.0.2.dev20240330}/langfun/core/llms/__init__.py +5 -3
- langfun-0.0.2.dev20240327/langfun/core/llms/gemini.py → langfun-0.0.2.dev20240330/langfun/core/llms/google_genai.py +117 -15
- langfun-0.0.2.dev20240327/langfun/core/llms/gemini_test.py → langfun-0.0.2.dev20240330/langfun/core/llms/google_genai_test.py +76 -13
- {langfun-0.0.2.dev20240327 → langfun-0.0.2.dev20240330}/langfun/core/structured/schema.py +21 -20
- {langfun-0.0.2.dev20240327 → langfun-0.0.2.dev20240330}/langfun/core/structured/schema_test.py +38 -21
- {langfun-0.0.2.dev20240327 → langfun-0.0.2.dev20240330}/langfun.egg-info/PKG-INFO +2 -2
- {langfun-0.0.2.dev20240327 → langfun-0.0.2.dev20240330}/langfun.egg-info/SOURCES.txt +2 -2
- {langfun-0.0.2.dev20240327 → langfun-0.0.2.dev20240330}/langfun.egg-info/requires.txt +1 -1
- {langfun-0.0.2.dev20240327 → langfun-0.0.2.dev20240330}/LICENSE +0 -0
- {langfun-0.0.2.dev20240327 → langfun-0.0.2.dev20240330}/README.md +0 -0
- {langfun-0.0.2.dev20240327 → langfun-0.0.2.dev20240330}/langfun/__init__.py +0 -0
- {langfun-0.0.2.dev20240327 → langfun-0.0.2.dev20240330}/langfun/core/__init__.py +0 -0
- {langfun-0.0.2.dev20240327 → langfun-0.0.2.dev20240330}/langfun/core/coding/__init__.py +0 -0
- {langfun-0.0.2.dev20240327 → langfun-0.0.2.dev20240330}/langfun/core/coding/python/__init__.py +0 -0
- {langfun-0.0.2.dev20240327 → langfun-0.0.2.dev20240330}/langfun/core/coding/python/correction.py +0 -0
- {langfun-0.0.2.dev20240327 → langfun-0.0.2.dev20240330}/langfun/core/coding/python/correction_test.py +0 -0
- {langfun-0.0.2.dev20240327 → langfun-0.0.2.dev20240330}/langfun/core/coding/python/errors.py +0 -0
- {langfun-0.0.2.dev20240327 → langfun-0.0.2.dev20240330}/langfun/core/coding/python/errors_test.py +0 -0
- {langfun-0.0.2.dev20240327 → langfun-0.0.2.dev20240330}/langfun/core/coding/python/execution.py +0 -0
- {langfun-0.0.2.dev20240327 → langfun-0.0.2.dev20240330}/langfun/core/coding/python/execution_test.py +0 -0
- {langfun-0.0.2.dev20240327 → langfun-0.0.2.dev20240330}/langfun/core/coding/python/generation.py +0 -0
- {langfun-0.0.2.dev20240327 → langfun-0.0.2.dev20240330}/langfun/core/coding/python/generation_test.py +0 -0
- {langfun-0.0.2.dev20240327 → langfun-0.0.2.dev20240330}/langfun/core/coding/python/parsing.py +0 -0
- {langfun-0.0.2.dev20240327 → langfun-0.0.2.dev20240330}/langfun/core/coding/python/parsing_test.py +0 -0
- {langfun-0.0.2.dev20240327 → langfun-0.0.2.dev20240330}/langfun/core/coding/python/permissions.py +0 -0
- {langfun-0.0.2.dev20240327 → langfun-0.0.2.dev20240330}/langfun/core/coding/python/permissions_test.py +0 -0
- {langfun-0.0.2.dev20240327 → langfun-0.0.2.dev20240330}/langfun/core/component.py +0 -0
- {langfun-0.0.2.dev20240327 → langfun-0.0.2.dev20240330}/langfun/core/component_test.py +0 -0
- {langfun-0.0.2.dev20240327 → langfun-0.0.2.dev20240330}/langfun/core/concurrent.py +0 -0
- {langfun-0.0.2.dev20240327 → langfun-0.0.2.dev20240330}/langfun/core/concurrent_test.py +0 -0
- {langfun-0.0.2.dev20240327 → langfun-0.0.2.dev20240330}/langfun/core/console.py +0 -0
- {langfun-0.0.2.dev20240327 → langfun-0.0.2.dev20240330}/langfun/core/console_test.py +0 -0
- {langfun-0.0.2.dev20240327 → langfun-0.0.2.dev20240330}/langfun/core/eval/__init__.py +0 -0
- {langfun-0.0.2.dev20240327 → langfun-0.0.2.dev20240330}/langfun/core/eval/matching.py +0 -0
- {langfun-0.0.2.dev20240327 → langfun-0.0.2.dev20240330}/langfun/core/eval/scoring.py +0 -0
- {langfun-0.0.2.dev20240327 → langfun-0.0.2.dev20240330}/langfun/core/langfunc.py +0 -0
- {langfun-0.0.2.dev20240327 → langfun-0.0.2.dev20240330}/langfun/core/langfunc_test.py +0 -0
- {langfun-0.0.2.dev20240327 → langfun-0.0.2.dev20240330}/langfun/core/language_model.py +0 -0
- {langfun-0.0.2.dev20240327 → langfun-0.0.2.dev20240330}/langfun/core/language_model_test.py +0 -0
- {langfun-0.0.2.dev20240327 → langfun-0.0.2.dev20240330}/langfun/core/llms/cache/__init__.py +0 -0
- {langfun-0.0.2.dev20240327 → langfun-0.0.2.dev20240330}/langfun/core/llms/cache/base.py +0 -0
- {langfun-0.0.2.dev20240327 → langfun-0.0.2.dev20240330}/langfun/core/llms/cache/in_memory.py +0 -0
- {langfun-0.0.2.dev20240327 → langfun-0.0.2.dev20240330}/langfun/core/llms/cache/in_memory_test.py +0 -0
- {langfun-0.0.2.dev20240327 → langfun-0.0.2.dev20240330}/langfun/core/llms/fake.py +0 -0
- {langfun-0.0.2.dev20240327 → langfun-0.0.2.dev20240330}/langfun/core/llms/fake_test.py +0 -0
- {langfun-0.0.2.dev20240327 → langfun-0.0.2.dev20240330}/langfun/core/llms/llama_cpp.py +0 -0
- {langfun-0.0.2.dev20240327 → langfun-0.0.2.dev20240330}/langfun/core/llms/llama_cpp_test.py +0 -0
- {langfun-0.0.2.dev20240327 → langfun-0.0.2.dev20240330}/langfun/core/llms/openai.py +0 -0
- {langfun-0.0.2.dev20240327 → langfun-0.0.2.dev20240330}/langfun/core/llms/openai_test.py +0 -0
- {langfun-0.0.2.dev20240327 → langfun-0.0.2.dev20240330}/langfun/core/memories/__init__.py +0 -0
- {langfun-0.0.2.dev20240327 → langfun-0.0.2.dev20240330}/langfun/core/memories/conversation_history.py +0 -0
- {langfun-0.0.2.dev20240327 → langfun-0.0.2.dev20240330}/langfun/core/memories/conversation_history_test.py +0 -0
- {langfun-0.0.2.dev20240327 → langfun-0.0.2.dev20240330}/langfun/core/memory.py +0 -0
- {langfun-0.0.2.dev20240327 → langfun-0.0.2.dev20240330}/langfun/core/message.py +0 -0
- {langfun-0.0.2.dev20240327 → langfun-0.0.2.dev20240330}/langfun/core/message_test.py +0 -0
- {langfun-0.0.2.dev20240327 → langfun-0.0.2.dev20240330}/langfun/core/modalities/__init__.py +0 -0
- {langfun-0.0.2.dev20240327 → langfun-0.0.2.dev20240330}/langfun/core/modalities/image.py +0 -0
- {langfun-0.0.2.dev20240327 → langfun-0.0.2.dev20240330}/langfun/core/modalities/image_test.py +0 -0
- {langfun-0.0.2.dev20240327 → langfun-0.0.2.dev20240330}/langfun/core/modalities/mime.py +0 -0
- {langfun-0.0.2.dev20240327 → langfun-0.0.2.dev20240330}/langfun/core/modalities/mime_test.py +0 -0
- {langfun-0.0.2.dev20240327 → langfun-0.0.2.dev20240330}/langfun/core/modalities/video.py +0 -0
- {langfun-0.0.2.dev20240327 → langfun-0.0.2.dev20240330}/langfun/core/modalities/video_test.py +0 -0
- {langfun-0.0.2.dev20240327 → langfun-0.0.2.dev20240330}/langfun/core/modality.py +0 -0
- {langfun-0.0.2.dev20240327 → langfun-0.0.2.dev20240330}/langfun/core/modality_test.py +0 -0
- {langfun-0.0.2.dev20240327 → langfun-0.0.2.dev20240330}/langfun/core/natural_language.py +0 -0
- {langfun-0.0.2.dev20240327 → langfun-0.0.2.dev20240330}/langfun/core/natural_language_test.py +0 -0
- {langfun-0.0.2.dev20240327 → langfun-0.0.2.dev20240330}/langfun/core/sampling.py +0 -0
- {langfun-0.0.2.dev20240327 → langfun-0.0.2.dev20240330}/langfun/core/sampling_test.py +0 -0
- {langfun-0.0.2.dev20240327 → langfun-0.0.2.dev20240330}/langfun/core/structured/__init__.py +0 -0
- {langfun-0.0.2.dev20240327 → langfun-0.0.2.dev20240330}/langfun/core/structured/completion.py +0 -0
- {langfun-0.0.2.dev20240327 → langfun-0.0.2.dev20240330}/langfun/core/structured/completion_test.py +0 -0
- {langfun-0.0.2.dev20240327 → langfun-0.0.2.dev20240330}/langfun/core/structured/description.py +0 -0
- {langfun-0.0.2.dev20240327 → langfun-0.0.2.dev20240330}/langfun/core/structured/description_test.py +0 -0
- {langfun-0.0.2.dev20240327 → langfun-0.0.2.dev20240330}/langfun/core/structured/mapping.py +0 -0
- {langfun-0.0.2.dev20240327 → langfun-0.0.2.dev20240330}/langfun/core/structured/mapping_test.py +0 -0
- {langfun-0.0.2.dev20240327 → langfun-0.0.2.dev20240330}/langfun/core/structured/parsing.py +0 -0
- {langfun-0.0.2.dev20240327 → langfun-0.0.2.dev20240330}/langfun/core/structured/parsing_test.py +0 -0
- {langfun-0.0.2.dev20240327 → langfun-0.0.2.dev20240330}/langfun/core/structured/prompting.py +0 -0
- {langfun-0.0.2.dev20240327 → langfun-0.0.2.dev20240330}/langfun/core/structured/prompting_test.py +0 -0
- {langfun-0.0.2.dev20240327 → langfun-0.0.2.dev20240330}/langfun/core/structured/schema_generation.py +0 -0
- {langfun-0.0.2.dev20240327 → langfun-0.0.2.dev20240330}/langfun/core/structured/schema_generation_test.py +0 -0
- {langfun-0.0.2.dev20240327 → langfun-0.0.2.dev20240330}/langfun/core/structured/scoring.py +0 -0
- {langfun-0.0.2.dev20240327 → langfun-0.0.2.dev20240330}/langfun/core/structured/scoring_test.py +0 -0
- {langfun-0.0.2.dev20240327 → langfun-0.0.2.dev20240330}/langfun/core/subscription.py +0 -0
- {langfun-0.0.2.dev20240327 → langfun-0.0.2.dev20240330}/langfun/core/subscription_test.py +0 -0
- {langfun-0.0.2.dev20240327 → langfun-0.0.2.dev20240330}/langfun/core/template.py +0 -0
- {langfun-0.0.2.dev20240327 → langfun-0.0.2.dev20240330}/langfun/core/template_test.py +0 -0
- {langfun-0.0.2.dev20240327 → langfun-0.0.2.dev20240330}/langfun/core/templates/__init__.py +0 -0
- {langfun-0.0.2.dev20240327 → langfun-0.0.2.dev20240330}/langfun/core/templates/completion.py +0 -0
- {langfun-0.0.2.dev20240327 → langfun-0.0.2.dev20240330}/langfun/core/templates/completion_test.py +0 -0
- {langfun-0.0.2.dev20240327 → langfun-0.0.2.dev20240330}/langfun/core/templates/conversation.py +0 -0
- {langfun-0.0.2.dev20240327 → langfun-0.0.2.dev20240330}/langfun/core/templates/conversation_test.py +0 -0
- {langfun-0.0.2.dev20240327 → langfun-0.0.2.dev20240330}/langfun/core/templates/demonstration.py +0 -0
- {langfun-0.0.2.dev20240327 → langfun-0.0.2.dev20240330}/langfun/core/templates/demonstration_test.py +0 -0
- {langfun-0.0.2.dev20240327 → langfun-0.0.2.dev20240330}/langfun/core/templates/selfplay.py +0 -0
- {langfun-0.0.2.dev20240327 → langfun-0.0.2.dev20240330}/langfun/core/templates/selfplay_test.py +0 -0
- {langfun-0.0.2.dev20240327 → langfun-0.0.2.dev20240330}/langfun/core/text_formatting.py +0 -0
- {langfun-0.0.2.dev20240327 → langfun-0.0.2.dev20240330}/langfun/core/text_formatting_test.py +0 -0
- {langfun-0.0.2.dev20240327 → langfun-0.0.2.dev20240330}/langfun.egg-info/dependency_links.txt +0 -0
- {langfun-0.0.2.dev20240327 → langfun-0.0.2.dev20240330}/langfun.egg-info/top_level.txt +0 -0
- {langfun-0.0.2.dev20240327 → langfun-0.0.2.dev20240330}/setup.cfg +0 -0
- {langfun-0.0.2.dev20240327 → langfun-0.0.2.dev20240330}/setup.py +0 -0
PKG-INFO:

@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: langfun
-Version: 0.0.2.dev20240327
+Version: 0.0.2.dev20240330
 Summary: Langfun: Language as Functions.
 Home-page: https://github.com/google/langfun
 Author: Langfun Authors
@@ -24,7 +24,7 @@ License-File: LICENSE
 Requires-Dist: google-generativeai>=0.3.2
 Requires-Dist: jinja2>=3.1.2
 Requires-Dist: openai==0.27.2
-Requires-Dist: pyglove>=0.4.5.
+Requires-Dist: pyglove>=0.4.5.dev20240323
 Requires-Dist: python-magic>=0.4.27
 Requires-Dist: requests>=2.31.0
 Requires-Dist: termcolor==1.1.0
langfun/core/eval/base.py:

@@ -27,6 +27,7 @@ import time
 from typing import Annotated, Any, Callable, Iterator, Literal, Optional, Sequence, Type, Union
 
 import langfun.core as lf
+import langfun.core.coding as lf_coding
 from langfun.core.llms.cache import in_memory
 import langfun.core.structured as lf_structured
 import pyglove as pg
@@ -41,14 +42,6 @@ class Evaluable(lf.Component):
   INDEX_HTML = 'index.html'
   SUMMARY_HTML = 'summary.html'
 
-  id: Annotated[
-      str,
-      (
-          'The ID of the evaluation, which should be unique across all '
-          'evaluations.'
-      ),
-  ]
-
   root_dir: Annotated[
       str | None,
       (
@@ -61,6 +54,18 @@ class Evaluable(lf.Component):
       int, 'Number of decimals when reporting precision.'
   ] = lf.contextual(default=1)
 
+  @property
+  @abc.abstractmethod
+  def id(self) -> str:
+    """Returns the ID of the task.
+
+    Returns:
+      Evaluation task ID. Different evaluation tasks should have unique task
+      IDs, as each task will be stored in a sub-directory identified by its
+      ID. For suites, the ID could be an empty string as they will not
+      produce sub-directories.
+    """
+
   @property
   def dir(self) -> str | None:
     """Returns the directory for saving results and details."""
@@ -578,12 +583,15 @@ class _LeafNode:
   progress_bar: int | None = None
 
 
-@pg.use_init_args(['
+@pg.use_init_args(['children'])
 class Suite(Evaluable):
   """Evaluation suite."""
 
   children: Annotated[list[Evaluable], 'Child evaluation sets or suites.']
 
+  # Use empty ID as suite is just a container of child evaluations.
+  id: str = ''
+
   __kwargs__: Annotated[
       Any,
       (
@@ -841,8 +849,10 @@ class Evaluation(Evaluable):
     kwargs['evaluation'] = self
     return self.schema_fn(**kwargs)
 
-  def _formalize_schema(self, annotation) -> lf_structured.Schema:
+  def _formalize_schema(self, annotation) -> lf_structured.Schema | None:
     """Formalizes schema from annotation."""
+    if annotation in (str, None):
+      return None
     if self.method == 'complete':
       if not hasattr(annotation, '__schema__'):
         raise TypeError(
@@ -883,6 +893,14 @@ class Evaluation(Evaluable):
         completion_examples.append(ex)
     return completion_examples
 
+  @property
+  def id(self) -> str:
+    """Returns the ID of this evaluation."""
+    id_prefix = self.__class__.__name__
+    if not self.is_deterministic:
+      return id_prefix
+    return f'{id_prefix}@{self.hash}'
+
   @functools.cached_property
   def children(self) -> list['Evaluation']:
     """Returns the trials as child evaluations if this evaluation is a space."""
@@ -892,7 +910,6 @@ class Evaluation(Evaluable):
     for i, child in enumerate(pg.iter(self)):
       child.sym_setparent(self)
       child.sym_setpath(self.sym_path + f'children[{i}]')
-      child.rebind(id=f'{self.id}@{child.hash}', skip_notification=True)
       children.append(child)
     return children
 
@@ -1004,7 +1021,11 @@ class Evaluation(Evaluable):
     self._reset()
 
     def _process(example: Any):
-      return self.process(example, **(self.additional_args or {}))
+      # NOTE(daiyip): set the `input` symbol of the globals to None, so LLM
+      # generated code with calls to `input` will raise an error, thus not
+      # blocking the evaluation.
+      with lf_coding.context(input=None):
+        return self.process(example, **(self.additional_args or {}))
 
     try:
       for example, message, error in lf.concurrent_map(
@@ -1015,10 +1036,7 @@ class Evaluation(Evaluable):
           status_fn=self._status,
       ):
         if error is not None:
-
-          self._failures.append((example, str(error)))
-        except Exception as e:  # pylint: disable=broad-exception-caught
-          self._failures.append((example, str(e)))
+          self._failures.append((example, str(error)))
         else:
           output = message.text if self.schema is None else message.result
           self.audit(example, output, message)
@@ -1521,9 +1539,12 @@ class Summary(pg.Object):
     pivot_field = pivot_field or self.pivot_field
     s = io.StringIO()
     s.write('<html><body>')
-    for task in self.tasks():
+    for task in sorted(self.tasks(), key=lambda cls: cls.__name__):
+      table_id = task.__name__.lower()
       s.write('<div>')
-      s.write(f'<
+      s.write(f'<a id="{table_id}"')
+      s.write(f'<h2><a href="#{table_id}">{task.__name__}</a></h2>')
+      s.write('</a>')
       table = Summary.Table.from_evaluations(
           self.select(task=task).evaluations, pivot_field
       )
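The net effect of the `id` changes above: an evaluation's ID is no longer a user-supplied field but is derived from its class name and, for deterministic evaluations, its hash. A minimal self-contained sketch of that derivation (the class below is a hypothetical stand-in; the real `Evaluation` computes `is_deterministic` and `hash` itself):

```python
# Hypothetical stand-in mirroring the `Evaluation.id` property added above.
class FakeEvaluation:

  def __init__(self, is_deterministic: bool, hash_value: str):
    self.is_deterministic = is_deterministic
    self.hash = hash_value

  @property
  def id(self) -> str:
    # Deterministic evaluations render as 'ClassName@<hash>'; a
    # hyper-parameter search space renders as just 'ClassName', and its
    # trial children carry the hashes.
    id_prefix = self.__class__.__name__
    if not self.is_deterministic:
      return id_prefix
    return f'{id_prefix}@{self.hash}'


assert FakeEvaluation(True, '17915dc6').id == 'FakeEvaluation@17915dc6'
assert FakeEvaluation(False, '17915dc6').id == 'FakeEvaluation'
```

This is why the test expectations below change from hand-written IDs such as 'search_space_test' to derived ones such as 'Evaluation@17915dc6'.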
langfun/core/eval/base_test.py:

@@ -70,8 +70,7 @@ def eval_set(
   """Creates an evaluation object for testing."""
   tmp_dir = tempfile.gettempdir()
   return cls(
-
-      root_dir=tmp_dir,
+      root_dir=os.path.join(tmp_dir, eval_id),
       inputs=base.as_inputs([
           pg.Dict(question='Compute 1 + 1'),
           pg.Dict(question='Compute 1 + 2'),
@@ -210,7 +209,7 @@ class EvaluationTest(unittest.TestCase):
         s.result,
         dict(
             experiment_setup=dict(
-                id='
+                id='Evaluation@17915dc6',
                 dir=s.dir,
                 model='StaticSequence',
                 prompt_template='{{example.question}}',
@@ -302,7 +301,6 @@ class EvaluationTest(unittest.TestCase):
         '3',
     ])
     s = base.Evaluation(
-        id='search_space_test',
         root_dir=tempfile.gettempdir(),
         inputs=base.as_inputs([
             pg.Dict(question='Compute 1 + 1'),
@@ -439,7 +437,6 @@ class SuiteTest(unittest.TestCase):
         '3',
     ] * 5)
     s = base.Suite(
-        'suite_run_test',
         [
             eval_set('run_test_1', 'query', schema_fn=answer_schema()),
             # A suite of search space. Two of the sub-experiments are identical,
@@ -548,7 +545,6 @@ class SummaryTest(unittest.TestCase):
   def _eval_set(self, root_dir):
     return base.Suite(id='select_test', children=[
         TaskA(
-            id='task_a',
            inputs=base.as_inputs([
                pg.Dict(question='Compute 1 + 1'),
            ]),
@@ -569,7 +565,6 @@ class SummaryTest(unittest.TestCase):
             max_workers=1,
         ),
         TaskB(
-            id='task_b',
             inputs=base.as_inputs([
                 pg.Dict(question='Compute 1 + 1'),
             ]),
@@ -650,10 +645,10 @@ class SummaryTest(unittest.TestCase):
         len(base.Summary.from_dirs(root_dir)), 2 * 2 * 2 * 2 + 2 * 1 * 1 * 2
     )
     self.assertEqual(
-        len(base.Summary.from_dirs(root_dir, '
+        len(base.Summary.from_dirs(root_dir, 'TaskB')), 2 * 1 * 1 * 2
     )
     self.assertEqual(
-        len(base.Summary.from_dirs(root_dir, ('
+        len(base.Summary.from_dirs(root_dir, ('TaskA'))), 2 * 2 * 2 * 2
     )
 
   def test_monitor(self):
langfun/core/eval/matching_test.py:

@@ -65,10 +65,8 @@ def eval_set(
     use_cache: bool = True,
 ):
   """Creates an evaluation object for testing."""
-  tmp_dir = tempfile.gettempdir()
   return MyTask(
-
-      root_dir=tmp_dir,
+      root_dir=os.path.join(tempfile.gettempdir(), eval_id),
       inputs=base.as_inputs([
           pg.Dict(question='Compute 1 + 1', groundtruth=2),
           pg.Dict(question='Compute 1 + 2', groundtruth=3),
@@ -105,7 +103,7 @@ class MatchingTest(unittest.TestCase):
         s.result,
         dict(
             experiment_setup=dict(
-                id='
+                id='MyTask@3d87f97f',
                 dir=s.dir,
                 model='StaticSequence',
                 prompt_template='{{example.question}}',
langfun/core/eval/scoring_test.py:

@@ -43,7 +43,6 @@ def constrained_by_upperbound(upper_bound: int):
 
 
 class ConstraintFollowing(scoring.Scoring):
-  id = 'constraint_following'
   inputs = constrained_by_upperbound(1)
   prompt = '{{example}}'
   method = 'query'
@@ -82,7 +81,7 @@ class ScoringTest(unittest.TestCase):
         s.result,
         dict(
             experiment_setup=dict(
-                id='
+                id='ConstraintFollowing@9e51bb9e',
                 dir=s.dir,
                 model='StaticSequence',
                 prompt_template='{{example}}',
langfun/core/llms/__init__.py:

@@ -25,9 +25,11 @@ from langfun.core.llms.fake import StaticResponse
 from langfun.core.llms.fake import StaticSequence
 
 # Gemini models.
-from langfun.core.llms.
-from langfun.core.llms.
-from langfun.core.llms.
+from langfun.core.llms.google_genai import GenAI
+from langfun.core.llms.google_genai import GeminiPro
+from langfun.core.llms.google_genai import GeminiProVision
+from langfun.core.llms.google_genai import Palm2
+from langfun.core.llms.google_genai import Palm2_IT
 
 # OpenAI models.
 from langfun.core.llms.openai import OpenAI
langfun/core/llms/gemini.py → langfun/core/llms/google_genai.py:

@@ -13,6 +13,7 @@
 # limitations under the License.
 """Gemini models exposed through Google Generative AI APIs."""
 
+import abc
 import functools
 import os
 from typing import Annotated, Any, Literal
@@ -20,14 +21,20 @@ from typing import Annotated, Any, Literal
 import google.generativeai as genai
 import langfun.core as lf
 from langfun.core import modalities as lf_modalities
+import pyglove as pg
 
 
 @lf.use_init_args(['model'])
-class 
-  """Language
+class GenAI(lf.LanguageModel):
+  """Language models provided by Google GenAI."""
 
   model: Annotated[
-      Literal[
+      Literal[
+          'gemini-pro',
+          'gemini-pro-vision',
+          'text-bison-001',
+          'chat-bison-001',
+      ],
       'Model name.',
   ]
 
@@ -35,7 +42,8 @@ class Gemini(lf.LanguageModel):
       str | None,
       (
           'API key. If None, the key will be read from environment variable '
-          "'GOOGLE_API_KEY'."
+          "'GOOGLE_API_KEY'. "
+          'Get an API key at https://ai.google.dev/tutorials/setup'
       ),
   ] = None
 
@@ -43,6 +51,9 @@ class Gemini(lf.LanguageModel):
       False
   )
 
+  # Set the default max concurrency to 8 workers.
+  max_concurrency = 8
+
   def _on_bound(self):
     super()._on_bound()
     self.__dict__.pop('_api_initialized', None)
@@ -67,7 +78,11 @@ class Gemini(lf.LanguageModel):
     return [
         m.name.lstrip('models/')
         for m in genai.list_models()
-        if
+        if (
+            'generateContent' in m.supported_generation_methods
+            or 'generateText' in m.supported_generation_methods
+            or 'generateMessage' in m.supported_generation_methods
+        )
     ]
 
   @property
@@ -80,11 +95,6 @@ class Gemini(lf.LanguageModel):
     """Returns a string to identify the resource for rate control."""
     return self.model_id
 
-  @property
-  def max_concurrency(self) -> int:
-    """Max concurrent requests."""
-    return 8
-
   def _generation_config(self, options: lf.LMSamplingOptions) -> dict[str, Any]:
     """Creates generation config from langfun sampling options."""
     return genai.GenerationConfig(
@@ -117,7 +127,7 @@ class Gemini(lf.LanguageModel):
     return chunks
 
   def _response_to_result(
-      self, response: genai.types.GenerateContentResponse
+      self, response: genai.types.GenerateContentResponse | pg.Dict
   ) -> lf.LMSamplingResult:
     """Parses generative response into message."""
     samples = []
@@ -149,17 +159,97 @@ class Gemini(lf.LanguageModel):
     return self._response_to_result(response)
 
 
+class _LegacyGenerativeModel(pg.Object):
+  """Base for legacy GenAI generative model."""
+
+  model: str
+
+  def generate_content(
+      self,
+      input_content: list[str | genai.types.BlobDict],
+      generation_config: genai.GenerationConfig,
+  ) -> pg.Dict:
+    """Generate content."""
+    segments = []
+    for s in input_content:
+      if not isinstance(s, str):
+        raise ValueError(f'Unsupported modality: {s!r}')
+      segments.append(s)
+    return self.generate(' '.join(segments), generation_config)
+
+  @abc.abstractmethod
+  def generate(
+      self, prompt: str, generation_config: genai.GenerationConfig) -> pg.Dict:
+    """Generate response based on prompt."""
+
+
+class _LegacyCompletionModel(_LegacyGenerativeModel):
+  """Legacy GenAI completion model."""
+
+  def generate(
+      self, prompt: str, generation_config: genai.GenerationConfig
+  ) -> pg.Dict:
+    completion: genai.types.Completion = genai.generate_text(
+        model=f'models/{self.model}',
+        prompt=prompt,
+        temperature=generation_config.temperature,
+        top_k=generation_config.top_k,
+        top_p=generation_config.top_p,
+        candidate_count=generation_config.candidate_count,
+        max_output_tokens=generation_config.max_output_tokens,
+        stop_sequences=generation_config.stop_sequences,
+    )
+    return pg.Dict(
+        candidates=[
+            pg.Dict(content=pg.Dict(parts=[pg.Dict(text=c['output'])]))
+            for c in completion.candidates
+        ]
+    )
+
+
+class _LegacyChatModel(_LegacyGenerativeModel):
+  """Legacy GenAI chat model."""
+
+  def generate(
+      self, prompt: str, generation_config: genai.GenerationConfig
+  ) -> pg.Dict:
+    response: genai.types.ChatResponse = genai.chat(
+        model=f'models/{self.model}',
+        messages=prompt,
+        temperature=generation_config.temperature,
+        top_k=generation_config.top_k,
+        top_p=generation_config.top_p,
+        candidate_count=generation_config.candidate_count,
+    )
+    return pg.Dict(
+        candidates=[
+            pg.Dict(content=pg.Dict(parts=[pg.Dict(text=c['content'])]))
+            for c in response.candidates
+        ]
+    )
+
+
 class _ModelHub:
   """Google Generative AI model hub."""
 
   def __init__(self):
     self._model_cache = {}
 
-  def get(
+  def get(
+      self, model_name: str
+  ) -> genai.GenerativeModel | _LegacyGenerativeModel:
     """Gets a generative model by model id."""
     model = self._model_cache.get(model_name, None)
     if model is None:
-      model = genai.GenerativeModel(model_name)
+      model_info = genai.get_model(f'models/{model_name}')
+      if 'generateContent' in model_info.supported_generation_methods:
+        model = genai.GenerativeModel(model_name)
+      elif 'generateText' in model_info.supported_generation_methods:
+        model = _LegacyCompletionModel(model_name)
+      elif 'generateMessage' in model_info.supported_generation_methods:
+        model = _LegacyChatModel(model_name)
+      else:
+        raise ValueError(f'Unsupported model: {model_name!r}')
      self._model_cache[model_name] = model
    return model
 
@@ -172,14 +262,26 @@ _GOOGLE_GENAI_MODEL_HUB = _ModelHub()
 #
 
 
-class GeminiPro(
+class GeminiPro(GenAI):
   """Gemini Pro model."""
 
   model = 'gemini-pro'
 
 
-class GeminiProVision(
+class GeminiProVision(GenAI):
   """Gemini Pro vision model."""
 
   model = 'gemini-pro-vision'
   multimodal = True
+
+
+class Palm2(GenAI):
+  """PaLM2 model."""
+
+  model = 'text-bison-001'
+
+
+class Palm2_IT(GenAI):  # pylint: disable=invalid-name
+  """PaLM2 instruction-tuned model."""
+
+  model = 'chat-bison-001'
langfun/core/llms/gemini_test.py → langfun/core/llms/google_genai_test.py:

@@ -20,7 +20,7 @@ from unittest import mock
 from google import generativeai as genai
 import langfun.core as lf
 from langfun.core import modalities as lf_modalities
-from langfun.core.llms import 
+from langfun.core.llms import google_genai
 import pyglove as pg
 
 
@@ -36,6 +36,29 @@ example_image = (
 )
 
 
+def mock_get_model(model_name, *args, **kwargs):
+  del args, kwargs
+  if 'gemini' in model_name:
+    method = 'generateContent'
+  elif 'chat' in model_name:
+    method = 'generateMessage'
+  else:
+    method = 'generateText'
+  return pg.Dict(supported_generation_methods=[method])
+
+
+def mock_generate_text(*, model, prompt, **kwargs):
+  return pg.Dict(
+      candidates=[pg.Dict(output=f'{prompt} to {model} with {kwargs}')]
+  )
+
+
+def mock_chat(*, model, messages, **kwargs):
+  return pg.Dict(
+      candidates=[pg.Dict(content=f'{messages} to {model} with {kwargs}')]
+  )
+
+
 def mock_generate_content(content, generation_config, **kwargs):
   del kwargs
   c = generation_config
@@ -68,12 +91,12 @@ def mock_generate_content(content, generation_config, **kwargs):
 )
 
 
-class 
-  """Tests for 
+class GenAITest(unittest.TestCase):
+  """Tests for Google GenAI model."""
 
   def test_content_from_message_text_only(self):
     text = 'This is a beautiful day'
-    model = 
+    model = google_genai.GeminiPro()
     chunks = model._content_from_message(lf.UserMessage(text))
     self.assertEqual(chunks, [text])
 
@@ -85,9 +108,9 @@ class GeminiTest(unittest.TestCase):
 
     # Non-multimodal model.
     with self.assertRaisesRegex(ValueError, 'Unsupported modality'):
-
+      google_genai.GeminiPro()._content_from_message(message)
 
-    model = 
+    model = google_genai.GeminiProVision()
     chunks = model._content_from_message(message)
     self.maxDiff = None
     self.assertEqual(
@@ -118,7 +141,7 @@ class GeminiTest(unittest.TestCase):
            ],
        ),
    )
-    model = 
+    model = google_genai.GeminiProVision()
     result = model._response_to_result(response)
     self.assertEqual(
         result,
@@ -129,26 +152,28 @@ class GeminiTest(unittest.TestCase):
     )
 
   def test_model_hub(self):
-    model = 
+    model = google_genai._GOOGLE_GENAI_MODEL_HUB.get('gemini-pro')
     self.assertIsNotNone(model)
-    self.assertIs(
+    self.assertIs(google_genai._GOOGLE_GENAI_MODEL_HUB.get('gemini-pro'), model)
 
   def test_api_key_check(self):
     with self.assertRaisesRegex(ValueError, 'Please specify `api_key`'):
-      _ = 
+      _ = google_genai.GeminiPro()._api_initialized
 
-    self.assertTrue(
+    self.assertTrue(google_genai.GeminiPro(api_key='abc')._api_initialized)
     os.environ['GOOGLE_API_KEY'] = 'abc'
-    self.assertTrue(
+    self.assertTrue(google_genai.GeminiPro()._api_initialized)
     del os.environ['GOOGLE_API_KEY']
 
   def test_call(self):
     with mock.patch(
         'google.generativeai.generative_models.GenerativeModel.generate_content'
     ) as mock_generate:
+      orig_get_model = genai.get_model
+      genai.get_model = mock_get_model
       mock_generate.side_effect = mock_generate_content
 
-      lm = 
+      lm = google_genai.GeminiPro(api_key='test_key')
       self.maxDiff = None
       self.assertEqual(
           lm('hello', temperature=2.0, top_k=20).text,
@@ -157,6 +182,44 @@ class GeminiTest(unittest.TestCase):
               'top_p=None, top_k=20, max_tokens=1024, stop=None.'
          ),
      )
+      genai.get_model = orig_get_model
+
+  def test_call_with_legacy_completion_model(self):
+    orig_get_model = genai.get_model
+    genai.get_model = mock_get_model
+    orig_generate_text = genai.generate_text
+    genai.generate_text = mock_generate_text
+
+    lm = google_genai.Palm2(api_key='test_key')
+    self.maxDiff = None
+    self.assertEqual(
+        lm('hello', temperature=2.0, top_k=20).text,
+        (
+            "hello to models/text-bison-001 with {'temperature': 2.0, "
+            "'top_k': 20, 'top_p': None, 'candidate_count': 1, "
+            "'max_output_tokens': 1024, 'stop_sequences': None}"
+        ),
+    )
+    genai.get_model = orig_get_model
+    genai.generate_text = orig_generate_text
+
+  def test_call_with_legacy_chat_model(self):
+    orig_get_model = genai.get_model
+    genai.get_model = mock_get_model
+    orig_chat = genai.chat
+    genai.chat = mock_chat
+
+    lm = google_genai.Palm2_IT(api_key='test_key')
+    self.maxDiff = None
+    self.assertEqual(
+        lm('hello', temperature=2.0, top_k=20).text,
+        (
+            "hello to models/chat-bison-001 with {'temperature': 2.0, "
+            "'top_k': 20, 'top_p': None, 'candidate_count': 1}"
+        ),
+    )
+    genai.get_model = orig_get_model
+    genai.chat = orig_chat
 
 
 if __name__ == '__main__':
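One note on the new legacy-model tests above: they swap `genai.get_model`, `genai.generate_text`, and `genai.chat` by plain assignment and restore them manually, so a failing assertion would leave the module patched for later tests. A hedged sketch of the same setup using `mock.patch.object`, which restores the attributes on exit regardless of outcome (reusing the mocks defined in this test file):

```python
from unittest import mock

# Equivalent setup for the legacy completion test; patches are undone
# automatically even if an assertion fails inside the block.
with mock.patch.object(genai, 'get_model', mock_get_model), \
     mock.patch.object(genai, 'generate_text', mock_generate_text):
  lm = google_genai.Palm2(api_key='test_key')
  assert 'models/text-bison-001' in lm('hello').text
```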