langfun 0.1.2.dev202502110804__py3-none-any.whl → 0.1.2.dev202502120804__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -78,6 +78,33 @@ def mock_requests_post(url: str, json: dict[str, Any], **kwargs):
78
78
  class GeminiTest(unittest.TestCase):
79
79
  """Tests for Vertex model with REST API."""
80
80
 
81
+ def test_dir(self):
82
+ self.assertIn('gemini-1.5-pro', gemini.Gemini.dir())
83
+
84
+ def test_estimate_cost(self):
85
+ model = gemini.Gemini('gemini-1.5-pro', api_endpoint='')
86
+ self.assertEqual(
87
+ model.estimate_cost(
88
+ lf.LMSamplingUsage(
89
+ prompt_tokens=100_000,
90
+ completion_tokens=1000,
91
+ total_tokens=101_000,
92
+ )
93
+ ),
94
+ 0.13
95
+ )
96
+ # Prompt length greater than 128K.
97
+ self.assertEqual(
98
+ model.estimate_cost(
99
+ lf.LMSamplingUsage(
100
+ prompt_tokens=200_000,
101
+ completion_tokens=1000,
102
+ total_tokens=201_000,
103
+ )
104
+ ),
105
+ 0.51
106
+ )
107
+
81
108
  def test_content_from_message_text_only(self):
82
109
  text = 'This is a beautiful day'
83
110
  model = gemini.Gemini('gemini-1.5-pro', api_endpoint='')
@@ -89,13 +116,6 @@ class GeminiTest(unittest.TestCase):
89
116
  message = lf.UserMessage(
90
117
  'This is an <<[[image]]>>, what is it?', image=image
91
118
  )
92
-
93
- # Non-multimodal model.
94
- with self.assertRaisesRegex(lf.ModalityError, 'Unsupported modality'):
95
- gemini.Gemini(
96
- 'gemini-1.0-pro', api_endpoint=''
97
- )._content_from_message(message)
98
-
99
119
  model = gemini.Gemini('gemini-1.5-pro', api_endpoint='')
100
120
  content = model._content_from_message(message)
