langfun 0.1.2.dev202501010804__py3-none-any.whl → 0.1.2.dev202501060804__py3-none-any.whl

This diff shows the contents of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between package versions as they appear in their respective public registries.
Files changed (33)
  1. langfun/core/__init__.py +0 -4
  2. langfun/core/eval/matching.py +2 -2
  3. langfun/core/eval/scoring.py +6 -2
  4. langfun/core/eval/v2/checkpointing.py +106 -72
  5. langfun/core/eval/v2/checkpointing_test.py +108 -3
  6. langfun/core/eval/v2/eval_test_helper.py +56 -0
  7. langfun/core/eval/v2/evaluation.py +25 -4
  8. langfun/core/eval/v2/evaluation_test.py +11 -0
  9. langfun/core/eval/v2/example.py +11 -1
  10. langfun/core/eval/v2/example_test.py +16 -2
  11. langfun/core/eval/v2/experiment.py +83 -19
  12. langfun/core/eval/v2/experiment_test.py +121 -3
  13. langfun/core/eval/v2/reporting.py +67 -20
  14. langfun/core/eval/v2/reporting_test.py +119 -2
  15. langfun/core/eval/v2/runners.py +7 -4
  16. langfun/core/llms/__init__.py +23 -24
  17. langfun/core/llms/anthropic.py +12 -0
  18. langfun/core/llms/cache/in_memory.py +6 -0
  19. langfun/core/llms/cache/in_memory_test.py +5 -0
  20. langfun/core/llms/gemini.py +507 -0
  21. langfun/core/llms/gemini_test.py +195 -0
  22. langfun/core/llms/google_genai.py +46 -310
  23. langfun/core/llms/google_genai_test.py +9 -204
  24. langfun/core/llms/openai.py +23 -37
  25. langfun/core/llms/vertexai.py +28 -348
  26. langfun/core/llms/vertexai_test.py +6 -166
  27. {langfun-0.1.2.dev202501010804.dist-info → langfun-0.1.2.dev202501060804.dist-info}/METADATA +7 -13
  28. {langfun-0.1.2.dev202501010804.dist-info → langfun-0.1.2.dev202501060804.dist-info}/RECORD +31 -31
  29. {langfun-0.1.2.dev202501010804.dist-info → langfun-0.1.2.dev202501060804.dist-info}/WHEEL +1 -1
  30. langfun/core/repr_utils.py +0 -204
  31. langfun/core/repr_utils_test.py +0 -90
  32. {langfun-0.1.2.dev202501010804.dist-info → langfun-0.1.2.dev202501060804.dist-info}/LICENSE +0 -0
  33. {langfun-0.1.2.dev202501010804.dist-info → langfun-0.1.2.dev202501060804.dist-info}/top_level.txt +0 -0
@@ -11,223 +11,28 @@
11
11
  # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
12
  # See the License for the specific language governing permissions and
13
13
  # limitations under the License.
14
- """Tests for Gemini models."""
14
+ """Tests for Google GenAI models."""
15
15
 
16
16
  import os
17
17
  import unittest
18
- from unittest import mock
19
-
20
- from google import generativeai as genai
21
- import langfun.core as lf
22
- from langfun.core import modalities as lf_modalities
23
18
  from langfun.core.llms import google_genai
