langfun 0.0.2.dev20240330__py3-none-any.whl → 0.1.2.dev202501140804__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (145) hide show
  1. langfun/__init__.py +22 -2
  2. langfun/core/__init__.py +17 -5
  3. langfun/core/agentic/__init__.py +30 -0
  4. langfun/core/agentic/action.py +854 -0
  5. langfun/core/agentic/action_eval.py +150 -0
  6. langfun/core/agentic/action_eval_test.py +109 -0
  7. langfun/core/agentic/action_test.py +136 -0
  8. langfun/core/coding/python/__init__.py +5 -11
  9. langfun/core/coding/python/correction.py +37 -28
  10. langfun/core/coding/python/correction_test.py +29 -3
  11. langfun/core/coding/python/execution.py +40 -216
  12. langfun/core/coding/python/execution_test.py +29 -89
  13. langfun/core/coding/python/generation.py +21 -11
  14. langfun/core/coding/python/generation_test.py +2 -2
  15. langfun/core/coding/python/parsing.py +108 -193
  16. langfun/core/coding/python/parsing_test.py +2 -105
  17. langfun/core/component.py +69 -2
  18. langfun/core/component_test.py +54 -0
  19. langfun/core/concurrent.py +414 -117
  20. langfun/core/concurrent_test.py +111 -24
  21. langfun/core/console.py +18 -5
  22. langfun/core/console_test.py +17 -0
  23. langfun/core/eval/__init__.py +17 -0
  24. langfun/core/eval/base.py +767 -140
  25. langfun/core/eval/base_test.py +238 -53
  26. langfun/core/eval/matching.py +80 -76
  27. langfun/core/eval/matching_test.py +19 -9
  28. langfun/core/eval/patching.py +130 -0
  29. langfun/core/eval/patching_test.py +170 -0
  30. langfun/core/eval/scoring.py +37 -28
  31. langfun/core/eval/scoring_test.py +21 -3
  32. langfun/core/eval/v2/__init__.py +42 -0
  33. langfun/core/eval/v2/checkpointing.py +380 -0
  34. langfun/core/eval/v2/checkpointing_test.py +228 -0
  35. langfun/core/eval/v2/eval_test_helper.py +136 -0
  36. langfun/core/eval/v2/evaluation.py +725 -0
  37. langfun/core/eval/v2/evaluation_test.py +180 -0
  38. langfun/core/eval/v2/example.py +305 -0
  39. langfun/core/eval/v2/example_test.py +128 -0
  40. langfun/core/eval/v2/experiment.py +1048 -0
  41. langfun/core/eval/v2/experiment_test.py +433 -0
  42. langfun/core/eval/v2/metric_values.py +156 -0
  43. langfun/core/eval/v2/metric_values_test.py +80 -0
  44. langfun/core/eval/v2/metrics.py +357 -0
  45. langfun/core/eval/v2/metrics_test.py +203 -0
  46. langfun/core/eval/v2/progress.py +348 -0
  47. langfun/core/eval/v2/progress_test.py +82 -0
  48. langfun/core/eval/v2/progress_tracking.py +210 -0
  49. langfun/core/eval/v2/progress_tracking_test.py +66 -0
  50. langfun/core/eval/v2/reporting.py +270 -0
  51. langfun/core/eval/v2/reporting_test.py +158 -0
  52. langfun/core/eval/v2/runners.py +488 -0
  53. langfun/core/eval/v2/runners_test.py +334 -0
  54. langfun/core/langfunc.py +3 -21
  55. langfun/core/langfunc_test.py +26 -8
  56. langfun/core/language_model.py +686 -48
  57. langfun/core/language_model_test.py +681 -44
  58. langfun/core/llms/__init__.py +100 -12
  59. langfun/core/llms/anthropic.py +488 -0
  60. langfun/core/llms/anthropic_test.py +235 -0
  61. langfun/core/llms/cache/base.py +21 -2
  62. langfun/core/llms/cache/in_memory.py +13 -0
  63. langfun/core/llms/cache/in_memory_test.py +88 -28
  64. langfun/core/llms/compositional.py +101 -0
  65. langfun/core/llms/compositional_test.py +73 -0
  66. langfun/core/llms/deepseek.py +117 -0
  67. langfun/core/llms/deepseek_test.py +61 -0
  68. langfun/core/llms/fake.py +39 -26
  69. langfun/core/llms/fake_test.py +136 -11
  70. langfun/core/llms/gemini.py +507 -0
  71. langfun/core/llms/gemini_test.py +195 -0
  72. langfun/core/llms/google_genai.py +62 -218
  73. langfun/core/llms/google_genai_test.py +9 -197
  74. langfun/core/llms/groq.py +276 -0
  75. langfun/core/llms/groq_test.py +64 -0
  76. langfun/core/llms/llama_cpp.py +15 -40
  77. langfun/core/llms/llama_cpp_test.py +4 -30
  78. langfun/core/llms/openai.py +436 -226
  79. langfun/core/llms/openai_compatible.py +179 -0
  80. langfun/core/llms/openai_compatible_test.py +495 -0
  81. langfun/core/llms/openai_test.py +35 -174
  82. langfun/core/llms/rest.py +113 -0
  83. langfun/core/llms/rest_test.py +111 -0
  84. langfun/core/llms/vertexai.py +192 -0
  85. langfun/core/llms/vertexai_test.py +52 -0
  86. langfun/core/logging.py +284 -0
  87. langfun/core/logging_test.py +125 -0
  88. langfun/core/message.py +319 -9
  89. langfun/core/message_test.py +190 -13
  90. langfun/core/modalities/__init__.py +6 -2
  91. langfun/core/modalities/audio.py +30 -0
  92. langfun/core/modalities/audio_test.py +63 -0
  93. langfun/core/modalities/image.py +39 -20
  94. langfun/core/modalities/image_test.py +52 -9
  95. langfun/core/modalities/mime.py +206 -29
  96. langfun/core/modalities/mime_test.py +90 -9
  97. langfun/core/modalities/ms_office.py +117 -0
  98. langfun/core/modalities/ms_office_test.py +389 -0
  99. langfun/core/modalities/pdf.py +22 -0
  100. langfun/core/modalities/pdf_test.py +57 -0
  101. langfun/core/modalities/video.py +9 -23
  102. langfun/core/modalities/video_test.py +3 -3
  103. langfun/core/modality.py +26 -3
  104. langfun/core/modality_test.py +2 -2
  105. langfun/core/sampling.py +11 -11
  106. langfun/core/structured/__init__.py +15 -16
  107. langfun/core/structured/completion.py +32 -5
  108. langfun/core/structured/completion_test.py +9 -8
  109. langfun/core/structured/description.py +2 -2
  110. langfun/core/structured/description_test.py +3 -3
  111. langfun/core/structured/function_generation.py +278 -0
  112. langfun/core/structured/function_generation_test.py +399 -0
  113. langfun/core/structured/mapping.py +150 -46
  114. langfun/core/structured/mapping_test.py +105 -0
  115. langfun/core/structured/parsing.py +33 -21
  116. langfun/core/structured/parsing_test.py +71 -22
  117. langfun/core/structured/querying.py +746 -0
  118. langfun/core/structured/{prompting_test.py → querying_test.py} +545 -60
  119. langfun/core/structured/schema.py +208 -99
  120. langfun/core/structured/schema_generation.py +1 -1
  121. langfun/core/structured/schema_generation_test.py +2 -2
  122. langfun/core/structured/schema_test.py +133 -34
  123. langfun/core/structured/scoring.py +125 -19
  124. langfun/core/structured/scoring_test.py +30 -0
  125. langfun/core/structured/tokenization.py +64 -0
  126. langfun/core/structured/tokenization_test.py +48 -0
  127. langfun/core/template.py +240 -11
  128. langfun/core/template_test.py +146 -1
  129. langfun/core/templates/conversation.py +9 -0
  130. langfun/core/templates/conversation_test.py +4 -3
  131. langfun/core/templates/selfplay_test.py +14 -2
  132. langfun-0.1.2.dev202501140804.dist-info/METADATA +225 -0
  133. langfun-0.1.2.dev202501140804.dist-info/RECORD +153 -0
  134. {langfun-0.0.2.dev20240330.dist-info → langfun-0.1.2.dev202501140804.dist-info}/WHEEL +1 -1
  135. langfun/core/coding/python/errors.py +0 -108
  136. langfun/core/coding/python/errors_test.py +0 -99
  137. langfun/core/coding/python/permissions.py +0 -90
  138. langfun/core/coding/python/permissions_test.py +0 -86
  139. langfun/core/structured/prompting.py +0 -217
  140. langfun/core/text_formatting.py +0 -162
  141. langfun/core/text_formatting_test.py +0 -47
  142. langfun-0.0.2.dev20240330.dist-info/METADATA +0 -99
  143. langfun-0.0.2.dev20240330.dist-info/RECORD +0 -102
  144. {langfun-0.0.2.dev20240330.dist-info → langfun-0.1.2.dev202501140804.dist-info}/LICENSE +0 -0
  145. {langfun-0.0.2.dev20240330.dist-info → langfun-0.1.2.dev202501140804.dist-info}/top_level.txt +0 -0
