langfun 0.0.2.dev20240422__py3-none-any.whl → 0.0.2.dev20240425__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- langfun/__init__.py +1 -0
- langfun/core/component.py +6 -0
- langfun/core/component_test.py +1 -0
- langfun/core/eval/__init__.py +2 -0
- langfun/core/eval/base.py +175 -17
- langfun/core/eval/base_test.py +34 -6
- langfun/core/eval/matching.py +18 -1
- langfun/core/eval/matching_test.py +2 -1
- langfun/core/eval/scoring.py +11 -1
- langfun/core/eval/scoring_test.py +2 -1
- langfun/core/language_model.py +14 -0
- langfun/core/language_model_test.py +32 -0
- langfun/core/llms/anthropic.py +36 -22
- langfun/core/llms/anthropic_test.py +7 -7
- langfun/core/llms/groq.py +27 -18
- langfun/core/llms/groq_test.py +5 -5
- langfun/core/llms/openai.py +55 -50
- langfun/core/llms/openai_test.py +3 -3
- langfun/core/structured/__init__.py +1 -0
- langfun/core/structured/completion_test.py +1 -2
- langfun/core/structured/mapping.py +38 -1
- langfun/core/structured/mapping_test.py +17 -0
- langfun/core/structured/parsing_test.py +2 -4
- langfun/core/structured/prompting_test.py +2 -4
- langfun/core/structured/schema_generation_test.py +2 -2
- langfun/core/template.py +26 -8
- langfun/core/template_test.py +9 -0
- {langfun-0.0.2.dev20240422.dist-info → langfun-0.0.2.dev20240425.dist-info}/METADATA +3 -2
- {langfun-0.0.2.dev20240422.dist-info → langfun-0.0.2.dev20240425.dist-info}/RECORD +32 -32
- {langfun-0.0.2.dev20240422.dist-info → langfun-0.0.2.dev20240425.dist-info}/LICENSE +0 -0
- {langfun-0.0.2.dev20240422.dist-info → langfun-0.0.2.dev20240425.dist-info}/WHEEL +0 -0
- {langfun-0.0.2.dev20240422.dist-info → langfun-0.0.2.dev20240425.dist-info}/top_level.txt +0 -0
langfun/core/llms/anthropic.py
CHANGED
@@ -26,12 +26,15 @@ import requests
|
|
26
26
|
|
27
27
|
SUPPORTED_MODELS_AND_SETTINGS = {
|
28
28
|
# See https://docs.anthropic.com/claude/docs/models-overview
|
29
|
-
|
30
|
-
|
31
|
-
|
32
|
-
'claude-
|
33
|
-
'claude-
|
34
|
-
'claude-
|
29
|
+
# Rate limits from https://docs.anthropic.com/claude/reference/rate-limits
|
30
|
+
# RPM/TPM for Claude-2.1, Claude-2.0, and Claude-Instant-1.2 estimated
|
31
|
+
# as RPM/TPM of the largest-available model (Claude-3-Opus).
|
32
|
+
'claude-3-opus-20240229': pg.Dict(max_tokens=4096, rpm=4000, tpm=400000),
|
33
|
+
'claude-3-sonnet-20240229': pg.Dict(max_tokens=4096, rpm=4000, tpm=400000),
|
34
|
+
'claude-3-haiku-20240307': pg.Dict(max_tokens=4096, rpm=4000, tpm=400000),
|
35
|
+
'claude-2.1': pg.Dict(max_tokens=4096, rpm=4000, tpm=400000),
|
36
|
+
'claude-2.0': pg.Dict(max_tokens=4096, rpm=4000, tpm=400000),
|
37
|
+
'claude-instant-1.2': pg.Dict(max_tokens=4096, rpm=4000, tpm=400000),
|
35
38
|
}
|
36
39
|
|
37
40
|
|
@@ -81,6 +84,7 @@ class Anthropic(lf.LanguageModel):
|
|
81
84
|
super()._on_bound()
|
82
85
|
self._api_key = None
|
83
86
|
self.__dict__.pop('_api_initialized', None)
|
87
|
+
self.__dict__.pop('_session', None)
|
84
88
|
|
85
89
|
@functools.cached_property
|
86
90
|
def _api_initialized(self):
|
@@ -93,6 +97,17 @@ class Anthropic(lf.LanguageModel):
|
|
93
97
|
self._api_key = api_key
|
94
98
|
return True
|
95
99
|
|
100
|
+
@functools.cached_property
|
101
|
+
def _session(self) -> requests.Session:
|
102
|
+
assert self._api_initialized
|
103
|
+
s = requests.Session()
|
104
|
+
s.headers.update({
|
105
|
+
'x-api-key': self._api_key,
|
106
|
+
'anthropic-version': _ANTHROPIC_API_VERSION,
|
107
|
+
'content-type': 'application/json',
|
108
|
+
})
|
109
|
+
return s
|
110
|
+
|
96
111
|
@property
|
97
112
|
def model_id(self) -> str:
|
98
113
|
"""Returns a string to identify the model."""
|
@@ -100,7 +115,11 @@ class Anthropic(lf.LanguageModel):
|
|
100
115
|
|
101
116
|
@property
|
102
117
|
def max_concurrency(self) -> int:
|
103
|
-
|
118
|
+
rpm = SUPPORTED_MODELS_AND_SETTINGS[self.model].get('rpm', 0)
|
119
|
+
tpm = SUPPORTED_MODELS_AND_SETTINGS[self.model].get('tpm', 0)
|
120
|
+
return self.rate_to_max_concurrency(
|
121
|
+
requests_per_min=rpm, tokens_per_min=tpm
|
122
|
+
)
|
104
123
|
|
105
124
|
def _sample(self, prompts: list[lf.Message]) -> list[lf.LMSamplingResult]:
|
106
125
|
assert self._api_initialized
|
@@ -165,8 +184,8 @@ class Anthropic(lf.LanguageModel):
|
|
165
184
|
def _parse_response(self, response: requests.Response) -> lf.LMSamplingResult:
|
166
185
|
"""Parses Anthropic's response."""
|
167
186
|
# NOTE(daiyip): Refer https://docs.anthropic.com/claude/reference/errors
|
168
|
-
output = response.json()
|
169
187
|
if response.status_code == 200:
|
188
|
+
output = response.json()
|
170
189
|
message = self._message_from_content(output['content'])
|
171
190
|
input_tokens = output['usage']['input_tokens']
|
172
191
|
output_tokens = output['usage']['output_tokens']
|
@@ -181,12 +200,11 @@ class Anthropic(lf.LanguageModel):
|
|
181
200
|
else:
|
182
201
|
if response.status_code == 429:
|
183
202
|
error_cls = RateLimitError
|
184
|
-
elif response.status_code
|
203
|
+
elif response.status_code in (502, 529):
|
185
204
|
error_cls = OverloadedError
|
186
205
|
else:
|
187
206
|
error_cls = AnthropicError
|
188
|
-
|
189
|
-
raise error_cls(f'{error["type"]}: {error["message"]}')
|
207
|
+
raise error_cls(f'{response.status_code}: {response.content}')
|
190
208
|
|
191
209
|
def _sample_single(self, prompt: lf.Message) -> lf.LMSamplingResult:
|
192
210
|
request = dict()
|
@@ -198,17 +216,13 @@ class Anthropic(lf.LanguageModel):
|
|
198
216
|
]
|
199
217
|
)
|
200
218
|
)
|
201
|
-
|
202
|
-
|
203
|
-
|
204
|
-
|
205
|
-
|
206
|
-
|
207
|
-
|
208
|
-
},
|
209
|
-
timeout=self.timeout,
|
210
|
-
)
|
211
|
-
return self._parse_response(response)
|
219
|
+
try:
|
220
|
+
response = self._session.post(
|
221
|
+
_ANTHROPIC_MESSAGE_API_ENDPOINT, json=request, timeout=self.timeout,
|
222
|
+
)
|
223
|
+
return self._parse_response(response)
|
224
|
+
except ConnectionError as e:
|
225
|
+
raise OverloadedError(str(e)) from e
|
212
226
|
|
213
227
|
|
214
228
|
class Claude3(Anthropic):
|
@@ -98,20 +98,20 @@ def mock_requests_post_error(status_code, error_type, error_message):
|
|
98
98
|
return _mock_requests
|
99
99
|
|
100
100
|
|
101
|
-
class
|
101
|
+
class AnthropicTest(unittest.TestCase):
|
102
102
|
|
103
103
|
def test_basics(self):
|
104
104
|
self.assertEqual(
|
105
105
|
anthropic.Claude3Haiku().model_id, 'claude-3-haiku-20240307'
|
106
106
|
)
|
107
|
-
self.
|
107
|
+
self.assertGreater(anthropic.Claude3Haiku().max_concurrency, 0)
|
108
108
|
|
109
109
|
def test_api_key(self):
|
110
110
|
lm = anthropic.Claude3Haiku()
|
111
111
|
with self.assertRaisesRegex(ValueError, 'Please specify `api_key`'):
|
112
112
|
lm('hi')
|
113
113
|
|
114
|
-
with mock.patch('requests.post') as mock_request:
|
114
|
+
with mock.patch('requests.Session.post') as mock_request:
|
115
115
|
mock_request.side_effect = mock_requests_post
|
116
116
|
|
117
117
|
lm = anthropic.Claude3Haiku(api_key='fake key')
|
@@ -123,7 +123,7 @@ class AuthropicTest(unittest.TestCase):
|
|
123
123
|
del os.environ['ANTHROPIC_API_KEY']
|
124
124
|
|
125
125
|
def test_call(self):
|
126
|
-
with mock.patch('requests.post') as mock_request:
|
126
|
+
with mock.patch('requests.Session.post') as mock_request:
|
127
127
|
mock_request.side_effect = mock_requests_post
|
128
128
|
lm = anthropic.Claude3Haiku(api_key='fake_key')
|
129
129
|
response = lm('hello', temperature=0.0, top_k=0.1, top_p=0.2, stop=['\n'])
|
@@ -140,7 +140,7 @@ class AuthropicTest(unittest.TestCase):
|
|
140
140
|
self.assertIsNotNone(response.usage.total_tokens, 3)
|
141
141
|
|
142
142
|
def test_mm_call(self):
|
143
|
-
with mock.patch('requests.post') as mock_mm_request:
|
143
|
+
with mock.patch('requests.Session.post') as mock_mm_request:
|
144
144
|
mock_mm_request.side_effect = mock_mm_requests_post
|
145
145
|
lm = anthropic.Claude3Haiku(api_key='fake_key')
|
146
146
|
response = lm(lf_modalities.Image.from_bytes(image_content), lm=lm)
|
@@ -152,13 +152,13 @@ class AuthropicTest(unittest.TestCase):
|
|
152
152
|
(529, 'service_unavailable', 'Service unavailable.'),
|
153
153
|
(500, 'bad_request', 'Bad request.'),
|
154
154
|
]:
|
155
|
-
with mock.patch('requests.post') as mock_mm_request:
|
155
|
+
with mock.patch('requests.Session.post') as mock_mm_request:
|
156
156
|
mock_mm_request.side_effect = mock_requests_post_error(
|
157
157
|
status_code, error_type, error_message
|
158
158
|
)
|
159
159
|
lm = anthropic.Claude3Haiku(api_key='fake_key')
|
160
160
|
with self.assertRaisesRegex(
|
161
|
-
Exception, f'{
|
161
|
+
Exception, f'.*{status_code}: .*{error_message}'
|
162
162
|
):
|
163
163
|
lm('hello', lm=lm, max_attempts=1)
|
164
164
|
|
langfun/core/llms/groq.py
CHANGED
@@ -78,6 +78,7 @@ class Groq(lf.LanguageModel):
|
|
78
78
|
super()._on_bound()
|
79
79
|
self._api_key = None
|
80
80
|
self.__dict__.pop('_api_initialized', None)
|
81
|
+
self.__dict__.pop('_session', None)
|
81
82
|
|
82
83
|
@functools.cached_property
|
83
84
|
def _api_initialized(self):
|
@@ -85,11 +86,21 @@ class Groq(lf.LanguageModel):
|
|
85
86
|
if not api_key:
|
86
87
|
raise ValueError(
|
87
88
|
'Please specify `api_key` during `__init__` or set environment '
|
88
|
-
'variable `GROQ_API_KEY` with your
|
89
|
+
'variable `GROQ_API_KEY` with your Groq API key.'
|
89
90
|
)
|
90
91
|
self._api_key = api_key
|
91
92
|
return True
|
92
93
|
|
94
|
+
@functools.cached_property
|
95
|
+
def _session(self) -> requests.Session:
|
96
|
+
assert self._api_initialized
|
97
|
+
s = requests.Session()
|
98
|
+
s.headers.update({
|
99
|
+
'Authorization': f'Bearer {self._api_key}',
|
100
|
+
'Content-Type': 'application/json',
|
101
|
+
})
|
102
|
+
return s
|
103
|
+
|
93
104
|
@property
|
94
105
|
def model_id(self) -> str:
|
95
106
|
"""Returns a string to identify the model."""
|
@@ -119,7 +130,7 @@ class Groq(lf.LanguageModel):
|
|
119
130
|
return args
|
120
131
|
|
121
132
|
def _content_from_message(self, prompt: lf.Message) -> list[dict[str, Any]]:
|
122
|
-
"""Converts an message to
|
133
|
+
"""Converts an message to Groq's content protocol (list of dicts)."""
|
123
134
|
# Refer: https://platform.openai.com/docs/api-reference/chat/create
|
124
135
|
content = []
|
125
136
|
for chunk in prompt.chunk():
|
@@ -138,7 +149,7 @@ class Groq(lf.LanguageModel):
|
|
138
149
|
return content
|
139
150
|
|
140
151
|
def _message_from_choice(self, choice: dict[str, Any]) -> lf.Message:
|
141
|
-
"""Converts
|
152
|
+
"""Converts Groq's content protocol to message."""
|
142
153
|
# Refer: https://platform.openai.com/docs/api-reference/chat/create
|
143
154
|
content = choice['message']['content']
|
144
155
|
if isinstance(content, str):
|
@@ -148,10 +159,10 @@ class Groq(lf.LanguageModel):
|
|
148
159
|
)
|
149
160
|
|
150
161
|
def _parse_response(self, response: requests.Response) -> lf.LMSamplingResult:
|
151
|
-
"""Parses
|
162
|
+
"""Parses Groq's response."""
|
152
163
|
# Refer: https://platform.openai.com/docs/api-reference/chat/object
|
153
|
-
output = response.json()
|
154
164
|
if response.status_code == 200:
|
165
|
+
output = response.json()
|
155
166
|
samples = [
|
156
167
|
lf.LMSample(self._message_from_choice(choice), score=0.0)
|
157
168
|
for choice in output['choices']
|
@@ -169,12 +180,11 @@ class Groq(lf.LanguageModel):
|
|
169
180
|
# https://platform.openai.com/docs/guides/error-codes/api-errors
|
170
181
|
if response.status_code == 429:
|
171
182
|
error_cls = RateLimitError
|
172
|
-
elif response.status_code in (500, 503):
|
183
|
+
elif response.status_code in (500, 502, 503):
|
173
184
|
error_cls = OverloadedError
|
174
185
|
else:
|
175
186
|
error_cls = GroqError
|
176
|
-
|
177
|
-
raise error_cls(f'{error["type"]}: {error["message"]}')
|
187
|
+
raise error_cls(f'{response.status_code}: {response.content}')
|
178
188
|
|
179
189
|
def _sample(self, prompts: list[lf.Message]) -> list[lf.LMSamplingResult]:
|
180
190
|
assert self._api_initialized
|
@@ -194,16 +204,15 @@ class Groq(lf.LanguageModel):
|
|
194
204
|
]
|
195
205
|
)
|
196
206
|
)
|
197
|
-
|
198
|
-
|
199
|
-
|
200
|
-
|
201
|
-
|
202
|
-
|
203
|
-
|
204
|
-
|
205
|
-
|
206
|
-
return self._parse_response(response)
|
207
|
+
try:
|
208
|
+
response = self._session.post(
|
209
|
+
_CHAT_COMPLETE_API_ENDPOINT,
|
210
|
+
json=request,
|
211
|
+
timeout=self.timeout,
|
212
|
+
)
|
213
|
+
return self._parse_response(response)
|
214
|
+
except ConnectionError as e:
|
215
|
+
raise OverloadedError(str(e)) from e
|
207
216
|
|
208
217
|
|
209
218
|
class GroqLlama3_8B(Groq): # pylint: disable=invalid-name
|
langfun/core/llms/groq_test.py
CHANGED
@@ -107,7 +107,7 @@ class AuthropicTest(unittest.TestCase):
|
|
107
107
|
with self.assertRaisesRegex(ValueError, 'Please specify `api_key`'):
|
108
108
|
lm('hi')
|
109
109
|
|
110
|
-
with mock.patch('requests.post') as mock_request:
|
110
|
+
with mock.patch('requests.Session.post') as mock_request:
|
111
111
|
mock_request.side_effect = mock_requests_post
|
112
112
|
|
113
113
|
lm = groq.GroqMistral_8x7B(api_key='fake key')
|
@@ -119,7 +119,7 @@ class AuthropicTest(unittest.TestCase):
|
|
119
119
|
del os.environ['GROQ_API_KEY']
|
120
120
|
|
121
121
|
def test_call(self):
|
122
|
-
with mock.patch('requests.post') as mock_request:
|
122
|
+
with mock.patch('requests.Session.post') as mock_request:
|
123
123
|
mock_request.side_effect = mock_requests_post
|
124
124
|
lm = groq.GroqLlama3_70B(api_key='fake_key')
|
125
125
|
response = lm(
|
@@ -143,7 +143,7 @@ class AuthropicTest(unittest.TestCase):
|
|
143
143
|
self.assertIsNotNone(response.usage.total_tokens, 3)
|
144
144
|
|
145
145
|
def test_mm_call(self):
|
146
|
-
with mock.patch('requests.post') as mock_mm_request:
|
146
|
+
with mock.patch('requests.Session.post') as mock_mm_request:
|
147
147
|
mock_mm_request.side_effect = mock_mm_requests_post
|
148
148
|
lm = groq.GroqLlama3_70B(multimodal=True, api_key='fake_key')
|
149
149
|
response = lm(lf_modalities.Image.from_uri('https://fake/image.jpg'))
|
@@ -155,13 +155,13 @@ class AuthropicTest(unittest.TestCase):
|
|
155
155
|
(503, 'service_unavailable', 'Service unavailable.'),
|
156
156
|
(500, 'bad_request', 'Bad request.'),
|
157
157
|
]:
|
158
|
-
with mock.patch('requests.post') as mock_mm_request:
|
158
|
+
with mock.patch('requests.Session.post') as mock_mm_request:
|
159
159
|
mock_mm_request.side_effect = mock_requests_post_error(
|
160
160
|
status_code, error_type, error_message
|
161
161
|
)
|
162
162
|
lm = groq.GroqLlama3_70B(api_key='fake_key')
|
163
163
|
with self.assertRaisesRegex(
|
164
|
-
Exception, f'{
|
164
|
+
Exception, f'{status_code}:.*{error_type}'
|
165
165
|
):
|
166
166
|
lm('hello', lm=lm, max_attempts=1)
|
167
167
|
|
langfun/core/llms/openai.py
CHANGED
@@ -26,54 +26,55 @@ from openai import openai_object
|
|
26
26
|
import pyglove as pg
|
27
27
|
|
28
28
|
|
29
|
-
|
30
|
-
|
31
|
-
|
32
|
-
|
33
|
-
|
34
|
-
|
35
|
-
|
36
|
-
|
37
|
-
|
38
|
-
|
39
|
-
|
40
|
-
|
41
|
-
|
42
|
-
|
43
|
-
|
44
|
-
|
45
|
-
|
46
|
-
|
47
|
-
|
48
|
-
|
49
|
-
|
50
|
-
|
51
|
-
|
52
|
-
|
53
|
-
|
54
|
-
|
55
|
-
|
56
|
-
|
57
|
-
|
58
|
-
|
59
|
-
|
60
|
-
|
61
|
-
|
62
|
-
|
63
|
-
|
64
|
-
|
65
|
-
|
66
|
-
|
67
|
-
|
68
|
-
|
69
|
-
|
70
|
-
|
71
|
-
|
72
|
-
|
73
|
-
|
74
|
-
|
75
|
-
|
76
|
-
|
29
|
+
# From https://platform.openai.com/settings/organization/limits
|
30
|
+
_DEFAULT_TPM = 250000
|
31
|
+
_DEFAULT_RPM = 3000
|
32
|
+
|
33
|
+
SUPPORTED_MODELS_AND_SETTINGS = {
|
34
|
+
# Models from https://platform.openai.com/docs/models
|
35
|
+
# RPM is from https://platform.openai.com/docs/guides/rate-limits
|
36
|
+
# GPT-4-Turbo models
|
37
|
+
'gpt-4-turbo': pg.Dict(rpm=10000, tpm=1500000),
|
38
|
+
'gpt-4-turbo-2024-04-09': pg.Dict(rpm=10000, tpm=1500000),
|
39
|
+
'gpt-4-turbo-preview': pg.Dict(rpm=10000, tpm=1500000),
|
40
|
+
'gpt-4-0125-preview': pg.Dict(rpm=10000, tpm=1500000),
|
41
|
+
'gpt-4-1106-preview': pg.Dict(rpm=10000, tpm=1500000),
|
42
|
+
'gpt-4-vision-preview': pg.Dict(rpm=10000, tpm=1500000),
|
43
|
+
'gpt-4-1106-vision-preview': pg.Dict(
|
44
|
+
rpm=10000, tpm=1500000
|
45
|
+
),
|
46
|
+
# GPT-4 models
|
47
|
+
'gpt-4': pg.Dict(rpm=10000, tpm=300000),
|
48
|
+
'gpt-4-0613': pg.Dict(rpm=10000, tpm=300000),
|
49
|
+
'gpt-4-0314': pg.Dict(rpm=10000, tpm=300000),
|
50
|
+
'gpt-4-32k': pg.Dict(rpm=10000, tpm=300000),
|
51
|
+
'gpt-4-32k-0613': pg.Dict(rpm=10000, tpm=300000),
|
52
|
+
'gpt-4-32k-0314': pg.Dict(rpm=10000, tpm=300000),
|
53
|
+
# GPT-3.5-Turbo models
|
54
|
+
'gpt-3.5-turbo': pg.Dict(rpm=10000, tpm=2000000),
|
55
|
+
'gpt-3.5-turbo-0125': pg.Dict(rpm=10000, tpm=2000000),
|
56
|
+
'gpt-3.5-turbo-1106': pg.Dict(rpm=10000, tpm=2000000),
|
57
|
+
'gpt-3.5-turbo-0613': pg.Dict(rpm=10000, tpm=2000000),
|
58
|
+
'gpt-3.5-turbo-0301': pg.Dict(rpm=10000, tpm=2000000),
|
59
|
+
'gpt-3.5-turbo-16k': pg.Dict(rpm=10000, tpm=2000000),
|
60
|
+
'gpt-3.5-turbo-16k-0613': pg.Dict(rpm=10000, tpm=2000000),
|
61
|
+
'gpt-3.5-turbo-16k-0301': pg.Dict(rpm=10000, tpm=2000000),
|
62
|
+
# GPT-3.5 models
|
63
|
+
'text-davinci-003': pg.Dict(rpm=_DEFAULT_RPM, tpm=_DEFAULT_TPM),
|
64
|
+
'text-davinci-002': pg.Dict(rpm=_DEFAULT_RPM, tpm=_DEFAULT_TPM),
|
65
|
+
'code-davinci-002': pg.Dict(rpm=_DEFAULT_RPM, tpm=_DEFAULT_TPM),
|
66
|
+
# GPT-3 instruction-tuned models
|
67
|
+
'text-curie-001': pg.Dict(rpm=_DEFAULT_RPM, tpm=_DEFAULT_TPM),
|
68
|
+
'text-babbage-001': pg.Dict(rpm=_DEFAULT_RPM, tpm=_DEFAULT_TPM),
|
69
|
+
'text-ada-001': pg.Dict(rpm=_DEFAULT_RPM, tpm=_DEFAULT_TPM),
|
70
|
+
'davinci': pg.Dict(rpm=_DEFAULT_RPM, tpm=_DEFAULT_TPM),
|
71
|
+
'curie': pg.Dict(rpm=_DEFAULT_RPM, tpm=_DEFAULT_TPM),
|
72
|
+
'babbage': pg.Dict(rpm=_DEFAULT_RPM, tpm=_DEFAULT_TPM),
|
73
|
+
'ada': pg.Dict(rpm=_DEFAULT_RPM, tpm=_DEFAULT_TPM),
|
74
|
+
# GPT-3 base models
|
75
|
+
'babbage-002': pg.Dict(rpm=_DEFAULT_RPM, tpm=_DEFAULT_TPM),
|
76
|
+
'davinci-002': pg.Dict(rpm=_DEFAULT_RPM, tpm=_DEFAULT_TPM),
|
77
|
+
}
|
77
78
|
|
78
79
|
|
79
80
|
@lf.use_init_args(['model'])
|
@@ -82,7 +83,7 @@ class OpenAI(lf.LanguageModel):
|
|
82
83
|
|
83
84
|
model: pg.typing.Annotated[
|
84
85
|
pg.typing.Enum(
|
85
|
-
pg.MISSING_VALUE,
|
86
|
+
pg.MISSING_VALUE, list(SUPPORTED_MODELS_AND_SETTINGS.keys())
|
86
87
|
),
|
87
88
|
'The name of the model to use.',
|
88
89
|
] = 'gpt-3.5-turbo'
|
@@ -134,7 +135,11 @@ class OpenAI(lf.LanguageModel):
|
|
134
135
|
|
135
136
|
@property
|
136
137
|
def max_concurrency(self) -> int:
|
137
|
-
|
138
|
+
rpm = SUPPORTED_MODELS_AND_SETTINGS[self.model].get('rpm', 0)
|
139
|
+
tpm = SUPPORTED_MODELS_AND_SETTINGS[self.model].get('tpm', 0)
|
140
|
+
return self.rate_to_max_concurrency(
|
141
|
+
requests_per_min=rpm, tokens_per_min=tpm
|
142
|
+
)
|
138
143
|
|
139
144
|
@classmethod
|
140
145
|
def dir(cls):
|
langfun/core/llms/openai_test.py
CHANGED
@@ -11,7 +11,7 @@
|
|
11
11
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
12
|
# See the License for the specific language governing permissions and
|
13
13
|
# limitations under the License.
|
14
|
-
"""Tests for
|
14
|
+
"""Tests for OpenAI models."""
|
15
15
|
|
16
16
|
import unittest
|
17
17
|
from unittest import mock
|
@@ -85,7 +85,7 @@ def mock_chat_completion_query_vision(messages, *, n=1, **kwargs):
|
|
85
85
|
)
|
86
86
|
|
87
87
|
|
88
|
-
class
|
88
|
+
class OpenAITest(unittest.TestCase):
|
89
89
|
"""Tests for OpenAI language model."""
|
90
90
|
|
91
91
|
def test_model_id(self):
|
@@ -98,7 +98,7 @@ class OpenaiTest(unittest.TestCase):
|
|
98
98
|
)
|
99
99
|
|
100
100
|
def test_max_concurrency(self):
|
101
|
-
self.
|
101
|
+
self.assertGreater(openai.Gpt35(api_key='test_key').max_concurrency, 0)
|
102
102
|
|
103
103
|
def test_get_request_args(self):
|
104
104
|
self.assertEqual(
|
@@ -51,6 +51,7 @@ from langfun.core.structured.schema_generation import default_classgen_examples
|
|
51
51
|
from langfun.core.structured.function_generation import function_gen
|
52
52
|
|
53
53
|
from langfun.core.structured.mapping import Mapping
|
54
|
+
from langfun.core.structured.mapping import MappingError
|
54
55
|
from langfun.core.structured.mapping import MappingExample
|
55
56
|
|
56
57
|
from langfun.core.structured.parsing import ParseStructure
|
@@ -17,7 +17,6 @@ import inspect
|
|
17
17
|
import unittest
|
18
18
|
|
19
19
|
import langfun.core as lf
|
20
|
-
from langfun.core import coding
|
21
20
|
from langfun.core import modalities
|
22
21
|
from langfun.core.llms import fake
|
23
22
|
from langfun.core.structured import completion
|
@@ -608,7 +607,7 @@ class CompleteStructureTest(unittest.TestCase):
|
|
608
607
|
override_attrs=True,
|
609
608
|
):
|
610
609
|
with self.assertRaisesRegex(
|
611
|
-
|
610
|
+
mapping.MappingError,
|
612
611
|
'Expect .* but encountered .*',
|
613
612
|
):
|
614
613
|
completion.complete(Activity.partial(), autofix=0)
|
@@ -20,6 +20,43 @@ from langfun.core.structured import schema as schema_lib
|
|
20
20
|
import pyglove as pg
|
21
21
|
|
22
22
|
|
23
|
+
class MappingError(Exception): # pylint: disable=g-bad-exception-name
|
24
|
+
"""Mapping error."""
|
25
|
+
|
26
|
+
def __init__(self, lm_response: lf.Message, cause: Exception):
|
27
|
+
self._lm_response = lm_response
|
28
|
+
self._cause = cause
|
29
|
+
|
30
|
+
@property
|
31
|
+
def lm_response(self) -> lf.Message:
|
32
|
+
"""Returns the LM response that failed to be mapped."""
|
33
|
+
return self._lm_response
|
34
|
+
|
35
|
+
@property
|
36
|
+
def cause(self) -> Exception:
|
37
|
+
"""Returns the cause of the error."""
|
38
|
+
return self._cause
|
39
|
+
|
40
|
+
def __str__(self) -> str:
|
41
|
+
return self.format(include_lm_response=True)
|
42
|
+
|
43
|
+
def format(self, include_lm_response: bool = True) -> str:
|
44
|
+
"""Formats the mapping error."""
|
45
|
+
r = io.StringIO()
|
46
|
+
error_message = str(self.cause).rstrip()
|
47
|
+
r.write(
|
48
|
+
lf.colored(
|
49
|
+
f'{self.cause.__class__.__name__}: {error_message}', 'magenta'
|
50
|
+
)
|
51
|
+
)
|
52
|
+
if include_lm_response:
|
53
|
+
r.write('\n\n')
|
54
|
+
r.write(lf.colored('[LM Response]', 'blue', styles=['bold']))
|
55
|
+
r.write('\n')
|
56
|
+
r.write(lf.colored(self.lm_response.text, 'blue'))
|
57
|
+
return r.getvalue()
|
58
|
+
|
59
|
+
|
23
60
|
@pg.use_init_args(['input', 'output', 'schema', 'context'])
|
24
61
|
class MappingExample(lf.NaturalLanguageFormattable, lf.Component):
|
25
62
|
"""Mapping example between text, schema and structured value."""
|
@@ -308,7 +345,7 @@ class Mapping(lf.LangFunc):
|
|
308
345
|
lm_output.result = self.postprocess_result(self.parse_result(lm_output))
|
309
346
|
except Exception as e: # pylint: disable=broad-exception-caught
|
310
347
|
if self.default == lf.RAISE_IF_HAS_ERROR:
|
311
|
-
raise e
|
348
|
+
raise MappingError(lm_output, e) from e
|
312
349
|
lm_output.result = self.default
|
313
350
|
return lm_output
|
314
351
|
|
@@ -16,10 +16,27 @@
|
|
16
16
|
import inspect
|
17
17
|
import unittest
|
18
18
|
|
19
|
+
import langfun.core as lf
|
19
20
|
from langfun.core.structured import mapping
|
20
21
|
import pyglove as pg
|
21
22
|
|
22
23
|
|
24
|
+
class MappingErrorTest(unittest.TestCase):
|
25
|
+
|
26
|
+
def test_format(self):
|
27
|
+
error = mapping.MappingError(
|
28
|
+
lf.AIMessage('hi'), ValueError('Cannot parse message.')
|
29
|
+
)
|
30
|
+
self.assertEqual(
|
31
|
+
lf.text_formatting.decolored(str(error)),
|
32
|
+
'ValueError: Cannot parse message.\n\n[LM Response]\nhi',
|
33
|
+
)
|
34
|
+
self.assertEqual(
|
35
|
+
lf.text_formatting.decolored(error.format(include_lm_response=False)),
|
36
|
+
'ValueError: Cannot parse message.',
|
37
|
+
)
|
38
|
+
|
39
|
+
|
23
40
|
class MappingExampleTest(unittest.TestCase):
|
24
41
|
|
25
42
|
def test_basics(self):
|
@@ -17,11 +17,9 @@ import inspect
|
|
17
17
|
import unittest
|
18
18
|
|
19
19
|
import langfun.core as lf
|
20
|
-
from langfun.core import coding
|
21
20
|
from langfun.core.llms import fake
|
22
21
|
from langfun.core.structured import mapping
|
23
22
|
from langfun.core.structured import parsing
|
24
|
-
from langfun.core.structured import schema as schema_lib
|
25
23
|
import pyglove as pg
|
26
24
|
|
27
25
|
|
@@ -255,7 +253,7 @@ class ParseStructurePythonTest(unittest.TestCase):
|
|
255
253
|
override_attrs=True,
|
256
254
|
):
|
257
255
|
with self.assertRaisesRegex(
|
258
|
-
|
256
|
+
mapping.MappingError,
|
259
257
|
'name .* is not defined',
|
260
258
|
):
|
261
259
|
parsing.parse('three', int)
|
@@ -546,7 +544,7 @@ class ParseStructureJsonTest(unittest.TestCase):
|
|
546
544
|
override_attrs=True,
|
547
545
|
):
|
548
546
|
with self.assertRaisesRegex(
|
549
|
-
|
547
|
+
mapping.MappingError,
|
550
548
|
'No JSON dict in the output',
|
551
549
|
):
|
552
550
|
parsing.parse('three', int, protocol='json')
|
@@ -17,12 +17,10 @@ import inspect
|
|
17
17
|
import unittest
|
18
18
|
|
19
19
|
import langfun.core as lf
|
20
|
-
from langfun.core import coding
|
21
20
|
from langfun.core import modalities
|
22
21
|
from langfun.core.llms import fake
|
23
22
|
from langfun.core.structured import mapping
|
24
23
|
from langfun.core.structured import prompting
|
25
|
-
from langfun.core.structured import schema as schema_lib
|
26
24
|
import pyglove as pg
|
27
25
|
|
28
26
|
|
@@ -439,7 +437,7 @@ class QueryStructurePythonTest(unittest.TestCase):
|
|
439
437
|
override_attrs=True,
|
440
438
|
):
|
441
439
|
with self.assertRaisesRegex(
|
442
|
-
|
440
|
+
mapping.MappingError,
|
443
441
|
'name .* is not defined',
|
444
442
|
):
|
445
443
|
prompting.query('Compute 1 + 2', int)
|
@@ -677,7 +675,7 @@ class QueryStructureJsonTest(unittest.TestCase):
|
|
677
675
|
override_attrs=True,
|
678
676
|
):
|
679
677
|
with self.assertRaisesRegex(
|
680
|
-
|
678
|
+
mapping.MappingError,
|
681
679
|
'No JSON dict in the output',
|
682
680
|
):
|
683
681
|
prompting.query('Compute 1 + 2', int, protocol='json')
|
@@ -14,8 +14,8 @@
|
|
14
14
|
import inspect
|
15
15
|
import unittest
|
16
16
|
|
17
|
-
import langfun.core.coding as lf_coding
|
18
17
|
from langfun.core.llms import fake
|
18
|
+
from langfun.core.structured import mapping
|
19
19
|
from langfun.core.structured import schema_generation
|
20
20
|
|
21
21
|
|
@@ -92,7 +92,7 @@ class GenerateClassTest(unittest.TestCase):
|
|
92
92
|
)
|
93
93
|
self.assertIs(cls.__name__, 'B')
|
94
94
|
|
95
|
-
with self.assertRaises(
|
95
|
+
with self.assertRaises(mapping.MappingError):
|
96
96
|
schema_generation.generate_class(
|
97
97
|
'Foo',
|
98
98
|
'Generate a Foo class with a field pointing to another class A',
|