24
- import pyglove as pg
25
-
26
-
27
- example_image = (
28
- b'\x89PNG\r\n\x1a\n\x00\x00\x00\rIHDR\x00\x00\x00\x18\x00\x00\x00\x18\x04'
29
- b'\x03\x00\x00\x00\x12Y \xcb\x00\x00\x00\x18PLTE\x00\x00'
30
- b'\x00fff_chaag_cg_ch^ci_ciC\xedb\x94\x00\x00\x00\x08tRNS'
31
- b'\x00\n\x9f*\xd4\xff_\xf4\xe4\x8b\xf3a\x00\x00\x00>IDATx'
32
- b'\x01c \x05\x08)"\xd8\xcc\xae!\x06pNz\x88k\x19\\Q\xa8"\x10'
33
- b'\xc1\x14\x95\x01%\xc1\n\xa143Ta\xa8"D-\x84\x03QM\x98\xc3'
34
- b'\x1a\x1a\x1a@5\x0e\x04\xa0q\x88\x05\x00\x07\xf8\x18\xf9'
35
- b'\xdao\xd0|\x00\x00\x00\x00IEND\xaeB`\x82'
36
- )
37
-
38
-
39
- def mock_get_model(model_name, *args, **kwargs):
40
- del args, kwargs
41
- if 'gemini' in model_name:
42
- method = 'generateContent'
43
- elif 'chat' in model_name:
44
- method = 'generateMessage'
45
- else:
46
- method = 'generateText'
47
- return pg.Dict(supported_generation_methods=[method])
48
-
49
-
50
- def mock_generate_text(*, model, prompt, **kwargs):
51
- return pg.Dict(
52
- candidates=[pg.Dict(output=f'{prompt} to {model} with {kwargs}')]
53
- )
54
-
55
-
56
- def mock_chat(*, model, messages, **kwargs):
57
- return pg.Dict(
58
- candidates=[pg.Dict(content=f'{messages} to {model} with {kwargs}')]
59
- )
60
-
61
-
62
- def mock_generate_content(content, generation_config, **kwargs):
63
- del kwargs
64
- c = generation_config
65
- return genai.types.GenerateContentResponse(
66
- done=True,
67
- iterator=None,
68
- chunks=[],
69
- result=pg.Dict(
70
- prompt_feedback=pg.Dict(block_reason=None),
71
- candidates=[
72
- pg.Dict(
73
- content=pg.Dict(
74
- parts=[
75
- pg.Dict(
76
- text=(
77
- f'This is a response to {content[0]} with '
78
- f'n={c.candidate_count}, '
79
- f'temperature={c.temperature}, '
80
- f'top_p={c.top_p}, '
81
- f'top_k={c.top_k}, '
82
- f'max_tokens={c.max_output_tokens}, '
83
- f'stop={c.stop_sequences}.'
84
- )
85
- )
86
- ]
87
- ),
88
- ),
89
- ],
90
- ),
91
- )
92
19
 
93
20
 
94
21
  class GenAITest(unittest.TestCase):
95
- """Tests for Google GenAI model."""
96
-
97
- def test_content_from_message_text_only(self):
98
- text = 'This is a beautiful day'
99
- model = google_genai.GeminiPro()
100
- chunks = model._content_from_message(lf.UserMessage(text))
101
- self.assertEqual(chunks, [text])
102
-
103
- def test_content_from_message_mm(self):
104
- message = lf.UserMessage(
105
- 'This is an <<[[image]]>>, what is it?',
106
- image=lf_modalities.Image.from_bytes(example_image),
107
- )
22
+ """Tests for GenAI model."""
108
23
 
109
- # Non-multimodal model.
110
- with self.assertRaisesRegex(lf.ModalityError, 'Unsupported modality'):
111
- google_genai.GeminiPro()._content_from_message(message)
112
-
113
- model = google_genai.GeminiProVision()
114
- chunks = model._content_from_message(message)
115
- self.maxDiff = None
116
- self.assertEqual(
117
- chunks,
118
- [
119
- 'This is an',
120
- genai.types.BlobDict(mime_type='image/png', data=example_image),
121
- ', what is it?',
122
- ],
123
- )
124
-
125
- def test_response_to_result_text_only(self):
126
- response = genai.types.GenerateContentResponse(
127
- done=True,
128
- iterator=None,
129
- chunks=[],
130
- result=pg.Dict(
131
- prompt_feedback=pg.Dict(block_reason=None),
132
- candidates=[
133
- pg.Dict(
134
- content=pg.Dict(
135
- parts=[pg.Dict(text='This is response 1.')]
136
- ),
137
- ),
138
- pg.Dict(
139
- content=pg.Dict(parts=[pg.Dict(text='This is response 2.')])
140
- ),
141
- ],
142
- ),
143
- )
144
- model = google_genai.GeminiProVision()
145
- result = model._response_to_result(response)
146
- self.assertEqual(
147
- result,
148
- lf.LMSamplingResult([
149
- lf.LMSample(lf.AIMessage('This is response 1.'), score=0.0),
150
- lf.LMSample(lf.AIMessage('This is response 2.'), score=0.0),
151
- ]),
152
- )
153
-
154
- def test_model_hub(self):
155
- orig_get_model = genai.get_model
156
- genai.get_model = mock_get_model
157
-
158
- model = google_genai._GOOGLE_GENAI_MODEL_HUB.get('gemini-pro')
159
- self.assertIsNotNone(model)
160
- self.assertIs(google_genai._GOOGLE_GENAI_MODEL_HUB.get('gemini-pro'), model)
161
-
162
- genai.get_model = orig_get_model
163
-
164
- def test_api_key_check(self):
24
+ def test_basics(self):
165
25
  with self.assertRaisesRegex(ValueError, 'Please specify `api_key`'):
