langfun 0.1.2.dev202501050804__py3-none-any.whl → 0.1.2.dev202501090804__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (34)
  1. langfun/core/__init__.py +0 -5
  2. langfun/core/coding/python/correction.py +4 -3
  3. langfun/core/coding/python/errors.py +10 -9
  4. langfun/core/coding/python/execution.py +23 -12
  5. langfun/core/coding/python/execution_test.py +21 -2
  6. langfun/core/coding/python/generation.py +18 -9
  7. langfun/core/concurrent.py +2 -3
  8. langfun/core/console.py +8 -3
  9. langfun/core/eval/base.py +2 -3
  10. langfun/core/eval/v2/reporting.py +15 -6
  11. langfun/core/language_model.py +7 -4
  12. langfun/core/language_model_test.py +15 -0
  13. langfun/core/llms/__init__.py +25 -26
  14. langfun/core/llms/cache/in_memory.py +6 -0
  15. langfun/core/llms/cache/in_memory_test.py +5 -0
  16. langfun/core/llms/deepseek.py +261 -0
  17. langfun/core/llms/deepseek_test.py +438 -0
  18. langfun/core/llms/gemini.py +507 -0
  19. langfun/core/llms/gemini_test.py +195 -0
  20. langfun/core/llms/google_genai.py +46 -320
  21. langfun/core/llms/google_genai_test.py +9 -204
  22. langfun/core/llms/openai.py +5 -0
  23. langfun/core/llms/vertexai.py +31 -359
  24. langfun/core/llms/vertexai_test.py +6 -166
  25. langfun/core/structured/mapping.py +13 -13
  26. langfun/core/structured/mapping_test.py +2 -2
  27. langfun/core/structured/schema.py +16 -8
  28. {langfun-0.1.2.dev202501050804.dist-info → langfun-0.1.2.dev202501090804.dist-info}/METADATA +19 -14
  29. {langfun-0.1.2.dev202501050804.dist-info → langfun-0.1.2.dev202501090804.dist-info}/RECORD +32 -30
  30. {langfun-0.1.2.dev202501050804.dist-info → langfun-0.1.2.dev202501090804.dist-info}/WHEEL +1 -1
  31. langfun/core/text_formatting.py +0 -168
  32. langfun/core/text_formatting_test.py +0 -65
  33. {langfun-0.1.2.dev202501050804.dist-info → langfun-0.1.2.dev202501090804.dist-info}/LICENSE +0 -0
  34. {langfun-0.1.2.dev202501050804.dist-info → langfun-0.1.2.dev202501090804.dist-info}/top_level.txt +0 -0
@@ -553,26 +553,31 @@ class GptO1(OpenAI):
553
553
 
554
554
  model = 'o1'
555
555
  multimodal = True
556
+ timeout = None
556
557
 
557
558
 
558
559
  class GptO1Preview(OpenAI):
559
560
  """GPT-O1."""
560
561
  model = 'o1-preview'
562
+ timeout = None
561
563
 
562
564
 
563
565
  class GptO1Preview_20240912(OpenAI): # pylint: disable=invalid-name
564
566
  """GPT O1."""
565
567
  model = 'o1-preview-2024-09-12'
568
+ timeout = None
566
569
 
567
570
 
568
571
  class GptO1Mini(OpenAI):
569
572
  """GPT O1-mini."""
570
573
  model = 'o1-mini'
574
+ timeout = None
571
575
 
572
576
 
573
577
  class GptO1Mini_20240912(OpenAI): # pylint: disable=invalid-name
574
578
  """GPT O1-mini."""
575
579
  model = 'o1-mini-2024-09-12'
580
+ timeout = None
576
581
 
577
582
 
578
583
  class Gpt4(OpenAI):
@@ -13,14 +13,12 @@
13
13
  # limitations under the License.
14
14
  """Vertex AI generative models."""
15
15
 
16
- import base64
17
16
  import functools
18
17
  import os
19
18
  from typing import Annotated, Any
20
19
 
21
20
  import langfun.core as lf
22
- from langfun.core import modalities as lf_modalities
23
- from langfun.core.llms import rest
21
+ from langfun.core.llms import gemini
24
22
  import pyglove as pg
