langfun 0.0.2.dev20240319__py3-none-any.whl → 0.0.2.dev20240429__py3-none-any.whl

This diff compares the contents of two publicly released versions of the package as published to a supported registry. It is provided for informational purposes only and reflects the package versions exactly as they appear in their respective public registries.
Files changed (52)
  1. langfun/__init__.py +2 -0
  2. langfun/core/__init__.py +1 -0
  3. langfun/core/coding/python/correction.py +0 -7
  4. langfun/core/component.py +6 -0
  5. langfun/core/component_test.py +1 -0
  6. langfun/core/eval/__init__.py +2 -0
  7. langfun/core/eval/base.py +240 -37
  8. langfun/core/eval/base_test.py +52 -18
  9. langfun/core/eval/matching.py +26 -9
  10. langfun/core/eval/matching_test.py +3 -4
  11. langfun/core/eval/scoring.py +15 -6
  12. langfun/core/eval/scoring_test.py +2 -2
  13. langfun/core/langfunc.py +0 -5
  14. langfun/core/langfunc_test.py +6 -4
  15. langfun/core/language_model.py +124 -24
  16. langfun/core/language_model_test.py +249 -26
  17. langfun/core/llms/__init__.py +24 -5
  18. langfun/core/llms/anthropic.py +263 -0
  19. langfun/core/llms/anthropic_test.py +167 -0
  20. langfun/core/llms/cache/in_memory_test.py +37 -28
  21. langfun/core/llms/fake.py +31 -22
  22. langfun/core/llms/fake_test.py +122 -11
  23. langfun/core/llms/{gemini.py → google_genai.py} +117 -15
  24. langfun/core/llms/{gemini_test.py → google_genai_test.py} +83 -15
  25. langfun/core/llms/groq.py +260 -0
  26. langfun/core/llms/groq_test.py +170 -0
  27. langfun/core/llms/llama_cpp.py +3 -1
  28. langfun/core/llms/openai.py +97 -79
  29. langfun/core/llms/openai_test.py +285 -59
  30. langfun/core/modalities/video.py +5 -2
  31. langfun/core/structured/__init__.py +3 -0
  32. langfun/core/structured/completion_test.py +2 -2
  33. langfun/core/structured/function_generation.py +245 -0
  34. langfun/core/structured/function_generation_test.py +329 -0
  35. langfun/core/structured/mapping.py +59 -3
  36. langfun/core/structured/mapping_test.py +17 -0
  37. langfun/core/structured/parsing.py +2 -1
  38. langfun/core/structured/parsing_test.py +18 -13
  39. langfun/core/structured/prompting.py +27 -6
  40. langfun/core/structured/prompting_test.py +79 -12
  41. langfun/core/structured/schema.py +25 -22
  42. langfun/core/structured/schema_generation.py +2 -3
  43. langfun/core/structured/schema_generation_test.py +2 -2
  44. langfun/core/structured/schema_test.py +42 -27
  45. langfun/core/template.py +125 -10
  46. langfun/core/template_test.py +75 -0
  47. langfun/core/templates/selfplay_test.py +6 -2
  48. {langfun-0.0.2.dev20240319.dist-info → langfun-0.0.2.dev20240429.dist-info}/METADATA +3 -2
  49. {langfun-0.0.2.dev20240319.dist-info → langfun-0.0.2.dev20240429.dist-info}/RECORD +52 -46
  50. {langfun-0.0.2.dev20240319.dist-info → langfun-0.0.2.dev20240429.dist-info}/LICENSE +0 -0
  51. {langfun-0.0.2.dev20240319.dist-info → langfun-0.0.2.dev20240429.dist-info}/WHEEL +0 -0
  52. {langfun-0.0.2.dev20240319.dist-info → langfun-0.0.2.dev20240429.dist-info}/top_level.txt +0 -0
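A note on the headline changes before the hunks below: Anthropic and Groq gain dedicated client modules, a function_generation module is added under langfun/core/structured, and the Gemini wrapper moves from llms/gemini.py to llms/google_genai.py. A minimal migration sketch for the rename, inferred from the import change in the test diff further down (the api_key value is a placeholder):

# Before (0.0.2.dev20240319):
#   from langfun.core.llms import gemini
#   lm = gemini.GeminiPro(api_key='...')
# After (0.0.2.dev20240429):
from langfun.core.llms import google_genai

lm = google_genai.GeminiPro(api_key='...')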
langfun/core/llms/fake_test.py
@@ -25,7 +25,24 @@ class EchoTest(unittest.TestCase):
   def test_sample(self):
     lm = fakelm.Echo()
     self.assertEqual(
-        lm.sample(['hi']), [lf.LMSamplingResult([lf.LMSample('hi', 1.0)])]
+        lm.sample(['hi']),
+        [
+            lf.LMSamplingResult(
+                [
+                    lf.LMSample(
+                        lf.AIMessage(
+                            'hi',
+                            score=1.0,
+                            logprobs=None,
+                            usage=lf.LMSamplingUsage(2, 2, 4),
+                            tags=[lf.Message.TAG_LM_RESPONSE],
+                        ),
+                        score=1.0,
+                        logprobs=None,
+                    )
+                ],
+                lf.LMSamplingUsage(2, 2, 4))
+        ]
     )

   def test_call(self):
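In the new expectation above, each sample wraps an lf.AIMessage and carries token accounting. A hedged reading of the lf.LMSamplingUsage triple, inferred from the Echo case ('hi' in, 'hi' echoed back); the field naming is an assumption, not confirmed by this diff:

# Sketch only: positional order inferred from the (2, 2, 4) expectation above,
# read as (prompt_tokens, completion_tokens, total_tokens).
import langfun.core as lf

usage = lf.LMSamplingUsage(2, 2, 4)  # 2 prompt + 2 completion = 4 total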
@@ -34,8 +51,8 @@ class EchoTest(unittest.TestCase):
     with contextlib.redirect_stdout(string_io):
       self.assertEqual(lm('hi'), 'hi')
     debug_info = string_io.getvalue()
-    self.assertIn('[0] LM INFO:', debug_info)
-    self.assertIn('[0] PROMPT SENT TO LM:', debug_info)
+    self.assertIn('[0] LM INFO', debug_info)
+    self.assertIn('[0] PROMPT SENT TO LM', debug_info)
     self.assertIn('[0] LM RESPONSE', debug_info)

   def test_score(self):
@@ -53,11 +70,45 @@ class StaticResponseTest(unittest.TestCase):
     lm = fakelm.StaticResponse(canned_response)
     self.assertEqual(
         lm.sample(['hi']),
-        [lf.LMSamplingResult([lf.LMSample(canned_response, 1.0)])],
+        [
+            lf.LMSamplingResult(
+                [
+                    lf.LMSample(
+                        lf.AIMessage(
+                            canned_response,
+                            score=1.0,
+                            logprobs=None,
+                            usage=lf.LMSamplingUsage(2, 38, 40),
+                            tags=[lf.Message.TAG_LM_RESPONSE],
+                        ),
+                        score=1.0,
+                        logprobs=None,
+                    )
+                ],
+                usage=lf.LMSamplingUsage(2, 38, 40)
+            )
+        ],
     )
     self.assertEqual(
         lm.sample(['Tell me a joke.']),
-        [lf.LMSamplingResult([lf.LMSample(canned_response, 1.0)])],
+        [
+            lf.LMSamplingResult(
+                [
+                    lf.LMSample(
+                        lf.AIMessage(
+                            canned_response,
+                            score=1.0,
+                            logprobs=None,
+                            usage=lf.LMSamplingUsage(15, 38, 53),
+                            tags=[lf.Message.TAG_LM_RESPONSE],
+                        ),
+                        score=1.0,
+                        logprobs=None,
+                    )
+                ],
+                usage=lf.LMSamplingUsage(15, 38, 53)
+            )
+        ],
     )

   def test_call(self):
@@ -69,8 +120,8 @@ class StaticResponseTest(unittest.TestCase):
       self.assertEqual(lm('hi'), canned_response)

     debug_info = string_io.getvalue()
-    self.assertIn('[0] LM INFO:', debug_info)
-    self.assertIn('[0] PROMPT SENT TO LM:', debug_info)
+    self.assertIn('[0] LM INFO', debug_info)
+    self.assertIn('[0] PROMPT SENT TO LM', debug_info)
     self.assertIn('[0] LM RESPONSE', debug_info)

@@ -85,8 +136,38 @@ class StaticMappingTest(unittest.TestCase):
     self.assertEqual(
         lm.sample(['Hi', 'How are you?']),
         [
-            lf.LMSamplingResult([lf.LMSample('Hello', 1.0)]),
-            lf.LMSamplingResult([lf.LMSample('I am fine, how about you?', 1.0)])
+            lf.LMSamplingResult(
+                [
+                    lf.LMSample(
+                        lf.AIMessage(
+                            'Hello',
+                            score=1.0,
+                            logprobs=None,
+                            usage=lf.LMSamplingUsage(2, 5, 7),
+                            tags=[lf.Message.TAG_LM_RESPONSE],
+                        ),
+                        score=1.0,
+                        logprobs=None,
+                    )
+                ],
+                usage=lf.LMSamplingUsage(2, 5, 7)
+            ),
+            lf.LMSamplingResult(
+                [
+                    lf.LMSample(
+                        lf.AIMessage(
+                            'I am fine, how about you?',
+                            score=1.0,
+                            logprobs=None,
+                            usage=lf.LMSamplingUsage(12, 25, 37),
+                            tags=[lf.Message.TAG_LM_RESPONSE],
+                        ),
+                        score=1.0,
+                        logprobs=None,
+                    )
+                ],
+                usage=lf.LMSamplingUsage(12, 25, 37)
+            )
         ]
     )
     with self.assertRaises(KeyError):
@@ -104,8 +185,38 @@ class StaticSequenceTest(unittest.TestCase):
     self.assertEqual(
         lm.sample(['Hi', 'How are you?']),
         [
-            lf.LMSamplingResult([lf.LMSample('Hello', 1.0)]),
-            lf.LMSamplingResult([lf.LMSample('I am fine, how about you?', 1.0)])
+            lf.LMSamplingResult(
+                [
+                    lf.LMSample(
+                        lf.AIMessage(
+                            'Hello',
+                            score=1.0,
+                            logprobs=None,
+                            usage=lf.LMSamplingUsage(2, 5, 7),
+                            tags=[lf.Message.TAG_LM_RESPONSE],
+                        ),
+                        score=1.0,
+                        logprobs=None,
+                    )
+                ],
+                usage=lf.LMSamplingUsage(2, 5, 7)
+            ),
+            lf.LMSamplingResult(
+                [
+                    lf.LMSample(
+                        lf.AIMessage(
+                            'I am fine, how about you?',
+                            score=1.0,
+                            logprobs=None,
+                            usage=lf.LMSamplingUsage(12, 25, 37),
+                            tags=[lf.Message.TAG_LM_RESPONSE],
+                        ),
+                        score=1.0,
+                        logprobs=None,
+                    )
+                ],
+                usage=lf.LMSamplingUsage(12, 25, 37)
+            )
         ]
     )
     with self.assertRaises(IndexError):
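The StaticMapping and StaticSequence hunks above share the same result shape. A hedged usage sketch: the dict-style constructor is an assumption based on the test fixtures, while the KeyError/IndexError behavior comes straight from the assertions:

from langfun.core.llms import fake as fakelm

# Assumed constructor form: a prompt -> response mapping.
lm = fakelm.StaticMapping({
    'Hi': 'Hello',
    'How are you?': 'I am fine, how about you?',
})
assert lm('Hi') == 'Hello'
# lm('Bye') would raise KeyError: the prompt is not in the mapping.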
langfun/core/llms/{gemini.py → google_genai.py}
@@ -13,6 +13,7 @@
 # limitations under the License.
 """Gemini models exposed through Google Generative AI APIs."""

+import abc
 import functools
 import os
 from typing import Annotated, Any, Literal
@@ -20,14 +21,20 @@ from typing import Annotated, Any, Literal
 import google.generativeai as genai
 import langfun.core as lf
 from langfun.core import modalities as lf_modalities
+import pyglove as pg


 @lf.use_init_args(['model'])
-class Gemini(lf.LanguageModel):
-  """Language model served on VertexAI."""
+class GenAI(lf.LanguageModel):
+  """Language models provided by Google GenAI."""

   model: Annotated[
-      Literal['gemini-pro', 'gemini-pro-vision', ''],
+      Literal[
+          'gemini-pro',
+          'gemini-pro-vision',
+          'text-bison-001',
+          'chat-bison-001',
+      ],
       'Model name.',
   ]

@@ -35,7 +42,8 @@ class Gemini(lf.LanguageModel):
       str | None,
       (
           'API key. If None, the key will be read from environment variable '
-          "'GOOGLE_API_KEY'."
+          "'GOOGLE_API_KEY'. "
+          'Get an API key at https://ai.google.dev/tutorials/setup'
       ),
   ] = None

@@ -43,6 +51,9 @@
       False
   )

+  # Set the default max concurrency to 8 workers.
+  max_concurrency = 8
+
   def _on_bound(self):
     super()._on_bound()
     self.__dict__.pop('_api_initialized', None)
@@ -67,7 +78,11 @@ class Gemini(lf.LanguageModel):
     return [
         m.name.lstrip('models/')
         for m in genai.list_models()
-        if 'generateContent' in m.supported_generation_methods
+        if (
+            'generateContent' in m.supported_generation_methods
+            or 'generateText' in m.supported_generation_methods
+            or 'generateMessage' in m.supported_generation_methods
+        )
     ]

   @property
@@ -80,11 +95,6 @@
     """Returns a string to identify the resource for rate control."""
     return self.model_id

-  @property
-  def max_concurrency(self) -> int:
-    """Max concurrent requests."""
-    return 8
-
   def _generation_config(self, options: lf.LMSamplingOptions) -> dict[str, Any]:
     """Creates generation config from langfun sampling options."""
     return genai.GenerationConfig(
@@ -117,7 +127,7 @@
     return chunks

   def _response_to_result(
-      self, response: genai.types.GenerateContentResponse
+      self, response: genai.types.GenerateContentResponse | pg.Dict
   ) -> lf.LMSamplingResult:
     """Parses generative response into message."""
     samples = []
@@ -149,17 +159,97 @@
     return self._response_to_result(response)


+class _LegacyGenerativeModel(pg.Object):
+  """Base for legacy GenAI generative model."""
+
+  model: str
+
+  def generate_content(
+      self,
+      input_content: list[str | genai.types.BlobDict],
+      generation_config: genai.GenerationConfig,
+  ) -> pg.Dict:
+    """Generate content."""
+    segments = []
+    for s in input_content:
+      if not isinstance(s, str):
+        raise ValueError(f'Unsupported modality: {s!r}')
+      segments.append(s)
+    return self.generate(' '.join(segments), generation_config)
+
+  @abc.abstractmethod
+  def generate(
+      self, prompt: str, generation_config: genai.GenerationConfig) -> pg.Dict:
+    """Generate response based on prompt."""
+
+
+class _LegacyCompletionModel(_LegacyGenerativeModel):
+  """Legacy GenAI completion model."""
+
+  def generate(
+      self, prompt: str, generation_config: genai.GenerationConfig
+  ) -> pg.Dict:
+    completion: genai.types.Completion = genai.generate_text(
+        model=f'models/{self.model}',
+        prompt=prompt,
+        temperature=generation_config.temperature,
+        top_k=generation_config.top_k,
+        top_p=generation_config.top_p,
+        candidate_count=generation_config.candidate_count,
+        max_output_tokens=generation_config.max_output_tokens,
+        stop_sequences=generation_config.stop_sequences,
+    )
+    return pg.Dict(
+        candidates=[
+            pg.Dict(content=pg.Dict(parts=[pg.Dict(text=c['output'])]))
+            for c in completion.candidates
+        ]
+    )
+
+
+class _LegacyChatModel(_LegacyGenerativeModel):
+  """Legacy GenAI chat model."""
+
+  def generate(
+      self, prompt: str, generation_config: genai.GenerationConfig
+  ) -> pg.Dict:
+    response: genai.types.ChatResponse = genai.chat(
+        model=f'models/{self.model}',
+        messages=prompt,
+        temperature=generation_config.temperature,
+        top_k=generation_config.top_k,
+        top_p=generation_config.top_p,
+        candidate_count=generation_config.candidate_count,
+    )
+    return pg.Dict(
+        candidates=[
+            pg.Dict(content=pg.Dict(parts=[pg.Dict(text=c['content'])]))
+            for c in response.candidates
+        ]
+    )
+
+
 class _ModelHub:
   """Google Generative AI model hub."""

   def __init__(self):
     self._model_cache = {}

-  def get(self, model_name: str) -> genai.GenerativeModel:
+  def get(
+      self, model_name: str
+  ) -> genai.GenerativeModel | _LegacyGenerativeModel:
     """Gets a generative model by model id."""
     model = self._model_cache.get(model_name, None)
     if model is None:
-      model = genai.GenerativeModel(model_name)
+      model_info = genai.get_model(f'models/{model_name}')
+      if 'generateContent' in model_info.supported_generation_methods:
+        model = genai.GenerativeModel(model_name)
+      elif 'generateText' in model_info.supported_generation_methods:
+        model = _LegacyCompletionModel(model_name)
+      elif 'generateMessage' in model_info.supported_generation_methods:
+        model = _LegacyChatModel(model_name)
+      else:
+        raise ValueError(f'Unsupported model: {model_name!r}')
       self._model_cache[model_name] = model
     return model

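In words, _ModelHub.get now asks the GenAI service which generation methods a model supports and picks a wrapper accordingly, caching the result per model name. A short trace of the branch above (wrapper names as defined in this file):

# supported_generation_methods -> wrapper chosen by _ModelHub.get:
#   'generateContent' -> genai.GenerativeModel   (Gemini family)
#   'generateText'    -> _LegacyCompletionModel  (e.g. text-bison-001)
#   'generateMessage' -> _LegacyChatModel        (e.g. chat-bison-001)
#   anything else     -> ValueError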
@@ -172,14 +262,26 @@ _GOOGLE_GENAI_MODEL_HUB = _ModelHub()
 #


-class GeminiPro(Gemini):
+class GeminiPro(GenAI):
   """Gemini Pro model."""

   model = 'gemini-pro'


-class GeminiProVision(Gemini):
+class GeminiProVision(GenAI):
   """Gemini Pro vision model."""

   model = 'gemini-pro-vision'
   multimodal = True
+
+
+class Palm2(GenAI):
+  """PaLM2 model."""
+
+  model = 'text-bison-001'
+
+
+class Palm2_IT(GenAI):  # pylint: disable=invalid-name
+  """PaLM2 instruction-tuned model."""
+
+  model = 'chat-bison-001'
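Taken together, the module now exposes four public entry points. A minimal construction sketch (api_key values are placeholders; classes and model ids are as defined above):

from langfun.core.llms import google_genai

gemini = google_genai.GeminiPro(api_key='...')        # gemini-pro
vision = google_genai.GeminiProVision(api_key='...')  # gemini-pro-vision, multimodal
palm2 = google_genai.Palm2(api_key='...')             # text-bison-001, legacy generateText
palm2_it = google_genai.Palm2_IT(api_key='...')       # chat-bison-001, legacy generateMessage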
langfun/core/llms/{gemini_test.py → google_genai_test.py}
@@ -20,7 +20,7 @@ from unittest import mock
 from google import generativeai as genai
 import langfun.core as lf
 from langfun.core import modalities as lf_modalities
-from langfun.core.llms import gemini
+from langfun.core.llms import google_genai
 import pyglove as pg


@@ -36,6 +36,29 @@ example_image = (
 )


+def mock_get_model(model_name, *args, **kwargs):
+  del args, kwargs
+  if 'gemini' in model_name:
+    method = 'generateContent'
+  elif 'chat' in model_name:
+    method = 'generateMessage'
+  else:
+    method = 'generateText'
+  return pg.Dict(supported_generation_methods=[method])
+
+
+def mock_generate_text(*, model, prompt, **kwargs):
+  return pg.Dict(
+      candidates=[pg.Dict(output=f'{prompt} to {model} with {kwargs}')]
+  )
+
+
+def mock_chat(*, model, messages, **kwargs):
+  return pg.Dict(
+      candidates=[pg.Dict(content=f'{messages} to {model} with {kwargs}')]
+  )
+
+
 def mock_generate_content(content, generation_config, **kwargs):
   del kwargs
   c = generation_config
@@ -68,12 +91,12 @@ def mock_generate_content(content, generation_config, **kwargs):
   )


-class GeminiTest(unittest.TestCase):
-  """Tests for Evergreen language model."""
+class GenAITest(unittest.TestCase):
+  """Tests for Google GenAI model."""

   def test_content_from_message_text_only(self):
     text = 'This is a beautiful day'
-    model = gemini.GeminiPro()
+    model = google_genai.GeminiPro()
     chunks = model._content_from_message(lf.UserMessage(text))
     self.assertEqual(chunks, [text])

@@ -85,9 +108,9 @@ class GeminiTest(unittest.TestCase):

     # Non-multimodal model.
     with self.assertRaisesRegex(ValueError, 'Unsupported modality'):
-      gemini.GeminiPro()._content_from_message(message)
+      google_genai.GeminiPro()._content_from_message(message)

-    model = gemini.GeminiProVision()
+    model = google_genai.GeminiProVision()
     chunks = model._content_from_message(message)
     self.maxDiff = None
     self.assertEqual(
@@ -118,7 +141,7 @@ class GeminiTest(unittest.TestCase):
             ],
         ),
     )
-    model = gemini.GeminiProVision()
+    model = google_genai.GeminiProVision()
     result = model._response_to_result(response)
     self.assertEqual(
         result,
@@ -129,34 +152,79 @@ class GeminiTest(unittest.TestCase):
     )

   def test_model_hub(self):
-    model = gemini._GOOGLE_GENAI_MODEL_HUB.get('gemini-pro')
+    orig_get_model = genai.get_model
+    genai.get_model = mock_get_model
+
+    model = google_genai._GOOGLE_GENAI_MODEL_HUB.get('gemini-pro')
     self.assertIsNotNone(model)
-    self.assertIs(gemini._GOOGLE_GENAI_MODEL_HUB.get('gemini-pro'), model)
+    self.assertIs(google_genai._GOOGLE_GENAI_MODEL_HUB.get('gemini-pro'), model)
+
+    genai.get_model = orig_get_model

   def test_api_key_check(self):
     with self.assertRaisesRegex(ValueError, 'Please specify `api_key`'):
-      _ = gemini.GeminiPro()._api_initialized
+      _ = google_genai.GeminiPro()._api_initialized

-    self.assertTrue(gemini.GeminiPro(api_key='abc')._api_initialized)
+    self.assertTrue(google_genai.GeminiPro(api_key='abc')._api_initialized)
     os.environ['GOOGLE_API_KEY'] = 'abc'
-    self.assertTrue(gemini.GeminiPro()._api_initialized)
+    self.assertTrue(google_genai.GeminiPro()._api_initialized)
     del os.environ['GOOGLE_API_KEY']

   def test_call(self):
     with mock.patch(
-        'google.generativeai.generative_models.GenerativeModel.generate_content'
+        'google.generativeai.GenerativeModel.generate_content',
     ) as mock_generate:
+      orig_get_model = genai.get_model
+      genai.get_model = mock_get_model
       mock_generate.side_effect = mock_generate_content

-      lm = gemini.GeminiPro(api_key='test_key')
+      lm = google_genai.GeminiPro(api_key='test_key')
       self.maxDiff = None
       self.assertEqual(
-          lm('hello', temperature=2.0, top_k=20).text,
+          lm('hello', temperature=2.0, top_k=20, max_tokens=1024).text,
           (
               'This is a response to hello with n=1, temperature=2.0, '
               'top_p=None, top_k=20, max_tokens=1024, stop=None.'
           ),
       )
+      genai.get_model = orig_get_model
+
+  def test_call_with_legacy_completion_model(self):
+    orig_get_model = genai.get_model
+    genai.get_model = mock_get_model
+    orig_generate_text = genai.generate_text
+    genai.generate_text = mock_generate_text
+
+    lm = google_genai.Palm2(api_key='test_key')
+    self.maxDiff = None
+    self.assertEqual(
+        lm('hello', temperature=2.0, top_k=20).text,
+        (
+            "hello to models/text-bison-001 with {'temperature': 2.0, "
+            "'top_k': 20, 'top_p': None, 'candidate_count': 1, "
+            "'max_output_tokens': None, 'stop_sequences': None}"
+        ),
+    )
+    genai.get_model = orig_get_model
+    genai.generate_text = orig_generate_text
+
+  def test_call_with_legacy_chat_model(self):
+    orig_get_model = genai.get_model
+    genai.get_model = mock_get_model
+    orig_chat = genai.chat
+    genai.chat = mock_chat
+
+    lm = google_genai.Palm2_IT(api_key='test_key')
+    self.maxDiff = None
+    self.assertEqual(
+        lm('hello', temperature=2.0, top_k=20).text,
+        (
+            "hello to models/chat-bison-001 with {'temperature': 2.0, "
+            "'top_k': 20, 'top_p': None, 'candidate_count': 1}"
+        ),
+    )
+    genai.get_model = orig_get_model
+    genai.chat = orig_chat


 if __name__ == '__main__':