langfun 0.1.2.dev202501050804__py3-none-any.whl → 0.1.2.dev202501070804__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -11,223 +11,28 @@
11
11
  # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
12
  # See the License for the specific language governing permissions and
13
13
  # limitations under the License.
14
- """Tests for Gemini models."""
14
+ """Tests for Google GenAI models."""
15
15
 
16
16
  import os
17
17
  import unittest
18
- from unittest import mock
19
-
20
- from google import generativeai as genai
21
- import langfun.core as lf
22
- from langfun.core import modalities as lf_modalities
23
18
  from langfun.core.llms import google_genai
24
- import pyglove as pg
25
-
26
-
27
- example_image = (
28
- b'\x89PNG\r\n\x1a\n\x00\x00\x00\rIHDR\x00\x00\x00\x18\x00\x00\x00\x18\x04'
29
- b'\x03\x00\x00\x00\x12Y \xcb\x00\x00\x00\x18PLTE\x00\x00'
30
- b'\x00fff_chaag_cg_ch^ci_ciC\xedb\x94\x00\x00\x00\x08tRNS'
31
- b'\x00\n\x9f*\xd4\xff_\xf4\xe4\x8b\xf3a\x00\x00\x00>IDATx'
32
- b'\x01c \x05\x08)"\xd8\xcc\xae!\x06pNz\x88k\x19\\Q\xa8"\x10'
33
- b'\xc1\x14\x95\x01%\xc1\n\xa143Ta\xa8"D-\x84\x03QM\x98\xc3'
34
- b'\x1a\x1a\x1a@5\x0e\x04\xa0q\x88\x05\x00\x07\xf8\x18\xf9'
35
- b'\xdao\xd0|\x00\x00\x00\x00IEND\xaeB`\x82'
36
- )
37
-
38
-
39
- def mock_get_model(model_name, *args, **kwargs):
40
- del args, kwargs
41
- if 'gemini' in model_name:
42
- method = 'generateContent'
43
- elif 'chat' in model_name:
44
- method = 'generateMessage'
45
- else:
46
- method = 'generateText'
47
- return pg.Dict(supported_generation_methods=[method])
48
-
49
-
50
- def mock_generate_text(*, model, prompt, **kwargs):
51
- return pg.Dict(
52
- candidates=[pg.Dict(output=f'{prompt} to {model} with {kwargs}')]
53
- )
54
-
55
-
56
- def mock_chat(*, model, messages, **kwargs):
57
- return pg.Dict(
58
- candidates=[pg.Dict(content=f'{messages} to {model} with {kwargs}')]
59
- )
60
-
61
-
62
- def mock_generate_content(content, generation_config, **kwargs):
63
- del kwargs
64
- c = generation_config
65
- return genai.types.GenerateContentResponse(
66
- done=True,
67
- iterator=None,
68
- chunks=[],
69
- result=pg.Dict(
70
- prompt_feedback=pg.Dict(block_reason=None),
71
- candidates=[
72
- pg.Dict(
73
- content=pg.Dict(
74
- parts=[
75
- pg.Dict(
76
- text=(
77
- f'This is a response to {content[0]} with '
78
- f'n={c.candidate_count}, '
79
- f'temperature={c.temperature}, '
80
- f'top_p={c.top_p}, '
81
- f'top_k={c.top_k}, '
82
- f'max_tokens={c.max_output_tokens}, '
83
- f'stop={c.stop_sequences}.'
84
- )
85
- )
86
- ]
87
- ),
88
- ),
89
- ],
90
- ),
91
- )
92
19
 
93
20
 
94
21
  class GenAITest(unittest.TestCase):
95
- """Tests for Google GenAI model."""
96
-
97
- def test_content_from_message_text_only(self):
98
- text = 'This is a beautiful day'
99
- model = google_genai.GeminiPro()
100
- chunks = model._content_from_message(lf.UserMessage(text))
101
- self.assertEqual(chunks, [text])
102
-
103
- def test_content_from_message_mm(self):
104
- message = lf.UserMessage(
105
- 'This is an <<[[image]]>>, what is it?',
106
- image=lf_modalities.Image.from_bytes(example_image),
107
- )
22
+ """Tests for GenAI model."""
108
23
 
109
- # Non-multimodal model.
110
- with self.assertRaisesRegex(lf.ModalityError, 'Unsupported modality'):
111
- google_genai.GeminiPro()._content_from_message(message)
112
-
113
- model = google_genai.GeminiProVision()
114
- chunks = model._content_from_message(message)
115
- self.maxDiff = None
116
- self.assertEqual(
117
- chunks,
118
- [
119
- 'This is an',
120
- genai.types.BlobDict(mime_type='image/png', data=example_image),
121
- ', what is it?',
122
- ],
123
- )
124
-
125
- def test_response_to_result_text_only(self):
126
- response = genai.types.GenerateContentResponse(
127
- done=True,
128
- iterator=None,
129
- chunks=[],
130
- result=pg.Dict(
131
- prompt_feedback=pg.Dict(block_reason=None),
132
- candidates=[
133
- pg.Dict(
134
- content=pg.Dict(
135
- parts=[pg.Dict(text='This is response 1.')]
136
- ),
137
- ),
138
- pg.Dict(
139
- content=pg.Dict(parts=[pg.Dict(text='This is response 2.')])
140
- ),
141
- ],
142
- ),
143
- )
144
- model = google_genai.GeminiProVision()
145
- result = model._response_to_result(response)
146
- self.assertEqual(
147
- result,
148
- lf.LMSamplingResult([
149
- lf.LMSample(lf.AIMessage('This is response 1.'), score=0.0),
150
- lf.LMSample(lf.AIMessage('This is response 2.'), score=0.0),
151
- ]),
152
- )
153
-
154
- def test_model_hub(self):
155
- orig_get_model = genai.get_model
156
- genai.get_model = mock_get_model
157
-
158
- model = google_genai._GOOGLE_GENAI_MODEL_HUB.get('gemini-pro')
159
- self.assertIsNotNone(model)
160
- self.assertIs(google_genai._GOOGLE_GENAI_MODEL_HUB.get('gemini-pro'), model)
161
-
162
- genai.get_model = orig_get_model
163
-
164
- def test_api_key_check(self):
24
+ def test_basics(self):
165
25
  with self.assertRaisesRegex(ValueError, 'Please specify `api_key`'):