25
23
 
26
24
  try:
@@ -38,114 +36,11 @@ except ImportError:
38
36
  Credentials = Any
39
37
 
40
38
 
41
- # https://cloud.google.com/vertex-ai/generative-ai/pricing
42
- # describes that the average number of characters per token is about 4.
43
- AVGERAGE_CHARS_PER_TOKEN = 4
44
-
45
-
46
- # Price in US dollars,
47
- # from https://cloud.google.com/vertex-ai/generative-ai/pricing
48
- # as of 2024-10-10.
49
- SUPPORTED_MODELS_AND_SETTINGS = {
50
- 'gemini-1.5-pro-001': pg.Dict(
51
- rpm=100,
52
- cost_per_1k_input_chars=0.0003125,
53
- cost_per_1k_output_chars=0.00125,
54
- ),
55
- 'gemini-1.5-pro-002': pg.Dict(
56
- rpm=100,
57
- cost_per_1k_input_chars=0.0003125,
58
- cost_per_1k_output_chars=0.00125,
59
- ),
60
- 'gemini-1.5-flash-002': pg.Dict(
61
- rpm=500,
62
- cost_per_1k_input_chars=0.00001875,
63
- cost_per_1k_output_chars=0.000075,
64
- ),
65
- 'gemini-1.5-flash-001': pg.Dict(
66
- rpm=500,
67
- cost_per_1k_input_chars=0.00001875,
68
- cost_per_1k_output_chars=0.000075,
69
- ),
70
- 'gemini-1.5-pro': pg.Dict(
71
- rpm=100,
72
- cost_per_1k_input_chars=0.0003125,
73
- cost_per_1k_output_chars=0.00125,
74
- ),
75
- 'gemini-1.5-flash': pg.Dict(
76
- rpm=500,
77
- cost_per_1k_input_chars=0.00001875,
78
- cost_per_1k_output_chars=0.000075,
79
- ),
80
- 'gemini-1.5-pro-preview-0514': pg.Dict(
81
- rpm=50,
82
- cost_per_1k_input_chars=0.0003125,
83
- cost_per_1k_output_chars=0.00125,
84
- ),
85
- 'gemini-1.5-pro-preview-0409': pg.Dict(
86
- rpm=50,
87
- cost_per_1k_input_chars=0.0003125,
88
- cost_per_1k_output_chars=0.00125,
89
- ),
90
- 'gemini-1.5-flash-preview-0514': pg.Dict(
91
- rpm=200,
92
- cost_per_1k_input_chars=0.00001875,
93
- cost_per_1k_output_chars=0.000075,
94
- ),
95
- 'gemini-1.0-pro': pg.Dict(
96
- rpm=300,
97
- cost_per_1k_input_chars=0.000125,
98
- cost_per_1k_output_chars=0.000375,
99
- ),
100
- 'gemini-1.0-pro-vision': pg.Dict(
101
- rpm=100,
102
- cost_per_1k_input_chars=0.000125,
103
- cost_per_1k_output_chars=0.000375,
104
- ),
105
- # TODO(sharatsharat): Update costs when published
106
- 'gemini-exp-1206': pg.Dict(
107
- rpm=20,
108
- cost_per_1k_input_chars=0.000,
109
- cost_per_1k_output_chars=0.000,
110
- ),
111
- # TODO(sharatsharat): Update costs when published
112
- 'gemini-2.0-flash-exp': pg.Dict(
113
- rpm=10,
114
- cost_per_1k_input_chars=0.000,
115
- cost_per_1k_output_chars=0.000,
116
- ),
117
- # TODO(yifenglu): Update costs when published
118
- 'gemini-2.0-flash-thinking-exp-1219': pg.Dict(
119
- rpm=10,
120
- cost_per_1k_input_chars=0.000,
121
- cost_per_1k_output_chars=0.000,
122
- ),
123
- # TODO(chengrun): Set a more appropriate rpm for endpoint.
124
- 'vertexai-endpoint': pg.Dict(
125
- rpm=20,
126
- cost_per_1k_input_chars=0.0000125,
127
- cost_per_1k_output_chars=0.0000375,
128
- ),
129
- }
130
-
131
-
132
39
  @lf.use_init_args(['model'])
