langfun 0.0.2.dev20240109__tar.gz → 0.0.2.dev20240111__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {langfun-0.0.2.dev20240109 → langfun-0.0.2.dev20240111}/PKG-INFO +3 -2
- {langfun-0.0.2.dev20240109 → langfun-0.0.2.dev20240111}/langfun/core/coding/python/execution.py +47 -35
- {langfun-0.0.2.dev20240109 → langfun-0.0.2.dev20240111}/langfun/core/coding/python/execution_test.py +15 -6
- {langfun-0.0.2.dev20240109 → langfun-0.0.2.dev20240111}/langfun/core/coding/python/generation_test.py +2 -1
- {langfun-0.0.2.dev20240109 → langfun-0.0.2.dev20240111}/langfun/core/langfunc.py +6 -4
- {langfun-0.0.2.dev20240109 → langfun-0.0.2.dev20240111}/langfun/core/langfunc_test.py +2 -2
- {langfun-0.0.2.dev20240109 → langfun-0.0.2.dev20240111}/langfun/core/llms/__init__.py +5 -0
- langfun-0.0.2.dev20240111/langfun/core/llms/gemini.py +190 -0
- langfun-0.0.2.dev20240111/langfun/core/llms/gemini_test.py +163 -0
- {langfun-0.0.2.dev20240109 → langfun-0.0.2.dev20240111}/langfun/core/message.py +1 -1
- {langfun-0.0.2.dev20240109 → langfun-0.0.2.dev20240111}/langfun/core/message_test.py +3 -2
- {langfun-0.0.2.dev20240109 → langfun-0.0.2.dev20240111}/langfun/core/modality.py +5 -10
- {langfun-0.0.2.dev20240109 → langfun-0.0.2.dev20240111}/langfun/core/modality_test.py +29 -0
- {langfun-0.0.2.dev20240109 → langfun-0.0.2.dev20240111}/langfun/core/structured/completion.py +15 -22
- {langfun-0.0.2.dev20240109 → langfun-0.0.2.dev20240111}/langfun/core/structured/completion_test.py +34 -23
- {langfun-0.0.2.dev20240109 → langfun-0.0.2.dev20240111}/langfun/core/structured/description.py +3 -5
- {langfun-0.0.2.dev20240109 → langfun-0.0.2.dev20240111}/langfun/core/structured/description_test.py +18 -16
- {langfun-0.0.2.dev20240109 → langfun-0.0.2.dev20240111}/langfun/core/structured/mapping.py +16 -8
- {langfun-0.0.2.dev20240109 → langfun-0.0.2.dev20240111}/langfun/core/structured/parsing.py +5 -18
- {langfun-0.0.2.dev20240109 → langfun-0.0.2.dev20240111}/langfun/core/structured/parsing_test.py +2 -3
- {langfun-0.0.2.dev20240109 → langfun-0.0.2.dev20240111}/langfun/core/structured/prompting.py +23 -23
- {langfun-0.0.2.dev20240109 → langfun-0.0.2.dev20240111}/langfun/core/structured/prompting_test.py +212 -56
- {langfun-0.0.2.dev20240109 → langfun-0.0.2.dev20240111}/langfun.egg-info/PKG-INFO +3 -2
- {langfun-0.0.2.dev20240109 → langfun-0.0.2.dev20240111}/langfun.egg-info/SOURCES.txt +2 -0
- {langfun-0.0.2.dev20240109 → langfun-0.0.2.dev20240111}/langfun.egg-info/requires.txt +2 -1
- {langfun-0.0.2.dev20240109 → langfun-0.0.2.dev20240111}/LICENSE +0 -0
- {langfun-0.0.2.dev20240109 → langfun-0.0.2.dev20240111}/README.md +0 -0
- {langfun-0.0.2.dev20240109 → langfun-0.0.2.dev20240111}/langfun/__init__.py +0 -0
- {langfun-0.0.2.dev20240109 → langfun-0.0.2.dev20240111}/langfun/core/__init__.py +0 -0
- {langfun-0.0.2.dev20240109 → langfun-0.0.2.dev20240111}/langfun/core/coding/__init__.py +0 -0
- {langfun-0.0.2.dev20240109 → langfun-0.0.2.dev20240111}/langfun/core/coding/python/__init__.py +0 -0
- {langfun-0.0.2.dev20240109 → langfun-0.0.2.dev20240111}/langfun/core/coding/python/correction.py +0 -0
- {langfun-0.0.2.dev20240109 → langfun-0.0.2.dev20240111}/langfun/core/coding/python/correction_test.py +0 -0
- {langfun-0.0.2.dev20240109 → langfun-0.0.2.dev20240111}/langfun/core/coding/python/errors.py +0 -0
- {langfun-0.0.2.dev20240109 → langfun-0.0.2.dev20240111}/langfun/core/coding/python/errors_test.py +0 -0
- {langfun-0.0.2.dev20240109 → langfun-0.0.2.dev20240111}/langfun/core/coding/python/generation.py +0 -0
- {langfun-0.0.2.dev20240109 → langfun-0.0.2.dev20240111}/langfun/core/coding/python/parsing.py +0 -0
- {langfun-0.0.2.dev20240109 → langfun-0.0.2.dev20240111}/langfun/core/coding/python/parsing_test.py +0 -0
- {langfun-0.0.2.dev20240109 → langfun-0.0.2.dev20240111}/langfun/core/coding/python/permissions.py +0 -0
- {langfun-0.0.2.dev20240109 → langfun-0.0.2.dev20240111}/langfun/core/coding/python/permissions_test.py +0 -0
- {langfun-0.0.2.dev20240109 → langfun-0.0.2.dev20240111}/langfun/core/component.py +0 -0
- {langfun-0.0.2.dev20240109 → langfun-0.0.2.dev20240111}/langfun/core/component_test.py +0 -0
- {langfun-0.0.2.dev20240109 → langfun-0.0.2.dev20240111}/langfun/core/concurrent.py +0 -0
- {langfun-0.0.2.dev20240109 → langfun-0.0.2.dev20240111}/langfun/core/concurrent_test.py +0 -0
- {langfun-0.0.2.dev20240109 → langfun-0.0.2.dev20240111}/langfun/core/console.py +0 -0
- {langfun-0.0.2.dev20240109 → langfun-0.0.2.dev20240111}/langfun/core/console_test.py +0 -0
- {langfun-0.0.2.dev20240109 → langfun-0.0.2.dev20240111}/langfun/core/eval/__init__.py +0 -0
- {langfun-0.0.2.dev20240109 → langfun-0.0.2.dev20240111}/langfun/core/eval/base.py +0 -0
- {langfun-0.0.2.dev20240109 → langfun-0.0.2.dev20240111}/langfun/core/eval/base_test.py +0 -0
- {langfun-0.0.2.dev20240109 → langfun-0.0.2.dev20240111}/langfun/core/eval/matching.py +0 -0
- {langfun-0.0.2.dev20240109 → langfun-0.0.2.dev20240111}/langfun/core/eval/matching_test.py +0 -0
- {langfun-0.0.2.dev20240109 → langfun-0.0.2.dev20240111}/langfun/core/eval/scoring.py +0 -0
- {langfun-0.0.2.dev20240109 → langfun-0.0.2.dev20240111}/langfun/core/eval/scoring_test.py +0 -0
- {langfun-0.0.2.dev20240109 → langfun-0.0.2.dev20240111}/langfun/core/language_model.py +0 -0
- {langfun-0.0.2.dev20240109 → langfun-0.0.2.dev20240111}/langfun/core/language_model_test.py +0 -0
- {langfun-0.0.2.dev20240109 → langfun-0.0.2.dev20240111}/langfun/core/llms/cache/__init__.py +0 -0
- {langfun-0.0.2.dev20240109 → langfun-0.0.2.dev20240111}/langfun/core/llms/cache/base.py +0 -0
- {langfun-0.0.2.dev20240109 → langfun-0.0.2.dev20240111}/langfun/core/llms/cache/in_memory.py +0 -0
- {langfun-0.0.2.dev20240109 → langfun-0.0.2.dev20240111}/langfun/core/llms/cache/in_memory_test.py +0 -0
- {langfun-0.0.2.dev20240109 → langfun-0.0.2.dev20240111}/langfun/core/llms/fake.py +0 -0
- {langfun-0.0.2.dev20240109 → langfun-0.0.2.dev20240111}/langfun/core/llms/fake_test.py +0 -0
- {langfun-0.0.2.dev20240109 → langfun-0.0.2.dev20240111}/langfun/core/llms/llama_cpp.py +0 -0
- {langfun-0.0.2.dev20240109 → langfun-0.0.2.dev20240111}/langfun/core/llms/llama_cpp_test.py +0 -0
- {langfun-0.0.2.dev20240109 → langfun-0.0.2.dev20240111}/langfun/core/llms/openai.py +0 -0
- {langfun-0.0.2.dev20240109 → langfun-0.0.2.dev20240111}/langfun/core/llms/openai_test.py +0 -0
- {langfun-0.0.2.dev20240109 → langfun-0.0.2.dev20240111}/langfun/core/memories/__init__.py +0 -0
- {langfun-0.0.2.dev20240109 → langfun-0.0.2.dev20240111}/langfun/core/memories/conversation_history.py +0 -0
- {langfun-0.0.2.dev20240109 → langfun-0.0.2.dev20240111}/langfun/core/memories/conversation_history_test.py +0 -0
- {langfun-0.0.2.dev20240109 → langfun-0.0.2.dev20240111}/langfun/core/memory.py +0 -0
- {langfun-0.0.2.dev20240109 → langfun-0.0.2.dev20240111}/langfun/core/modalities/__init__.py +0 -0
- {langfun-0.0.2.dev20240109 → langfun-0.0.2.dev20240111}/langfun/core/modalities/image.py +0 -0
- {langfun-0.0.2.dev20240109 → langfun-0.0.2.dev20240111}/langfun/core/modalities/image_test.py +0 -0
- {langfun-0.0.2.dev20240109 → langfun-0.0.2.dev20240111}/langfun/core/natural_language.py +0 -0
- {langfun-0.0.2.dev20240109 → langfun-0.0.2.dev20240111}/langfun/core/natural_language_test.py +0 -0
- {langfun-0.0.2.dev20240109 → langfun-0.0.2.dev20240111}/langfun/core/sampling.py +0 -0
- {langfun-0.0.2.dev20240109 → langfun-0.0.2.dev20240111}/langfun/core/sampling_test.py +0 -0
- {langfun-0.0.2.dev20240109 → langfun-0.0.2.dev20240111}/langfun/core/structured/__init__.py +0 -0
- {langfun-0.0.2.dev20240109 → langfun-0.0.2.dev20240111}/langfun/core/structured/mapping_test.py +0 -0
- {langfun-0.0.2.dev20240109 → langfun-0.0.2.dev20240111}/langfun/core/structured/schema.py +0 -0
- {langfun-0.0.2.dev20240109 → langfun-0.0.2.dev20240111}/langfun/core/structured/schema_test.py +0 -0
- {langfun-0.0.2.dev20240109 → langfun-0.0.2.dev20240111}/langfun/core/subscription.py +0 -0
- {langfun-0.0.2.dev20240109 → langfun-0.0.2.dev20240111}/langfun/core/subscription_test.py +0 -0
- {langfun-0.0.2.dev20240109 → langfun-0.0.2.dev20240111}/langfun/core/template.py +0 -0
- {langfun-0.0.2.dev20240109 → langfun-0.0.2.dev20240111}/langfun/core/template_test.py +0 -0
- {langfun-0.0.2.dev20240109 → langfun-0.0.2.dev20240111}/langfun/core/templates/__init__.py +0 -0
- {langfun-0.0.2.dev20240109 → langfun-0.0.2.dev20240111}/langfun/core/templates/completion.py +0 -0
- {langfun-0.0.2.dev20240109 → langfun-0.0.2.dev20240111}/langfun/core/templates/completion_test.py +0 -0
- {langfun-0.0.2.dev20240109 → langfun-0.0.2.dev20240111}/langfun/core/templates/conversation.py +0 -0
- {langfun-0.0.2.dev20240109 → langfun-0.0.2.dev20240111}/langfun/core/templates/conversation_test.py +0 -0
- {langfun-0.0.2.dev20240109 → langfun-0.0.2.dev20240111}/langfun/core/templates/demonstration.py +0 -0
- {langfun-0.0.2.dev20240109 → langfun-0.0.2.dev20240111}/langfun/core/templates/demonstration_test.py +0 -0
- {langfun-0.0.2.dev20240109 → langfun-0.0.2.dev20240111}/langfun/core/templates/selfplay.py +0 -0
- {langfun-0.0.2.dev20240109 → langfun-0.0.2.dev20240111}/langfun/core/templates/selfplay_test.py +0 -0
- {langfun-0.0.2.dev20240109 → langfun-0.0.2.dev20240111}/langfun/core/text_formatting.py +0 -0
- {langfun-0.0.2.dev20240109 → langfun-0.0.2.dev20240111}/langfun/core/text_formatting_test.py +0 -0
- {langfun-0.0.2.dev20240109 → langfun-0.0.2.dev20240111}/langfun.egg-info/dependency_links.txt +0 -0
- {langfun-0.0.2.dev20240109 → langfun-0.0.2.dev20240111}/langfun.egg-info/top_level.txt +0 -0
- {langfun-0.0.2.dev20240109 → langfun-0.0.2.dev20240111}/setup.cfg +0 -0
- {langfun-0.0.2.dev20240109 → langfun-0.0.2.dev20240111}/setup.py +0 -0
@@ -1,6 +1,6 @@
|
|
1
1
|
Metadata-Version: 2.1
|
2
2
|
Name: langfun
|
3
|
-
Version: 0.0.2.dev20240109
|
3
|
+
Version: 0.0.2.dev20240111
|
4
4
|
Summary: Langfun: Language as Functions.
|
5
5
|
Home-page: https://github.com/google/langfun
|
6
6
|
Author: Langfun Authors
|
@@ -21,9 +21,10 @@ Classifier: Topic :: Software Development :: Libraries :: Python Modules
|
|
21
21
|
Classifier: Topic :: Software Development :: Libraries
|
22
22
|
Description-Content-Type: text/markdown
|
23
23
|
License-File: LICENSE
|
24
|
+
Requires-Dist: google-generativeai>=0.3.2
|
24
25
|
Requires-Dist: jinja2>=3.1.2
|
25
26
|
Requires-Dist: openai==0.27.2
|
26
|
-
Requires-Dist: pyglove>=0.4.5.
|
27
|
+
Requires-Dist: pyglove>=0.4.5.dev20240109
|
27
28
|
Requires-Dist: requests>=2.31.0
|
28
29
|
Requires-Dist: termcolor==1.1.0
|
29
30
|
Requires-Dist: tqdm>=4.64.1
|
{langfun-0.0.2.dev20240109 → langfun-0.0.2.dev20240111}/langfun/core/coding/python/execution.py
RENAMED
@@ -15,6 +15,7 @@
|
|
15
15
|
|
16
16
|
import ast
|
17
17
|
import contextlib
|
18
|
+
import io
|
18
19
|
import multiprocessing
|
19
20
|
from typing import Any, Callable
|
20
21
|
|
@@ -24,6 +25,9 @@ from langfun.core.coding.python import permissions
|
|
24
25
|
import pyglove as pg
|
25
26
|
|
26
27
|
|
28
|
+
# Key in returned dict that captures stdout.
|
29
|
+
STDOUT_KEY = '__stdout__'
|
30
|
+
|
27
31
|
# Key in the returned dict that represents the final result.
|
28
32
|
RESULT_KEY = '__result__'
|
29
33
|
_TLS_CODE_RUN_CONTEXT = '__code_run_context__'
|
@@ -86,45 +90,51 @@ def evaluate(
|
|
86
90
|
code, code_block = parsing.PythonCodeParser().parse(code, permission)
|
87
91
|
global_vars, orig_global_vars = ctx, ctx.copy()
|
88
92
|
|
89
|
-
|
90
|
-
|
91
|
-
|
93
|
+
# No code.
|
94
|
+
if not code_block.body:
|
95
|
+
return {} if outputs_intermediate else None
|
92
96
|
|
93
|
-
|
94
|
-
|
95
|
-
|
97
|
+
stdout = io.StringIO()
|
98
|
+
with contextlib.redirect_stdout(stdout):
|
99
|
+
if hasattr(code_block.body[-1], 'value'):
|
100
|
+
last_expr = code_block.body.pop() # pytype: disable=attribute-error
|
101
|
+
result_vars = [RESULT_KEY]
|
96
102
|
|
97
|
-
|
103
|
+
if isinstance(last_expr, ast.Assign):
|
104
|
+
for name_node in last_expr.targets:
|
105
|
+
result_vars.append(name_node.id)
|
98
106
|
|
99
|
-
|
100
|
-
# Execute the lines before the last expression.
|
101
|
-
# NOTE(daiyip): Only a `globals` dict is specified here, which will also
|
102
|
-
# be used to output intermediate values by `exec`. We do not specify a
|
103
|
-
# separate `locals` dict here, for - "If exec gets two separate objects as
|
104
|
-
# globals and locals, the code will be executed as if it were embedded in
|
105
|
-
# a class definition." - as the Python document explains. The outcome is
|
106
|
-
# that new functions defined in the code block could not be called by
|
107
|
-
# other newly defined functions.
|
108
|
-
# Refer to https://stackoverflow.com/questions/
|
109
|
-
# 73940751/why-cant-i-call-a-function-from-another-function-using-exec
|
110
|
-
# for more details.
|
111
|
-
exec(compile(code_block, '', mode='exec'), global_vars) # pylint: disable=exec-used
|
112
|
-
|
113
|
-
# Evaluate the last expression.
|
114
|
-
result = eval( # pylint: disable=eval-used
|
115
|
-
compile(last_expr, '', mode='eval'), global_vars
|
116
|
-
)
|
117
|
-
except Exception as e:
|
118
|
-
raise errors.CodeError(code, e) from e
|
107
|
+
last_expr = ast.Expression(last_expr.value) # pytype: disable=attribute-error
|
119
108
|
|
120
|
-
|
121
|
-
|
122
|
-
|
123
|
-
|
124
|
-
|
125
|
-
|
126
|
-
|
127
|
-
|
109
|
+
try:
|
110
|
+
# Execute the lines before the last expression.
|
111
|
+
# NOTE(daiyip): Only a `globals` dict is specified here, which will also
|
112
|
+
# be used to output intermediate values by `exec`. We do not specify a
|
113
|
+
# separate `locals` dict here, for - "If exec gets two separate objects
|
114
|
+
# as globals and locals, the code will be executed as if it were
|
115
|
+
# embedded in a class definition." - as the Python document explains.
|
116
|
+
# The outcome is that new functions defined in the code block could not
|
117
|
+
# be called by other newly defined functions.
|
118
|
+
# Refer to https://stackoverflow.com/questions/
|
119
|
+
# 73940751/why-cant-i-call-a-function-from-another-function-using-exec
|
120
|
+
# for more details.
|
121
|
+
exec(compile(code_block, '', mode='exec'), global_vars) # pylint: disable=exec-used
|
122
|
+
|
123
|
+
# Evaluate the last expression.
|
124
|
+
result = eval( # pylint: disable=eval-used
|
125
|
+
compile(last_expr, '', mode='eval'), global_vars
|
126
|
+
)
|
127
|
+
except Exception as e:
|
128
|
+
raise errors.CodeError(code, e) from e
|
129
|
+
|
130
|
+
for result_var in result_vars:
|
131
|
+
global_vars[result_var] = result
|
132
|
+
else:
|
133
|
+
try:
|
134
|
+
exec(compile(code_block, '', mode='exec'), global_vars) # pylint: disable=exec-used
|
135
|
+
except Exception as e:
|
136
|
+
raise errors.CodeError(code, e) from e
|
137
|
+
global_vars[RESULT_KEY] = list(global_vars.values())[-1]
|
128
138
|
|
129
139
|
if outputs_intermediate:
|
130
140
|
outputs = {}
|
@@ -133,6 +143,8 @@ def evaluate(
|
|
133
143
|
continue
|
134
144
|
if k not in orig_global_vars or v is not orig_global_vars[k]:
|
135
145
|
outputs[k] = v
|
146
|
+
# Add stdout to outputs.
|
147
|
+
outputs[STDOUT_KEY] = stdout.getvalue()
|
136
148
|
return outputs
|
137
149
|
return global_vars[RESULT_KEY]
|
138
150
|
|
{langfun-0.0.2.dev20240109 → langfun-0.0.2.dev20240111}/langfun/core/coding/python/execution_test.py
RENAMED
@@ -36,7 +36,7 @@ class EvaluateTest(unittest.TestCase):
|
|
36
36
|
global_vars=dict(z=3),
|
37
37
|
outputs_intermediate=True,
|
38
38
|
),
|
39
|
-
dict(p=2 + 0 + 3, __result__=2 + 0 + 3),
|
39
|
+
dict(p=2 + 0 + 3, __result__=2 + 0 + 3, __stdout__=''),
|
40
40
|
)
|
41
41
|
|
42
42
|
def test_basics(self):
|
@@ -45,17 +45,19 @@ class EvaluateTest(unittest.TestCase):
|
|
45
45
|
"""
|
46
46
|
x = 1
|
47
47
|
y = x + 1
|
48
|
+
print(y)
|
48
49
|
z = x + y
|
49
50
|
""",
|
50
51
|
outputs_intermediate=True,
|
51
52
|
),
|
52
|
-
dict(x=1, y=2, z=3, __result__=3),
|
53
|
+
dict(x=1, y=2, z=3, __result__=3, __stdout__='2\n'),
|
53
54
|
)
|
54
55
|
self.assertEqual(
|
55
56
|
execution.evaluate(
|
56
57
|
"""
|
57
58
|
x = 1
|
58
59
|
y = x + 1
|
60
|
+
print(y)
|
59
61
|
z = x + y
|
60
62
|
""",
|
61
63
|
),
|
@@ -75,9 +77,10 @@ class EvaluateTest(unittest.TestCase):
|
|
75
77
|
global_vars=dict(pg=pg),
|
76
78
|
outputs_intermediate=True,
|
77
79
|
)
|
78
|
-
self.assertEqual(list(ret.keys()), ['A', '__result__'])
|
80
|
+
self.assertEqual(list(ret.keys()), ['A', '__result__', '__stdout__'])
|
79
81
|
self.assertTrue(issubclass(ret['A'], pg.Object))
|
80
82
|
self.assertIs(ret['__result__'], ret['A'])
|
83
|
+
self.assertEqual(ret['__stdout__'], '')
|
81
84
|
|
82
85
|
def test_function_def(self):
|
83
86
|
ret = execution.evaluate(
|
@@ -91,7 +94,9 @@ class EvaluateTest(unittest.TestCase):
|
|
91
94
|
permission=permissions.CodePermission.ALL,
|
92
95
|
outputs_intermediate=True,
|
93
96
|
)
|
94
|
-
self.assertEqual(
|
97
|
+
self.assertEqual(
|
98
|
+
list(ret.keys()), ['foo', 'bar', '__result__', '__stdout__']
|
99
|
+
)
|
95
100
|
self.assertTrue(inspect.isfunction(ret['foo']))
|
96
101
|
self.assertTrue(inspect.isfunction(ret['bar']))
|
97
102
|
self.assertIs(ret['__result__'], ret['bar'])
|
@@ -110,7 +115,9 @@ class EvaluateTest(unittest.TestCase):
|
|
110
115
|
permission=permissions.CodePermission.ALL,
|
111
116
|
outputs_intermediate=True,
|
112
117
|
)
|
113
|
-
self.assertEqual(
|
118
|
+
self.assertEqual(
|
119
|
+
list(ret.keys()), ['foo', 'bar', '__result__', '__stdout__']
|
120
|
+
)
|
114
121
|
self.assertEqual(ret['__result__'], 3)
|
115
122
|
|
116
123
|
def test_complex(self):
|
@@ -131,7 +138,9 @@ class EvaluateTest(unittest.TestCase):
|
|
131
138
|
global_vars=dict(pg=pg),
|
132
139
|
outputs_intermediate=True,
|
133
140
|
)
|
134
|
-
self.assertEqual(
|
141
|
+
self.assertEqual(
|
142
|
+
list(ret.keys()), ['A', 'foo', 'k', '__result__', '__stdout__']
|
143
|
+
)
|
135
144
|
self.assertTrue(issubclass(ret['A'], pg.Object))
|
136
145
|
self.assertTrue(inspect.isfunction(ret['foo']))
|
137
146
|
self.assertIsInstance(ret['k'], pg.Object)
|
@@ -41,9 +41,10 @@ class PythonCodeTest(unittest.TestCase):
|
|
41
41
|
generation.PythonCode("""
|
42
42
|
x = 1
|
43
43
|
y = x + 1
|
44
|
+
print(y)
|
44
45
|
z = x + y
|
45
46
|
""").eval(),
|
46
|
-
dict(x=1, y=2, z=3, __result__=3),
|
47
|
+
dict(x=1, y=2, z=3, __result__=3, __stdout__='2\n'),
|
47
48
|
)
|
48
49
|
|
49
50
|
def test_call(self):
|
@@ -326,18 +326,20 @@ class LangFunc(
|
|
326
326
|
return lm_output
|
327
327
|
|
328
328
|
@classmethod
|
329
|
-
def from_value(
|
329
|
+
def from_value(
|
330
|
+
cls, value: Union[str, template_lib.Template], **kwargs
|
331
|
+
) -> 'LangFunc':
|
330
332
|
"""Create a LangFunc object from a string or template."""
|
331
333
|
if isinstance(value, LangFunc):
|
332
334
|
return value
|
333
335
|
if isinstance(value, template_lib.Template):
|
334
|
-
lfun = LangFunc(value.template_str)
|
336
|
+
lfun = LangFunc(value.template_str, **kwargs)
|
335
337
|
# So lfun could acccess all attributes from value.
|
336
338
|
lfun.sym_setparent(value)
|
337
339
|
return lfun
|
338
340
|
if isinstance(value, str):
|
339
|
-
return LangFunc(template_str=value)
|
340
|
-
|
341
|
+
return LangFunc(template_str=value, **kwargs)
|
342
|
+
return LangFunc('{{input}}', input=value, **kwargs)
|
341
343
|
|
342
344
|
|
343
345
|
# Register converter from str to LangFunc, therefore we can always
|
@@ -64,8 +64,8 @@ class BasicTest(unittest.TestCase):
|
|
64
64
|
l3 = LangFunc.from_value(c.l)
|
65
65
|
self.assertEqual(l3.render(), '1 + 2')
|
66
66
|
|
67
|
-
|
68
|
-
|
67
|
+
l4 = LangFunc.from_value(1)
|
68
|
+
self.assertEqual(l4.render(), '1')
|
69
69
|
|
70
70
|
|
71
71
|
class LangFuncCallTest(unittest.TestCase):
|
@@ -23,6 +23,11 @@ from langfun.core.llms.fake import StaticMapping
|
|
23
23
|
from langfun.core.llms.fake import StaticResponse
|
24
24
|
from langfun.core.llms.fake import StaticSequence
|
25
25
|
|
26
|
+
# Gemini models.
|
27
|
+
from langfun.core.llms.gemini import Gemini
|
28
|
+
from langfun.core.llms.gemini import GeminiPro
|
29
|
+
from langfun.core.llms.gemini import GeminiProVision
|
30
|
+
|
26
31
|
# OpenAI models.
|
27
32
|
from langfun.core.llms.openai import OpenAI
|
28
33
|
|
@@ -0,0 +1,190 @@
|
|
1
|
+
# Copyright 2024 The Langfun Authors
|
2
|
+
#
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at
|
6
|
+
#
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
#
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
"""Gemini models exposed through Google Generative AI APIs."""
|
15
|
+
|
16
|
+
import functools
|
17
|
+
import os
|
18
|
+
from typing import Annotated, Any, Literal
|
19
|
+
|
20
|
+
import google.generativeai as genai
|
21
|
+
import langfun.core as lf
|
22
|
+
from langfun.core import modalities as lf_modalities
|
23
|
+
|
24
|
+
|
25
|
+
@lf.use_init_args(['model'])
|
26
|
+
class Gemini(lf.LanguageModel):
|
27
|
+
"""Language model served on VertexAI."""
|
28
|
+
|
29
|
+
model: Annotated[
|
30
|
+
Literal['gemini-pro', 'gemini-pro-vision', ''],
|
31
|
+
'Model name.',
|
32
|
+
]
|
33
|
+
|
34
|
+
api_key: Annotated[
|
35
|
+
str | None,
|
36
|
+
(
|
37
|
+
'API key. If None, the key will be read from environment variable '
|
38
|
+
"'GOOGLE_API_KEY'."
|
39
|
+
),
|
40
|
+
] = None
|
41
|
+
|
42
|
+
multimodal: Annotated[bool, 'Whether this model has multimodal support.'] = (
|
43
|
+
False
|
44
|
+
)
|
45
|
+
|
46
|
+
def _on_bound(self):
|
47
|
+
super()._on_bound()
|
48
|
+
self.__dict__.pop('_api_initialized', None)
|
49
|
+
|
50
|
+
@functools.cached_property
|
51
|
+
def _api_initialized(self):
|
52
|
+
api_key = self.api_key or os.environ.get('GOOGLE_API_KEY', None)
|
53
|
+
if not api_key:
|
54
|
+
raise ValueError(
|
55
|
+
'Please specify `api_key` during `__init__` or set environment '
|
56
|
+
'variable `GOOGLE_API_KEY` with your Google Cloud API key. '
|
57
|
+
'Check out '
|
58
|
+
'https://cloud.google.com/api-keys/docs/create-manage-api-keys '
|
59
|
+
'for more details.'
|
60
|
+
)
|
61
|
+
genai.configure(api_key=api_key)
|
62
|
+
return True
|
63
|
+
|
64
|
+
@classmethod
|
65
|
+
def dir(cls) -> list[str]:
|
66
|
+
"""Lists generative models."""
|
67
|
+
return [
|
68
|
+
m.name.lstrip('models/')
|
69
|
+
for m in genai.list_models()
|
70
|
+
if 'generateContent' in m.supported_generation_methods
|
71
|
+
]
|
72
|
+
|
73
|
+
@property
|
74
|
+
def model_id(self) -> str:
|
75
|
+
"""Returns a string to identify the model."""
|
76
|
+
return self.model
|
77
|
+
|
78
|
+
@property
|
79
|
+
def resource_id(self) -> str:
|
80
|
+
"""Returns a string to identify the resource for rate control."""
|
81
|
+
return self.model_id
|
82
|
+
|
83
|
+
@property
|
84
|
+
def max_concurrency(self) -> int:
|
85
|
+
"""Max concurrent requests."""
|
86
|
+
return 8
|
87
|
+
|
88
|
+
def _generation_config(self, options: lf.LMSamplingOptions) -> dict[str, Any]:
|
89
|
+
"""Creates generation config from langfun sampling options."""
|
90
|
+
return genai.GenerationConfig(
|
91
|
+
candidate_count=options.n,
|
92
|
+
temperature=options.temperature,
|
93
|
+
top_p=options.top_p,
|
94
|
+
top_k=options.top_k,
|
95
|
+
max_output_tokens=options.max_tokens,
|
96
|
+
stop_sequences=options.stop,
|
97
|
+
)
|
98
|
+
|
99
|
+
def _content_from_message(
|
100
|
+
self, prompt: lf.Message
|
101
|
+
) -> list[str | genai.types.BlobDict]:
|
102
|
+
"""Gets Evergreen formatted content from langfun message."""
|
103
|
+
formatted = lf.UserMessage(prompt.text)
|
104
|
+
formatted.source = prompt
|
105
|
+
|
106
|
+
chunks = []
|
107
|
+
for lf_chunk in formatted.chunk():
|
108
|
+
if isinstance(lf_chunk, str):
|
109
|
+
chunk = lf_chunk
|
110
|
+
elif self.multimodal and isinstance(lf_chunk, lf_modalities.Image):
|
111
|
+
chunk = genai.types.BlobDict(
|
112
|
+
data=lf_chunk.to_bytes(), mime_type=f'image/{lf_chunk.image_format}'
|
113
|
+
)
|
114
|
+
else:
|
115
|
+
raise ValueError(f'Unsupported modality: {lf_chunk!r}')
|
116
|
+
chunks.append(chunk)
|
117
|
+
return chunks
|
118
|
+
|
119
|
+
def _response_to_result(
|
120
|
+
self, response: genai.types.GenerateContentResponse
|
121
|
+
) -> lf.LMSamplingResult:
|
122
|
+
"""Parses generative response into message."""
|
123
|
+
samples = []
|
124
|
+
for candidate in response.candidates:
|
125
|
+
chunks = []
|
126
|
+
for part in candidate.content.parts:
|
127
|
+
# TODO(daiyip): support multi-modal parts when they are available via
|
128
|
+
# Gemini API.
|
129
|
+
if hasattr(part, 'text'):
|
130
|
+
chunks.append(part.text)
|
131
|
+
samples.append(lf.LMSample(lf.AIMessage.from_chunks(chunks), score=0.0))
|
132
|
+
return lf.LMSamplingResult(samples)
|
133
|
+
|
134
|
+
def _sample(self, prompts: list[lf.Message]) -> list[lf.LMSamplingResult]:
|
135
|
+
assert self._api_initialized, 'Vertex AI API is not initialized.'
|
136
|
+
return lf.concurrent_execute(
|
137
|
+
self._sample_single,
|
138
|
+
prompts,
|
139
|
+
executor=self.resource_id,
|
140
|
+
max_workers=self.max_concurrency,
|
141
|
+
# NOTE(daiyip): Vertex has its own policy on handling
|
142
|
+
# with rate limit, so we do not retry on errors.
|
143
|
+
retry_on_errors=None,
|
144
|
+
)
|
145
|
+
|
146
|
+
def _sample_single(self, prompt: lf.Message) -> lf.LMSamplingResult:
|
147
|
+
"""Samples a single prompt."""
|
148
|
+
model = _GOOGLE_GENAI_MODEL_HUB.get(self.model)
|
149
|
+
input_content = self._content_from_message(prompt)
|
150
|
+
response = model.generate_content(
|
151
|
+
input_content,
|
152
|
+
generation_config=self._generation_config(self.sampling_options),
|
153
|
+
)
|
154
|
+
return self._response_to_result(response)
|
155
|
+
|
156
|
+
|
157
|
+
class _ModelHub:
|
158
|
+
"""Google Generative AI model hub."""
|
159
|
+
|
160
|
+
def __init__(self):
|
161
|
+
self._model_cache = {}
|
162
|
+
|
163
|
+
def get(self, model_name: str) -> genai.GenerativeModel:
|
164
|
+
"""Gets a generative model by model id."""
|
165
|
+
model = self._model_cache.get(model_name, None)
|
166
|
+
if model is None:
|
167
|
+
model = genai.GenerativeModel(model_name)
|
168
|
+
self._model_cache[model_name] = model
|
169
|
+
return model
|
170
|
+
|
171
|
+
|
172
|
+
_GOOGLE_GENAI_MODEL_HUB = _ModelHub()
|
173
|
+
|
174
|
+
|
175
|
+
#
|
176
|
+
# Public Gemini models.
|
177
|
+
#
|
178
|
+
|
179
|
+
|
180
|
+
class GeminiPro(Gemini):
|
181
|
+
"""Gemini Pro model."""
|
182
|
+
|
183
|
+
model = 'gemini-pro'
|
184
|
+
|
185
|
+
|
186
|
+
class GeminiProVision(Gemini):
|
187
|
+
"""Gemini Pro vision model."""
|
188
|
+
|
189
|
+
model = 'gemini-pro-vision'
|
190
|
+
multimodal = True
|
@@ -0,0 +1,163 @@
|
|
1
|
+
# Copyright 2024 The Langfun Authors
|
2
|
+
#
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at
|
6
|
+
#
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
#
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
"""Tests for Gemini models."""
|
15
|
+
|
16
|
+
import os
|
17
|
+
import unittest
|
18
|
+
from unittest import mock
|
19
|
+
|
20
|
+
from google import generativeai as genai
|
21
|
+
import langfun.core as lf
|
22
|
+
from langfun.core import modalities as lf_modalities
|
23
|
+
from langfun.core.llms import gemini
|
24
|
+
import pyglove as pg
|
25
|
+
|
26
|
+
|
27
|
+
example_image = (
|
28
|
+
b'\x89PNG\r\n\x1a\n\x00\x00\x00\rIHDR\x00\x00\x00\x18\x00\x00\x00\x18\x04'
|
29
|
+
b'\x03\x00\x00\x00\x12Y \xcb\x00\x00\x00\x18PLTE\x00\x00'
|
30
|
+
b'\x00fff_chaag_cg_ch^ci_ciC\xedb\x94\x00\x00\x00\x08tRNS'
|
31
|
+
b'\x00\n\x9f*\xd4\xff_\xf4\xe4\x8b\xf3a\x00\x00\x00>IDATx'
|
32
|
+
b'\x01c \x05\x08)"\xd8\xcc\xae!\x06pNz\x88k\x19\\Q\xa8"\x10'
|
33
|
+
b'\xc1\x14\x95\x01%\xc1\n\xa143Ta\xa8"D-\x84\x03QM\x98\xc3'
|
34
|
+
b'\x1a\x1a\x1a@5\x0e\x04\xa0q\x88\x05\x00\x07\xf8\x18\xf9'
|
35
|
+
b'\xdao\xd0|\x00\x00\x00\x00IEND\xaeB`\x82'
|
36
|
+
)
|
37
|
+
|
38
|
+
|
39
|
+
def mock_generate_content(content, generation_config, **kwargs):
|
40
|
+
del kwargs
|
41
|
+
c = generation_config
|
42
|
+
return genai.types.GenerateContentResponse(
|
43
|
+
done=True,
|
44
|
+
iterator=None,
|
45
|
+
chunks=[],
|
46
|
+
result=pg.Dict(
|
47
|
+
prompt_feedback=pg.Dict(block_reason=None),
|
48
|
+
candidates=[
|
49
|
+
pg.Dict(
|
50
|
+
content=pg.Dict(
|
51
|
+
parts=[
|
52
|
+
pg.Dict(
|
53
|
+
text=(
|
54
|
+
f'This is a response to {content[0]} with '
|
55
|
+
f'n={c.candidate_count}, '
|
56
|
+
f'temperature={c.temperature}, '
|
57
|
+
f'top_p={c.top_p}, '
|
58
|
+
f'top_k={c.top_k}, '
|
59
|
+
f'max_tokens={c.max_output_tokens}, '
|
60
|
+
f'stop={c.stop_sequences}.'
|
61
|
+
)
|
62
|
+
)
|
63
|
+
]
|
64
|
+
),
|
65
|
+
),
|
66
|
+
],
|
67
|
+
),
|
68
|
+
)
|
69
|
+
|
70
|
+
|
71
|
+
class GeminiTest(unittest.TestCase):
|
72
|
+
"""Tests for Evergreen language model."""
|
73
|
+
|
74
|
+
def test_content_from_message_text_only(self):
|
75
|
+
text = 'This is a beautiful day'
|
76
|
+
model = gemini.GeminiPro()
|
77
|
+
chunks = model._content_from_message(lf.UserMessage(text))
|
78
|
+
self.assertEqual(chunks, [text])
|
79
|
+
|
80
|
+
def test_content_from_message_mm(self):
|
81
|
+
message = lf.UserMessage(
|
82
|
+
'This is an {{image}}, what is it?',
|
83
|
+
image=lf_modalities.Image.from_bytes(example_image),
|
84
|
+
)
|
85
|
+
|
86
|
+
# Non-multimodal model.
|
87
|
+
with self.assertRaisesRegex(ValueError, 'Unsupported modality'):
|
88
|
+
gemini.GeminiPro()._content_from_message(message)
|
89
|
+
|
90
|
+
model = gemini.GeminiProVision()
|
91
|
+
chunks = model._content_from_message(message)
|
92
|
+
self.maxDiff = None
|
93
|
+
self.assertEqual(
|
94
|
+
chunks,
|
95
|
+
[
|
96
|
+
'This is an',
|
97
|
+
genai.types.BlobDict(mime_type='image/png', data=example_image),
|
98
|
+
', what is it?',
|
99
|
+
],
|
100
|
+
)
|
101
|
+
|
102
|
+
def test_response_to_result_text_only(self):
|
103
|
+
response = genai.types.GenerateContentResponse(
|
104
|
+
done=True,
|
105
|
+
iterator=None,
|
106
|
+
chunks=[],
|
107
|
+
result=pg.Dict(
|
108
|
+
prompt_feedback=pg.Dict(block_reason=None),
|
109
|
+
candidates=[
|
110
|
+
pg.Dict(
|
111
|
+
content=pg.Dict(
|
112
|
+
parts=[pg.Dict(text='This is response 1.')]
|
113
|
+
),
|
114
|
+
),
|
115
|
+
pg.Dict(
|
116
|
+
content=pg.Dict(parts=[pg.Dict(text='This is response 2.')])
|
117
|
+
),
|
118
|
+
],
|
119
|
+
),
|
120
|
+
)
|
121
|
+
model = gemini.GeminiProVision()
|
122
|
+
result = model._response_to_result(response)
|
123
|
+
self.assertEqual(
|
124
|
+
result,
|
125
|
+
lf.LMSamplingResult([
|
126
|
+
lf.LMSample(lf.AIMessage('This is response 1.'), score=0.0),
|
127
|
+
lf.LMSample(lf.AIMessage('This is response 2.'), score=0.0),
|
128
|
+
]),
|
129
|
+
)
|
130
|
+
|
131
|
+
def test_model_hub(self):
|
132
|
+
model = gemini._GOOGLE_GENAI_MODEL_HUB.get('gemini-pro')
|
133
|
+
self.assertIsNotNone(model)
|
134
|
+
self.assertIs(gemini._GOOGLE_GENAI_MODEL_HUB.get('gemini-pro'), model)
|
135
|
+
|
136
|
+
def test_api_key_check(self):
|
137
|
+
with self.assertRaisesRegex(ValueError, 'Please specify `api_key`'):
|
138
|
+
_ = gemini.GeminiPro()._api_initialized
|
139
|
+
|
140
|
+
self.assertTrue(gemini.GeminiPro(api_key='abc')._api_initialized)
|
141
|
+
os.environ['GOOGLE_API_KEY'] = 'abc'
|
142
|
+
self.assertTrue(gemini.GeminiPro()._api_initialized)
|
143
|
+
del os.environ['GOOGLE_API_KEY']
|
144
|
+
|
145
|
+
def test_call(self):
|
146
|
+
with mock.patch(
|
147
|
+
'google.generativeai.generative_models.GenerativeModel.generate_content'
|
148
|
+
) as mock_generate:
|
149
|
+
mock_generate.side_effect = mock_generate_content
|
150
|
+
|
151
|
+
lm = gemini.GeminiPro(api_key='test_key')
|
152
|
+
self.maxDiff = None
|
153
|
+
self.assertEqual(
|
154
|
+
lm('hello', temperature=2.0, top_k=20).text,
|
155
|
+
(
|
156
|
+
'This is a response to hello with n=1, temperature=2.0, '
|
157
|
+
'top_p=None, top_k=20, max_tokens=1024, stop=None.'
|
158
|
+
),
|
159
|
+
)
|
160
|
+
|
161
|
+
|
162
|
+
if __name__ == '__main__':
|
163
|
+
unittest.main()
|
@@ -195,7 +195,7 @@ class Message(natural_language.NaturalLanguageFormattable, pg.Object):
|
|
195
195
|
if key_path == Message.PATH_TEXT:
|
196
196
|
return self.text
|
197
197
|
else:
|
198
|
-
v = self.metadata.sym_get(key_path, default)
|
198
|
+
v = self.metadata.sym_get(key_path, default, use_inferred=True)
|
199
199
|
return v.value if isinstance(v, pg.Ref) else v
|
200
200
|
|
201
201
|
#
|
@@ -110,16 +110,17 @@ class MessageTest(unittest.TestCase):
|
|
110
110
|
def test_get(self):
|
111
111
|
|
112
112
|
class A(pg.Object):
|
113
|
-
|
113
|
+
p: int
|
114
114
|
|
115
115
|
# Create a symbolic object and assign it to a container, so we could test
|
116
116
|
# pg.Ref.
|
117
|
-
a = A()
|
117
|
+
a = A(1)
|
118
118
|
d = pg.Dict(x=a)
|
119
119
|
|
120
120
|
m = message.UserMessage('hi', x=pg.Ref(a), y=dict(z=[0, 1, 2]))
|
121
121
|
self.assertEqual(m.get('text'), 'hi')
|
122
122
|
self.assertIs(m.get('x'), a)
|
123
|
+
self.assertIs(m.get('x.p'), 1)
|
123
124
|
self.assertEqual(m.get('y'), dict(z=[0, 1, 2]))
|
124
125
|
self.assertEqual(m.get('y.z'), [0, 1, 2])
|
125
126
|
self.assertEqual(m.get('y.z[0]'), 0)
|