langfun 0.0.2.dev20240429__py3-none-any.whl → 0.1.2.dev202501150804__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (144) hide show
  1. langfun/__init__.py +20 -2
  2. langfun/core/__init__.py +16 -5
  3. langfun/core/agentic/__init__.py +30 -0
  4. langfun/core/agentic/action.py +854 -0
  5. langfun/core/agentic/action_eval.py +150 -0
  6. langfun/core/agentic/action_eval_test.py +109 -0
  7. langfun/core/agentic/action_test.py +136 -0
  8. langfun/core/coding/python/__init__.py +5 -11
  9. langfun/core/coding/python/correction.py +37 -21
  10. langfun/core/coding/python/correction_test.py +29 -3
  11. langfun/core/coding/python/execution.py +40 -216
  12. langfun/core/coding/python/execution_test.py +29 -89
  13. langfun/core/coding/python/generation.py +21 -11
  14. langfun/core/coding/python/generation_test.py +2 -2
  15. langfun/core/coding/python/parsing.py +108 -193
  16. langfun/core/coding/python/parsing_test.py +2 -105
  17. langfun/core/component.py +63 -2
  18. langfun/core/component_test.py +53 -0
  19. langfun/core/concurrent.py +414 -117
  20. langfun/core/concurrent_test.py +111 -24
  21. langfun/core/console.py +17 -5
  22. langfun/core/console_test.py +17 -0
  23. langfun/core/eval/__init__.py +16 -1
  24. langfun/core/eval/base.py +622 -174
  25. langfun/core/eval/base_test.py +200 -54
  26. langfun/core/eval/matching.py +63 -76
  27. langfun/core/eval/matching_test.py +17 -8
  28. langfun/core/eval/patching.py +130 -0
  29. langfun/core/eval/patching_test.py +170 -0
  30. langfun/core/eval/scoring.py +26 -26
  31. langfun/core/eval/scoring_test.py +19 -2
  32. langfun/core/eval/v2/__init__.py +42 -0
  33. langfun/core/eval/v2/checkpointing.py +380 -0
  34. langfun/core/eval/v2/checkpointing_test.py +228 -0
  35. langfun/core/eval/v2/eval_test_helper.py +136 -0
  36. langfun/core/eval/v2/evaluation.py +725 -0
  37. langfun/core/eval/v2/evaluation_test.py +180 -0
  38. langfun/core/eval/v2/example.py +305 -0
  39. langfun/core/eval/v2/example_test.py +128 -0
  40. langfun/core/eval/v2/experiment.py +1048 -0
  41. langfun/core/eval/v2/experiment_test.py +433 -0
  42. langfun/core/eval/v2/metric_values.py +156 -0
  43. langfun/core/eval/v2/metric_values_test.py +80 -0
  44. langfun/core/eval/v2/metrics.py +357 -0
  45. langfun/core/eval/v2/metrics_test.py +203 -0
  46. langfun/core/eval/v2/progress.py +348 -0
  47. langfun/core/eval/v2/progress_test.py +82 -0
  48. langfun/core/eval/v2/progress_tracking.py +210 -0
  49. langfun/core/eval/v2/progress_tracking_test.py +66 -0
  50. langfun/core/eval/v2/reporting.py +270 -0
  51. langfun/core/eval/v2/reporting_test.py +158 -0
  52. langfun/core/eval/v2/runners.py +488 -0
  53. langfun/core/eval/v2/runners_test.py +334 -0
  54. langfun/core/langfunc.py +4 -17
  55. langfun/core/langfunc_test.py +22 -6
  56. langfun/core/language_model.py +577 -39
  57. langfun/core/language_model_test.py +470 -56
  58. langfun/core/llms/__init__.py +87 -16
  59. langfun/core/llms/anthropic.py +312 -87
  60. langfun/core/llms/anthropic_test.py +71 -3
  61. langfun/core/llms/cache/base.py +21 -2
  62. langfun/core/llms/cache/in_memory.py +13 -0
  63. langfun/core/llms/cache/in_memory_test.py +53 -2
  64. langfun/core/llms/compositional.py +101 -0
  65. langfun/core/llms/compositional_test.py +73 -0
  66. langfun/core/llms/deepseek.py +117 -0
  67. langfun/core/llms/deepseek_test.py +61 -0
  68. langfun/core/llms/fake.py +11 -7
  69. langfun/core/llms/fake_test.py +14 -0
  70. langfun/core/llms/gemini.py +507 -0
  71. langfun/core/llms/gemini_test.py +195 -0
  72. langfun/core/llms/google_genai.py +62 -218
  73. langfun/core/llms/google_genai_test.py +9 -202
  74. langfun/core/llms/groq.py +160 -144
  75. langfun/core/llms/groq_test.py +31 -137
  76. langfun/core/llms/llama_cpp.py +15 -42
  77. langfun/core/llms/llama_cpp_test.py +4 -30
  78. langfun/core/llms/openai.py +395 -203
  79. langfun/core/llms/openai_compatible.py +179 -0
  80. langfun/core/llms/openai_compatible_test.py +495 -0
  81. langfun/core/llms/openai_test.py +30 -395
  82. langfun/core/llms/rest.py +113 -0
  83. langfun/core/llms/rest_test.py +111 -0
  84. langfun/core/llms/vertexai.py +192 -0
  85. langfun/core/llms/vertexai_test.py +52 -0
  86. langfun/core/logging.py +284 -0
  87. langfun/core/logging_test.py +125 -0
  88. langfun/core/message.py +319 -9
  89. langfun/core/message_test.py +190 -13
  90. langfun/core/modalities/__init__.py +6 -2
  91. langfun/core/modalities/audio.py +30 -0
  92. langfun/core/modalities/audio_test.py +63 -0
  93. langfun/core/modalities/image.py +39 -20
  94. langfun/core/modalities/image_test.py +52 -9
  95. langfun/core/modalities/mime.py +206 -29
  96. langfun/core/modalities/mime_test.py +90 -9
  97. langfun/core/modalities/ms_office.py +117 -0
  98. langfun/core/modalities/ms_office_test.py +389 -0
  99. langfun/core/modalities/pdf.py +22 -0
  100. langfun/core/modalities/pdf_test.py +57 -0
  101. langfun/core/modalities/video.py +9 -26
  102. langfun/core/modalities/video_test.py +3 -3
  103. langfun/core/modality.py +26 -3
  104. langfun/core/modality_test.py +2 -2
  105. langfun/core/sampling.py +11 -11
  106. langfun/core/structured/__init__.py +12 -16
  107. langfun/core/structured/completion.py +32 -5
  108. langfun/core/structured/completion_test.py +7 -6
  109. langfun/core/structured/description.py +2 -2
  110. langfun/core/structured/description_test.py +3 -3
  111. langfun/core/structured/function_generation.py +60 -27
  112. langfun/core/structured/function_generation_test.py +72 -2
  113. langfun/core/structured/mapping.py +97 -47
  114. langfun/core/structured/mapping_test.py +90 -2
  115. langfun/core/structured/parsing.py +33 -21
  116. langfun/core/structured/parsing_test.py +53 -9
  117. langfun/core/structured/querying.py +746 -0
  118. langfun/core/structured/{prompting_test.py → querying_test.py} +469 -51
  119. langfun/core/structured/schema.py +204 -97
  120. langfun/core/structured/schema_generation.py +1 -1
  121. langfun/core/structured/schema_test.py +130 -29
  122. langfun/core/structured/scoring.py +125 -19
  123. langfun/core/structured/scoring_test.py +30 -0
  124. langfun/core/structured/tokenization.py +64 -0
  125. langfun/core/structured/tokenization_test.py +48 -0
  126. langfun/core/template.py +115 -1
  127. langfun/core/template_test.py +71 -1
  128. langfun/core/templates/conversation.py +9 -0
  129. langfun/core/templates/conversation_test.py +4 -3
  130. langfun/core/templates/selfplay_test.py +10 -2
  131. langfun-0.1.2.dev202501150804.dist-info/METADATA +225 -0
  132. langfun-0.1.2.dev202501150804.dist-info/RECORD +153 -0
  133. {langfun-0.0.2.dev20240429.dist-info → langfun-0.1.2.dev202501150804.dist-info}/WHEEL +1 -1
  134. langfun/core/coding/python/errors.py +0 -108
  135. langfun/core/coding/python/errors_test.py +0 -99
  136. langfun/core/coding/python/permissions.py +0 -90
  137. langfun/core/coding/python/permissions_test.py +0 -86
  138. langfun/core/structured/prompting.py +0 -238
  139. langfun/core/text_formatting.py +0 -162
  140. langfun/core/text_formatting_test.py +0 -47
  141. langfun-0.0.2.dev20240429.dist-info/METADATA +0 -100
  142. langfun-0.0.2.dev20240429.dist-info/RECORD +0 -108
  143. {langfun-0.0.2.dev20240429.dist-info → langfun-0.1.2.dev202501150804.dist-info}/LICENSE +0 -0
  144. {langfun-0.0.2.dev20240429.dist-info → langfun-0.1.2.dev202501150804.dist-info}/top_level.txt +0 -0