133
40
  @pg.members([('api_endpoint', pg.typing.Str().freeze(''))])
134
- class VertexAI(rest.REST):
41
+ class VertexAI(gemini.Gemini):
135
42
  """Language model served on VertexAI with REST API."""
136
43
 
137
- model: pg.typing.Annotated[
138
- pg.typing.Enum(
139
- pg.MISSING_VALUE, list(SUPPORTED_MODELS_AND_SETTINGS.keys())
140
- ),
141
- (
142
- 'Vertex AI model name with REST API support. See '
143
- 'https://cloud.google.com/vertex-ai/generative-ai/docs/'
144
- 'model-reference/inference#supported-models'
145
- ' for details.'
146
- ),
147
- ]
148
-
149
44
  project: Annotated[
150
45
  str | None,
151
46
  (
@@ -170,11 +65,6 @@ class VertexAI(rest.REST):
170
65
  ),
171
66
  ] = None
172
67
 
173
- supported_modalities: Annotated[
174
- list[str],
175
- 'A list of MIME types for supported modalities'
176
- ] = []
177
-
178
68
  def _on_bound(self):
179
69
  super()._on_bound()
180
70
  if google_auth is None:
@@ -200,6 +90,8 @@ class VertexAI(rest.REST):
200
90
  )
201
91
 
202
92
  self._project = project
93
+ self._location = location
94
+
203
95
  credentials = self.credentials
204
96
  if credentials is None:
205
97
  # Use default credentials.
@@ -209,31 +101,9 @@ class VertexAI(rest.REST):
209
101
  self._credentials = credentials
210
102
 
211
103
  @property
212
- def max_concurrency(self) -> int:
213
- """Returns the maximum number of concurrent requests."""
214
- return self.rate_to_max_concurrency(
215
- requests_per_min=SUPPORTED_MODELS_AND_SETTINGS[self.model].rpm,
216
- tokens_per_min=0,
217
- )
218
-
219
- def estimate_cost(
220
- self,
221
- num_input_tokens: int,
222
- num_output_tokens: int
223
- ) -> float | None:
224
- """Estimate the cost based on usage."""
225
- cost_per_1k_input_chars = SUPPORTED_MODELS_AND_SETTINGS[self.model].get(
226
- 'cost_per_1k_input_chars', None
227
- )
228
- cost_per_1k_output_chars = SUPPORTED_MODELS_AND_SETTINGS[self.model].get(
229
- 'cost_per_1k_output_chars', None
230
- )
231
- if cost_per_1k_output_chars is None or cost_per_1k_input_chars is None:
232
- return None
233
- return (
234
- cost_per_1k_input_chars * num_input_tokens
235
- + cost_per_1k_output_chars * num_output_tokens
236
- ) * AVGERAGE_CHARS_PER_TOKEN / 1000
104
+ def model_id(self) -> str:
105
+ """Returns a string to identify the model."""
106
+ return f'VertexAI({self.model})'
237
107
 
238
108
  @functools.cached_property
239
109
  def _session(self):
@@ -244,277 +114,79 @@ class VertexAI(rest.REST):
244
114
  s.headers.update(self.headers or {})
245
115
  return s
246
116
 
247
- @property
248
- def headers(self):
249
- return {
250
- 'Content-Type': 'application/json; charset=utf-8',
251
- }
252
-
253
117
  @property
254
118
  def api_endpoint(self) -> str:
119
+ assert self._api_initialized
255
120
  return (
256
- f'https://{self.location}-aiplatform.googleapis.com/v1/projects/'
257
- f'{self.project}/locations/{self.location}/publishers/google/'
121
+ f'https://{self._location}-aiplatform.googleapis.com/v1/projects/'
122
+ f'{self._project}/locations/{self._location}/publishers/google/'
258
123
  f'models/{self.model}:generateContent'
259
124
  )
260
125
 