@@ -11,73 +11,22 @@
11
11
  # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
12
  # See the License for the specific language governing permissions and
13
13
  # limitations under the License.
14
- """Tests for openai models."""
14
+ """Tests for OpenAI models."""
15
15
 
16
16
  import unittest
17
- from unittest import mock
18
-
19
17
  import langfun.core as lf
20
- from langfun.core import modalities as lf_modalities
21
18
  from langfun.core.llms import openai
22
- import pyglove as pg
23
-
24
-
25
- def mock_completion_query(prompt, *, n=1, **kwargs):
26
- del kwargs
27
- choices = []
28
- for i, _ in enumerate(prompt):
29
- for k in range(n):
30
- choices.append(pg.Dict(
31
- index=i,
32
- text=f'Sample {k} for prompt {i}.',
33
- logprobs=k / 10,
34
- ))
35
- return pg.Dict(choices=choices, usage=openai.Usage(
36
- prompt_tokens=100,
37
- completion_tokens=100,
38
- total_tokens=200,
39
- ))
40
-
41
-
42
- def mock_chat_completion_query(messages, *, n=1, **kwargs):
43
- del messages, kwargs
44
- choices = []
45
- for k in range(n):
46
- choices.append(pg.Dict(
47
- message=pg.Dict(
48
- content=f'Sample {k} for message.'
49
- ),
50
- logprobs=None,
51
- ))
52
- return pg.Dict(choices=choices, usage=openai.Usage(
53
- prompt_tokens=100,
54
- completion_tokens=100,
55
- total_tokens=200,
56
- ))
57
19
 