166
- _ = google_genai.GeminiPro()._api_initialized
26
+ _ = google_genai.GeminiPro1_5().api_endpoint
27
+
28
+ self.assertIsNotNone(google_genai.GeminiPro1_5(api_key='abc').api_endpoint)
167
29
 
168
- self.assertTrue(google_genai.GeminiPro(api_key='abc')._api_initialized)
169
30
  os.environ['GOOGLE_API_KEY'] = 'abc'
170
- self.assertTrue(google_genai.GeminiPro()._api_initialized)
31
+ lm = google_genai.GeminiPro1_5()
32
+ self.assertIsNotNone(lm.api_endpoint)
33
+ self.assertTrue(lm.model_id.startswith('GenAI('))
171
34
  del os.environ['GOOGLE_API_KEY']
172
35
 
173
- def test_call(self):
174
- with mock.patch(
175
- 'google.generativeai.GenerativeModel.generate_content',
176
- ) as mock_generate:
177
- orig_get_model = genai.get_model
178
- genai.get_model = mock_get_model
179
- mock_generate.side_effect = mock_generate_content
180
-
181
- lm = google_genai.GeminiPro(api_key='test_key')
182
- self.maxDiff = None
183
- self.assertEqual(
184
- lm('hello', temperature=2.0, top_k=20, max_tokens=1024).text,
185
- (
186
- 'This is a response to hello with n=1, temperature=2.0, '
187
- 'top_p=None, top_k=20, max_tokens=1024, stop=None.'
188
- ),
189
- )
190
- genai.get_model = orig_get_model
191
-
192
- def test_call_with_legacy_completion_model(self):
193
- orig_get_model = genai.get_model
194
- genai.get_model = mock_get_model
195
- orig_generate_text = getattr(genai, 'generate_text', None)
196
- if orig_generate_text is not None:
197
- genai.generate_text = mock_generate_text
198
-
199
- lm = google_genai.Palm2(api_key='test_key')
200
- self.maxDiff = None
201
- self.assertEqual(
202
- lm('hello', temperature=2.0, top_k=20).text,
203
- (
204
- "hello to models/text-bison-001 with {'temperature': 2.0, "
205
- "'top_k': 20, 'top_p': None, 'candidate_count': 1, "
206
- "'max_output_tokens': None, 'stop_sequences': None}"
207
- ),
208
- )
209
- genai.generate_text = orig_generate_text
210
- genai.get_model = orig_get_model
211
-
212
- def test_call_with_legacy_chat_model(self):
213
- orig_get_model = genai.get_model
214
- genai.get_model = mock_get_model
215
- orig_chat = getattr(genai, 'chat', None)
216
- if orig_chat is not None:
217
- genai.chat = mock_chat
218
-
219
- lm = google_genai.Palm2_IT(api_key='test_key')
220
- self.maxDiff = None
221
- self.assertEqual(
222
- lm('hello', temperature=2.0, top_k=20).text,
223
- (
224
- "hello to models/chat-bison-001 with {'temperature': 2.0, "
225
- "'top_k': 20, 'top_p': None, 'candidate_count': 1}"
226
- ),
227
- )
228
- genai.chat = orig_chat
229
- genai.get_model = orig_get_model
230
-
231
36
 
232
37
  if __name__ == '__main__':
233
38
  unittest.main()