261
- def request(
262
- self, prompt: lf.Message, sampling_options: lf.LMSamplingOptions
263
- ) -> dict[str, Any]:
264
- request = dict(
265
- generationConfig=self._generation_config(prompt, sampling_options)
266
- )
267
- request['contents'] = [self._content_from_message(prompt)]
268
- return request
269
-
270
- def _generation_config(
271
- self, prompt: lf.Message, options: lf.LMSamplingOptions
272
- ) -> dict[str, Any]:
273
- """Returns a dict as generation config for prompt and LMSamplingOptions."""
274
- config = dict(
275
- temperature=options.temperature,
276
- maxOutputTokens=options.max_tokens,
277
- candidateCount=options.n,
278
- topK=options.top_k,
279
- topP=options.top_p,
280
- stopSequences=options.stop,
281
- seed=options.random_seed,
282
- responseLogprobs=options.logprobs,
283
- logprobs=options.top_logprobs,
284
- )
285
126
 
286
- if json_schema := prompt.metadata.get('json_schema'):
287
- if not isinstance(json_schema, dict):
288
- raise ValueError(
289
- f'`json_schema` must be a dict, got {json_schema!r}.'
290
- )
291
- json_schema = pg.to_json(json_schema)
292
- config['responseSchema'] = json_schema
293
- config['responseMimeType'] = 'application/json'
294
- prompt.metadata.formatted_text = (
295
- prompt.text
296
- + '\n\n [RESPONSE FORMAT (not part of prompt)]\n'
297
- + pg.to_json_str(json_schema, json_indent=2)
298
- )
299
- return config
300
-
301
- def _content_from_message(self, prompt: lf.Message) -> dict[str, Any]:
302
- """Gets generation content from langfun message."""
303
- parts = []
304
- for lf_chunk in prompt.chunk():
305
- if isinstance(lf_chunk, str):
306
- parts.append({'text': lf_chunk})
307
- elif isinstance(lf_chunk, lf_modalities.Mime):
308
- try:
309
- modalities = lf_chunk.make_compatible(
310
- self.supported_modalities + ['text/plain']
311
- )
312
- if isinstance(modalities, lf_modalities.Mime):
313
- modalities = [modalities]
314
- for modality in modalities:
315
- if modality.is_text:
316
- parts.append({'text': modality.to_text()})
317
- else:
318
- parts.append({
319
- 'inlineData': {
320
- 'data': base64.b64encode(modality.to_bytes()).decode(),
321
- 'mimeType': modality.mime_type,
322
- }
323
- })
324
- except lf.ModalityError as e:
325
- raise lf.ModalityError(f'Unsupported modality: {lf_chunk!r}') from e
326
- else:
327
- raise lf.ModalityError(f'Unsupported modality: {lf_chunk!r}')
328
- return dict(role='user', parts=parts)
329
-
330
- def result(self, json: dict[str, Any]) -> lf.LMSamplingResult:
331
- messages = [
332
- self._message_from_content_parts(candidate['content']['parts'])
333
- for candidate in json['candidates']
334
- ]
335
- usage = json['usageMetadata']
336
- input_tokens = usage['promptTokenCount']
337
- output_tokens = usage['candidatesTokenCount']
338
- return lf.LMSamplingResult(
339
- [lf.LMSample(message) for message in messages],
340
- usage=lf.LMSamplingUsage(
341
- prompt_tokens=input_tokens,
342
- completion_tokens=output_tokens,
343
- total_tokens=input_tokens + output_tokens,
344
- estimated_cost=self.estimate_cost(
345
- num_input_tokens=input_tokens,
346
- num_output_tokens=output_tokens,
347
- ),
348
- ),
349
- )
127
+ class VertexAIGeminiFlash2_0ThinkingExp_20241219(VertexAI): # pylint: disable=invalid-name
128
+ """Vertex AI Gemini Flash 2.0 Thinking model launched on 12/19/2024."""
129
+
130
+ api_version = 'v1alpha'
131
+ model = 'gemini-2.0-flash-thinking-exp-1219'
132
+ timeout = None
133
+
350
134
 
