langfun 0.1.2.dev202501050804__py3-none-any.whl → 0.1.2.dev202501070804__py3-none-any.whl
This diff compares publicly available package versions as released to a supported registry. It is provided for informational purposes only and reflects the changes between the two versions as they appear in the public registry.
- langfun/core/eval/v2/reporting.py +7 -2
- langfun/core/language_model.py +4 -1
- langfun/core/language_model_test.py +15 -0
- langfun/core/llms/__init__.py +21 -26
- langfun/core/llms/cache/in_memory.py +6 -0
- langfun/core/llms/cache/in_memory_test.py +5 -0
- langfun/core/llms/gemini.py +507 -0
- langfun/core/llms/gemini_test.py +195 -0
- langfun/core/llms/google_genai.py +46 -320
- langfun/core/llms/google_genai_test.py +9 -204
- langfun/core/llms/openai.py +5 -0
- langfun/core/llms/vertexai.py +26 -357
- langfun/core/llms/vertexai_test.py +6 -166
- {langfun-0.1.2.dev202501050804.dist-info → langfun-0.1.2.dev202501070804.dist-info}/METADATA +7 -13
- {langfun-0.1.2.dev202501050804.dist-info → langfun-0.1.2.dev202501070804.dist-info}/RECORD +18 -16
- {langfun-0.1.2.dev202501050804.dist-info → langfun-0.1.2.dev202501070804.dist-info}/WHEEL +1 -1
- {langfun-0.1.2.dev202501050804.dist-info → langfun-0.1.2.dev202501070804.dist-info}/LICENSE +0 -0
- {langfun-0.1.2.dev202501050804.dist-info → langfun-0.1.2.dev202501070804.dist-info}/top_level.txt +0 -0
langfun/core/llms/google_genai_test.py
CHANGED
@@ -11,223 +11,28 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-"""Tests for
+"""Tests for Google GenAI models."""

 import os
 import unittest
-from unittest import mock
-
-from google import generativeai as genai
-import langfun.core as lf
-from langfun.core import modalities as lf_modalities
 from langfun.core.llms import google_genai
-import pyglove as pg
-
-
-example_image = (
-    b'\x89PNG\r\n\x1a\n\x00\x00\x00\rIHDR\x00\x00\x00\x18\x00\x00\x00\x18\x04'
-    b'\x03\x00\x00\x00\x12Y \xcb\x00\x00\x00\x18PLTE\x00\x00'
-    b'\x00fff_chaag_cg_ch^ci_ciC\xedb\x94\x00\x00\x00\x08tRNS'
-    b'\x00\n\x9f*\xd4\xff_\xf4\xe4\x8b\xf3a\x00\x00\x00>IDATx'
-    b'\x01c \x05\x08)"\xd8\xcc\xae!\x06pNz\x88k\x19\\Q\xa8"\x10'
-    b'\xc1\x14\x95\x01%\xc1\n\xa143Ta\xa8"D-\x84\x03QM\x98\xc3'
-    b'\x1a\x1a\x1a@5\x0e\x04\xa0q\x88\x05\x00\x07\xf8\x18\xf9'
-    b'\xdao\xd0|\x00\x00\x00\x00IEND\xaeB`\x82'
-)
-
-
-def mock_get_model(model_name, *args, **kwargs):
-  del args, kwargs
-  if 'gemini' in model_name:
-    method = 'generateContent'
-  elif 'chat' in model_name:
-    method = 'generateMessage'
-  else:
-    method = 'generateText'
-  return pg.Dict(supported_generation_methods=[method])
-
-
-def mock_generate_text(*, model, prompt, **kwargs):
-  return pg.Dict(
-      candidates=[pg.Dict(output=f'{prompt} to {model} with {kwargs}')]
-  )
-
-
-def mock_chat(*, model, messages, **kwargs):
-  return pg.Dict(
-      candidates=[pg.Dict(content=f'{messages} to {model} with {kwargs}')]
-  )
-
-
-def mock_generate_content(content, generation_config, **kwargs):
-  del kwargs
-  c = generation_config
-  return genai.types.GenerateContentResponse(
-      done=True,
-      iterator=None,
-      chunks=[],
-      result=pg.Dict(
-          prompt_feedback=pg.Dict(block_reason=None),
-          candidates=[
-              pg.Dict(
-                  content=pg.Dict(
-                      parts=[
-                          pg.Dict(
-                              text=(
-                                  f'This is a response to {content[0]} with '
-                                  f'n={c.candidate_count}, '
-                                  f'temperature={c.temperature}, '
-                                  f'top_p={c.top_p}, '
-                                  f'top_k={c.top_k}, '
-                                  f'max_tokens={c.max_output_tokens}, '
-                                  f'stop={c.stop_sequences}.'
-                              )
-                          )
-                      ]
-                  ),
-              ),
-          ],
-      ),
-  )


 class GenAITest(unittest.TestCase):
-  """Tests for
-
-  def test_content_from_message_text_only(self):
-    text = 'This is a beautiful day'
-    model = google_genai.GeminiPro()
-    chunks = model._content_from_message(lf.UserMessage(text))
-    self.assertEqual(chunks, [text])
-
-  def test_content_from_message_mm(self):
-    message = lf.UserMessage(
-        'This is an <<[[image]]>>, what is it?',
-        image=lf_modalities.Image.from_bytes(example_image),
-    )
+  """Tests for GenAI model."""

-
-    with self.assertRaisesRegex(lf.ModalityError, 'Unsupported modality'):
-      google_genai.GeminiPro()._content_from_message(message)
-
-    model = google_genai.GeminiProVision()
-    chunks = model._content_from_message(message)
-    self.maxDiff = None
-    self.assertEqual(
-        chunks,
-        [
-            'This is an',
-            genai.types.BlobDict(mime_type='image/png', data=example_image),
-            ', what is it?',
-        ],
-    )
-
-  def test_response_to_result_text_only(self):
-    response = genai.types.GenerateContentResponse(
-        done=True,
-        iterator=None,
-        chunks=[],
-        result=pg.Dict(
-            prompt_feedback=pg.Dict(block_reason=None),
-            candidates=[
-                pg.Dict(
-                    content=pg.Dict(
-                        parts=[pg.Dict(text='This is response 1.')]
-                    ),
-                ),
-                pg.Dict(
-                    content=pg.Dict(parts=[pg.Dict(text='This is response 2.')])
-                ),
-            ],
-        ),
-    )
-    model = google_genai.GeminiProVision()
-    result = model._response_to_result(response)
-    self.assertEqual(
-        result,
-        lf.LMSamplingResult([
-            lf.LMSample(lf.AIMessage('This is response 1.'), score=0.0),
-            lf.LMSample(lf.AIMessage('This is response 2.'), score=0.0),
-        ]),
-    )
-
-  def test_model_hub(self):
-    orig_get_model = genai.get_model
-    genai.get_model = mock_get_model
-
-    model = google_genai._GOOGLE_GENAI_MODEL_HUB.get('gemini-pro')
-    self.assertIsNotNone(model)
-    self.assertIs(google_genai._GOOGLE_GENAI_MODEL_HUB.get('gemini-pro'), model)
-
-    genai.get_model = orig_get_model
-
-  def test_api_key_check(self):
+  def test_basics(self):
     with self.assertRaisesRegex(ValueError, 'Please specify `api_key`'):
-      _ = google_genai.
+      _ = google_genai.GeminiPro1_5().api_endpoint
+
+    self.assertIsNotNone(google_genai.GeminiPro1_5(api_key='abc').api_endpoint)

-    self.assertTrue(google_genai.GeminiPro(api_key='abc')._api_initialized)
     os.environ['GOOGLE_API_KEY'] = 'abc'
-
+    lm = google_genai.GeminiPro1_5()
+    self.assertIsNotNone(lm.api_endpoint)
+    self.assertTrue(lm.model_id.startswith('GenAI('))
     del os.environ['GOOGLE_API_KEY']

-  def test_call(self):
-    with mock.patch(
-        'google.generativeai.GenerativeModel.generate_content',
-    ) as mock_generate:
-      orig_get_model = genai.get_model
-      genai.get_model = mock_get_model
-      mock_generate.side_effect = mock_generate_content
-
-      lm = google_genai.GeminiPro(api_key='test_key')
-      self.maxDiff = None
-      self.assertEqual(
-          lm('hello', temperature=2.0, top_k=20, max_tokens=1024).text,
-          (
-              'This is a response to hello with n=1, temperature=2.0, '
-              'top_p=None, top_k=20, max_tokens=1024, stop=None.'
-          ),
-      )
-      genai.get_model = orig_get_model
-
-  def test_call_with_legacy_completion_model(self):
-    orig_get_model = genai.get_model
-    genai.get_model = mock_get_model
-    orig_generate_text = getattr(genai, 'generate_text', None)
-    if orig_generate_text is not None:
-      genai.generate_text = mock_generate_text
-
-    lm = google_genai.Palm2(api_key='test_key')
-    self.maxDiff = None
-    self.assertEqual(
-        lm('hello', temperature=2.0, top_k=20).text,
-        (
-            "hello to models/text-bison-001 with {'temperature': 2.0, "
-            "'top_k': 20, 'top_p': None, 'candidate_count': 1, "
-            "'max_output_tokens': None, 'stop_sequences': None}"
-        ),
-    )
-    genai.generate_text = orig_generate_text
-    genai.get_model = orig_get_model
-
-  def test_call_with_legacy_chat_model(self):
-    orig_get_model = genai.get_model
-    genai.get_model = mock_get_model
-    orig_chat = getattr(genai, 'chat', None)
-    if orig_chat is not None:
-      genai.chat = mock_chat
-
-    lm = google_genai.Palm2_IT(api_key='test_key')
-    self.maxDiff = None
-    self.assertEqual(
-        lm('hello', temperature=2.0, top_k=20).text,
-        (
-            "hello to models/chat-bison-001 with {'temperature': 2.0, "
-            "'top_k': 20, 'top_p': None, 'candidate_count': 1}"
-        ),
-    )
-    genai.chat = orig_chat
-    genai.get_model = orig_get_model
-

 if __name__ == '__main__':
   unittest.main()
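The rewritten test no longer mocks the `google.generativeai` SDK: per the file list above, the request/response plumbing appears to have moved into the new shared `gemini.py` module, with `google_genai.py` shrinking to a thin wrapper, so the test now only exercises API-key resolution and endpoint construction. Below is a minimal sketch of that pattern using only names that appear in this diff; the key value is a placeholder.

import os
from langfun.core.llms import google_genai

# Without an explicit `api_key`, the key is read from GOOGLE_API_KEY;
# per the test above, if neither is present, accessing `api_endpoint`
# raises ValueError('Please specify `api_key` ...').
os.environ['GOOGLE_API_KEY'] = 'my-key'  # placeholder
lm = google_genai.GeminiPro1_5()
assert lm.api_endpoint is not None
assert lm.model_id.startswith('GenAI(')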
langfun/core/llms/openai.py
CHANGED
@@ -553,26 +553,31 @@ class GptO1(OpenAI):

   model = 'o1'
   multimodal = True
+  timeout = None


 class GptO1Preview(OpenAI):
   """GPT-O1."""
   model = 'o1-preview'
+  timeout = None


 class GptO1Preview_20240912(OpenAI):  # pylint: disable=invalid-name
   """GPT O1."""
   model = 'o1-preview-2024-09-12'
+  timeout = None


 class GptO1Mini(OpenAI):
   """GPT O1-mini."""
   model = 'o1-mini'
+  timeout = None


 class GptO1Mini_20240912(OpenAI):  # pylint: disable=invalid-name
   """GPT O1-mini."""
   model = 'o1-mini-2024-09-12'
+  timeout = None


 class Gpt4(OpenAI):
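The only change to `openai.py` is the class-level `timeout = None` added to each O1 variant, presumably to disable the client-side request timeout, since O1 reasoning calls can run far longer than a typical completion. A hedged sketch of how that default surfaces, assuming `timeout` behaves like langfun's other class-level model attributes and remains overridable per instance; the key values are placeholders:

from langfun.core.llms import openai

# Class-level default from this diff: no client-side timeout for O1 models.
lm = openai.GptO1Mini(api_key='test-key')
assert lm.timeout is None

# Assumed per-instance override: cap request time for a specific caller.
lm_capped = openai.GptO1Mini(api_key='test-key', timeout=600.0)
assert lm_capped.timeout == 600.0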