langfun 0.1.2.dev202502110804__py3-none-any.whl → 0.1.2.dev202502120804__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- langfun/core/__init__.py +6 -2
- langfun/core/language_model.py +365 -22
- langfun/core/language_model_test.py +123 -35
- langfun/core/llms/__init__.py +50 -57
- langfun/core/llms/anthropic.py +434 -163
- langfun/core/llms/anthropic_test.py +20 -1
- langfun/core/llms/deepseek.py +90 -51
- langfun/core/llms/deepseek_test.py +15 -16
- langfun/core/llms/fake.py +6 -0
- langfun/core/llms/gemini.py +480 -390
- langfun/core/llms/gemini_test.py +27 -7
- langfun/core/llms/google_genai.py +80 -50
- langfun/core/llms/google_genai_test.py +11 -4
- langfun/core/llms/groq.py +268 -167
- langfun/core/llms/groq_test.py +9 -3
- langfun/core/llms/openai.py +839 -328
- langfun/core/llms/openai_compatible.py +3 -18
- langfun/core/llms/openai_compatible_test.py +20 -5
- langfun/core/llms/openai_test.py +14 -4
- langfun/core/llms/rest.py +11 -6
- langfun/core/llms/vertexai.py +238 -240
- langfun/core/llms/vertexai_test.py +35 -8
- {langfun-0.1.2.dev202502110804.dist-info → langfun-0.1.2.dev202502120804.dist-info}/METADATA +1 -1
- {langfun-0.1.2.dev202502110804.dist-info → langfun-0.1.2.dev202502120804.dist-info}/RECORD +27 -27
- {langfun-0.1.2.dev202502110804.dist-info → langfun-0.1.2.dev202502120804.dist-info}/LICENSE +0 -0
- {langfun-0.1.2.dev202502110804.dist-info → langfun-0.1.2.dev202502120804.dist-info}/WHEEL +0 -0
- {langfun-0.1.2.dev202502110804.dist-info → langfun-0.1.2.dev202502120804.dist-info}/top_level.txt +0 -0
langfun/core/llms/gemini_test.py
CHANGED
@@ -78,6 +78,33 @@ def mock_requests_post(url: str, json: dict[str, Any], **kwargs):
|
|
78
78
|
class GeminiTest(unittest.TestCase):
|
79
79
|
"""Tests for Vertex model with REST API."""
|
80
80
|
|
81
|
+
def test_dir(self):
|
82
|
+
self.assertIn('gemini-1.5-pro', gemini.Gemini.dir())
|
83
|
+
|
84
|
+
def test_estimate_cost(self):
|
85
|
+
model = gemini.Gemini('gemini-1.5-pro', api_endpoint='')
|
86
|
+
self.assertEqual(
|
87
|
+
model.estimate_cost(
|
88
|
+
lf.LMSamplingUsage(
|
89
|
+
prompt_tokens=100_000,
|
90
|
+
completion_tokens=1000,
|
91
|
+
total_tokens=101_000,
|
92
|
+
)
|
93
|
+
),
|
94
|
+
0.13
|
95
|
+
)
|
96
|
+
# Prompt length greater than 128K.
|
97
|
+
self.assertEqual(
|
98
|
+
model.estimate_cost(
|
99
|
+
lf.LMSamplingUsage(
|
100
|
+
prompt_tokens=200_000,
|
101
|
+
completion_tokens=1000,
|
102
|
+
total_tokens=201_000,
|
103
|
+
)
|
104
|
+
),
|
105
|
+
0.51
|
106
|
+
)
|
107
|
+
|
81
108
|
def test_content_from_message_text_only(self):
|
82
109
|
text = 'This is a beautiful day'
|
83
110
|
model = gemini.Gemini('gemini-1.5-pro', api_endpoint='')
|
@@ -89,13 +116,6 @@ class GeminiTest(unittest.TestCase):
|
|
89
116
|
message = lf.UserMessage(
|
90
117
|
'This is an <<[[image]]>>, what is it?', image=image
|
91
118
|
)
|
92
|
-
|
93
|
-
# Non-multimodal model.
|
94
|
-
with self.assertRaisesRegex(lf.ModalityError, 'Unsupported modality'):
|
95
|
-
gemini.Gemini(
|
96
|
-
'gemini-1.0-pro', api_endpoint=''
|
97
|
-
)._content_from_message(message)
|
98
|
-
|
99
119
|
model = gemini.Gemini('gemini-1.5-pro', api_endpoint='')
|
100
120
|
content = model._content_from_message(message)
|
101
121
|
self.assertEqual(
|
@@ -13,6 +13,7 @@
|
|
13
13
|
# limitations under the License.
|
14
14
|
"""Language models from Google GenAI."""
|
15
15
|
|
16
|
+
import functools
|
16
17
|
import os
|
17
18
|
from typing import Annotated, Literal
|
18
19
|
|
@@ -26,6 +27,20 @@ import pyglove as pg
|
|
26
27
|
class GenAI(gemini.Gemini):
|
27
28
|
"""Language models provided by Google GenAI."""
|
28
29
|
|
30
|
+
model: pg.typing.Annotated[
|
31
|
+
pg.typing.Enum(
|
32
|
+
pg.MISSING_VALUE,
|
33
|
+
[
|
34
|
+
m.model_id for m in gemini.SUPPORTED_MODELS
|
35
|
+
if m.provider == 'Google GenAI' or (
|
36
|
+
isinstance(m.provider, pg.hyper.OneOf)
|
37
|
+
and 'Google GenAI' in m.provider.candidates
|
38
|
+
)
|
39
|
+
]
|
40
|
+
),
|
41
|
+
'The name of the model to use.',
|
42
|
+
]
|
43
|
+
|
29
44
|
api_key: Annotated[
|
30
45
|
str | None,
|
31
46
|
(
|
@@ -40,10 +55,11 @@ class GenAI(gemini.Gemini):
|
|
40
55
|
'The API version to use.'
|
41
56
|
] = 'v1beta'
|
42
57
|
|
43
|
-
@
|
44
|
-
def
|
45
|
-
|
46
|
-
|
58
|
+
@functools.cached_property
|
59
|
+
def model_info(self) -> lf.ModelInfo:
|
60
|
+
return super().model_info.clone(
|
61
|
+
override=dict(provider='Google GenAI')
|
62
|
+
)
|
47
63
|
|
48
64
|
@property
|
49
65
|
def api_endpoint(self) -> str:
|
@@ -63,91 +79,105 @@ class GenAI(gemini.Gemini):
|
|
63
79
|
)
|
64
80
|
|
65
81
|
|
66
|
-
|
67
|
-
"""Gemini Flash 2.0 model launched on 02/05/2025."""
|
68
|
-
|
69
|
-
api_version = 'v1beta'
|
70
|
-
model = 'gemini-2.0-flash'
|
82
|
+
# pylint: disable=invalid-name
|
71
83
|
|
84
|
+
#
|
85
|
+
# Experimental models.
|
86
|
+
#
|
72
87
|
|
73
|
-
class Gemini2ProExp_20250205(GenAI): # pylint: disable=invalid-name
|
74
|
-
"""Gemini Flash 2.0 Pro model launched on 02/05/2025."""
|
75
88
|
|
76
|
-
|
89
|
+
class Gemini2ProExp_20250205(GenAI):
|
90
|
+
"""Gemini 2.0 Pro experimental model launched on 02/05/2025."""
|
77
91
|
model = 'gemini-2.0-pro-exp-02-05'
|
78
92
|
|
79
93
|
|
80
|
-
class Gemini2FlashThinkingExp_20250121(GenAI):
|
81
|
-
"""Gemini
|
82
|
-
|
94
|
+
class Gemini2FlashThinkingExp_20250121(GenAI):
|
95
|
+
"""Gemini 2.0 Flash Thinking model launched on 01/21/2025."""
|
83
96
|
api_version = 'v1beta'
|
84
97
|
model = 'gemini-2.0-flash-thinking-exp-01-21'
|
85
98
|
timeout = None
|
86
99
|
|
87
100
|
|
88
|
-
class
|
89
|
-
"""Gemini
|
101
|
+
class GeminiExp_20241206(GenAI):
|
102
|
+
"""Gemini Experimental model launched on 12/06/2024."""
|
103
|
+
model = 'gemini-exp-1206'
|
90
104
|
|
91
|
-
api_version = 'v1beta'
|
92
|
-
model = 'gemini-2.0-flash-thinking-exp-1219'
|
93
|
-
timeout = None
|
94
105
|
|
106
|
+
#
|
107
|
+
# Production models.
|
108
|
+
#
|
95
109
|
|
96
|
-
class GeminiFlash2_0Exp(GenAI): # pylint: disable=invalid-name
|
97
|
-
"""Gemini Flash 2.0 model launched on 12/11/2024."""
|
98
110
|
|
99
|
-
|
111
|
+
class Gemini2Flash(GenAI):
|
112
|
+
"""Gemini 2.0 Flash model (latest stable)."""
|
113
|
+
model = 'gemini-2.0-flash'
|
100
114
|
|
101
115
|
|
102
|
-
class
|
103
|
-
"""Gemini
|
116
|
+
class Gemini2Flash_001(GenAI):
|
117
|
+
"""Gemini 2.0 Flash model launched on 02/05/2025."""
|
118
|
+
model = 'gemini-2.0-flash-001'
|
104
119
|
|
105
|
-
|
120
|
+
|
121
|
+
class Gemini2FlashLitePreview_20250205(GenAI):
|
122
|
+
"""Gemini 2.0 Flash lite preview model launched on 02/05/2025."""
|
123
|
+
model = 'gemini-2.0-flash-lite-preview-02-05'
|
106
124
|
|
107
125
|
|
108
|
-
class
|
109
|
-
"""Gemini
|
126
|
+
class Gemini15Pro(GenAI):
|
127
|
+
"""Gemini 1.5 Pro latest stable model."""
|
128
|
+
model = 'gemini-1.5-pro'
|
110
129
|
|
111
|
-
model = 'gemini-exp-1114'
|
112
130
|
|
131
|
+
class Gemini15Pro_002(GenAI):
|
132
|
+
"""Gemini 1.5 Pro stable version 002."""
|
133
|
+
model = 'gemini-1.5-pro-002'
|
113
134
|
|
114
|
-
class GeminiPro1_5(GenAI): # pylint: disable=invalid-name
|
115
|
-
"""Gemini Pro latest model."""
|
116
135
|
|
117
|
-
|
136
|
+
class Gemini15Pro_001(GenAI):
|
137
|
+
"""Gemini 1.5 Pro stable version 001."""
|
138
|
+
model = 'gemini-1.5-pro-001'
|
118
139
|
|
119
140
|
|
120
|
-
class
|
121
|
-
"""Gemini
|
141
|
+
class Gemini15Flash(GenAI):
|
142
|
+
"""Gemini 1.5 Flash latest model."""
|
143
|
+
model = 'gemini-1.5-flash'
|
122
144
|
|
123
|
-
model = 'gemini-1.5-pro-002'
|
124
145
|
|
146
|
+
class Gemini15Flash_002(GenAI):
|
147
|
+
"""Gemini 1.5 Flash model stable version 002."""
|
148
|
+
model = 'gemini-1.5-flash-002'
|
125
149
|
|
126
|
-
class GeminiPro1_5_001(GenAI): # pylint: disable=invalid-name
|
127
|
-
"""Gemini Pro latest model."""
|
128
150
|
|
129
|
-
|
151
|
+
class Gemini15Flash_001(GenAI):
|
152
|
+
"""Gemini 1.5 Flash model stable version 001."""
|
153
|
+
model = 'gemini-1.5-flash-001'
|
130
154
|
|
131
155
|
|
132
|
-
class
|
133
|
-
"""Gemini Flash latest
|
156
|
+
class Gemini15Flash8B(GenAI):
|
157
|
+
"""Gemini 1.5 Flash 8B model (latest stable)."""
|
158
|
+
model = 'gemini-1.5-flash-8b'
|
134
159
|
|
135
|
-
model = 'gemini-1.5-flash-latest'
|
136
160
|
|
161
|
+
class Gemini15Flash8B_001(GenAI):
|
162
|
+
"""Gemini 1.5 Flash 8B model (version 001)."""
|
163
|
+
model = 'gemini-1.5-flash-8b-001'
|
137
164
|
|
138
|
-
class GeminiFlash1_5_002(GenAI): # pylint: disable=invalid-name
|
139
|
-
"""Gemini Flash 1.5 model stable version 002."""
|
140
165
|
|
141
|
-
|
166
|
+
# For backward compatibility.
|
167
|
+
GeminiPro1_5 = Gemini15Pro
|
168
|
+
GeminiFlash1_5 = Gemini15Flash
|
142
169
|
|
170
|
+
# pylint: enable=invalid-name
|
143
171
|
|
144
|
-
class GeminiFlash1_5_001(GenAI): # pylint: disable=invalid-name
|
145
|
-
"""Gemini Flash 1.5 model stable version 001."""
|
146
172
|
|
147
|
-
|
173
|
+
def _genai_model(model: str, *args, **kwargs) -> GenAI:
|
174
|
+
model = model.removeprefix('google_genai://')
|
175
|
+
return GenAI(model=model, *args, **kwargs)
|
148
176
|
|
149
177
|
|
150
|
-
|
151
|
-
"""
|
178
|
+
def _register_genai_models():
|
179
|
+
"""Register GenAI models."""
|
180
|
+
for m in gemini.SUPPORTED_MODELS:
|
181
|
+
lf.LanguageModel.register('google_genai://' + m.model_id, _genai_model)
|
152
182
|
|
153
|
-
|
183
|
+
_register_genai_models()
|
@@ -15,6 +15,7 @@
|
|
15
15
|
|
16
16
|
import os
|
17
17
|
import unittest
|
18
|
+
import langfun.core as lf
|
18
19
|
from langfun.core.llms import google_genai
|
19
20
|
|
20
21
|
|
@@ -23,16 +24,22 @@ class GenAITest(unittest.TestCase):
|
|
23
24
|
|
24
25
|
def test_basics(self):
|
25
26
|
with self.assertRaisesRegex(ValueError, 'Please specify `api_key`'):
|
26
|
-
_ = google_genai.
|
27
|
+
_ = google_genai.Gemini15Pro().api_endpoint
|
27
28
|
|
28
|
-
self.assertIsNotNone(google_genai.
|
29
|
+
self.assertIsNotNone(google_genai.Gemini15Pro(api_key='abc').api_endpoint)
|
29
30
|
|
30
31
|
os.environ['GOOGLE_API_KEY'] = 'abc'
|
31
|
-
lm = google_genai.
|
32
|
+
lm = google_genai.Gemini15Pro_001()
|
32
33
|
self.assertIsNotNone(lm.api_endpoint)
|
33
|
-
self.
|
34
|
+
self.assertEqual(lm.model_id, 'gemini-1.5-pro-001')
|
35
|
+
self.assertEqual(lm.resource_id, 'google_genai://gemini-1.5-pro-001')
|
34
36
|
del os.environ['GOOGLE_API_KEY']
|
35
37
|
|
38
|
+
def test_lm_get(self):
|
39
|
+
self.assertIsInstance(
|
40
|
+
lf.LanguageModel.get('google_genai://gemini-1.5-pro'),
|
41
|
+
google_genai.GenAI,
|
42
|
+
)
|
36
43
|
|
37
44
|
if __name__ == '__main__':
|
38
45
|
unittest.main()
|