166
- _ = google_genai.GeminiPro()._api_initialized
26
+ _ = google_genai.GeminiPro1_5().api_endpoint
27
+
28
+ self.assertIsNotNone(google_genai.GeminiPro1_5(api_key='abc').api_endpoint)
167
29
 
168
- self.assertTrue(google_genai.GeminiPro(api_key='abc')._api_initialized)
169
30
  os.environ['GOOGLE_API_KEY'] = 'abc'
170
- self.assertTrue(google_genai.GeminiPro()._api_initialized)
31
+ lm = google_genai.GeminiPro1_5()
32
+ self.assertIsNotNone(lm.api_endpoint)
33
+ self.assertTrue(lm.model_id.startswith('GenAI('))
171
34
  del os.environ['GOOGLE_API_KEY']
172
35
 
173
- def test_call(self):
174
- with mock.patch(
175
- 'google.generativeai.GenerativeModel.generate_content',
176
- ) as mock_generate:
177
- orig_get_model = genai.get_model
178
- genai.get_model = mock_get_model
179
- mock_generate.side_effect = mock_generate_content
180
-
181
- lm = google_genai.GeminiPro(api_key='test_key')
182
- self.maxDiff = None
183
- self.assertEqual(
184
- lm('hello', temperature=2.0, top_k=20, max_tokens=1024).text,
185
- (
186
- 'This is a response to hello with n=1, temperature=2.0, '
187
- 'top_p=None, top_k=20, max_tokens=1024, stop=None.'
188
- ),
189
- )
190
- genai.get_model = orig_get_model
191
-
192
- def test_call_with_legacy_completion_model(self):
193
- orig_get_model = genai.get_model
194
- genai.get_model = mock_get_model
195
- orig_generate_text = getattr(genai, 'generate_text', None)
196
- if orig_generate_text is not None:
197
- genai.generate_text = mock_generate_text
198
-
199
- lm = google_genai.Palm2(api_key='test_key')
200
- self.maxDiff = None
201
- self.assertEqual(
202
- lm('hello', temperature=2.0, top_k=20).text,
203
- (
204
- "hello to models/text-bison-001 with {'temperature': 2.0, "
205
- "'top_k': 20, 'top_p': None, 'candidate_count': 1, "
206
- "'max_output_tokens': None, 'stop_sequences': None}"
207
- ),
208
- )
209
- genai.generate_text = orig_generate_text
210
- genai.get_model = orig_get_model
211
-
212
- def test_call_with_legacy_chat_model(self):
213
- orig_get_model = genai.get_model
214
- genai.get_model = mock_get_model
215
- orig_chat = getattr(genai, 'chat', None)
216
- if orig_chat is not None:
217
- genai.chat = mock_chat
218
-
219
- lm = google_genai.Palm2_IT(api_key='test_key')
220
- self.maxDiff = None
221
- self.assertEqual(
222
- lm('hello', temperature=2.0, top_k=20).text,
223
- (
224
- "hello to models/chat-bison-001 with {'temperature': 2.0, "
225
- "'top_k': 20, 'top_p': None, 'candidate_count': 1}"
226
- ),
227
- )
228
- genai.chat = orig_chat
229
- genai.get_model = orig_get_model
230
-
231
36
 
232
37
  if __name__ == '__main__':
233
38
  unittest.main()
@@ -553,26 +553,31 @@ class GptO1(OpenAI):
553
553
 
554
554
  model = 'o1'
555
555
  multimodal = True
556
+ timeout = None
556
557
 
557
558
 
558
559
  class GptO1Preview(OpenAI):
559
560
  """GPT-O1."""
560
561
  model = 'o1-preview'
562
+ timeout = None
561
563
 
562
564
 
563
565
  class GptO1Preview_20240912(OpenAI): # pylint: disable=invalid-name
564
566
  """GPT O1."""
565
567
  model = 'o1-preview-2024-09-12'
568
+ timeout = None
566
569
 
567
570
 
568
571
  class GptO1Mini(OpenAI):
569
572
  """GPT O1-mini."""
570
573
  model = 'o1-mini'
574
+ timeout = None
571
575
 
572
576
 
573
577
  class GptO1Mini_20240912(OpenAI): # pylint: disable=invalid-name
574
578
  """GPT O1-mini."""
575
579
  model = 'o1-mini-2024-09-12'
580
+ timeout = None
576
581
 
577
582
 
578
583
  class Gpt4(OpenAI):