351
- def _message_from_content_parts(
352
- self, parts: list[dict[str, Any]]
353
- ) -> lf.Message:
354
- """Converts Vertex AI's content parts protocol to message."""
355
- chunks = []
356
- for part in parts:
357
- if text_part := part.get('text'):
358
- chunks.append(text_part)
359
- else:
360
- raise ValueError(f'Unsupported part: {part}')
361
- return lf.AIMessage.from_chunks(chunks)
362
-
363
-
364
- IMAGE_TYPES = [
365
- 'image/png',
366
- 'image/jpeg',
367
- 'image/webp',
368
- 'image/heic',
369
- 'image/heif',
370
- ]
371
-
372
- AUDIO_TYPES = [
373
- 'audio/aac',
374
- 'audio/flac',
375
- 'audio/mp3',
376
- 'audio/m4a',
377
- 'audio/mpeg',
378
- 'audio/mpga',
379
- 'audio/mp4',
380
- 'audio/opus',
381
- 'audio/pcm',
382
- 'audio/wav',
383
- 'audio/webm',
384
- ]
385
-
386
- VIDEO_TYPES = [
387
- 'video/mov',
388
- 'video/mpeg',
389
- 'video/mpegps',
390
- 'video/mpg',
391
- 'video/mp4',
392
- 'video/webm',
393
- 'video/wmv',
394
- 'video/x-flv',
395
- 'video/3gpp',
396
- 'video/quicktime',
397
- ]
398
-
399
- DOCUMENT_TYPES = [
400
- 'application/pdf',
401
- 'text/plain',
402
- 'text/csv',
403
- 'text/html',
404
- 'text/xml',
405
- 'text/x-script.python',
406
- 'application/json',
407
- ]
408
-
409
-
410
- class VertexAIGemini2_0(VertexAI): # pylint: disable=invalid-name
411
- """Vertex AI Gemini 2.0 model."""
412
-
413
- supported_modalities: pg.typing.List(str).freeze( # pytype: disable=invalid-annotation
414
- DOCUMENT_TYPES + IMAGE_TYPES + AUDIO_TYPES + VIDEO_TYPES
415
- )
416
-
417
-
418
- class VertexAIGeminiFlash2_0Exp(VertexAIGemini2_0): # pylint: disable=invalid-name
135
+ class VertexAIGeminiFlash2_0Exp(VertexAI): # pylint: disable=invalid-name
419
136
  """Vertex AI Gemini 2.0 Flash model."""
420
137
 
421
138
  model = 'gemini-2.0-flash-exp'
422
139
 
423
140
 
424
- class VertexAIGeminiFlash2_0ThinkingExp(VertexAIGemini2_0): # pylint: disable=invalid-name
425
- """Vertex AI Gemini 2.0 Flash model."""
141
+ class VertexAIGeminiExp_20241206(VertexAI): # pylint: disable=invalid-name
142
+ """Vertex AI Gemini Experimental model launched on 12/06/2024."""
426
143
 
427
- model = 'gemini-2.0-flash-thinking-exp-1219'
144
+ model = 'gemini-exp-1206'
428
145
 
429
146
 
430
- class VertexAIGemini1_5(VertexAI): # pylint: disable=invalid-name
431
- """Vertex AI Gemini 1.5 model."""
147
+ class VertexAIGeminiExp_20241114(VertexAI): # pylint: disable=invalid-name
148
+ """Vertex AI Gemini Experimental model launched on 11/14/2024."""
432
149
 
433
- supported_modalities: pg.typing.List(str).freeze( # pytype: disable=invalid-annotation
434
- DOCUMENT_TYPES + IMAGE_TYPES + AUDIO_TYPES + VIDEO_TYPES
435
- )
150
+ model = 'gemini-exp-1114'
436
151
 
437
152
 
438
- class VertexAIGeminiPro1_5(VertexAIGemini1_5): # pylint: disable=invalid-name
153
+ class VertexAIGeminiPro1_5(VertexAI): # pylint: disable=invalid-name
439
154
  """Vertex AI Gemini 1.5 Pro model."""
440
155
 
441
- model = 'gemini-1.5-pro'
156
+ model = 'gemini-1.5-pro-latest'
442
157
 
443
158
 