101
121
  self.assertEqual(
@@ -13,6 +13,7 @@
13
13
  # limitations under the License.
14
14
  """Language models from Google GenAI."""
15
15
 
16
+ import functools
16
17
  import os
17
18
  from typing import Annotated, Literal
18
19
 
@@ -26,6 +27,20 @@ import pyglove as pg
26
27
  class GenAI(gemini.Gemini):
27
28
  """Language models provided by Google GenAI."""
28
29
 
30
+ model: pg.typing.Annotated[
31
+ pg.typing.Enum(
32
+ pg.MISSING_VALUE,
33
+ [
34
+ m.model_id for m in gemini.SUPPORTED_MODELS
35
+ if m.provider == 'Google GenAI' or (
36
+ isinstance(m.provider, pg.hyper.OneOf)
37
+ and 'Google GenAI' in m.provider.candidates
38
+ )
39
+ ]
40
+ ),
41
+ 'The name of the model to use.',
42
+ ]
43
+
29
44
  api_key: Annotated[
30
45
  str | None,
31
46
  (
@@ -40,10 +55,11 @@ class GenAI(gemini.Gemini):
40
55
  'The API version to use.'
41
56
  ] = 'v1beta'
42
57
 
43
- @property
44
- def model_id(self) -> str:
45
- """Returns a string to identify the model."""
46
- return f'GenAI({self.model})'
58
+ @functools.cached_property
59
+ def model_info(self) -> lf.ModelInfo:
60
+ return super().model_info.clone(
61
+ override=dict(provider='Google GenAI')
62
+ )
47
63
 
48
64
  @property
49
65
  def api_endpoint(self) -> str:
@@ -63,91 +79,105 @@ class GenAI(gemini.Gemini):
63
79
  )
64
80
 
65
81
 
66
# pylint: disable=invalid-name

#
# Experimental models.
#


class Gemini2ProExp_20250205(GenAI):
  """Experimental Gemini 2.0 Pro model, released on 02/05/2025."""
  model = 'gemini-2.0-pro-exp-02-05'


class Gemini2FlashThinkingExp_20250121(GenAI):
  """Experimental Gemini 2.0 Flash Thinking model, released on 01/21/2025."""
  api_version = 'v1beta'
  model = 'gemini-2.0-flash-thinking-exp-01-21'
  # Explicitly None for this model (presumably disables the request timeout;
  # confirm against the base class).
  timeout = None


class GeminiExp_20241206(GenAI):
  """Experimental Gemini model, released on 12/06/2024."""
  model = 'gemini-exp-1206'
94
105
 
106
#
# Production models.
#


class Gemini2Flash(GenAI):
  """Gemini 2.0 Flash model (latest stable)."""
  model = 'gemini-2.0-flash'


class Gemini2Flash_001(GenAI):
  """Gemini 2.0 Flash model launched on 02/05/2025."""
  model = 'gemini-2.0-flash-001'


class Gemini2FlashLitePreview_20250205(GenAI):
  """Gemini 2.0 Flash Lite preview model launched on 02/05/2025."""
  model = 'gemini-2.0-flash-lite-preview-02-05'


class Gemini15Pro(GenAI):
  """Gemini 1.5 Pro model (latest stable)."""
  model = 'gemini-1.5-pro'


class Gemini15Pro_002(GenAI):
  """Gemini 1.5 Pro model (stable version 002)."""
  model = 'gemini-1.5-pro-002'


class Gemini15Pro_001(GenAI):
  """Gemini 1.5 Pro model (stable version 001)."""
  model = 'gemini-1.5-pro-001'


class Gemini15Flash(GenAI):
  """Gemini 1.5 Flash model (latest stable)."""
  model = 'gemini-1.5-flash'


class Gemini15Flash_002(GenAI):
  """Gemini 1.5 Flash model (stable version 002)."""
  model = 'gemini-1.5-flash-002'


class Gemini15Flash_001(GenAI):
  """Gemini 1.5 Flash model (stable version 001)."""
  model = 'gemini-1.5-flash-001'


class Gemini15Flash8B(GenAI):
  """Gemini 1.5 Flash 8B model (latest stable)."""
  model = 'gemini-1.5-flash-8b'


class Gemini15Flash8B_001(GenAI):
  """Gemini 1.5 Flash 8B model (stable version 001)."""
  model = 'gemini-1.5-flash-8b-001'
166
# Class names from earlier releases, kept as aliases for backward
# compatibility. Versioned aliases point at the class serving the same model
# id; the unversioned `GeminiPro1_5`/`GeminiFlash1_5` names (formerly the
# `-latest` models) now map to `gemini-1.5-pro`/`gemini-1.5-flash`.
GeminiPro1_5 = Gemini15Pro
GeminiPro1_5_002 = Gemini15Pro_002
GeminiPro1_5_001 = Gemini15Pro_001
GeminiFlash1_5 = Gemini15Flash
GeminiFlash1_5_002 = Gemini15Flash_002
GeminiFlash1_5_001 = Gemini15Flash_001

# pylint: enable=invalid-name
143
171
 
144
def _genai_model(model: str, *args, **kwargs) -> GenAI:
  """Creates a `GenAI` language model from a model id.

  Args:
    model: Model id, optionally prefixed with `google_genai://`.
    *args: Extra positional arguments forwarded to `GenAI` after the model id.
    **kwargs: Keyword arguments forwarded to `GenAI`.

  Returns:
    A `GenAI` instance for the (prefix-stripped) model id.
  """
  model = model.removeprefix('google_genai://')
  # Pass `model` positionally, as `gemini.Gemini('gemini-1.5-pro', ...)` does.
  # `GenAI(model=model, *args)` unpacks `*args` before the keyword, so any
  # positional argument would also bind to `model` and raise TypeError.
  return GenAI(model, *args, **kwargs)
148
176
 
149
177
 
150
def _register_genai_models():
  """Register GenAI models."""
  # Every supported Gemini model becomes resolvable through
  # `lf.LanguageModel` as 'google_genai://<model_id>'.
  scheme = 'google_genai://'
  for spec in gemini.SUPPORTED_MODELS:
    resource_id = scheme + spec.model_id
    lf.LanguageModel.register(resource_id, _genai_model)


_register_genai_models()
@@ -15,6 +15,7 @@
15
15
 
16
16
import os
import unittest
from unittest import mock

import langfun.core as lf
from langfun.core.llms import google_genai
19
20
 
20
21
 
@@ -23,16 +24,22 @@ class GenAITest(unittest.TestCase):
23
24
 
24
25
  def test_basics(self):
25
26
  with self.assertRaisesRegex(ValueError, 'Please specify `api_key`'):
26
- _ = google_genai.GeminiPro1_5().api_endpoint
27
+ _ = google_genai.Gemini15Pro().api_endpoint
27
28
 
28
- self.assertIsNotNone(google_genai.GeminiPro1_5(api_key='abc').api_endpoint)
29
+ self.assertIsNotNone(google_genai.Gemini15Pro(api_key='abc').api_endpoint)
29
30
 
30
31
  os.environ['GOOGLE_API_KEY'] = 'abc'
31
- lm = google_genai.GeminiPro1_5()
32
+ lm = google_genai.Gemini15Pro_001()
32
33
  self.assertIsNotNone(lm.api_endpoint)
33
- self.assertTrue(lm.model_id.startswith('GenAI('))
34
+ self.assertEqual(lm.model_id, 'gemini-1.5-pro-001')
35
+ self.assertEqual(lm.resource_id, 'google_genai://gemini-1.5-pro-001')
34
36
  del os.environ['GOOGLE_API_KEY']
35
37
 
38
+ def test_lm_get(self):
39
+ self.assertIsInstance(
40
+ lf.LanguageModel.get('google_genai://gemini-1.5-pro'),
41
+ google_genai.GenAI,
42
+ )
36
43
 
37
44
if __name__ == '__main__':
  # Run this module's tests when it is executed directly.
  unittest.main()