langfun/core/llms/fake.py CHANGED
@@ -21,9 +21,13 @@ import langfun.core as lf
21
21
  class Fake(lf.LanguageModel):
22
22
  """The base class for all fake language models."""
23
23
 
24
- def _score(self, prompt: lf.Message, completions: list[lf.Message]):
24
+ def _score(self, prompt: lf.Message| list[lf.Message],
25
+ completions: list[lf.Message]):
25
26
  return [lf.LMScoringResult(score=-i * 1.0) for i in range(len(completions))]
26
27
 
28
+ def _tokenize(self, prompt: lf.Message) -> list[tuple[str | bytes, int]]:
29
+ return [(w, i) for i, w in enumerate(prompt.text.split(' '))]
30
+
27
31
  def _sample(self, prompts: list[lf.Message]) -> list[lf.LMSamplingResult]:
28
32
  results = []
29
33
  for prompt in prompts:
@@ -57,12 +61,12 @@ class StaticResponse(Fake):
57
61
  """Language model that always gives the same canned response."""
58
62
 
59
63
  response: Annotated[
60
- str,
64
+ str | lf.Message,
61
65
  'A canned response that will be returned regardless of the prompt.'
62
66
  ]
63
67
 
64
68
  def _response_from(self, prompt: lf.Message) -> lf.Message:
65
- return lf.AIMessage(self.response)
69
+ return lf.AIMessage.from_value(self.response)
66
70
 
67
71
 
68
72
  @lf.use_init_args(['mapping'])
@@ -70,12 +74,12 @@ class StaticMapping(Fake):
70
74
  """A static mapping from prompt to response."""
71
75
 
72
76
  mapping: Annotated[
73
- dict[str, str],
77
+ dict[str, str | lf.Message],
74
78
  'A mapping from prompt to response.'
75
79
  ]
76
80
 
77
81
  def _response_from(self, prompt: lf.Message) -> lf.Message:
78
- return lf.AIMessage(self.mapping[prompt])
82
+ return lf.AIMessage.from_value(self.mapping[prompt])
79
83
 
80
84
 
81
85
  @lf.use_init_args(['sequence'])
@@ -83,7 +87,7 @@ class StaticSequence(Fake):
83
87
  """A static sequence of responses to use."""
84
88
 
85
89
  sequence: Annotated[
86
- list[str],
90
+ list[str | lf.Message],
87
91
  'A sequence of strings as the response.'
88
92
  ]
89
93
 
@@ -92,6 +96,6 @@ class StaticSequence(Fake):
92
96
  self._pos = 0
93
97
 
94
98
  def _response_from(self, prompt: lf.Message) -> lf.Message:
95
- r = lf.AIMessage(self.sequence[self._pos])
99
+ r = lf.AIMessage.from_value(self.sequence[self._pos])
96
100
  self._pos += 1
97
101
  return r
@@ -34,6 +34,7 @@ class EchoTest(unittest.TestCase):
34
34
  'hi',
35
35
  score=1.0,
36
36
  logprobs=None,
37
+ is_cached=False,
37
38
  usage=lf.LMSamplingUsage(2, 2, 4),
38
39
  tags=[lf.Message.TAG_LM_RESPONSE],
39
40
  ),
@@ -62,6 +63,13 @@ class EchoTest(unittest.TestCase):
62
63
  [lf.LMScoringResult(0.0), lf.LMScoringResult(-1.0)],
63
64
  )
64
65
 
66
+ def test_tokenize(self):
67
+ lm = fakelm.Echo()
68
+ self.assertEqual(
69
+ lm.tokenize('hi'),
70
+ [('hi', 0)]
71
+ )
72
+
65
73
 
66
74
  class StaticResponseTest(unittest.TestCase):
67
75
 
@@ -78,6 +86,7 @@ class StaticResponseTest(unittest.TestCase):
78
86
  canned_response,
79
87
  score=1.0,
80
88
  logprobs=None,
89
+ is_cached=False,
81
90
  usage=lf.LMSamplingUsage(2, 38, 40),
82
91
  tags=[lf.Message.TAG_LM_RESPONSE],
83
92
  ),
@@ -99,6 +108,7 @@ class StaticResponseTest(unittest.TestCase):
99
108
  canned_response,
100
109
  score=1.0,
101
110
  logprobs=None,
111
+ is_cached=False,
102
112
  usage=lf.LMSamplingUsage(15, 38, 53),
103
113
  tags=[lf.Message.TAG_LM_RESPONSE],
104
114
  ),
@@ -143,6 +153,7 @@ class StaticMappingTest(unittest.TestCase):
143
153
  'Hello',
144
154
  score=1.0,
145
155
  logprobs=None,
156
+ is_cached=False,
146
157
  usage=lf.LMSamplingUsage(2, 5, 7),
147
158
  tags=[lf.Message.TAG_LM_RESPONSE],
148
159
  ),
@@ -159,6 +170,7 @@ class StaticMappingTest(unittest.TestCase):
159
170
  'I am fine, how about you?',
160
171
  score=1.0,
161
172
  logprobs=None,
173
+ is_cached=False,
162
174
  usage=lf.LMSamplingUsage(12, 25, 37),
163
175
  tags=[lf.Message.TAG_LM_RESPONSE],
164
176
  ),
@@ -192,6 +204,7 @@ class StaticSequenceTest(unittest.TestCase):
192
204
  'Hello',
193
205
  score=1.0,
194
206
  logprobs=None,
207
+ is_cached=False,
195
208
  usage=lf.LMSamplingUsage(2, 5, 7),
196
209
  tags=[lf.Message.TAG_LM_RESPONSE],
197
210
  ),
@@ -208,6 +221,7 @@ class StaticSequenceTest(unittest.TestCase):
208
221
  'I am fine, how about you?',