444
- class VertexAIGeminiPro1_5_002(VertexAIGemini1_5): # pylint: disable=invalid-name
159
+ class VertexAIGeminiPro1_5_002(VertexAI): # pylint: disable=invalid-name
445
160
  """Vertex AI Gemini 1.5 Pro model."""
446
161
 
447
162
  model = 'gemini-1.5-pro-002'
448
163
 
449
164
 
450
- class VertexAIGeminiPro1_5_001(VertexAIGemini1_5): # pylint: disable=invalid-name
165
+ class VertexAIGeminiPro1_5_001(VertexAI): # pylint: disable=invalid-name
451
166
  """Vertex AI Gemini 1.5 Pro model."""
452
167
 
453
168
  model = 'gemini-1.5-pro-001'
454
169
 
455
170
 
456
- class VertexAIGeminiPro1_5_0514(VertexAIGemini1_5): # pylint: disable=invalid-name
457
- """Vertex AI Gemini 1.5 Pro preview model."""
458
-
459
- model = 'gemini-1.5-pro-preview-0514'
460
-
461
-
462
- class VertexAIGeminiPro1_5_0409(VertexAIGemini1_5): # pylint: disable=invalid-name
463
- """Vertex AI Gemini 1.5 Pro preview model."""
464
-
465
- model = 'gemini-1.5-pro-preview-0409'
466
-
467
-
468
- class VertexAIGeminiFlash1_5(VertexAIGemini1_5): # pylint: disable=invalid-name
171
+ class VertexAIGeminiFlash1_5(VertexAI): # pylint: disable=invalid-name
469
172
  """Vertex AI Gemini 1.5 Flash model."""
470
173
 
471
174
  model = 'gemini-1.5-flash'
472
175
 
473
176
 
474
- class VertexAIGeminiFlash1_5_002(VertexAIGemini1_5): # pylint: disable=invalid-name
177
+ class VertexAIGeminiFlash1_5_002(VertexAI): # pylint: disable=invalid-name
475
178
  """Vertex AI Gemini 1.5 Flash model."""
476
179
 
477
180
  model = 'gemini-1.5-flash-002'
478
181
 
479
182
 
480
- class VertexAIGeminiFlash1_5_001(VertexAIGemini1_5): # pylint: disable=invalid-name
183
+ class VertexAIGeminiFlash1_5_001(VertexAI): # pylint: disable=invalid-name
481
184
  """Vertex AI Gemini 1.5 Flash model."""
482
185
 
483
186
  model = 'gemini-1.5-flash-001'
484
187
 
485
188
 
486
- class VertexAIGeminiFlash1_5_0514(VertexAIGemini1_5): # pylint: disable=invalid-name
487
- """Vertex AI Gemini 1.5 Flash preview model."""
488
-
489
- model = 'gemini-1.5-flash-preview-0514'
490
-
491
-
492
189
  class VertexAIGeminiPro1(VertexAI): # pylint: disable=invalid-name
493
190
  """Vertex AI Gemini 1.0 Pro model."""
494
191
 
495
192
  model = 'gemini-1.0-pro'
496
-
497
-
498
- class VertexAIGeminiPro1Vision(VertexAI): # pylint: disable=invalid-name
499
- """Vertex AI Gemini 1.0 Pro Vision model."""
500
-
501
- model = 'gemini-1.0-pro-vision'
502
- supported_modalities: pg.typing.List(str).freeze( # pytype: disable=invalid-annotation
503
- IMAGE_TYPES + VIDEO_TYPES
504
- )
505
-
506
-
507
- class VertexAIEndpoint(VertexAI): # pylint: disable=invalid-name
508
- """Vertex AI Endpoint model."""
509
-
510
- model = 'vertexai-endpoint'
511
-
512
- endpoint: Annotated[str, 'Vertex AI Endpoint ID.']
513
-
514
- @property
515
- def api_endpoint(self) -> str:
516
- return (
517
- f'https://{self.location}-aiplatform.googleapis.com/v1/projects/'
518
- f'{self.project}/locations/{self.location}/'
519
- f'endpoints/{self.endpoint}:generateContent'
520
- )