58
20
 
59
- def mock_chat_completion_query_vision(messages, *, n=1, **kwargs):
60
- del kwargs
61
- choices = []
62
- urls = [
63
- c['image_url'] for c in messages[0]['content'] if c['type'] == 'image_url'
64
- ]
65
- for k in range(n):
66
- choices.append(pg.Dict(
67
- message=pg.Dict(
68
- content=f'Sample {k} for message: {"".join(urls)}'
69
- ),
70
- logprobs=None,
71
- ))
72
- return pg.Dict(choices=choices, usage=openai.Usage(
73
- prompt_tokens=100,
74
- completion_tokens=100,
75
- total_tokens=200,
76
- ))
21
+ class OpenAITest(unittest.TestCase):
22
+ """Tests for OpenAI language model."""
77
23
 
24
+ def test_dir(self):
25
+ self.assertIn('gpt-4-turbo', openai.OpenAI.dir())
78
26
 
79
- class OpenaiTest(unittest.TestCase):
80
- """Tests for OpenAI language model."""
27
+ def test_key(self):
28
+ with self.assertRaisesRegex(ValueError, 'Please specify `api_key`'):
29
+ openai.Gpt4()('hi')
81
30
 
82
31
  def test_model_id(self):
83
32
  self.assertEqual(
@@ -88,136 +37,48 @@ class OpenaiTest(unittest.TestCase):
88
37
  openai.Gpt35(api_key='test_key').resource_id, 'OpenAI(text-davinci-003)'
89
38
  )
90
39
 
40
+ def test_headers(self):
41
+ self.assertEqual(
42
+ openai.Gpt35(api_key='test_key').headers,
43
+ {
44
+ 'Content-Type': 'application/json',
45
+ 'Authorization': 'Bearer test_key',
46
+ },
47
+ )
48
+
91
49
  def test_max_concurrency(self):
92
- self.assertEqual(openai.Gpt35(api_key='test_key').max_concurrency, 8)
50
+ self.assertGreater(openai.Gpt35(api_key='test_key').max_concurrency, 0)
93
51
 
94
- def test_get_request_args(self):
52
+ def test_request_args(self):
95
53
  self.assertEqual(
96
- openai.Gpt35(api_key='test_key', timeout=90.0)._get_request_args(
54
+ openai.Gpt4(api_key='test_key')._request_args(
97
55
  lf.LMSamplingOptions(
98
- temperature=2.0,
99
- n=2,
100
- max_tokens=4096,
101
- top_p=1.0)),
102
- dict(
103
- engine='text-davinci-003',
104
- logprobs=False,
105
- top_logprobs=None,
106
- n=2,
107
- temperature=2.0,
108
- max_tokens=4096,
109
- stream=False,
110
- timeout=90.0,
111
- top_p=1.0,
112
- )
113
- )
114
- self.assertEqual(
115
- openai.Gpt4(api_key='test_key')._get_request_args(
116
- lf.LMSamplingOptions(temperature=1.0, stop=['\n'], n=1)
56
+ temperature=1.0, stop=['\n'], n=1, random_seed=123
57
+ )
117
58
  ),
118
59
  dict(
119
60
  model='gpt-4',
120
- logprobs=False,
121
61
  top_logprobs=None,
122
62
  n=1,
123
63
  temperature=1.0,
124
- max_tokens=1024,
125
- stream=False,
126
- timeout=120.0,
127
64
  stop=['\n'],
65
+ seed=123,
128
66
  ),
129
67
  )
