langfun 0.0.2.dev20240319__py3-none-any.whl → 0.0.2.dev20240429__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- langfun/__init__.py +2 -0
- langfun/core/__init__.py +1 -0
- langfun/core/coding/python/correction.py +0 -7
- langfun/core/component.py +6 -0
- langfun/core/component_test.py +1 -0
- langfun/core/eval/__init__.py +2 -0
- langfun/core/eval/base.py +240 -37
- langfun/core/eval/base_test.py +52 -18
- langfun/core/eval/matching.py +26 -9
- langfun/core/eval/matching_test.py +3 -4
- langfun/core/eval/scoring.py +15 -6
- langfun/core/eval/scoring_test.py +2 -2
- langfun/core/langfunc.py +0 -5
- langfun/core/langfunc_test.py +6 -4
- langfun/core/language_model.py +124 -24
- langfun/core/language_model_test.py +249 -26
- langfun/core/llms/__init__.py +24 -5
- langfun/core/llms/anthropic.py +263 -0
- langfun/core/llms/anthropic_test.py +167 -0
- langfun/core/llms/cache/in_memory_test.py +37 -28
- langfun/core/llms/fake.py +31 -22
- langfun/core/llms/fake_test.py +122 -11
- langfun/core/llms/{gemini.py → google_genai.py} +117 -15
- langfun/core/llms/{gemini_test.py → google_genai_test.py} +83 -15
- langfun/core/llms/groq.py +260 -0
- langfun/core/llms/groq_test.py +170 -0
- langfun/core/llms/llama_cpp.py +3 -1
- langfun/core/llms/openai.py +97 -79
- langfun/core/llms/openai_test.py +285 -59
- langfun/core/modalities/video.py +5 -2
- langfun/core/structured/__init__.py +3 -0
- langfun/core/structured/completion_test.py +2 -2
- langfun/core/structured/function_generation.py +245 -0
- langfun/core/structured/function_generation_test.py +329 -0
- langfun/core/structured/mapping.py +59 -3
- langfun/core/structured/mapping_test.py +17 -0
- langfun/core/structured/parsing.py +2 -1
- langfun/core/structured/parsing_test.py +18 -13
- langfun/core/structured/prompting.py +27 -6
- langfun/core/structured/prompting_test.py +79 -12
- langfun/core/structured/schema.py +25 -22
- langfun/core/structured/schema_generation.py +2 -3
- langfun/core/structured/schema_generation_test.py +2 -2
- langfun/core/structured/schema_test.py +42 -27
- langfun/core/template.py +125 -10
- langfun/core/template_test.py +75 -0
- langfun/core/templates/selfplay_test.py +6 -2
- {langfun-0.0.2.dev20240319.dist-info → langfun-0.0.2.dev20240429.dist-info}/METADATA +3 -2
- {langfun-0.0.2.dev20240319.dist-info → langfun-0.0.2.dev20240429.dist-info}/RECORD +52 -46
- {langfun-0.0.2.dev20240319.dist-info → langfun-0.0.2.dev20240429.dist-info}/LICENSE +0 -0
- {langfun-0.0.2.dev20240319.dist-info → langfun-0.0.2.dev20240429.dist-info}/WHEEL +0 -0
- {langfun-0.0.2.dev20240319.dist-info → langfun-0.0.2.dev20240429.dist-info}/top_level.txt +0 -0
langfun/core/llms/fake_test.py
CHANGED
@@ -25,7 +25,24 @@ class EchoTest(unittest.TestCase):
|
|
25
25
|
def test_sample(self):
|
26
26
|
lm = fakelm.Echo()
|
27
27
|
self.assertEqual(
|
28
|
-
lm.sample(['hi']),
|
28
|
+
lm.sample(['hi']),
|
29
|
+
[
|
30
|
+
lf.LMSamplingResult(
|
31
|
+
[
|
32
|
+
lf.LMSample(
|
33
|
+
lf.AIMessage(
|
34
|
+
'hi',
|
35
|
+
score=1.0,
|
36
|
+
logprobs=None,
|
37
|
+
usage=lf.LMSamplingUsage(2, 2, 4),
|
38
|
+
tags=[lf.Message.TAG_LM_RESPONSE],
|
39
|
+
),
|
40
|
+
score=1.0,
|
41
|
+
logprobs=None,
|
42
|
+
)
|
43
|
+
],
|
44
|
+
lf.LMSamplingUsage(2, 2, 4))
|
45
|
+
]
|
29
46
|
)
|
30
47
|
|
31
48
|
def test_call(self):
|
@@ -34,8 +51,8 @@ class EchoTest(unittest.TestCase):
|
|
34
51
|
with contextlib.redirect_stdout(string_io):
|
35
52
|
self.assertEqual(lm('hi'), 'hi')
|
36
53
|
debug_info = string_io.getvalue()
|
37
|
-
self.assertIn('[0] LM INFO
|
38
|
-
self.assertIn('[0] PROMPT SENT TO LM
|
54
|
+
self.assertIn('[0] LM INFO', debug_info)
|
55
|
+
self.assertIn('[0] PROMPT SENT TO LM', debug_info)
|
39
56
|
self.assertIn('[0] LM RESPONSE', debug_info)
|
40
57
|
|
41
58
|
def test_score(self):
|
@@ -53,11 +70,45 @@ class StaticResponseTest(unittest.TestCase):
|
|
53
70
|
lm = fakelm.StaticResponse(canned_response)
|
54
71
|
self.assertEqual(
|
55
72
|
lm.sample(['hi']),
|
56
|
-
[
|
73
|
+
[
|
74
|
+
lf.LMSamplingResult(
|
75
|
+
[
|
76
|
+
lf.LMSample(
|
77
|
+
lf.AIMessage(
|
78
|
+
canned_response,
|
79
|
+
score=1.0,
|
80
|
+
logprobs=None,
|
81
|
+
usage=lf.LMSamplingUsage(2, 38, 40),
|
82
|
+
tags=[lf.Message.TAG_LM_RESPONSE],
|
83
|
+
),
|
84
|
+
score=1.0,
|
85
|
+
logprobs=None,
|
86
|
+
)
|
87
|
+
],
|
88
|
+
usage=lf.LMSamplingUsage(2, 38, 40)
|
89
|
+
)
|
90
|
+
],
|
57
91
|
)
|
58
92
|
self.assertEqual(
|
59
93
|
lm.sample(['Tell me a joke.']),
|
60
|
-
[
|
94
|
+
[
|
95
|
+
lf.LMSamplingResult(
|
96
|
+
[
|
97
|
+
lf.LMSample(
|
98
|
+
lf.AIMessage(
|
99
|
+
canned_response,
|
100
|
+
score=1.0,
|
101
|
+
logprobs=None,
|
102
|
+
usage=lf.LMSamplingUsage(15, 38, 53),
|
103
|
+
tags=[lf.Message.TAG_LM_RESPONSE],
|
104
|
+
),
|
105
|
+
score=1.0,
|
106
|
+
logprobs=None,
|
107
|
+
)
|
108
|
+
],
|
109
|
+
usage=lf.LMSamplingUsage(15, 38, 53)
|
110
|
+
)
|
111
|
+
],
|
61
112
|
)
|
62
113
|
|
63
114
|
def test_call(self):
|
@@ -69,8 +120,8 @@ class StaticResponseTest(unittest.TestCase):
|
|
69
120
|
self.assertEqual(lm('hi'), canned_response)
|
70
121
|
|
71
122
|
debug_info = string_io.getvalue()
|
72
|
-
self.assertIn('[0] LM INFO
|
73
|
-
self.assertIn('[0] PROMPT SENT TO LM
|
123
|
+
self.assertIn('[0] LM INFO', debug_info)
|
124
|
+
self.assertIn('[0] PROMPT SENT TO LM', debug_info)
|
74
125
|
self.assertIn('[0] LM RESPONSE', debug_info)
|
75
126
|
|
76
127
|
|
@@ -85,8 +136,38 @@ class StaticMappingTest(unittest.TestCase):
|
|
85
136
|
self.assertEqual(
|
86
137
|
lm.sample(['Hi', 'How are you?']),
|
87
138
|
[
|
88
|
-
lf.LMSamplingResult(
|
89
|
-
|
139
|
+
lf.LMSamplingResult(
|
140
|
+
[
|
141
|
+
lf.LMSample(
|
142
|
+
lf.AIMessage(
|
143
|
+
'Hello',
|
144
|
+
score=1.0,
|
145
|
+
logprobs=None,
|
146
|
+
usage=lf.LMSamplingUsage(2, 5, 7),
|
147
|
+
tags=[lf.Message.TAG_LM_RESPONSE],
|
148
|
+
),
|
149
|
+
score=1.0,
|
150
|
+
logprobs=None,
|
151
|
+
)
|
152
|
+
],
|
153
|
+
usage=lf.LMSamplingUsage(2, 5, 7)
|
154
|
+
),
|
155
|
+
lf.LMSamplingResult(
|
156
|
+
[
|
157
|
+
lf.LMSample(
|
158
|
+
lf.AIMessage(
|
159
|
+
'I am fine, how about you?',
|
160
|
+
score=1.0,
|
161
|
+
logprobs=None,
|
162
|
+
usage=lf.LMSamplingUsage(12, 25, 37),
|
163
|
+
tags=[lf.Message.TAG_LM_RESPONSE],
|
164
|
+
),
|
165
|
+
score=1.0,
|
166
|
+
logprobs=None,
|
167
|
+
)
|
168
|
+
],
|
169
|
+
usage=lf.LMSamplingUsage(12, 25, 37)
|
170
|
+
)
|
90
171
|
]
|
91
172
|
)
|
92
173
|
with self.assertRaises(KeyError):
|
@@ -104,8 +185,38 @@ class StaticSequenceTest(unittest.TestCase):
|
|
104
185
|
self.assertEqual(
|
105
186
|
lm.sample(['Hi', 'How are you?']),
|
106
187
|
[
|
107
|
-
lf.LMSamplingResult(
|
108
|
-
|
188
|
+
lf.LMSamplingResult(
|
189
|
+
[
|
190
|
+
lf.LMSample(
|
191
|
+
lf.AIMessage(
|
192
|
+
'Hello',
|
193
|
+
score=1.0,
|
194
|
+
logprobs=None,
|
195
|
+
usage=lf.LMSamplingUsage(2, 5, 7),
|
196
|
+
tags=[lf.Message.TAG_LM_RESPONSE],
|
197
|
+
),
|
198
|
+
score=1.0,
|
199
|
+
logprobs=None,
|
200
|
+
)
|
201
|
+
],
|
202
|
+
usage=lf.LMSamplingUsage(2, 5, 7)
|
203
|
+
),
|
204
|
+
lf.LMSamplingResult(
|
205
|
+
[
|
206
|
+
lf.LMSample(
|
207
|
+
lf.AIMessage(
|
208
|
+
'I am fine, how about you?',
|
209
|
+
score=1.0,
|
210
|
+
logprobs=None,
|
211
|
+
usage=lf.LMSamplingUsage(12, 25, 37),
|
212
|
+
tags=[lf.Message.TAG_LM_RESPONSE],
|
213
|
+
),
|
214
|
+
score=1.0,
|
215
|
+
logprobs=None,
|
216
|
+
)
|
217
|
+
],
|
218
|
+
usage=lf.LMSamplingUsage(12, 25, 37)
|
219
|
+
)
|
109
220
|
]
|
110
221
|
)
|
111
222
|
with self.assertRaises(IndexError):
|
@@ -13,6 +13,7 @@
|
|
13
13
|
# limitations under the License.
|
14
14
|
"""Gemini models exposed through Google Generative AI APIs."""
|
15
15
|
|
16
|
+
import abc
|
16
17
|
import functools
|
17
18
|
import os
|
18
19
|
from typing import Annotated, Any, Literal
|
@@ -20,14 +21,20 @@ from typing import Annotated, Any, Literal
|
|
20
21
|
import google.generativeai as genai
|
21
22
|
import langfun.core as lf
|
22
23
|
from langfun.core import modalities as lf_modalities
|
24
|
+
import pyglove as pg
|
23
25
|
|
24
26
|
|
25
27
|
@lf.use_init_args(['model'])
|
26
|
-
class
|
27
|
-
"""Language
|
28
|
+
class GenAI(lf.LanguageModel):
|
29
|
+
"""Language models provided by Google GenAI."""
|
28
30
|
|
29
31
|
model: Annotated[
|
30
|
-
Literal[
|
32
|
+
Literal[
|
33
|
+
'gemini-pro',
|
34
|
+
'gemini-pro-vision',
|
35
|
+
'text-bison-001',
|
36
|
+
'chat-bison-001',
|
37
|
+
],
|
31
38
|
'Model name.',
|
32
39
|
]
|
33
40
|
|
@@ -35,7 +42,8 @@ class Gemini(lf.LanguageModel):
|
|
35
42
|
str | None,
|
36
43
|
(
|
37
44
|
'API key. If None, the key will be read from environment variable '
|
38
|
-
"'GOOGLE_API_KEY'."
|
45
|
+
"'GOOGLE_API_KEY'. "
|
46
|
+
'Get an API key at https://ai.google.dev/tutorials/setup'
|
39
47
|
),
|
40
48
|
] = None
|
41
49
|
|
@@ -43,6 +51,9 @@ class Gemini(lf.LanguageModel):
|
|
43
51
|
False
|
44
52
|
)
|
45
53
|
|
54
|
+
# Set the default max concurrency to 8 workers.
|
55
|
+
max_concurrency = 8
|
56
|
+
|
46
57
|
def _on_bound(self):
|
47
58
|
super()._on_bound()
|
48
59
|
self.__dict__.pop('_api_initialized', None)
|
@@ -67,7 +78,11 @@ class Gemini(lf.LanguageModel):
|
|
67
78
|
return [
|
68
79
|
m.name.lstrip('models/')
|
69
80
|
for m in genai.list_models()
|
70
|
-
if
|
81
|
+
if (
|
82
|
+
'generateContent' in m.supported_generation_methods
|
83
|
+
or 'generateText' in m.supported_generation_methods
|
84
|
+
or 'generateMessage' in m.supported_generation_methods
|
85
|
+
)
|
71
86
|
]
|
72
87
|
|
73
88
|
@property
|
@@ -80,11 +95,6 @@ class Gemini(lf.LanguageModel):
|
|
80
95
|
"""Returns a string to identify the resource for rate control."""
|
81
96
|
return self.model_id
|
82
97
|
|
83
|
-
@property
|
84
|
-
def max_concurrency(self) -> int:
|
85
|
-
"""Max concurrent requests."""
|
86
|
-
return 8
|
87
|
-
|
88
98
|
def _generation_config(self, options: lf.LMSamplingOptions) -> dict[str, Any]:
|
89
99
|
"""Creates generation config from langfun sampling options."""
|
90
100
|
return genai.GenerationConfig(
|
@@ -117,7 +127,7 @@ class Gemini(lf.LanguageModel):
|
|
117
127
|
return chunks
|
118
128
|
|
119
129
|
def _response_to_result(
|
120
|
-
self, response: genai.types.GenerateContentResponse
|
130
|
+
self, response: genai.types.GenerateContentResponse | pg.Dict
|
121
131
|
) -> lf.LMSamplingResult:
|
122
132
|
"""Parses generative response into message."""
|
123
133
|
samples = []
|
@@ -149,17 +159,97 @@ class Gemini(lf.LanguageModel):
|
|
149
159
|
return self._response_to_result(response)
|
150
160
|
|
151
161
|
|
162
|
+
class _LegacyGenerativeModel(pg.Object):
|
163
|
+
"""Base for legacy GenAI generative model."""
|
164
|
+
|
165
|
+
model: str
|
166
|
+
|
167
|
+
def generate_content(
|
168
|
+
self,
|
169
|
+
input_content: list[str | genai.types.BlobDict],
|
170
|
+
generation_config: genai.GenerationConfig,
|
171
|
+
) -> pg.Dict:
|
172
|
+
"""Generate content."""
|
173
|
+
segments = []
|
174
|
+
for s in input_content:
|
175
|
+
if not isinstance(s, str):
|
176
|
+
raise ValueError(f'Unsupported modality: {s!r}')
|
177
|
+
segments.append(s)
|
178
|
+
return self.generate(' '.join(segments), generation_config)
|
179
|
+
|
180
|
+
@abc.abstractmethod
|
181
|
+
def generate(
|
182
|
+
self, prompt: str, generation_config: genai.GenerationConfig) -> pg.Dict:
|
183
|
+
"""Generate response based on prompt."""
|
184
|
+
|
185
|
+
|
186
|
+
class _LegacyCompletionModel(_LegacyGenerativeModel):
|
187
|
+
"""Legacy GenAI completion model."""
|
188
|
+
|
189
|
+
def generate(
|
190
|
+
self, prompt: str, generation_config: genai.GenerationConfig
|
191
|
+
) -> pg.Dict:
|
192
|
+
completion: genai.types.Completion = genai.generate_text(
|
193
|
+
model=f'models/{self.model}',
|
194
|
+
prompt=prompt,
|
195
|
+
temperature=generation_config.temperature,
|
196
|
+
top_k=generation_config.top_k,
|
197
|
+
top_p=generation_config.top_p,
|
198
|
+
candidate_count=generation_config.candidate_count,
|
199
|
+
max_output_tokens=generation_config.max_output_tokens,
|
200
|
+
stop_sequences=generation_config.stop_sequences,
|
201
|
+
)
|
202
|
+
return pg.Dict(
|
203
|
+
candidates=[
|
204
|
+
pg.Dict(content=pg.Dict(parts=[pg.Dict(text=c['output'])]))
|
205
|
+
for c in completion.candidates
|
206
|
+
]
|
207
|
+
)
|
208
|
+
|
209
|
+
|
210
|
+
class _LegacyChatModel(_LegacyGenerativeModel):
|
211
|
+
"""Legacy GenAI chat model."""
|
212
|
+
|
213
|
+
def generate(
|
214
|
+
self, prompt: str, generation_config: genai.GenerationConfig
|
215
|
+
) -> pg.Dict:
|
216
|
+
response: genai.types.ChatResponse = genai.chat(
|
217
|
+
model=f'models/{self.model}',
|
218
|
+
messages=prompt,
|
219
|
+
temperature=generation_config.temperature,
|
220
|
+
top_k=generation_config.top_k,
|
221
|
+
top_p=generation_config.top_p,
|
222
|
+
candidate_count=generation_config.candidate_count,
|
223
|
+
)
|
224
|
+
return pg.Dict(
|
225
|
+
candidates=[
|
226
|
+
pg.Dict(content=pg.Dict(parts=[pg.Dict(text=c['content'])]))
|
227
|
+
for c in response.candidates
|
228
|
+
]
|
229
|
+
)
|
230
|
+
|
231
|
+
|
152
232
|
class _ModelHub:
|
153
233
|
"""Google Generative AI model hub."""
|
154
234
|
|
155
235
|
def __init__(self):
|
156
236
|
self._model_cache = {}
|
157
237
|
|
158
|
-
def get(
|
238
|
+
def get(
|
239
|
+
self, model_name: str
|
240
|
+
) -> genai.GenerativeModel | _LegacyGenerativeModel:
|
159
241
|
"""Gets a generative model by model id."""
|
160
242
|
model = self._model_cache.get(model_name, None)
|
161
243
|
if model is None:
|
162
|
-
|
244
|
+
model_info = genai.get_model(f'models/{model_name}')
|
245
|
+
if 'generateContent' in model_info.supported_generation_methods:
|
246
|
+
model = genai.GenerativeModel(model_name)
|
247
|
+
elif 'generateText' in model_info.supported_generation_methods:
|
248
|
+
model = _LegacyCompletionModel(model_name)
|
249
|
+
elif 'generateMessage' in model_info.supported_generation_methods:
|
250
|
+
model = _LegacyChatModel(model_name)
|
251
|
+
else:
|
252
|
+
raise ValueError(f'Unsupported model: {model_name!r}')
|
163
253
|
self._model_cache[model_name] = model
|
164
254
|
return model
|
165
255
|
|
@@ -172,14 +262,26 @@ _GOOGLE_GENAI_MODEL_HUB = _ModelHub()
|
|
172
262
|
#
|
173
263
|
|
174
264
|
|
175
|
-
class GeminiPro(
|
265
|
+
class GeminiPro(GenAI):
|
176
266
|
"""Gemini Pro model."""
|
177
267
|
|
178
268
|
model = 'gemini-pro'
|
179
269
|
|
180
270
|
|
181
|
-
class GeminiProVision(
|
271
|
+
class GeminiProVision(GenAI):
|
182
272
|
"""Gemini Pro vision model."""
|
183
273
|
|
184
274
|
model = 'gemini-pro-vision'
|
185
275
|
multimodal = True
|
276
|
+
|
277
|
+
|
278
|
+
class Palm2(GenAI):
|
279
|
+
"""PaLM2 model."""
|
280
|
+
|
281
|
+
model = 'text-bison-001'
|
282
|
+
|
283
|
+
|
284
|
+
class Palm2_IT(GenAI): # pylint: disable=invalid-name
|
285
|
+
"""PaLM2 instruction-tuned model."""
|
286
|
+
|
287
|
+
model = 'chat-bison-001'
|
@@ -20,7 +20,7 @@ from unittest import mock
|
|
20
20
|
from google import generativeai as genai
|
21
21
|
import langfun.core as lf
|
22
22
|
from langfun.core import modalities as lf_modalities
|
23
|
-
from langfun.core.llms import
|
23
|
+
from langfun.core.llms import google_genai
|
24
24
|
import pyglove as pg
|
25
25
|
|
26
26
|
|
@@ -36,6 +36,29 @@ example_image = (
|
|
36
36
|
)
|
37
37
|
|
38
38
|
|
39
|
+
def mock_get_model(model_name, *args, **kwargs):
|
40
|
+
del args, kwargs
|
41
|
+
if 'gemini' in model_name:
|
42
|
+
method = 'generateContent'
|
43
|
+
elif 'chat' in model_name:
|
44
|
+
method = 'generateMessage'
|
45
|
+
else:
|
46
|
+
method = 'generateText'
|
47
|
+
return pg.Dict(supported_generation_methods=[method])
|
48
|
+
|
49
|
+
|
50
|
+
def mock_generate_text(*, model, prompt, **kwargs):
|
51
|
+
return pg.Dict(
|
52
|
+
candidates=[pg.Dict(output=f'{prompt} to {model} with {kwargs}')]
|
53
|
+
)
|
54
|
+
|
55
|
+
|
56
|
+
def mock_chat(*, model, messages, **kwargs):
|
57
|
+
return pg.Dict(
|
58
|
+
candidates=[pg.Dict(content=f'{messages} to {model} with {kwargs}')]
|
59
|
+
)
|
60
|
+
|
61
|
+
|
39
62
|
def mock_generate_content(content, generation_config, **kwargs):
|
40
63
|
del kwargs
|
41
64
|
c = generation_config
|
@@ -68,12 +91,12 @@ def mock_generate_content(content, generation_config, **kwargs):
|
|
68
91
|
)
|
69
92
|
|
70
93
|
|
71
|
-
class
|
72
|
-
"""Tests for
|
94
|
+
class GenAITest(unittest.TestCase):
|
95
|
+
"""Tests for Google GenAI model."""
|
73
96
|
|
74
97
|
def test_content_from_message_text_only(self):
|
75
98
|
text = 'This is a beautiful day'
|
76
|
-
model =
|
99
|
+
model = google_genai.GeminiPro()
|
77
100
|
chunks = model._content_from_message(lf.UserMessage(text))
|
78
101
|
self.assertEqual(chunks, [text])
|
79
102
|
|
@@ -85,9 +108,9 @@ class GeminiTest(unittest.TestCase):
|
|
85
108
|
|
86
109
|
# Non-multimodal model.
|
87
110
|
with self.assertRaisesRegex(ValueError, 'Unsupported modality'):
|
88
|
-
|
111
|
+
google_genai.GeminiPro()._content_from_message(message)
|
89
112
|
|
90
|
-
model =
|
113
|
+
model = google_genai.GeminiProVision()
|
91
114
|
chunks = model._content_from_message(message)
|
92
115
|
self.maxDiff = None
|
93
116
|
self.assertEqual(
|
@@ -118,7 +141,7 @@ class GeminiTest(unittest.TestCase):
|
|
118
141
|
],
|
119
142
|
),
|
120
143
|
)
|
121
|
-
model =
|
144
|
+
model = google_genai.GeminiProVision()
|
122
145
|
result = model._response_to_result(response)
|
123
146
|
self.assertEqual(
|
124
147
|
result,
|
@@ -129,34 +152,79 @@ class GeminiTest(unittest.TestCase):
|
|
129
152
|
)
|
130
153
|
|
131
154
|
def test_model_hub(self):
|
132
|
-
|
155
|
+
orig_get_model = genai.get_model
|
156
|
+
genai.get_model = mock_get_model
|
157
|
+
|
158
|
+
model = google_genai._GOOGLE_GENAI_MODEL_HUB.get('gemini-pro')
|
133
159
|
self.assertIsNotNone(model)
|
134
|
-
self.assertIs(
|
160
|
+
self.assertIs(google_genai._GOOGLE_GENAI_MODEL_HUB.get('gemini-pro'), model)
|
161
|
+
|
162
|
+
genai.get_model = orig_get_model
|
135
163
|
|
136
164
|
def test_api_key_check(self):
|
137
165
|
with self.assertRaisesRegex(ValueError, 'Please specify `api_key`'):
|
138
|
-
_ =
|
166
|
+
_ = google_genai.GeminiPro()._api_initialized
|
139
167
|
|
140
|
-
self.assertTrue(
|
168
|
+
self.assertTrue(google_genai.GeminiPro(api_key='abc')._api_initialized)
|
141
169
|
os.environ['GOOGLE_API_KEY'] = 'abc'
|
142
|
-
self.assertTrue(
|
170
|
+
self.assertTrue(google_genai.GeminiPro()._api_initialized)
|
143
171
|
del os.environ['GOOGLE_API_KEY']
|
144
172
|
|
145
173
|
def test_call(self):
|
146
174
|
with mock.patch(
|
147
|
-
'google.generativeai.
|
175
|
+
'google.generativeai.GenerativeModel.generate_content',
|
148
176
|
) as mock_generate:
|
177
|
+
orig_get_model = genai.get_model
|
178
|
+
genai.get_model = mock_get_model
|
149
179
|
mock_generate.side_effect = mock_generate_content
|
150
180
|
|
151
|
-
lm =
|
181
|
+
lm = google_genai.GeminiPro(api_key='test_key')
|
152
182
|
self.maxDiff = None
|
153
183
|
self.assertEqual(
|
154
|
-
lm('hello', temperature=2.0, top_k=20).text,
|
184
|
+
lm('hello', temperature=2.0, top_k=20, max_tokens=1024).text,
|
155
185
|
(
|
156
186
|
'This is a response to hello with n=1, temperature=2.0, '
|
157
187
|
'top_p=None, top_k=20, max_tokens=1024, stop=None.'
|
158
188
|
),
|
159
189
|
)
|
190
|
+
genai.get_model = orig_get_model
|
191
|
+
|
192
|
+
def test_call_with_legacy_completion_model(self):
|
193
|
+
orig_get_model = genai.get_model
|
194
|
+
genai.get_model = mock_get_model
|
195
|
+
orig_generate_text = genai.generate_text
|
196
|
+
genai.generate_text = mock_generate_text
|
197
|
+
|
198
|
+
lm = google_genai.Palm2(api_key='test_key')
|
199
|
+
self.maxDiff = None
|
200
|
+
self.assertEqual(
|
201
|
+
lm('hello', temperature=2.0, top_k=20).text,
|
202
|
+
(
|
203
|
+
"hello to models/text-bison-001 with {'temperature': 2.0, "
|
204
|
+
"'top_k': 20, 'top_p': None, 'candidate_count': 1, "
|
205
|
+
"'max_output_tokens': None, 'stop_sequences': None}"
|
206
|
+
),
|
207
|
+
)
|
208
|
+
genai.get_model = orig_get_model
|
209
|
+
genai.generate_text = orig_generate_text
|
210
|
+
|
211
|
+
def test_call_with_legacy_chat_model(self):
|
212
|
+
orig_get_model = genai.get_model
|
213
|
+
genai.get_model = mock_get_model
|
214
|
+
orig_chat = genai.chat
|
215
|
+
genai.chat = mock_chat
|
216
|
+
|
217
|
+
lm = google_genai.Palm2_IT(api_key='test_key')
|
218
|
+
self.maxDiff = None
|
219
|
+
self.assertEqual(
|
220
|
+
lm('hello', temperature=2.0, top_k=20).text,
|
221
|
+
(
|
222
|
+
"hello to models/chat-bison-001 with {'temperature': 2.0, "
|
223
|
+
"'top_k': 20, 'top_p': None, 'candidate_count': 1}"
|
224
|
+
),
|
225
|
+
)
|
226
|
+
genai.get_model = orig_get_model
|
227
|
+
genai.chat = orig_chat
|
160
228
|
|
161
229
|
|
162
230
|
if __name__ == '__main__':
|