209
222
  score=1.0,
210
223
  logprobs=None,
224
+ is_cached=False,
211
225
  usage=lf.LMSamplingUsage(12, 25, 37),
212
226
  tags=[lf.Message.TAG_LM_RESPONSE],
213
227
  ),
@@ -0,0 +1,507 @@
1
+ # Copyright 2025 The Langfun Authors
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ """Gemini REST API (Shared by Google GenAI and Vertex AI)."""
15
+
16
+ import base64
17
+ from typing import Any
18
+
19
+ import langfun.core as lf
20
+ from langfun.core import modalities as lf_modalities
21
+ from langfun.core.llms import rest
22
+ import pyglove as pg
23
+
24
+ # Supported modalities.
25
+
26
+ IMAGE_TYPES = [
27
+ 'image/png',
28
+ 'image/jpeg',
29
+ 'image/webp',
30
+ 'image/heic',
31
+ 'image/heif',
32
+ ]
33
+
34
+ AUDIO_TYPES = [
35
+ 'audio/aac',
36
+ 'audio/flac',
37
+ 'audio/mp3',
38
+ 'audio/m4a',
39
+ 'audio/mpeg',
40
+ 'audio/mpga',
41
+ 'audio/mp4',
42
+ 'audio/opus',
43
+ 'audio/pcm',
44
+ 'audio/wav',
45
+ 'audio/webm',
46
+ ]
47
+
48
+ VIDEO_TYPES = [
49
+ 'video/mov',
50
+ 'video/mpeg',
51
+ 'video/mpegps',
52
+ 'video/mpg',
53
+ 'video/mp4',
54
+ 'video/webm',
55
+ 'video/wmv',
56
+ 'video/x-flv',
57
+ 'video/3gpp',
58
+ 'video/quicktime',
59
+ ]
60
+
61
+ DOCUMENT_TYPES = [
62
+ 'application/pdf',
63
+ 'text/plain',
64
+ 'text/csv',
65
+ 'text/html',
66
+ 'text/xml',
67
+ 'text/x-script.python',
68
+ 'application/json',
69
+ ]
70
+
71
+ TEXT_ONLY = []
72
+
73
+ ALL_MODALITIES = (
74
+ IMAGE_TYPES + AUDIO_TYPES + VIDEO_TYPES + DOCUMENT_TYPES
75
+ )
76
+
77
+ SUPPORTED_MODELS_AND_SETTINGS = {
78
+ # For automatic rate control and cost estimation, we explicitly register
79
+ # supported models here. This may be inconvenient for new models, but it
80
+ # helps us to keep track of the models and their pricing.
81
+ # Models and RPM are from
82
+ # https://ai.google.dev/gemini-api/docs/models/gemini?_gl=1*114hbho*_up*MQ..&gclid=Cj0KCQiAst67BhCEARIsAKKdWOljBY5aQdNQ41zOPkXFCwymUfMNFl_7ukm1veAf75ZTD9qWFrFr11IaApL3EALw_wcB
83
+ # Pricing in US dollars, from https://ai.google.dev/pricing
84
+ # as of 2025-01-03.
85
+ # NOTE: Please update google_genai.py, vertexai.py, __init__.py when
86
+ # adding new models.
87
+ # !!! PLEASE KEEP MODELS SORTED BY RELEASE DATE !!!
88
+ 'gemini-2.0-flash-thinking-exp-1219': pg.Dict(
89
+ latest_update='2024-12-19',
90
+ experimental=True,
91
+ in_service=True,
92
+ supported_modalities=ALL_MODALITIES,
93
+ rpm_free=10,
94
+ tpm_free=4_000_000,
95
+ rpm_paid=0,
96
+ tpm_paid=0,
97
+ cost_per_1m_input_tokens_up_to_128k=0,
98
+ cost_per_1m_output_tokens_up_to_128k=0,
99
+ cost_per_1m_cached_tokens_up_to_128k=0,
100
+ cost_per_1m_input_tokens_longer_than_128k=0,
101
+ cost_per_1m_output_tokens_longer_than_128k=0,
102
+ cost_per_1m_cached_tokens_longer_than_128k=0,
103
+ ),
104
+ 'gemini-2.0-flash-exp': pg.Dict(
105
+ latest_update='2024-12-11',
106
+ experimental=True,
107
+ in_service=True,
108
+ supported_modalities=ALL_MODALITIES,
109
+ rpm_free=10,
110
+ tpm_free=4_000_000,
111
+ rpm_paid=0,
112
+ tpm_paid=0,
113
+ cost_per_1m_input_tokens_up_to_128k=0,
114
+ cost_per_1m_output_tokens_up_to_128k=0,
115
+ cost_per_1m_cached_tokens_up_to_128k=0,
116
+ cost_per_1m_input_tokens_longer_than_128k=0,
117
+ cost_per_1m_output_tokens_longer_than_128k=0,
118
+ cost_per_1m_cached_tokens_longer_than_128k=0,
119
+ ),
120
+ 'gemini-exp-1206': pg.Dict(
121
+ latest_update='2024-12-06',
122
+ experimental=True,
123
+ in_service=True,
124
+ supported_modalities=ALL_MODALITIES,
125
+ rpm_free=10,
126
+ tpm_free=4_000_000,
127
+ rpm_paid=0,
128
+ tpm_paid=0,
129
+ cost_per_1m_input_tokens_up_to_128k=0,
130
+ cost_per_1m_output_tokens_up_to_128k=0,
131
+ cost_per_1m_cached_tokens_up_to_128k=0,
132
+ cost_per_1m_input_tokens_longer_than_128k=0,
133
+ cost_per_1m_output_tokens_longer_than_128k=0,
134
+ cost_per_1m_cached_tokens_longer_than_128k=0,
135
+ ),
136
+ 'learnlm-1.5-pro-experimental': pg.Dict(
137
+ latest_update='2024-11-19',
138
+ experimental=True,
139
+ in_service=True,
140
+ supported_modalities=ALL_MODALITIES,
141
+ rpm_free=10,
142
+ tpm_free=4_000_000,
143
+ rpm_paid=0,
144
+ tpm_paid=0,
145
+ cost_per_1m_input_tokens_up_to_128k=0,
146
+ cost_per_1m_output_tokens_up_to_128k=0,
147
+ cost_per_1m_cached_tokens_up_to_128k=0,
148
+ cost_per_1m_input_tokens_longer_than_128k=0,
149
+ cost_per_1m_output_tokens_longer_than_128k=0,
150
+ cost_per_1m_cached_tokens_longer_than_128k=0,
151
+ ),
152
+ 'gemini-exp-1114': pg.Dict(
153
+ latest_update='2024-11-14',
154
+ experimental=True,
155
+ in_service=True,
156
+ supported_modalities=ALL_MODALITIES,
157
+ rpm_free=10,
158
+ tpm_free=4_000_000,
159
+ rpm_paid=0,
160
+ tpm_paid=0,
161
+ cost_per_1m_input_tokens_up_to_128k=0,
162
+ cost_per_1m_output_tokens_up_to_128k=0,
163
+ cost_per_1m_cached_tokens_up_to_128k=0,
164
+ cost_per_1m_input_tokens_longer_than_128k=0,
165
+ cost_per_1m_output_tokens_longer_than_128k=0,
166
+ cost_per_1m_cached_tokens_longer_than_128k=0,
167
+ ),
168
+ 'gemini-1.5-flash-latest': pg.Dict(
169
+ latest_update='2024-09-30',
170
+ in_service=True,
171
+ supported_modalities=ALL_MODALITIES,
172
+ rpm_free=15,
173
+ tpm_free=1_000_000,
174
+ rpm_paid=2000,
175
+ tpm_paid=4_000_000,
176
+ cost_per_1m_input_tokens_up_to_128k=0.075,
177
+ cost_per_1m_output_tokens_up_to_128k=0.3,
178
+ cost_per_1m_cached_tokens_up_to_128k=0.01875,
179
+ cost_per_1m_input_tokens_longer_than_128k=0.15,
180
+ cost_per_1m_output_tokens_longer_than_128k=0.6,
181
+ cost_per_1m_cached_tokens_longer_than_128k=0.0375,
182
+ ),
183
+ 'gemini-1.5-flash': pg.Dict(
184
+ latest_update='2024-09-30',
185
+ in_service=True,
186
+ supported_modalities=ALL_MODALITIES,
187
+ rpm_free=15,
188
+ tpm_free=1_000_000,
189
+ rpm_paid=2000,
190
+ tpm_paid=4_000_000,
191
+ cost_per_1m_input_tokens_up_to_128k=0.075,
192
+ cost_per_1m_output_tokens_up_to_128k=0.3,
193
+ cost_per_1m_cached_tokens_up_to_128k=0.01875,
194
+ cost_per_1m_input_tokens_longer_than_128k=0.15,
195
+ cost_per_1m_output_tokens_longer_than_128k=0.6,
196
+ cost_per_1m_cached_tokens_longer_than_128k=0.0375,
197
+ ),
198
+ 'gemini-1.5-flash-001': pg.Dict(
199
+ latest_update='2024-09-30',
200
+ in_service=True,
201
+ supported_modalities=ALL_MODALITIES,
202
+ rpm_free=15,
203
+ tpm_free=1_000_000,
204
+ rpm_paid=2000,
205
+ tpm_paid=4_000_000,
206
+ cost_per_1m_input_tokens_up_to_128k=0.075,
207
+ cost_per_1m_output_tokens_up_to_128k=0.3,
208
+ cost_per_1m_cached_tokens_up_to_128k=0.01875,
209
+ cost_per_1m_input_tokens_longer_than_128k=0.15,
210
+ cost_per_1m_output_tokens_longer_than_128k=0.6,
211
+ cost_per_1m_cached_tokens_longer_than_128k=0.0375,
212
+ ),
213
+ 'gemini-1.5-flash-002': pg.Dict(
214
+ latest_update='2024-09-30',
215
+ in_service=True,
216
+ supported_modalities=ALL_MODALITIES,
217
+ rpm_free=15,
218
+ tpm_free=1_000_000,
219
+ rpm_paid=2000,
220
+ tpm_paid=4_000_000,
221
+ cost_per_1m_input_tokens_up_to_128k=0.075,
222
+ cost_per_1m_output_tokens_up_to_128k=0.3,
223
+ cost_per_1m_cached_tokens_up_to_128k=0.01875,
224
+ cost_per_1m_input_tokens_longer_than_128k=0.15,
225
+ cost_per_1m_output_tokens_longer_than_128k=0.6,
226
+ cost_per_1m_cached_tokens_longer_than_128k=0.0375,
227
+ ),
228
+ 'gemini-1.5-flash-8b': pg.Dict(
229
+ latest_update='2024-10-30',
230
+ in_service=True,
231
+ supported_modalities=ALL_MODALITIES,
232
+ rpm_free=15,
233
+ tpm_free=1_000_000,
234
+ rpm_paid=4000,
235
+ tpm_paid=4_000_000,
236
+ cost_per_1m_input_tokens_up_to_128k=0.0375,
237
+ cost_per_1m_output_tokens_up_to_128k=0.15,
238
+ cost_per_1m_cached_tokens_up_to_128k=0.01,
239
+ cost_per_1m_input_tokens_longer_than_128k=0.075,
240
+ cost_per_1m_output_tokens_longer_than_128k=0.3,
241
+ cost_per_1m_cached_tokens_longer_than_128k=0.02,
242
+ ),
243
+ 'gemini-1.5-flash-8b-001': pg.Dict(
244
+ latest_update='2024-10-30',
245
+ in_service=True,
246
+ supported_modalities=ALL_MODALITIES,
247
+ rpm_free=15,
248
+ tpm_free=1_000_000,
249
+ rpm_paid=4000,
250
+ tpm_paid=4_000_000,
251
+ cost_per_1m_input_tokens_up_to_128k=0.0375,
252
+ cost_per_1m_output_tokens_up_to_128k=0.15,
253
+ cost_per_1m_cached_tokens_up_to_128k=0.01,
254
+ cost_per_1m_input_tokens_longer_than_128k=0.075,
255
+ cost_per_1m_output_tokens_longer_than_128k=0.3,
256
+ cost_per_1m_cached_tokens_longer_than_128k=0.02,
257
+ ),
258
+ 'gemini-1.5-pro-latest': pg.Dict(
259
+ latest_update='2024-09-30',
260
+ in_service=True,
261
+ supported_modalities=ALL_MODALITIES,
262
+ rpm_free=2,
263
+ tpm_free=32_000,
264
+ rpm_paid=1000,
265
+ tpm_paid=4_000_000,
266
+ cost_per_1m_input_tokens_up_to_128k=1.25,
267
+ cost_per_1m_output_tokens_up_to_128k=5.00,
268
+ cost_per_1m_cached_tokens_up_to_128k=0.3125,
269
+ cost_per_1m_input_tokens_longer_than_128k=2.5,
270
+ cost_per_1m_output_tokens_longer_than_128k=10.00,
271
+ cost_per_1m_cached_tokens_longer_than_128k=0.625,
272
+ ),
273
+ 'gemini-1.5-pro': pg.Dict(
274
+ latest_update='2024-09-30',
275
+ in_service=True,
276
+ supported_modalities=ALL_MODALITIES,
277
+ rpm_free=2,
278
+ tpm_free=32_000,
279
+ rpm_paid=1000,
280
+ tpm_paid=4_000_000,
281
+ cost_per_1m_input_tokens_up_to_128k=1.25,
282
+ cost_per_1m_output_tokens_up_to_128k=5.00,
283
+ cost_per_1m_cached_tokens_up_to_128k=0.3125,
284
+ cost_per_1m_input_tokens_longer_than_128k=2.5,
285
+ cost_per_1m_output_tokens_longer_than_128k=10.00,
286
+ cost_per_1m_cached_tokens_longer_than_128k=0.625,
287
+ ),
288
+ 'gemini-1.5-pro-001': pg.Dict(
289
+ latest_update='2024-09-30',
290
+ in_service=True,
291
+ supported_modalities=ALL_MODALITIES,
292
+ rpm_free=2,
293
+ tpm_free=32_000,
294
+ rpm_paid=1000,
295
+ tpm_paid=4_000_000,
296
+ cost_per_1m_input_tokens_up_to_128k=1.25,
297
+ cost_per_1m_output_tokens_up_to_128k=5.00,
298
+ cost_per_1m_cached_tokens_up_to_128k=0.3125,
299
+ cost_per_1m_input_tokens_longer_than_128k=2.5,
300
+ cost_per_1m_output_tokens_longer_than_128k=10.00,
301
+ cost_per_1m_cached_tokens_longer_than_128k=0.625,
302
+ ),
303
+ 'gemini-1.5-pro-002': pg.Dict(
304
+ latest_update='2024-09-30',
305
+ in_service=True,
306
+ supported_modalities=ALL_MODALITIES,
307
+ rpm_free=2,
308
+ tpm_free=32_000,
309
+ rpm_paid=1000,
310
+ tpm_paid=4_000_000,
311
+ cost_per_1m_input_tokens_up_to_128k=1.25,
312
+ cost_per_1m_output_tokens_up_to_128k=5.00,
313
+ cost_per_1m_cached_tokens_up_to_128k=0.3125,
314
+ cost_per_1m_input_tokens_longer_than_128k=2.5,
315
+ cost_per_1m_output_tokens_longer_than_128k=10.00,
316
+ cost_per_1m_cached_tokens_longer_than_128k=0.625,
317
+ ),
318
+ 'gemini-1.0-pro': pg.Dict(
319
+ in_service=False,
320
+ supported_modalities=TEXT_ONLY,
321
+ rpm_free=15,
322
+ tpm_free=32_000,
323
+ rpm_paid=360,
324
+ tpm_paid=120_000,
325
+ cost_per_1m_input_tokens_up_to_128k=0.5,
326
+ cost_per_1m_output_tokens_up_to_128k=1.5,
327
+ cost_per_1m_cached_tokens_up_to_128k=0,
328
+ cost_per_1m_input_tokens_longer_than_128k=0.5,
329
+ cost_per_1m_output_tokens_longer_than_128k=1.5,
330
+ cost_per_1m_cached_tokens_longer_than_128k=0,
331
+ ),
332
+ }
333
+
334
+
@pg.use_init_args(['model'])
class Gemini(rest.REST):
  """Gemini models served through the Gemini REST API.

  This class is shared by the Google GenAI and Vertex AI backends. Per-model
  rate limits, pricing and supported modalities are looked up in
  `SUPPORTED_MODELS_AND_SETTINGS`, keyed by `model`.
  """

  model: pg.typing.Annotated[
      pg.typing.Enum(
          pg.MISSING_VALUE, list(SUPPORTED_MODELS_AND_SETTINGS.keys())
      ),
      'The name of the model to use.',
  ]

  @property
  def supported_modalities(self) -> list[str]:
    """Returns the MIME types registered as supported for the model."""
    return SUPPORTED_MODELS_AND_SETTINGS[self.model].supported_modalities

  @property
  def max_concurrency(self) -> int:
    """Returns the maximum number of concurrent requests.

    The larger of the free-tier and paid-tier limits is used for both the
    requests-per-minute and the tokens-per-minute budget.
    """
    entry = SUPPORTED_MODELS_AND_SETTINGS[self.model]
    return self.rate_to_max_concurrency(
        requests_per_min=max(entry.rpm_free, entry.rpm_paid),
        tokens_per_min=max(entry.tpm_free, entry.tpm_paid),
    )

  def estimate_cost(
      self,
      num_input_tokens: int,
      num_output_tokens: int
  ) -> float | None:
    """Estimates the cost of a request based on its token usage.

    Args:
      num_input_tokens: Number of tokens in the prompt.
      num_output_tokens: Number of tokens in the generated candidates.

    Returns:
      The estimated cost, computed from the per-1M-token prices registered
      for the model. The price tier is selected by the input size.
    """
    entry = SUPPORTED_MODELS_AND_SETTINGS[self.model]
    if num_input_tokens < 128_000:
      cost_per_1m_input_tokens = entry.cost_per_1m_input_tokens_up_to_128k
      cost_per_1m_output_tokens = entry.cost_per_1m_output_tokens_up_to_128k
    else:
      cost_per_1m_input_tokens = entry.cost_per_1m_input_tokens_longer_than_128k
      cost_per_1m_output_tokens = (
          entry.cost_per_1m_output_tokens_longer_than_128k
      )
    # Prices are quoted per one million tokens. The divisor must be exactly
    # 1_000_000 (`1000_1000` is the integer 10_001_000).
    return (
        cost_per_1m_input_tokens * num_input_tokens
        + cost_per_1m_output_tokens * num_output_tokens
    ) / 1_000_000

  @property
  def model_id(self) -> str:
    """Returns a string to identify the model."""
    return self.model

  @classmethod
  def dir(cls):
    """Returns the names of the registered models that are in service."""
    return [k for k, v in SUPPORTED_MODELS_AND_SETTINGS.items() if v.in_service]

  @property
  def headers(self):
    """Returns the HTTP headers sent with each request."""
    return {
        'Content-Type': 'application/json; charset=utf-8',
    }

  def request(
      self, prompt: lf.Message, sampling_options: lf.LMSamplingOptions
  ) -> dict[str, Any]:
    """Builds the JSON request body for a prompt.

    Args:
      prompt: The message to send to the model.
      sampling_options: Sampling options used for the generation config.

    Returns:
      A dict with `generationConfig` and a single-element `contents` list.
    """
    request = dict(
        generationConfig=self._generation_config(prompt, sampling_options)
    )
    request['contents'] = [self._content_from_message(prompt)]
    return request

  def _generation_config(
      self, prompt: lf.Message, options: lf.LMSamplingOptions
  ) -> dict[str, Any]:
    """Returns a dict as generation config for prompt and LMSamplingOptions.

    Args:
      prompt: The prompt message. If its metadata carries a `json_schema`,
        the schema is added to the config and a human-readable rendering of
        it is written to `prompt.metadata.formatted_text`.
      options: Sampling options to translate into generation config fields.

    Returns:
      The generation config dict.

    Raises:
      ValueError: If `json_schema` is present but is not a dict.
    """
    config = dict(
        temperature=options.temperature,
        maxOutputTokens=options.max_tokens,
        candidateCount=options.n,
        topK=options.top_k,
        topP=options.top_p,
        stopSequences=options.stop,
        seed=options.random_seed,
        responseLogprobs=options.logprobs,
        logprobs=options.top_logprobs,
    )

    if json_schema := prompt.metadata.get('json_schema'):
      if not isinstance(json_schema, dict):
        raise ValueError(
            f'`json_schema` must be a dict, got {json_schema!r}.'
        )
      json_schema = pg.to_json(json_schema)
      config['responseSchema'] = json_schema
      config['responseMimeType'] = 'application/json'
      # Side effect: keep a record of the requested response format next to
      # the prompt text.
      prompt.metadata.formatted_text = (
          prompt.text
          + '\n\n [RESPONSE FORMAT (not part of prompt)]\n'
          + pg.to_json_str(json_schema, json_indent=2)
      )
    return config

  def _content_from_message(self, prompt: lf.Message) -> dict[str, Any]:
    """Gets generation content from langfun message.

    Args:
      prompt: The message whose chunks are converted into content parts.

    Returns:
      A dict with role `user` and the list of converted `parts`.

    Raises:
      lf.ModalityError: If a chunk is neither a string nor a `Mime` object,
        or if a `Mime` chunk cannot be converted to a supported modality.
    """
    parts = []
    for lf_chunk in prompt.chunk():
      if isinstance(lf_chunk, str):
        parts.append({'text': lf_chunk})
      elif isinstance(lf_chunk, lf_modalities.Mime):
        try:
          modalities = lf_chunk.make_compatible(
              self.supported_modalities + ['text/plain']
          )
          # `make_compatible` may return a single object or a list of them.
          if isinstance(modalities, lf_modalities.Mime):
            modalities = [modalities]
          for modality in modalities:
            if modality.is_text:
              parts.append({'text': modality.to_text()})
            else:
              parts.append({
                  'inlineData': {
                      'data': base64.b64encode(modality.to_bytes()).decode(),
                      'mimeType': modality.mime_type,
                  }
              })
        except lf.ModalityError as e:
          raise lf.ModalityError(f'Unsupported modality: {lf_chunk!r}') from e
      else:
        raise lf.ModalityError(f'Unsupported modality: {lf_chunk!r}')
    return dict(role='user', parts=parts)

  def result(self, json: dict[str, Any]) -> lf.LMSamplingResult:
    """Converts a JSON response into a sampling result.

    Args:
      json: The decoded JSON response, with `candidates` and `usageMetadata`.

    Returns:
      A sampling result with one sample per candidate and the token usage,
      including the estimated cost.
    """
    messages = [
        self._message_from_content_parts(candidate['content']['parts'])
        for candidate in json['candidates']
    ]
    usage = json['usageMetadata']
    input_tokens = usage['promptTokenCount']
    # `candidatesTokenCount` is treated as 0 when the response omits it, so a
    # missing field does not turn into a KeyError.
    output_tokens = usage.get('candidatesTokenCount', 0)
    return lf.LMSamplingResult(
        [lf.LMSample(message) for message in messages],
        usage=lf.LMSamplingUsage(
            prompt_tokens=input_tokens,
            completion_tokens=output_tokens,
            total_tokens=input_tokens + output_tokens,
            estimated_cost=self.estimate_cost(
                num_input_tokens=input_tokens,
                num_output_tokens=output_tokens,
            ),
        ),
    )

  def _message_from_content_parts(
      self, parts: list[dict[str, Any]]
  ) -> lf.Message:
    """Converts Gemini content parts to a message.

    Args:
      parts: The `parts` list of a response candidate's content.

    Returns:
      An AI message built from the text parts. Parts flagged as `thought`
      are collected separately and attached under the `thought` key.

    Raises:
      ValueError: If a part has no `text` field.
    """
    chunks = []
    thought_chunks = []
    for part in parts:
      text_part = part.get('text')
      if text_part is None:
        raise ValueError(f'Unsupported part: {part}')
      if not text_part:
        # An empty text part carries no content; skip it rather than
        # reporting it as unsupported.
        continue
      if part.get('thought'):
        thought_chunks.append(text_part)
      else:
        chunks.append(text_part)
    message = lf.AIMessage.from_chunks(chunks)
    if thought_chunks:
      message.set('thought', lf.AIMessage.from_chunks(thought_chunks))
    return message