langfun 0.1.2.dev202501010804__py3-none-any.whl → 0.1.2.dev202501060804__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- langfun/core/__init__.py +0 -4
- langfun/core/eval/matching.py +2 -2
- langfun/core/eval/scoring.py +6 -2
- langfun/core/eval/v2/checkpointing.py +106 -72
- langfun/core/eval/v2/checkpointing_test.py +108 -3
- langfun/core/eval/v2/eval_test_helper.py +56 -0
- langfun/core/eval/v2/evaluation.py +25 -4
- langfun/core/eval/v2/evaluation_test.py +11 -0
- langfun/core/eval/v2/example.py +11 -1
- langfun/core/eval/v2/example_test.py +16 -2
- langfun/core/eval/v2/experiment.py +83 -19
- langfun/core/eval/v2/experiment_test.py +121 -3
- langfun/core/eval/v2/reporting.py +67 -20
- langfun/core/eval/v2/reporting_test.py +119 -2
- langfun/core/eval/v2/runners.py +7 -4
- langfun/core/llms/__init__.py +23 -24
- langfun/core/llms/anthropic.py +12 -0
- langfun/core/llms/cache/in_memory.py +6 -0
- langfun/core/llms/cache/in_memory_test.py +5 -0
- langfun/core/llms/gemini.py +507 -0
- langfun/core/llms/gemini_test.py +195 -0
- langfun/core/llms/google_genai.py +46 -310
- langfun/core/llms/google_genai_test.py +9 -204
- langfun/core/llms/openai.py +23 -37
- langfun/core/llms/vertexai.py +28 -348
- langfun/core/llms/vertexai_test.py +6 -166
- {langfun-0.1.2.dev202501010804.dist-info → langfun-0.1.2.dev202501060804.dist-info}/METADATA +7 -13
- {langfun-0.1.2.dev202501010804.dist-info → langfun-0.1.2.dev202501060804.dist-info}/RECORD +31 -31
- {langfun-0.1.2.dev202501010804.dist-info → langfun-0.1.2.dev202501060804.dist-info}/WHEEL +1 -1
- langfun/core/repr_utils.py +0 -204
- langfun/core/repr_utils_test.py +0 -90
- {langfun-0.1.2.dev202501010804.dist-info → langfun-0.1.2.dev202501060804.dist-info}/LICENSE +0 -0
- {langfun-0.1.2.dev202501010804.dist-info → langfun-0.1.2.dev202501060804.dist-info}/top_level.txt +0 -0
@@ -11,223 +11,28 @@
|
|
11
11
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
12
|
# See the License for the specific language governing permissions and
|
13
13
|
# limitations under the License.
|
14
|
-
"""Tests for
|
14
|
+
"""Tests for Google GenAI models."""
|
15
15
|
|
16
16
|
import os
|
17
17
|
import unittest
|
18
|
-
from unittest import mock
|
19
|
-
|
20
|
-
from google import generativeai as genai
|
21
|
-
import langfun.core as lf
|
22
|
-
from langfun.core import modalities as lf_modalities
|
23
18
|
from langfun.core.llms import google_genai
|
24
|
-
import pyglove as pg
|
25
|
-
|
26
|
-
|
27
|
-
example_image = (
|
28
|
-
b'\x89PNG\r\n\x1a\n\x00\x00\x00\rIHDR\x00\x00\x00\x18\x00\x00\x00\x18\x04'
|
29
|
-
b'\x03\x00\x00\x00\x12Y \xcb\x00\x00\x00\x18PLTE\x00\x00'
|
30
|
-
b'\x00fff_chaag_cg_ch^ci_ciC\xedb\x94\x00\x00\x00\x08tRNS'
|
31
|
-
b'\x00\n\x9f*\xd4\xff_\xf4\xe4\x8b\xf3a\x00\x00\x00>IDATx'
|
32
|
-
b'\x01c \x05\x08)"\xd8\xcc\xae!\x06pNz\x88k\x19\\Q\xa8"\x10'
|
33
|
-
b'\xc1\x14\x95\x01%\xc1\n\xa143Ta\xa8"D-\x84\x03QM\x98\xc3'
|
34
|
-
b'\x1a\x1a\x1a@5\x0e\x04\xa0q\x88\x05\x00\x07\xf8\x18\xf9'
|
35
|
-
b'\xdao\xd0|\x00\x00\x00\x00IEND\xaeB`\x82'
|
36
|
-
)
|
37
|
-
|
38
|
-
|
39
|
-
def mock_get_model(model_name, *args, **kwargs):
|
40
|
-
del args, kwargs
|
41
|
-
if 'gemini' in model_name:
|
42
|
-
method = 'generateContent'
|
43
|
-
elif 'chat' in model_name:
|
44
|
-
method = 'generateMessage'
|
45
|
-
else:
|
46
|
-
method = 'generateText'
|
47
|
-
return pg.Dict(supported_generation_methods=[method])
|
48
|
-
|
49
|
-
|
50
|
-
def mock_generate_text(*, model, prompt, **kwargs):
|
51
|
-
return pg.Dict(
|
52
|
-
candidates=[pg.Dict(output=f'{prompt} to {model} with {kwargs}')]
|
53
|
-
)
|
54
|
-
|
55
|
-
|
56
|
-
def mock_chat(*, model, messages, **kwargs):
|
57
|
-
return pg.Dict(
|
58
|
-
candidates=[pg.Dict(content=f'{messages} to {model} with {kwargs}')]
|
59
|
-
)
|
60
|
-
|
61
|
-
|
62
|
-
def mock_generate_content(content, generation_config, **kwargs):
|
63
|
-
del kwargs
|
64
|
-
c = generation_config
|
65
|
-
return genai.types.GenerateContentResponse(
|
66
|
-
done=True,
|
67
|
-
iterator=None,
|
68
|
-
chunks=[],
|
69
|
-
result=pg.Dict(
|
70
|
-
prompt_feedback=pg.Dict(block_reason=None),
|
71
|
-
candidates=[
|
72
|
-
pg.Dict(
|
73
|
-
content=pg.Dict(
|
74
|
-
parts=[
|
75
|
-
pg.Dict(
|
76
|
-
text=(
|
77
|
-
f'This is a response to {content[0]} with '
|
78
|
-
f'n={c.candidate_count}, '
|
79
|
-
f'temperature={c.temperature}, '
|
80
|
-
f'top_p={c.top_p}, '
|
81
|
-
f'top_k={c.top_k}, '
|
82
|
-
f'max_tokens={c.max_output_tokens}, '
|
83
|
-
f'stop={c.stop_sequences}.'
|
84
|
-
)
|
85
|
-
)
|
86
|
-
]
|
87
|
-
),
|
88
|
-
),
|
89
|
-
],
|
90
|
-
),
|
91
|
-
)
|
92
19
|
|
93
20
|
|
94
21
|
class GenAITest(unittest.TestCase):
|
95
|
-
"""Tests for
|
96
|
-
|
97
|
-
def test_content_from_message_text_only(self):
|
98
|
-
text = 'This is a beautiful day'
|
99
|
-
model = google_genai.GeminiPro()
|
100
|
-
chunks = model._content_from_message(lf.UserMessage(text))
|
101
|
-
self.assertEqual(chunks, [text])
|
102
|
-
|
103
|
-
def test_content_from_message_mm(self):
|
104
|
-
message = lf.UserMessage(
|
105
|
-
'This is an <<[[image]]>>, what is it?',
|
106
|
-
image=lf_modalities.Image.from_bytes(example_image),
|
107
|
-
)
|
22
|
+
"""Tests for GenAI model."""
|
108
23
|
|
109
|
-
|
110
|
-
with self.assertRaisesRegex(lf.ModalityError, 'Unsupported modality'):
|
111
|
-
google_genai.GeminiPro()._content_from_message(message)
|
112
|
-
|
113
|
-
model = google_genai.GeminiProVision()
|
114
|
-
chunks = model._content_from_message(message)
|
115
|
-
self.maxDiff = None
|
116
|
-
self.assertEqual(
|
117
|
-
chunks,
|
118
|
-
[
|
119
|
-
'This is an',
|
120
|
-
genai.types.BlobDict(mime_type='image/png', data=example_image),
|
121
|
-
', what is it?',
|
122
|
-
],
|
123
|
-
)
|
124
|
-
|
125
|
-
def test_response_to_result_text_only(self):
|
126
|
-
response = genai.types.GenerateContentResponse(
|
127
|
-
done=True,
|
128
|
-
iterator=None,
|
129
|
-
chunks=[],
|
130
|
-
result=pg.Dict(
|
131
|
-
prompt_feedback=pg.Dict(block_reason=None),
|
132
|
-
candidates=[
|
133
|
-
pg.Dict(
|
134
|
-
content=pg.Dict(
|
135
|
-
parts=[pg.Dict(text='This is response 1.')]
|
136
|
-
),
|
137
|
-
),
|
138
|
-
pg.Dict(
|
139
|
-
content=pg.Dict(parts=[pg.Dict(text='This is response 2.')])
|
140
|
-
),
|
141
|
-
],
|
142
|
-
),
|
143
|
-
)
|
144
|
-
model = google_genai.GeminiProVision()
|
145
|
-
result = model._response_to_result(response)
|
146
|
-
self.assertEqual(
|
147
|
-
result,
|
148
|
-
lf.LMSamplingResult([
|
149
|
-
lf.LMSample(lf.AIMessage('This is response 1.'), score=0.0),
|
150
|
-
lf.LMSample(lf.AIMessage('This is response 2.'), score=0.0),
|
151
|
-
]),
|
152
|
-
)
|
153
|
-
|
154
|
-
def test_model_hub(self):
|
155
|
-
orig_get_model = genai.get_model
|
156
|
-
genai.get_model = mock_get_model
|
157
|
-
|
158
|
-
model = google_genai._GOOGLE_GENAI_MODEL_HUB.get('gemini-pro')
|
159
|
-
self.assertIsNotNone(model)
|
160
|
-
self.assertIs(google_genai._GOOGLE_GENAI_MODEL_HUB.get('gemini-pro'), model)
|
161
|
-
|
162
|
-
genai.get_model = orig_get_model
|
163
|
-
|
164
|
-
def test_api_key_check(self):
|
24
|
+
def test_basics(self):
|
165
25
|
with self.assertRaisesRegex(ValueError, 'Please specify `api_key`'):
|
166
|
-
_ = google_genai.
|
26
|
+
_ = google_genai.GeminiPro1_5().api_endpoint
|
27
|
+
|
28
|
+
self.assertIsNotNone(google_genai.GeminiPro1_5(api_key='abc').api_endpoint)
|
167
29
|
|
168
|
-
self.assertTrue(google_genai.GeminiPro(api_key='abc')._api_initialized)
|
169
30
|
os.environ['GOOGLE_API_KEY'] = 'abc'
|
170
|
-
|
31
|
+
lm = google_genai.GeminiPro1_5()
|
32
|
+
self.assertIsNotNone(lm.api_endpoint)
|
33
|
+
self.assertTrue(lm.model_id.startswith('GenAI('))
|
171
34
|
del os.environ['GOOGLE_API_KEY']
|
172
35
|
|
173
|
-
def test_call(self):
|
174
|
-
with mock.patch(
|
175
|
-
'google.generativeai.GenerativeModel.generate_content',
|
176
|
-
) as mock_generate:
|
177
|
-
orig_get_model = genai.get_model
|
178
|
-
genai.get_model = mock_get_model
|
179
|
-
mock_generate.side_effect = mock_generate_content
|
180
|
-
|
181
|
-
lm = google_genai.GeminiPro(api_key='test_key')
|
182
|
-
self.maxDiff = None
|
183
|
-
self.assertEqual(
|
184
|
-
lm('hello', temperature=2.0, top_k=20, max_tokens=1024).text,
|
185
|
-
(
|
186
|
-
'This is a response to hello with n=1, temperature=2.0, '
|
187
|
-
'top_p=None, top_k=20, max_tokens=1024, stop=None.'
|
188
|
-
),
|
189
|
-
)
|
190
|
-
genai.get_model = orig_get_model
|
191
|
-
|
192
|
-
def test_call_with_legacy_completion_model(self):
|
193
|
-
orig_get_model = genai.get_model
|
194
|
-
genai.get_model = mock_get_model
|
195
|
-
orig_generate_text = getattr(genai, 'generate_text', None)
|
196
|
-
if orig_generate_text is not None:
|
197
|
-
genai.generate_text = mock_generate_text
|
198
|
-
|
199
|
-
lm = google_genai.Palm2(api_key='test_key')
|
200
|
-
self.maxDiff = None
|
201
|
-
self.assertEqual(
|
202
|
-
lm('hello', temperature=2.0, top_k=20).text,
|
203
|
-
(
|
204
|
-
"hello to models/text-bison-001 with {'temperature': 2.0, "
|
205
|
-
"'top_k': 20, 'top_p': None, 'candidate_count': 1, "
|
206
|
-
"'max_output_tokens': None, 'stop_sequences': None}"
|
207
|
-
),
|
208
|
-
)
|
209
|
-
genai.generate_text = orig_generate_text
|
210
|
-
genai.get_model = orig_get_model
|
211
|
-
|
212
|
-
def test_call_with_legacy_chat_model(self):
|
213
|
-
orig_get_model = genai.get_model
|
214
|
-
genai.get_model = mock_get_model
|
215
|
-
orig_chat = getattr(genai, 'chat', None)
|
216
|
-
if orig_chat is not None:
|
217
|
-
genai.chat = mock_chat
|
218
|
-
|
219
|
-
lm = google_genai.Palm2_IT(api_key='test_key')
|
220
|
-
self.maxDiff = None
|
221
|
-
self.assertEqual(
|
222
|
-
lm('hello', temperature=2.0, top_k=20).text,
|
223
|
-
(
|
224
|
-
"hello to models/chat-bison-001 with {'temperature': 2.0, "
|
225
|
-
"'top_k': 20, 'top_p': None, 'candidate_count': 1}"
|
226
|
-
),
|
227
|
-
)
|
228
|
-
genai.chat = orig_chat
|
229
|
-
genai.get_model = orig_get_model
|
230
|
-
|
231
36
|
|
232
37
|
if __name__ == '__main__':
|
233
38
|
unittest.main()
|
langfun/core/llms/openai.py
CHANGED
@@ -32,6 +32,13 @@ SUPPORTED_MODELS_AND_SETTINGS = {
|
|
32
32
|
# o1 (preview) models.
|
33
33
|
# Pricing in US dollars, from https://openai.com/api/pricing/
|
34
34
|
# as of 2024-10-10.
|
35
|
+
'o1': pg.Dict(
|
36
|
+
in_service=True,
|
37
|
+
rpm=10000,
|
38
|
+
tpm=5000000,
|
39
|
+
cost_per_1k_input_tokens=0.015,
|
40
|
+
cost_per_1k_output_tokens=0.06,
|
41
|
+
),
|
35
42
|
'o1-preview': pg.Dict(
|
36
43
|
in_service=True,
|
37
44
|
rpm=10000,
|
@@ -255,25 +262,17 @@ SUPPORTED_MODELS_AND_SETTINGS = {
|
|
255
262
|
),
|
256
263
|
# GPT-3.5 models
|
257
264
|
'text-davinci-003': pg.Dict(
|
258
|
-
in_service=False,
|
259
|
-
rpm=_DEFAULT_RPM,
|
260
|
-
tpm=_DEFAULT_TPM
|
265
|
+
in_service=False, rpm=_DEFAULT_RPM, tpm=_DEFAULT_TPM
|
261
266
|
),
|
262
267
|
'text-davinci-002': pg.Dict(
|
263
|
-
in_service=False,
|
264
|
-
rpm=_DEFAULT_RPM,
|
265
|
-
tpm=_DEFAULT_TPM
|
268
|
+
in_service=False, rpm=_DEFAULT_RPM, tpm=_DEFAULT_TPM
|
266
269
|
),
|
267
270
|
'code-davinci-002': pg.Dict(
|
268
|
-
in_service=False,
|
269
|
-
rpm=_DEFAULT_RPM,
|
270
|
-
tpm=_DEFAULT_TPM
|
271
|
+
in_service=False, rpm=_DEFAULT_RPM, tpm=_DEFAULT_TPM
|
271
272
|
),
|
272
273
|
# GPT-3 instruction-tuned models (Deprecated)
|
273
274
|
'text-curie-001': pg.Dict(
|
274
|
-
in_service=False,
|
275
|
-
rpm=_DEFAULT_RPM,
|
276
|
-
tpm=_DEFAULT_TPM
|
275
|
+
in_service=False, rpm=_DEFAULT_RPM, tpm=_DEFAULT_TPM
|
277
276
|
),
|
278
277
|
'text-babbage-001': pg.Dict(
|
279
278
|
in_service=False,
|
@@ -290,32 +289,12 @@ SUPPORTED_MODELS_AND_SETTINGS = {
|
|
290
289
|
rpm=_DEFAULT_RPM,
|
291
290
|
tpm=_DEFAULT_TPM,
|
292
291
|
),
|
293
|
-
'curie': pg.Dict(
|
294
|
-
in_service=False,
|
295
|
-
rpm=_DEFAULT_RPM,
|
296
|
-
tpm=_DEFAULT_TPM
|
297
|
-
),
|
298
|
-
'babbage': pg.Dict(
|
299
|
-
in_service=False,
|
300
|
-
rpm=_DEFAULT_RPM,
|
301
|
-
tpm=_DEFAULT_TPM
|
302
|
-
),
|
303
|
-
'ada': pg.Dict(
|
304
|
-
in_service=False,
|
305
|
-
rpm=_DEFAULT_RPM,
|
306
|
-
tpm=_DEFAULT_TPM
|
307
|
-
),
|
292
|
+
'curie': pg.Dict(in_service=False, rpm=_DEFAULT_RPM, tpm=_DEFAULT_TPM),
|
293
|
+
'babbage': pg.Dict(in_service=False, rpm=_DEFAULT_RPM, tpm=_DEFAULT_TPM),
|
294
|
+
'ada': pg.Dict(in_service=False, rpm=_DEFAULT_RPM, tpm=_DEFAULT_TPM),
|
308
295
|
# GPT-3 base models that are still in service.
|
309
|
-
'babbage-002': pg.Dict(
|
310
|
-
in_service=True,
|
311
|
-
rpm=_DEFAULT_RPM,
|
312
|
-
tpm=_DEFAULT_TPM
|
313
|
-
),
|
314
|
-
'davinci-002': pg.Dict(
|
315
|
-
in_service=True,
|
316
|
-
rpm=_DEFAULT_RPM,
|
317
|
-
tpm=_DEFAULT_TPM
|
318
|
-
),
|
296
|
+
'babbage-002': pg.Dict(in_service=True, rpm=_DEFAULT_RPM, tpm=_DEFAULT_TPM),
|
297
|
+
'davinci-002': pg.Dict(in_service=True, rpm=_DEFAULT_RPM, tpm=_DEFAULT_TPM),
|
319
298
|
}
|
320
299
|
|
321
300
|
|
@@ -569,6 +548,13 @@ class OpenAI(rest.REST):
|
|
569
548
|
)
|
570
549
|
|
571
550
|
|
551
|
+
class GptO1(OpenAI):
|
552
|
+
"""GPT-O1."""
|
553
|
+
|
554
|
+
model = 'o1'
|
555
|
+
multimodal = True
|
556
|
+
|
557
|
+
|
572
558
|
class GptO1Preview(OpenAI):
|
573
559
|
"""GPT-O1."""
|
574
560
|
model = 'o1-preview'
|