130
-
131
- def test_call_completion(self):
132
- with mock.patch('openai.Completion.create') as mock_completion:
133
- mock_completion.side_effect = mock_completion_query
134
- lm = openai.OpenAI(api_key='test_key', model='text-davinci-003')
135
- self.assertEqual(
136
- lm('hello', sampling_options=lf.LMSamplingOptions(n=2)),
137
- 'Sample 0 for prompt 0.',
138
- )
139
-
140
- def test_call_chat_completion(self):
141
- with mock.patch('openai.ChatCompletion.create') as mock_chat_completion:
142
- mock_chat_completion.side_effect = mock_chat_completion_query
143
- lm = openai.OpenAI(api_key='test_key', model='gpt-4')
144
- self.assertEqual(
145
- lm('hello', sampling_options=lf.LMSamplingOptions(n=2)),
146
- 'Sample 0 for message.',
147
- )
148
-
149
- def test_call_chat_completion_vision(self):
150
- with mock.patch('openai.ChatCompletion.create') as mock_chat_completion:
151
- mock_chat_completion.side_effect = mock_chat_completion_query_vision
152
- lm = openai.Gpt4TurboVision(api_key='test_key')
153
- self.assertEqual(
154
- lm(
155
- lf.UserMessage(
156
- 'hello {{image}}',
157
- image=lf_modalities.Image.from_uri('https://fake/image')
158
- ),
159
- sampling_options=lf.LMSamplingOptions(n=2)
160
- ),
161
- 'Sample 0 for message: https://fake/image',
162
- )
163
-
164
- def test_sample_completion(self):
165
- with mock.patch('openai.Completion.create') as mock_completion:
166
- mock_completion.side_effect = mock_completion_query
167
- lm = openai.OpenAI(api_key='test_key', model='text-davinci-003')
168
- results = lm.sample(
169
- ['hello', 'bye'], sampling_options=lf.LMSamplingOptions(n=3)
170
- )
171
-
172
- self.assertEqual(len(results), 2)
173
- self.assertEqual(results[0], openai.LMSamplingResult([
174
- lf.LMSample('Sample 0 for prompt 0.', score=0.0),
175
- lf.LMSample('Sample 1 for prompt 0.', score=0.1),
176
- lf.LMSample('Sample 2 for prompt 0.', score=0.2),
177
- ], usage=openai.Usage(
178
- prompt_tokens=100, completion_tokens=100, total_tokens=200)))
179
-
180
- self.assertEqual(results[1], openai.LMSamplingResult([
181
- lf.LMSample('Sample 0 for prompt 1.', score=0.0),
182
- lf.LMSample('Sample 1 for prompt 1.', score=0.1),
183
- lf.LMSample('Sample 2 for prompt 1.', score=0.2),
184
- ]))
185
-
186
- def test_sample_chat_completion(self):
187
- with mock.patch('openai.ChatCompletion.create') as mock_chat_completion:
188
- mock_chat_completion.side_effect = mock_chat_completion_query
189
- lm = openai.OpenAI(api_key='test_key', model='gpt-4')
190
- results = lm.sample(
191
- ['hello', 'bye'], sampling_options=lf.LMSamplingOptions(n=3)
68
+ with self.assertRaisesRegex(RuntimeError, '`logprobs` is not supported.*'):
69
+ openai.GptO1Preview(api_key='test_key')._request_args(
70
+ lf.LMSamplingOptions(
71
+ temperature=1.0, logprobs=True
72
+ )
192
73
  )
193
74
 
194
- self.assertEqual(len(results), 2)
195
- self.assertEqual(results[0], openai.LMSamplingResult([
196
- lf.LMSample('Sample 0 for message.', score=0.0),
197
- lf.LMSample('Sample 1 for message.', score=0.0),
198
- lf.LMSample('Sample 2 for message.', score=0.0),
199
- ], usage=openai.Usage(
200
- prompt_tokens=100, completion_tokens=100, total_tokens=200)))
201
- self.assertEqual(results[1], openai.LMSamplingResult([
202
- lf.LMSample('Sample 0 for message.', score=0.0),
203
- lf.LMSample('Sample 1 for message.', score=0.0),
204
- lf.LMSample('Sample 2 for message.', score=0.0),
205
- ], usage=openai.Usage(
206
- prompt_tokens=100, completion_tokens=100, total_tokens=200)))
207
-
208
- def test_sample_with_contextual_options(self):
209
- with mock.patch('openai.Completion.create') as mock_completion:
210
- mock_completion.side_effect = mock_completion_query
211
- lm = openai.OpenAI(api_key='test_key', model='text-davinci-003')
212
- with lf.use_settings(sampling_options=lf.LMSamplingOptions(n=2)):
213
- results = lm.sample(['hello'])
214
-
215
- self.assertEqual(len(results), 1)
216
- self.assertEqual(results[0], openai.LMSamplingResult([
217
- lf.LMSample('Sample 0 for prompt 0.', score=0.0),
218
- lf.LMSample('Sample 1 for prompt 0.', score=0.1),
219
- ], usage=openai.Usage(
220
- prompt_tokens=100, completion_tokens=100, total_tokens=200)))
75
+ def test_estimate_cost(self):
76
+ self.assertEqual(
77
+ openai.Gpt4(api_key='test_key').estimate_cost(
78
+ num_input_tokens=100, num_output_tokens=100
79
+ ),
80
+ 0.009
81
+ )
221
82
 
222
83
 
223
84
  if __name__ == '__main__':
@@ -0,0 +1,113 @@
1
+ # Copyright 2024 The Langfun Authors
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ """Base class for language models through REST APIs."""
15
+
16
+ import functools
17
+ from typing import Annotated, Any, Callable
18
+
19
+ import langfun.core as lf
20
+ import requests
21
+
22
+
23
class REST(lf.LanguageModel):
  """Language model served through a JSON-over-HTTP (REST) endpoint.

  For every prompt, the JSON payload produced by `request` is POSTed to
  `api_endpoint`. The JSON body of a 200 response is converted by `result`;
  any other status code is raised as an `lf.LMError` (or a subclass of it, see
  `_error`). Subclasses can override `_initialize` for one-time setup that
  must happen before the first request is sent.
  """

  api_endpoint: Annotated[
      str,
      'The endpoint of the REST API.'
  ]

  request: Annotated[
      Callable[[lf.Message, lf.LMSamplingOptions], dict[str, Any]],
      'A function to convert a Langfun message to a JSON request.'
  ]

  result: Annotated[
      Callable[[dict[str, Any]], lf.LMSamplingResult],
      'A function to convert a JSON response to an LMSamplingResult.'
  ]

  model: Annotated[
      str | None,
      'Model ID.'
  ] = None

  headers: Annotated[
      dict[str, Any] | None,
      'The headers for the REST API.'
  ] = None

  @property
  def model_id(self) -> str:
    """Returns a string to identify the model (`'unknown'` if unset)."""
    return self.model or 'unknown'

  @functools.cached_property
  def _api_initialized(self) -> bool:
    """Runs `_initialize()` on first access and then always returns True."""
    self._initialize()
    return True

  def _initialize(self) -> None:
    """Initializes the API. Subclasses can override."""

  @functools.cached_property
  def _session(self) -> requests.Session:
    """Returns the lazily created HTTP session carrying `headers`."""
    # Accessing the cached property triggers the one-time `_initialize()`.
    # This is deliberately not wrapped in `assert`, which is stripped under
    # `python -O` and would silently skip the initialization.
    _ = self._api_initialized
    s = requests.Session()
    s.headers.update(self.headers or {})
    return s

  def _on_bound(self):
    """Drops cached state so a rebind re-initializes and rebuilds the session."""
    super()._on_bound()
    # `functools.cached_property` stores its value in the instance `__dict__`;
    # removing the entries forces recomputation on next access.
    self.__dict__.pop('_session', None)
    self.__dict__.pop('_api_initialized', None)

  def _sample(self, prompts: list[lf.Message]) -> list[lf.LMSamplingResult]:
    """Samples all prompts via the parent class's parallel-execution helper."""
    # See `_session` for why this is not an `assert`.
    _ = self._api_initialized
    return self._parallel_execute_with_currency_control(
        self._sample_single, prompts
    )

  def _sample_single(self, prompt: lf.Message) -> lf.LMSamplingResult:
    """Sends one prompt to the endpoint and returns the parsed result.

    Args:
      prompt: The message to convert into a JSON request via `request`.

    Returns:
      The sampling result produced by `result` from the JSON response.

    Raises:
      lf.LMError: On connection failures, or (via `_parse_response`) when the
        server answers with a non-200 status code.
    """
    try:
      response = self._session.post(
          self.api_endpoint,
          json=self.request(prompt, self.sampling_options),
          timeout=self.timeout,
      )
      return self._parse_response(response)
    # `requests.exceptions.ConnectionError` derives from `RequestException`
    # (an `IOError`), NOT from the builtin `ConnectionError`, so both must be
    # listed for transport failures raised by `requests` to be wrapped.
    except (requests.exceptions.ConnectionError, ConnectionError) as e:
      raise lf.LMError(str(e)) from e

  def _error(self, status_code: int, content: str) -> lf.LMError:
    """Builds (does not raise) the error matching an HTTP status code.

    Args:
      status_code: HTTP status code of the failed response.
      content: Response body embedded in the error message. Note that
        `_parse_response` passes `response.content`, which is `bytes`, so the
        message shows its bytes representation.

    Returns:
      `lf.RateLimitError` for 429, `lf.TemporaryLMError` for 500/502/503/529,
      and a plain `lf.LMError` otherwise.
    """
    if status_code == 429:
      error_cls = lf.RateLimitError
    elif status_code in (
        500,  # Server side issue (might be bug).
        502,  # Bad gateway (upstream issue, might retry).
        503,  # Servers currently under load, retry after a brief wait.
        529,  # Overloaded, retry after a brief wait.
    ):
      error_cls = lf.TemporaryLMError
    else:
      error_cls = lf.LMError
    return error_cls(f'{status_code}: {content}')

  def _parse_response(self, response: requests.Response) -> lf.LMSamplingResult:
    """Parses the REST API's HTTP response.

    Args:
      response: The response returned by the session's `post` call.

    Returns:
      The result of applying `result` to the JSON body of a 200 response.

    Raises:
      lf.LMError: (or a subclass, see `_error`) for any non-200 status code.
    """
    if response.status_code == 200:
      return self.result(response.json())
    else:
      raise self._error(response.status_code, response.content)
@@ -0,0 +1,111 @@
1
+ # Copyright 2024 The Langfun Authors
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ """Tests for REST models."""
15
+
16
+ from typing import Any
17
+ import unittest
18
+ from unittest import mock
19
+ import langfun.core as lf
20
+ from langfun.core.llms import rest
21
+ import pyglove as pg
22
+ import requests
23
+
24
+
25
def mock_requests_post(url: str, json: dict[str, Any], **kwargs):
  """Fakes a successful POST whose reply echoes the request's sampling args.

  Args:
    url: Ignored.
    json: The JSON payload that would have been sent; its `temperature`,
      `top_k`, `top_p`, `max_tokens` and `stop_sequences` entries are echoed
      back in the response text.
    **kwargs: Ignored.

  Returns:
    A `requests.Response` with status 200 and a JSON body of the form
    `{'content': [<text>]}`.
  """
  del url, kwargs
  echoed_text = (
      f'hello with temperature={json.get("temperature")}, '
      f'top_k={json.get("top_k")}, '
      f'top_p={json.get("top_p")}, '
      f'max_tokens={json.get("max_tokens")}, '
      f'stop={json.get("stop_sequences")}.'
  )
  fake = requests.Response()
  fake.status_code = 200
  fake._content = pg.to_json_str({'content': [echoed_text]}).encode()
  return fake
39
+
40
+
41
def mock_requests_post_error(status_code, error_type, error_message):
  """Returns a fake `post` that always answers with the given error.

  Args:
    status_code: HTTP status code to set on the fake response.
    error_type: Value stored under `error.type` in the JSON body.
    error_message: Value stored under `error.message` in the JSON body.

  Returns:
    A callable with the `(url, json, **kwargs)` shape used by the tests, which
    ignores its arguments and returns the error response.
  """
  # The body depends only on the outer arguments, so it is encoded once here.
  encoded_body = pg.to_json_str(
      {'error': {'type': error_type, 'message': error_message}}
  ).encode()

  def _mock_requests(url: str, json: dict[str, Any], **kwargs):
    del url, json, kwargs
    failed = requests.Response()
    failed.status_code = status_code
    failed._content = encoded_body
    return failed

  return _mock_requests
57
+
58
+
59
class RestTest(unittest.TestCase):
  """Tests `rest.REST` against a mocked `requests.Session.post`."""

  def setUp(self):
    super().setUp()
    # The request converter ignores the sampling options and always emits the
    # same fixed arguments; the result converter wraps each returned content
    # entry into an `lf.LMSample`.
    self._lm = rest.REST(
        api_endpoint='https://fake-api.com',
        request=lambda x, o: {
            'model': 'test-model',
            'prompt': x.text,
            'temperature': 0.0,
            'top_k': 0.1,
            'top_p': 0.2,
            'stop_sequences': ['\n'],
            'max_tokens': 4096,
        },
        result=lambda x: lf.LMSamplingResult(
            [lf.LMSample(c) for c in x['content']]
        ),
        headers={'api_key': 'fake_key'},
    )

  def test_call(self):
    expected_text = (
        'hello with temperature=0.0, top_k=0.1, top_p=0.2, '
        "max_tokens=4096, stop=['\\n']."
    )
    with mock.patch('requests.Session.post') as mock_request:
      mock_request.side_effect = mock_requests_post
      # No `model` was given in `setUp`, so the fallback id is reported.
      self.assertEqual(self._lm.model_id, 'unknown')
      response = self._lm(
          'hello', temperature=0.0, top_k=0.1, top_p=0.2, stop=['\n']
      )
      self.assertEqual(response.text, expected_text)
      self.assertIsInstance(response.usage, lf.UsageNotAvailable)

  def test_call_errors(self):
    error_cases = [
        (429, 'rate_limit', 'Rate limit exceeded.'),
        (529, 'service_unavailable', 'Service unavailable.'),
        (500, 'bad_request', 'Bad request.'),
    ]
    for status_code, error_type, error_message in error_cases:
      failing_post = mock_requests_post_error(
          status_code, error_type, error_message
      )
      with mock.patch('requests.Session.post', side_effect=failing_post):
        # `max_attempts=1` is passed so each case makes a single attempt.
        with self.assertRaisesRegex(
            Exception, f'.*{status_code}: .*{error_message}'
        ):
          self._lm('hello', max_attempts=1)
108
+
109
+
110
+ if __name__ == '__main__':
111
+ unittest.main()
@@ -0,0 +1,192 @@
1
+ # Copyright 2023 The Langfun Authors
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ """Vertex AI generative models."""
15
+
16
+ import functools
17
+ import os
18
+ from typing import Annotated, Any
19
+
20
+ import langfun.core as lf
21
+ from langfun.core.llms import gemini
22
+ import pyglove as pg
23
+
24
+ try:
25
+ # pylint: disable=g-import-not-at-top
26
+ from google import auth as google_auth
27
+ from google.auth import credentials as credentials_lib
28
+ from google.auth.transport import requests as auth_requests
29
+ # pylint: enable=g-import-not-at-top
30
+
31
+ Credentials = credentials_lib.Credentials
32
+ except ImportError:
33
+ google_auth = None
34
+ credentials_lib = None
35
+ auth_requests = None
36
+ Credentials = Any
37
+
38
+
39
@lf.use_init_args(['model'])
@pg.members([('api_endpoint', pg.typing.Str().freeze(''))])
class VertexAI(gemini.Gemini):
  """Language model served on VertexAI with REST API.

  The project and location are taken from the `project` / `location` fields,
  falling back to the `VERTEXAI_PROJECT` / `VERTEXAI_LOCATION` environment
  variables. Requests are sent through a `google.auth` authorized session
  built either from `credentials` or from the environment's default
  credentials.
  """

  project: Annotated[
      str | None,
      (
          'Vertex AI project ID. Or set from environment variable '
          'VERTEXAI_PROJECT.'
      ),
  ] = None

  location: Annotated[
      str | None,
      (
          'Vertex AI service location. Or set from environment variable '
          'VERTEXAI_LOCATION.'
      ),
  ] = None

  credentials: Annotated[
      Credentials | None,
      (
          'Credentials to use. If None, the default credentials to the '
          'environment will be used.'
      ),
  ] = None

  def _on_bound(self):
    """Checks that `google.auth` is importable and resets resolved state.

    Raises:
      ValueError: If the optional `google.auth` dependency is not installed.
    """
    super()._on_bound()
    if google_auth is None:
      raise ValueError(
          'Please install "langfun[llm-google-vertex]" to use Vertex AI models.'
      )
    # All three are resolved lazily by `_initialize()`; reset them together so
    # no attribute is left undefined or stale after a rebind.
    self._project = None
    self._location = None
    self._credentials = None

  def _initialize(self):
    """Resolves project, location and credentials.

    Raises:
      ValueError: If the project or the location is given neither as a field
        nor through its environment variable.
    """
    project = self.project or os.environ.get('VERTEXAI_PROJECT', None)
    if not project:
      raise ValueError(
          'Please specify `project` during `__init__` or set environment '
          'variable `VERTEXAI_PROJECT` with your Vertex AI project ID.'
      )

    location = self.location or os.environ.get('VERTEXAI_LOCATION', None)
    if not location:
      raise ValueError(
          'Please specify `location` during `__init__` or set environment '
          'variable `VERTEXAI_LOCATION` with your Vertex AI service location.'
      )

    self._project = project
    self._location = location

    credentials = self.credentials
    if credentials is None:
      # Use default credentials. `google.auth.default()` returns a
      # `(credentials, project_id)` tuple; only the credentials object may be
      # handed to `AuthorizedSession`.
      credentials, _ = google_auth.default(
          scopes=['https://www.googleapis.com/auth/cloud-platform']
      )
    self._credentials = credentials

  @property
  def model_id(self) -> str:
    """Returns a string to identify the model."""
    return f'VertexAI({self.model})'

  @functools.cached_property
  def _session(self):
    """Returns the lazily created authorized session carrying `headers`."""
    # Accessing the property triggers the one-time `_initialize()`. This is
    # deliberately not wrapped in `assert`, which is stripped under
    # `python -O` and would silently skip the initialization.
    _ = self._api_initialized
    assert self._credentials is not None
    assert auth_requests is not None
    s = auth_requests.AuthorizedSession(self._credentials)
    s.headers.update(self.headers or {})
    return s

  @property
  def api_endpoint(self) -> str:
    """Returns the `generateContent` URL for the resolved project/location."""
    # See `_session` for why this is not an `assert`.
    _ = self._api_initialized
    return (
        f'https://{self._location}-aiplatform.googleapis.com/v1/projects/'
        f'{self._project}/locations/{self._location}/publishers/google/'
        f'models/{self.model}:generateContent'
    )
125
+
126
+
127
+ class VertexAIGeminiFlash2_0ThinkingExp_20241219(VertexAI): # pylint: disable=invalid-name
128
+ """Vertex AI Gemini 2.0 Flash Thinking experimental model (`gemini-2.0-flash-thinking-exp-1219`, launched on 12/19/2024)."""
129
+
130
+ api_version = 'v1alpha'  # NOTE(review): `VertexAI.api_endpoint` in this file hard-codes `v1`; confirm where this value is consumed.
131
+ model = 'gemini-2.0-flash-thinking-exp-1219'
132
+ timeout = None  # NOTE(review): presumably disables the request timeout for this model; confirm against the base class.
133
+
134
+
135
+ class VertexAIGeminiFlash2_0Exp(VertexAI): # pylint: disable=invalid-name
136
+ """Vertex AI Gemini 2.0 Flash experimental model (model ID `gemini-2.0-flash-exp`)."""
137
+
138
+ model = 'gemini-2.0-flash-exp'
139
+
140
+
141
+ class VertexAIGeminiExp_20241206(VertexAI): # pylint: disable=invalid-name
142
+ """Vertex AI Gemini Experimental model launched on 12/06/2024 (model ID `gemini-exp-1206`)."""
143
+
144
+ model = 'gemini-exp-1206'
145
+
146
+
147
+ class VertexAIGeminiExp_20241114(VertexAI): # pylint: disable=invalid-name
148
+ """Vertex AI Gemini Experimental model launched on 11/14/2024 (model ID `gemini-exp-1114`)."""
149
+
150
+ model = 'gemini-exp-1114'
151
+
152
+
153
+ class VertexAIGeminiPro1_5(VertexAI): # pylint: disable=invalid-name
154
+ """Vertex AI Gemini 1.5 Pro model (model ID `gemini-1.5-pro-latest`)."""
155
+
156
+ model = 'gemini-1.5-pro-latest'
157
+
158
+
159
+ class VertexAIGeminiPro1_5_002(VertexAI): # pylint: disable=invalid-name
160
+ """Vertex AI Gemini 1.5 Pro model, version 002 (model ID `gemini-1.5-pro-002`)."""
161
+
162
+ model = 'gemini-1.5-pro-002'
163
+
164
+
165
+ class VertexAIGeminiPro1_5_001(VertexAI): # pylint: disable=invalid-name
166
+ """Vertex AI Gemini 1.5 Pro model, version 001 (model ID `gemini-1.5-pro-001`)."""
167
+
168
+ model = 'gemini-1.5-pro-001'
169
+
170
+
171
+ class VertexAIGeminiFlash1_5(VertexAI): # pylint: disable=invalid-name
172
+ """Vertex AI Gemini 1.5 Flash model (model ID `gemini-1.5-flash`)."""
173
+
174
+ model = 'gemini-1.5-flash'
175
+
176
+
177
+ class VertexAIGeminiFlash1_5_002(VertexAI): # pylint: disable=invalid-name
178
+ """Vertex AI Gemini 1.5 Flash model, version 002 (model ID `gemini-1.5-flash-002`)."""
179
+
180
+ model = 'gemini-1.5-flash-002'
181
+
182
+
183
+ class VertexAIGeminiFlash1_5_001(VertexAI): # pylint: disable=invalid-name
184
+ """Vertex AI Gemini 1.5 Flash model, version 001 (model ID `gemini-1.5-flash-001`)."""
185
+
186
+ model = 'gemini-1.5-flash-001'
187
+
188
+
189
+ class VertexAIGeminiPro1(VertexAI): # pylint: disable=invalid-name
190
+ """Vertex AI Gemini 1.0 Pro model (model ID `gemini-1.0-pro`)."""
191
+
192
+ model = 'gemini-1.0-pro'