@@ -32,6 +32,13 @@ SUPPORTED_MODELS_AND_SETTINGS = {
32
32
  # o1 (preview) models.
33
33
  # Pricing in US dollars, from https://openai.com/api/pricing/
34
34
  # as of 2024-10-10.
35
+ 'o1': pg.Dict(
36
+ in_service=True,
37
+ rpm=10000,
38
+ tpm=5000000,
39
+ cost_per_1k_input_tokens=0.015,
40
+ cost_per_1k_output_tokens=0.06,
41
+ ),
35
42
  'o1-preview': pg.Dict(
36
43
  in_service=True,
37
44
  rpm=10000,
@@ -255,25 +262,17 @@ SUPPORTED_MODELS_AND_SETTINGS = {
255
262
  ),
256
263
  # GPT-3.5 models
257
264
  'text-davinci-003': pg.Dict(
258
- in_service=False,
259
- rpm=_DEFAULT_RPM,
260
- tpm=_DEFAULT_TPM
265
+ in_service=False, rpm=_DEFAULT_RPM, tpm=_DEFAULT_TPM
261
266
  ),
262
267
  'text-davinci-002': pg.Dict(
263
- in_service=False,
264
- rpm=_DEFAULT_RPM,
265
- tpm=_DEFAULT_TPM
268
+ in_service=False, rpm=_DEFAULT_RPM, tpm=_DEFAULT_TPM
266
269
  ),
267
270
  'code-davinci-002': pg.Dict(
268
- in_service=False,
269
- rpm=_DEFAULT_RPM,
270
- tpm=_DEFAULT_TPM
271
+ in_service=False, rpm=_DEFAULT_RPM, tpm=_DEFAULT_TPM
271
272
  ),
272
273
  # GPT-3 instruction-tuned models (Deprecated)
273
274
  'text-curie-001': pg.Dict(
274
- in_service=False,
275
- rpm=_DEFAULT_RPM,
276
- tpm=_DEFAULT_TPM
275
+ in_service=False, rpm=_DEFAULT_RPM, tpm=_DEFAULT_TPM
277
276
  ),
278
277
  'text-babbage-001': pg.Dict(
279
278
  in_service=False,
@@ -290,32 +289,12 @@ SUPPORTED_MODELS_AND_SETTINGS = {
290
289
  rpm=_DEFAULT_RPM,
291
290
  tpm=_DEFAULT_TPM,
292
291
  ),
293
- 'curie': pg.Dict(
294
- in_service=False,
295
- rpm=_DEFAULT_RPM,
296
- tpm=_DEFAULT_TPM
297
- ),
298
- 'babbage': pg.Dict(
299
- in_service=False,
300
- rpm=_DEFAULT_RPM,
301
- tpm=_DEFAULT_TPM
302
- ),
303
- 'ada': pg.Dict(
304
- in_service=False,
305
- rpm=_DEFAULT_RPM,
306
- tpm=_DEFAULT_TPM
307
- ),
292
+ 'curie': pg.Dict(in_service=False, rpm=_DEFAULT_RPM, tpm=_DEFAULT_TPM),
293
+ 'babbage': pg.Dict(in_service=False, rpm=_DEFAULT_RPM, tpm=_DEFAULT_TPM),
294
+ 'ada': pg.Dict(in_service=False, rpm=_DEFAULT_RPM, tpm=_DEFAULT_TPM),
308
295
  # GPT-3 base models that are still in service.
309
- 'babbage-002': pg.Dict(
310
- in_service=True,
311
- rpm=_DEFAULT_RPM,
312
- tpm=_DEFAULT_TPM
313
- ),
314
- 'davinci-002': pg.Dict(
315
- in_service=True,
316
- rpm=_DEFAULT_RPM,
317
- tpm=_DEFAULT_TPM
318
- ),
296
+ 'babbage-002': pg.Dict(in_service=True, rpm=_DEFAULT_RPM, tpm=_DEFAULT_TPM),
297
+ 'davinci-002': pg.Dict(in_service=True, rpm=_DEFAULT_RPM, tpm=_DEFAULT_TPM),
319
298
  }
320
299
 
321
300
 
@@ -569,6 +548,13 @@ class OpenAI(rest.REST):
569
548
  )
570
549
 
571
550
 
551
+ class GptO1(OpenAI):
552
+ """GPT-O1."""
553
+
554
+ model = 'o1'
555
+ multimodal = True
556
+
557
+
572
558
  class GptO1Preview(OpenAI):
573
559
  """GPT-O1."""
574
560
  model = 'o1-preview'