langfun 0.0.2.dev20240422__py3-none-any.whl → 0.0.2.dev20240425__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (32)
  1. langfun/__init__.py +1 -0
  2. langfun/core/component.py +6 -0
  3. langfun/core/component_test.py +1 -0
  4. langfun/core/eval/__init__.py +2 -0
  5. langfun/core/eval/base.py +175 -17
  6. langfun/core/eval/base_test.py +34 -6
  7. langfun/core/eval/matching.py +18 -1
  8. langfun/core/eval/matching_test.py +2 -1
  9. langfun/core/eval/scoring.py +11 -1
  10. langfun/core/eval/scoring_test.py +2 -1
  11. langfun/core/language_model.py +14 -0
  12. langfun/core/language_model_test.py +32 -0
  13. langfun/core/llms/anthropic.py +36 -22
  14. langfun/core/llms/anthropic_test.py +7 -7
  15. langfun/core/llms/groq.py +27 -18
  16. langfun/core/llms/groq_test.py +5 -5
  17. langfun/core/llms/openai.py +55 -50
  18. langfun/core/llms/openai_test.py +3 -3
  19. langfun/core/structured/__init__.py +1 -0
  20. langfun/core/structured/completion_test.py +1 -2
  21. langfun/core/structured/mapping.py +38 -1
  22. langfun/core/structured/mapping_test.py +17 -0
  23. langfun/core/structured/parsing_test.py +2 -4
  24. langfun/core/structured/prompting_test.py +2 -4
  25. langfun/core/structured/schema_generation_test.py +2 -2
  26. langfun/core/template.py +26 -8
  27. langfun/core/template_test.py +9 -0
  28. {langfun-0.0.2.dev20240422.dist-info → langfun-0.0.2.dev20240425.dist-info}/METADATA +3 -2
  29. {langfun-0.0.2.dev20240422.dist-info → langfun-0.0.2.dev20240425.dist-info}/RECORD +32 -32
  30. {langfun-0.0.2.dev20240422.dist-info → langfun-0.0.2.dev20240425.dist-info}/LICENSE +0 -0
  31. {langfun-0.0.2.dev20240422.dist-info → langfun-0.0.2.dev20240425.dist-info}/WHEEL +0 -0
  32. {langfun-0.0.2.dev20240422.dist-info → langfun-0.0.2.dev20240425.dist-info}/top_level.txt +0 -0
langfun/core/llms/anthropic.py CHANGED
@@ -26,12 +26,15 @@ import requests
26
26
 
27
27
  SUPPORTED_MODELS_AND_SETTINGS = {
28
28
  # See https://docs.anthropic.com/claude/docs/models-overview
29
- 'claude-3-opus-20240229': pg.Dict(max_tokens=4096, max_concurrency=16),
30
- 'claude-3-sonnet-20240229': pg.Dict(max_tokens=4096, max_concurrency=16),
31
- 'claude-3-haiku-20240307': pg.Dict(max_tokens=4096, max_concurrency=16),
32
- 'claude-2.1': pg.Dict(max_tokens=4096, max_concurrency=16),
33
- 'claude-2.0': pg.Dict(max_tokens=4096, max_concurrency=16),
34
- 'claude-instant-1.2': pg.Dict(max_tokens=4096, max_concurrency=16),
29
+ # Rate limits from https://docs.anthropic.com/claude/reference/rate-limits
30
+ # RPM/TPM for Claude-2.1, Claude-2.0, and Claude-Instant-1.2 estimated
31
+ # as RPM/TPM of the largest-available model (Claude-3-Opus).
32
+ 'claude-3-opus-20240229': pg.Dict(max_tokens=4096, rpm=4000, tpm=400000),
33
+ 'claude-3-sonnet-20240229': pg.Dict(max_tokens=4096, rpm=4000, tpm=400000),
34
+ 'claude-3-haiku-20240307': pg.Dict(max_tokens=4096, rpm=4000, tpm=400000),
35
+ 'claude-2.1': pg.Dict(max_tokens=4096, rpm=4000, tpm=400000),
36
+ 'claude-2.0': pg.Dict(max_tokens=4096, rpm=4000, tpm=400000),
37
+ 'claude-instant-1.2': pg.Dict(max_tokens=4096, rpm=4000, tpm=400000),
35
38
  }
36
39
 
37
40
 
@@ -81,6 +84,7 @@ class Anthropic(lf.LanguageModel):
81
84
  super()._on_bound()
82
85
  self._api_key = None
83
86
  self.__dict__.pop('_api_initialized', None)
87
+ self.__dict__.pop('_session', None)
84
88
 
85
89
  @functools.cached_property
86
90
  def _api_initialized(self):
@@ -93,6 +97,17 @@ class Anthropic(lf.LanguageModel):
93
97
  self._api_key = api_key
94
98
  return True
95
99
 
100
+ @functools.cached_property
101
+ def _session(self) -> requests.Session:
102
+ assert self._api_initialized
103
+ s = requests.Session()
104
+ s.headers.update({
105
+ 'x-api-key': self._api_key,
106
+ 'anthropic-version': _ANTHROPIC_API_VERSION,
107
+ 'content-type': 'application/json',
108
+ })
109
+ return s
110
+
96
111
  @property
97
112
  def model_id(self) -> str:
98
113
  """Returns a string to identify the model."""
@@ -100,7 +115,11 @@ class Anthropic(lf.LanguageModel):
100
115
 
101
116
  @property
102
117
  def max_concurrency(self) -> int:
103
- return SUPPORTED_MODELS_AND_SETTINGS[self.model].max_concurrency
118
+ rpm = SUPPORTED_MODELS_AND_SETTINGS[self.model].get('rpm', 0)
119
+ tpm = SUPPORTED_MODELS_AND_SETTINGS[self.model].get('tpm', 0)
120
+ return self.rate_to_max_concurrency(
121
+ requests_per_min=rpm, tokens_per_min=tpm
122
+ )
104
123
 
105
124
  def _sample(self, prompts: list[lf.Message]) -> list[lf.LMSamplingResult]:
106
125
  assert self._api_initialized
@@ -165,8 +184,8 @@ class Anthropic(lf.LanguageModel):
165
184
  def _parse_response(self, response: requests.Response) -> lf.LMSamplingResult:
166
185
  """Parses Anthropic's response."""
167
186
  # NOTE(daiyip): Refer https://docs.anthropic.com/claude/reference/errors
168
- output = response.json()
169
187
  if response.status_code == 200:
188
+ output = response.json()
170
189
  message = self._message_from_content(output['content'])
171
190
  input_tokens = output['usage']['input_tokens']
172
191
  output_tokens = output['usage']['output_tokens']
@@ -181,12 +200,11 @@ class Anthropic(lf.LanguageModel):
181
200
  else:
182
201
  if response.status_code == 429:
183
202
  error_cls = RateLimitError
184
- elif response.status_code == 529:
203
+ elif response.status_code in (502, 529):
185
204
  error_cls = OverloadedError
186
205
  else:
187
206
  error_cls = AnthropicError
188
- error = output['error']
189
- raise error_cls(f'{error["type"]}: {error["message"]}')
207
+ raise error_cls(f'{response.status_code}: {response.content}')
190
208
 
191
209
  def _sample_single(self, prompt: lf.Message) -> lf.LMSamplingResult:
192
210
  request = dict()
@@ -198,17 +216,13 @@ class Anthropic(lf.LanguageModel):
198
216
  ]
199
217
  )
200
218
  )
201
- response = requests.post(
202
- _ANTHROPIC_MESSAGE_API_ENDPOINT,
203
- json=request,
204
- headers={
205
- 'x-api-key': self._api_key,
206
- 'anthropic-version': _ANTHROPIC_API_VERSION,
207
- 'content-type': 'application/json',
208
- },
209
- timeout=self.timeout,
210
- )
211
- return self._parse_response(response)
219
+ try:
220
+ response = self._session.post(
221
+ _ANTHROPIC_MESSAGE_API_ENDPOINT, json=request, timeout=self.timeout,
222
+ )
223
+ return self._parse_response(response)
224
+ except ConnectionError as e:
225
+ raise OverloadedError(str(e)) from e
212
226
 
213
227
 
214
228
  class Claude3(Anthropic):
langfun/core/llms/anthropic_test.py CHANGED
@@ -98,20 +98,20 @@ def mock_requests_post_error(status_code, error_type, error_message):
98
98
  return _mock_requests
99
99
 
100
100
 
101
- class AuthropicTest(unittest.TestCase):
101
+ class AnthropicTest(unittest.TestCase):
102
102
 
103
103
  def test_basics(self):
104
104
  self.assertEqual(
105
105
  anthropic.Claude3Haiku().model_id, 'claude-3-haiku-20240307'
106
106
  )
107
- self.assertEqual(anthropic.Claude3Haiku().max_concurrency, 16)
107
+ self.assertGreater(anthropic.Claude3Haiku().max_concurrency, 0)
108
108
 
109
109
  def test_api_key(self):
110
110
  lm = anthropic.Claude3Haiku()
111
111
  with self.assertRaisesRegex(ValueError, 'Please specify `api_key`'):
112
112
  lm('hi')
113
113
 
114
- with mock.patch('requests.post') as mock_request:
114
+ with mock.patch('requests.Session.post') as mock_request:
115
115
  mock_request.side_effect = mock_requests_post
116
116
 
117
117
  lm = anthropic.Claude3Haiku(api_key='fake key')
@@ -123,7 +123,7 @@ class AuthropicTest(unittest.TestCase):
123
123
  del os.environ['ANTHROPIC_API_KEY']
124
124
 
125
125
  def test_call(self):
126
- with mock.patch('requests.post') as mock_request:
126
+ with mock.patch('requests.Session.post') as mock_request:
127
127
  mock_request.side_effect = mock_requests_post
128
128
  lm = anthropic.Claude3Haiku(api_key='fake_key')
129
129
  response = lm('hello', temperature=0.0, top_k=0.1, top_p=0.2, stop=['\n'])
@@ -140,7 +140,7 @@ class AuthropicTest(unittest.TestCase):
140
140
  self.assertIsNotNone(response.usage.total_tokens, 3)
141
141
 
142
142
  def test_mm_call(self):
143
- with mock.patch('requests.post') as mock_mm_request:
143
+ with mock.patch('requests.Session.post') as mock_mm_request:
144
144
  mock_mm_request.side_effect = mock_mm_requests_post
145
145
  lm = anthropic.Claude3Haiku(api_key='fake_key')
146
146
  response = lm(lf_modalities.Image.from_bytes(image_content), lm=lm)
@@ -152,13 +152,13 @@ class AuthropicTest(unittest.TestCase):
152
152
  (529, 'service_unavailable', 'Service unavailable.'),
153
153
  (500, 'bad_request', 'Bad request.'),
154
154
  ]:
155
- with mock.patch('requests.post') as mock_mm_request:
155
+ with mock.patch('requests.Session.post') as mock_mm_request:
156
156
  mock_mm_request.side_effect = mock_requests_post_error(
157
157
  status_code, error_type, error_message
158
158
  )
159
159
  lm = anthropic.Claude3Haiku(api_key='fake_key')
160
160
  with self.assertRaisesRegex(
161
- Exception, f'{error_type}: {error_message}'
161
+ Exception, f'.*{status_code}: .*{error_message}'
162
162
  ):
163
163
  lm('hello', lm=lm, max_attempts=1)
164
164
 
langfun/core/llms/groq.py CHANGED
@@ -78,6 +78,7 @@ class Groq(lf.LanguageModel):
78
78
  super()._on_bound()
79
79
  self._api_key = None
80
80
  self.__dict__.pop('_api_initialized', None)
81
+ self.__dict__.pop('_session', None)
81
82
 
82
83
  @functools.cached_property
83
84
  def _api_initialized(self):
@@ -85,11 +86,21 @@ class Groq(lf.LanguageModel):
85
86
  if not api_key:
86
87
  raise ValueError(
87
88
  'Please specify `api_key` during `__init__` or set environment '
88
- 'variable `GROQ_API_KEY` with your Anthropic API key.'
89
+ 'variable `GROQ_API_KEY` with your Groq API key.'
89
90
  )
90
91
  self._api_key = api_key
91
92
  return True
92
93
 
94
+ @functools.cached_property
95
+ def _session(self) -> requests.Session:
96
+ assert self._api_initialized
97
+ s = requests.Session()
98
+ s.headers.update({
99
+ 'Authorization': f'Bearer {self._api_key}',
100
+ 'Content-Type': 'application/json',
101
+ })
102
+ return s
103
+
93
104
  @property
94
105
  def model_id(self) -> str:
95
106
  """Returns a string to identify the model."""
@@ -119,7 +130,7 @@ class Groq(lf.LanguageModel):
119
130
  return args
120
131
 
121
132
  def _content_from_message(self, prompt: lf.Message) -> list[dict[str, Any]]:
122
- """Converts an message to Anthropic's content protocol (list of dicts)."""
133
+ """Converts an message to Groq's content protocol (list of dicts)."""
123
134
  # Refer: https://platform.openai.com/docs/api-reference/chat/create
124
135
  content = []
125
136
  for chunk in prompt.chunk():
@@ -138,7 +149,7 @@ class Groq(lf.LanguageModel):
138
149
  return content
139
150
 
140
151
  def _message_from_choice(self, choice: dict[str, Any]) -> lf.Message:
141
- """Converts Anthropic's content protocol to message."""
152
+ """Converts Groq's content protocol to message."""
142
153
  # Refer: https://platform.openai.com/docs/api-reference/chat/create
143
154
  content = choice['message']['content']
144
155
  if isinstance(content, str):
@@ -148,10 +159,10 @@ class Groq(lf.LanguageModel):
148
159
  )
149
160
 
150
161
  def _parse_response(self, response: requests.Response) -> lf.LMSamplingResult:
151
- """Parses Anthropic's response."""
162
+ """Parses Groq's response."""
152
163
  # Refer: https://platform.openai.com/docs/api-reference/chat/object
153
- output = response.json()
154
164
  if response.status_code == 200:
165
+ output = response.json()
155
166
  samples = [
156
167
  lf.LMSample(self._message_from_choice(choice), score=0.0)
157
168
  for choice in output['choices']
@@ -169,12 +180,11 @@ class Groq(lf.LanguageModel):
169
180
  # https://platform.openai.com/docs/guides/error-codes/api-errors
170
181
  if response.status_code == 429:
171
182
  error_cls = RateLimitError
172
- elif response.status_code in (500, 503):
183
+ elif response.status_code in (500, 502, 503):
173
184
  error_cls = OverloadedError
174
185
  else:
175
186
  error_cls = GroqError
176
- error = output['error']
177
- raise error_cls(f'{error["type"]}: {error["message"]}')
187
+ raise error_cls(f'{response.status_code}: {response.content}')
178
188
 
179
189
  def _sample(self, prompts: list[lf.Message]) -> list[lf.LMSamplingResult]:
180
190
  assert self._api_initialized
@@ -194,16 +204,15 @@ class Groq(lf.LanguageModel):
194
204
  ]
195
205
  )
196
206
  )
197
- response = requests.post(
198
- _CHAT_COMPLETE_API_ENDPOINT,
199
- json=request,
200
- headers={
201
- 'Authorization': f'Bearer {self._api_key}',
202
- 'Content-Type': 'application/json',
203
- },
204
- timeout=self.timeout,
205
- )
206
- return self._parse_response(response)
207
+ try:
208
+ response = self._session.post(
209
+ _CHAT_COMPLETE_API_ENDPOINT,
210
+ json=request,
211
+ timeout=self.timeout,
212
+ )
213
+ return self._parse_response(response)
214
+ except ConnectionError as e:
215
+ raise OverloadedError(str(e)) from e
207
216
 
208
217
 
209
218
  class GroqLlama3_8B(Groq): # pylint: disable=invalid-name
langfun/core/llms/groq_test.py CHANGED
@@ -107,7 +107,7 @@ class AuthropicTest(unittest.TestCase):
107
107
  with self.assertRaisesRegex(ValueError, 'Please specify `api_key`'):
108
108
  lm('hi')
109
109
 
110
- with mock.patch('requests.post') as mock_request:
110
+ with mock.patch('requests.Session.post') as mock_request:
111
111
  mock_request.side_effect = mock_requests_post
112
112
 
113
113
  lm = groq.GroqMistral_8x7B(api_key='fake key')
@@ -119,7 +119,7 @@ class AuthropicTest(unittest.TestCase):
119
119
  del os.environ['GROQ_API_KEY']
120
120
 
121
121
  def test_call(self):
122
- with mock.patch('requests.post') as mock_request:
122
+ with mock.patch('requests.Session.post') as mock_request:
123
123
  mock_request.side_effect = mock_requests_post
124
124
  lm = groq.GroqLlama3_70B(api_key='fake_key')
125
125
  response = lm(
@@ -143,7 +143,7 @@ class AuthropicTest(unittest.TestCase):
143
143
  self.assertIsNotNone(response.usage.total_tokens, 3)
144
144
 
145
145
  def test_mm_call(self):
146
- with mock.patch('requests.post') as mock_mm_request:
146
+ with mock.patch('requests.Session.post') as mock_mm_request:
147
147
  mock_mm_request.side_effect = mock_mm_requests_post
148
148
  lm = groq.GroqLlama3_70B(multimodal=True, api_key='fake_key')
149
149
  response = lm(lf_modalities.Image.from_uri('https://fake/image.jpg'))
@@ -155,13 +155,13 @@ class AuthropicTest(unittest.TestCase):
155
155
  (503, 'service_unavailable', 'Service unavailable.'),
156
156
  (500, 'bad_request', 'Bad request.'),
157
157
  ]:
158
- with mock.patch('requests.post') as mock_mm_request:
158
+ with mock.patch('requests.Session.post') as mock_mm_request:
159
159
  mock_mm_request.side_effect = mock_requests_post_error(
160
160
  status_code, error_type, error_message
161
161
  )
162
162
  lm = groq.GroqLlama3_70B(api_key='fake_key')
163
163
  with self.assertRaisesRegex(
164
- Exception, f'{error_type}: {error_message}'
164
+ Exception, f'{status_code}:.*{error_type}'
165
165
  ):
166
166
  lm('hello', lm=lm, max_attempts=1)
167
167
 
langfun/core/llms/openai.py CHANGED
@@ -26,54 +26,55 @@ from openai import openai_object
26
26
  import pyglove as pg
27
27
 
28
28
 
29
- SUPPORTED_MODELS_AND_SETTINGS = [
30
- # Model name, max concurrent requests.
31
- # The concurrent requests is estimated by TPM/RPM from
32
- # https://platform.openai.com/account/limits
33
- # GPT-4 Turbo models.
34
- ('gpt-4-turbo', 8), # GPT-4 Turbo with Vision
35
- ('gpt-4-turbo-2024-04-09', 8), # GPT-4-Turbo with Vision, 04/09/2024
36
- ('gpt-4-turbo-preview', 8), # GPT-4 Turbo Preview
37
- ('gpt-4-0125-preview', 8), # GPT-4 Turbo Preview, 01/25/2024
38
- ('gpt-4-1106-preview', 8), # GPT-4 Turbo Preview, 11/06/2023
39
- ('gpt-4-vision-preview', 8), # GPT-4 Turbo Vision Preview.
40
- ('gpt-4-1106-vision-preview', 8), # GPT-4 Turbo Vision Preview, 11/06/2023
41
- # GPT-4 models.
42
- ('gpt-4', 4),
43
- ('gpt-4-0613', 4),
44
- ('gpt-4-0314', 4),
45
- ('gpt-4-32k', 4),
46
- ('gpt-4-32k-0613', 4),
47
- ('gpt-4-32k-0314', 4),
48
- # GPT-3.5 Turbo models.
49
- ('gpt-3.5-turbo', 16),
50
- ('gpt-3.5-turbo-0125', 16),
51
- ('gpt-3.5-turbo-1106', 16),
52
- ('gpt-3.5-turbo-0613', 16),
53
- ('gpt-3.5-turbo-0301', 16),
54
- ('gpt-3.5-turbo-16k', 16),
55
- ('gpt-3.5-turbo-16k-0613', 16),
56
- ('gpt-3.5-turbo-16k-0301', 16),
57
- # GPT-3.5 models.
58
- ('text-davinci-003', 8), # GPT-3.5, trained with RHLF.
59
- ('text-davinci-002', 4), # Trained with SFT but no RHLF.
60
- ('code-davinci-002', 4),
61
- # GPT-3 instruction-tuned models.
62
- ('text-curie-001', 4),
63
- ('text-babbage-001', 4),
64
- ('text-ada-001', 4),
65
- ('davinci', 4),
66
- ('curie', 4),
67
- ('babbage', 4),
68
- ('ada', 4),
69
- # GPT-3 base models without instruction tuning.
70
- ('babbage-002', 4),
71
- ('davinci-002', 4),
72
- ]
73
-
74
-
75
- # Model concurreny setting.
76
- _MODEL_CONCURRENCY = {m[0]: m[1] for m in SUPPORTED_MODELS_AND_SETTINGS}
29
+ # From https://platform.openai.com/settings/organization/limits
30
+ _DEFAULT_TPM = 250000
31
+ _DEFAULT_RPM = 3000
32
+
33
+ SUPPORTED_MODELS_AND_SETTINGS = {
34
+ # Models from https://platform.openai.com/docs/models
35
+ # RPM is from https://platform.openai.com/docs/guides/rate-limits
36
+ # GPT-4-Turbo models
37
+ 'gpt-4-turbo': pg.Dict(rpm=10000, tpm=1500000),
38
+ 'gpt-4-turbo-2024-04-09': pg.Dict(rpm=10000, tpm=1500000),
39
+ 'gpt-4-turbo-preview': pg.Dict(rpm=10000, tpm=1500000),
40
+ 'gpt-4-0125-preview': pg.Dict(rpm=10000, tpm=1500000),
41
+ 'gpt-4-1106-preview': pg.Dict(rpm=10000, tpm=1500000),
42
+ 'gpt-4-vision-preview': pg.Dict(rpm=10000, tpm=1500000),
43
+ 'gpt-4-1106-vision-preview': pg.Dict(
44
+ rpm=10000, tpm=1500000
45
+ ),
46
+ # GPT-4 models
47
+ 'gpt-4': pg.Dict(rpm=10000, tpm=300000),
48
+ 'gpt-4-0613': pg.Dict(rpm=10000, tpm=300000),
49
+ 'gpt-4-0314': pg.Dict(rpm=10000, tpm=300000),
50
+ 'gpt-4-32k': pg.Dict(rpm=10000, tpm=300000),
51
+ 'gpt-4-32k-0613': pg.Dict(rpm=10000, tpm=300000),
52
+ 'gpt-4-32k-0314': pg.Dict(rpm=10000, tpm=300000),
53
+ # GPT-3.5-Turbo models
54
+ 'gpt-3.5-turbo': pg.Dict(rpm=10000, tpm=2000000),
55
+ 'gpt-3.5-turbo-0125': pg.Dict(rpm=10000, tpm=2000000),
56
+ 'gpt-3.5-turbo-1106': pg.Dict(rpm=10000, tpm=2000000),
57
+ 'gpt-3.5-turbo-0613': pg.Dict(rpm=10000, tpm=2000000),
58
+ 'gpt-3.5-turbo-0301': pg.Dict(rpm=10000, tpm=2000000),
59
+ 'gpt-3.5-turbo-16k': pg.Dict(rpm=10000, tpm=2000000),
60
+ 'gpt-3.5-turbo-16k-0613': pg.Dict(rpm=10000, tpm=2000000),
61
+ 'gpt-3.5-turbo-16k-0301': pg.Dict(rpm=10000, tpm=2000000),
62
+ # GPT-3.5 models
63
+ 'text-davinci-003': pg.Dict(rpm=_DEFAULT_RPM, tpm=_DEFAULT_TPM),
64
+ 'text-davinci-002': pg.Dict(rpm=_DEFAULT_RPM, tpm=_DEFAULT_TPM),
65
+ 'code-davinci-002': pg.Dict(rpm=_DEFAULT_RPM, tpm=_DEFAULT_TPM),
66
+ # GPT-3 instruction-tuned models
67
+ 'text-curie-001': pg.Dict(rpm=_DEFAULT_RPM, tpm=_DEFAULT_TPM),
68
+ 'text-babbage-001': pg.Dict(rpm=_DEFAULT_RPM, tpm=_DEFAULT_TPM),
69
+ 'text-ada-001': pg.Dict(rpm=_DEFAULT_RPM, tpm=_DEFAULT_TPM),
70
+ 'davinci': pg.Dict(rpm=_DEFAULT_RPM, tpm=_DEFAULT_TPM),
71
+ 'curie': pg.Dict(rpm=_DEFAULT_RPM, tpm=_DEFAULT_TPM),
72
+ 'babbage': pg.Dict(rpm=_DEFAULT_RPM, tpm=_DEFAULT_TPM),
73
+ 'ada': pg.Dict(rpm=_DEFAULT_RPM, tpm=_DEFAULT_TPM),
74
+ # GPT-3 base models
75
+ 'babbage-002': pg.Dict(rpm=_DEFAULT_RPM, tpm=_DEFAULT_TPM),
76
+ 'davinci-002': pg.Dict(rpm=_DEFAULT_RPM, tpm=_DEFAULT_TPM),
77
+ }
77
78
 
78
79
 
79
80
  @lf.use_init_args(['model'])
@@ -82,7 +83,7 @@ class OpenAI(lf.LanguageModel):
82
83
 
83
84
  model: pg.typing.Annotated[
84
85
  pg.typing.Enum(
85
- pg.MISSING_VALUE, [m[0] for m in SUPPORTED_MODELS_AND_SETTINGS]
86
+ pg.MISSING_VALUE, list(SUPPORTED_MODELS_AND_SETTINGS.keys())
86
87
  ),
87
88
  'The name of the model to use.',
88
89
  ] = 'gpt-3.5-turbo'
@@ -134,7 +135,11 @@ class OpenAI(lf.LanguageModel):
134
135
 
135
136
  @property
136
137
  def max_concurrency(self) -> int:
137
- return _MODEL_CONCURRENCY[self.model]
138
+ rpm = SUPPORTED_MODELS_AND_SETTINGS[self.model].get('rpm', 0)
139
+ tpm = SUPPORTED_MODELS_AND_SETTINGS[self.model].get('tpm', 0)
140
+ return self.rate_to_max_concurrency(
141
+ requests_per_min=rpm, tokens_per_min=tpm
142
+ )
138
143
 
139
144
  @classmethod
140
145
  def dir(cls):
langfun/core/llms/openai_test.py CHANGED
@@ -11,7 +11,7 @@
11
11
  # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
12
  # See the License for the specific language governing permissions and
13
13
  # limitations under the License.
14
- """Tests for openai models."""
14
+ """Tests for OpenAI models."""
15
15
 
16
16
  import unittest
17
17
  from unittest import mock
@@ -85,7 +85,7 @@ def mock_chat_completion_query_vision(messages, *, n=1, **kwargs):
85
85
  )
86
86
 
87
87
 
88
- class OpenaiTest(unittest.TestCase):
88
+ class OpenAITest(unittest.TestCase):
89
89
  """Tests for OpenAI language model."""
90
90
 
91
91
  def test_model_id(self):
@@ -98,7 +98,7 @@ class OpenaiTest(unittest.TestCase):
98
98
  )
99
99
 
100
100
  def test_max_concurrency(self):
101
- self.assertEqual(openai.Gpt35(api_key='test_key').max_concurrency, 8)
101
+ self.assertGreater(openai.Gpt35(api_key='test_key').max_concurrency, 0)
102
102
 
103
103
  def test_get_request_args(self):
104
104
  self.assertEqual(
langfun/core/structured/__init__.py CHANGED
@@ -51,6 +51,7 @@ from langfun.core.structured.schema_generation import default_classgen_examples
51
51
  from langfun.core.structured.function_generation import function_gen
52
52
 
53
53
  from langfun.core.structured.mapping import Mapping
54
+ from langfun.core.structured.mapping import MappingError
54
55
  from langfun.core.structured.mapping import MappingExample
55
56
 
56
57
  from langfun.core.structured.parsing import ParseStructure
langfun/core/structured/completion_test.py CHANGED
@@ -17,7 +17,6 @@ import inspect
17
17
  import unittest
18
18
 
19
19
  import langfun.core as lf
20
- from langfun.core import coding
21
20
  from langfun.core import modalities
22
21
  from langfun.core.llms import fake
23
22
  from langfun.core.structured import completion
@@ -608,7 +607,7 @@ class CompleteStructureTest(unittest.TestCase):
608
607
  override_attrs=True,
609
608
  ):
610
609
  with self.assertRaisesRegex(
611
- coding.CodeError,
610
+ mapping.MappingError,
612
611
  'Expect .* but encountered .*',
613
612
  ):
614
613
  completion.complete(Activity.partial(), autofix=0)
langfun/core/structured/mapping.py CHANGED
@@ -20,6 +20,43 @@ from langfun.core.structured import schema as schema_lib
20
20
  import pyglove as pg
21
21
 
22
22
 
23
+ class MappingError(Exception): # pylint: disable=g-bad-exception-name
24
+ """Mapping error."""
25
+
26
+ def __init__(self, lm_response: lf.Message, cause: Exception):
27
+ self._lm_response = lm_response
28
+ self._cause = cause
29
+
30
+ @property
31
+ def lm_response(self) -> lf.Message:
32
+ """Returns the LM response that failed to be mapped."""
33
+ return self._lm_response
34
+
35
+ @property
36
+ def cause(self) -> Exception:
37
+ """Returns the cause of the error."""
38
+ return self._cause
39
+
40
+ def __str__(self) -> str:
41
+ return self.format(include_lm_response=True)
42
+
43
+ def format(self, include_lm_response: bool = True) -> str:
44
+ """Formats the mapping error."""
45
+ r = io.StringIO()
46
+ error_message = str(self.cause).rstrip()
47
+ r.write(
48
+ lf.colored(
49
+ f'{self.cause.__class__.__name__}: {error_message}', 'magenta'
50
+ )
51
+ )
52
+ if include_lm_response:
53
+ r.write('\n\n')
54
+ r.write(lf.colored('[LM Response]', 'blue', styles=['bold']))
55
+ r.write('\n')
56
+ r.write(lf.colored(self.lm_response.text, 'blue'))
57
+ return r.getvalue()
58
+
59
+
23
60
  @pg.use_init_args(['input', 'output', 'schema', 'context'])
24
61
  class MappingExample(lf.NaturalLanguageFormattable, lf.Component):
25
62
  """Mapping example between text, schema and structured value."""
@@ -308,7 +345,7 @@ class Mapping(lf.LangFunc):
308
345
  lm_output.result = self.postprocess_result(self.parse_result(lm_output))
309
346
  except Exception as e: # pylint: disable=broad-exception-caught
310
347
  if self.default == lf.RAISE_IF_HAS_ERROR:
311
- raise e
348
+ raise MappingError(lm_output, e) from e
312
349
  lm_output.result = self.default
313
350
  return lm_output
314
351
 
langfun/core/structured/mapping_test.py CHANGED
@@ -16,10 +16,27 @@
16
16
  import inspect
17
17
  import unittest
18
18
 
19
+ import langfun.core as lf
19
20
  from langfun.core.structured import mapping
20
21
  import pyglove as pg
21
22
 
22
23
 
24
+ class MappingErrorTest(unittest.TestCase):
25
+
26
+ def test_format(self):
27
+ error = mapping.MappingError(
28
+ lf.AIMessage('hi'), ValueError('Cannot parse message.')
29
+ )
30
+ self.assertEqual(
31
+ lf.text_formatting.decolored(str(error)),
32
+ 'ValueError: Cannot parse message.\n\n[LM Response]\nhi',
33
+ )
34
+ self.assertEqual(
35
+ lf.text_formatting.decolored(error.format(include_lm_response=False)),
36
+ 'ValueError: Cannot parse message.',
37
+ )
38
+
39
+
23
40
  class MappingExampleTest(unittest.TestCase):
24
41
 
25
42
  def test_basics(self):
langfun/core/structured/parsing_test.py CHANGED
@@ -17,11 +17,9 @@ import inspect
17
17
  import unittest
18
18
 
19
19
  import langfun.core as lf
20
- from langfun.core import coding
21
20
  from langfun.core.llms import fake
22
21
  from langfun.core.structured import mapping
23
22
  from langfun.core.structured import parsing
24
- from langfun.core.structured import schema as schema_lib
25
23
  import pyglove as pg
26
24
 
27
25
 
@@ -255,7 +253,7 @@ class ParseStructurePythonTest(unittest.TestCase):
255
253
  override_attrs=True,
256
254
  ):
257
255
  with self.assertRaisesRegex(
258
- coding.CodeError,
256
+ mapping.MappingError,
259
257
  'name .* is not defined',
260
258
  ):
261
259
  parsing.parse('three', int)
@@ -546,7 +544,7 @@ class ParseStructureJsonTest(unittest.TestCase):
546
544
  override_attrs=True,
547
545
  ):
548
546
  with self.assertRaisesRegex(
549
- schema_lib.JsonError,
547
+ mapping.MappingError,
550
548
  'No JSON dict in the output',
551
549
  ):
552
550
  parsing.parse('three', int, protocol='json')
langfun/core/structured/prompting_test.py CHANGED
@@ -17,12 +17,10 @@ import inspect
17
17
  import unittest
18
18
 
19
19
  import langfun.core as lf
20
- from langfun.core import coding
21
20
  from langfun.core import modalities
22
21
  from langfun.core.llms import fake
23
22
  from langfun.core.structured import mapping
24
23
  from langfun.core.structured import prompting
25
- from langfun.core.structured import schema as schema_lib
26
24
  import pyglove as pg
27
25
 
28
26
 
@@ -439,7 +437,7 @@ class QueryStructurePythonTest(unittest.TestCase):
439
437
  override_attrs=True,
440
438
  ):
441
439
  with self.assertRaisesRegex(
442
- coding.CodeError,
440
+ mapping.MappingError,
443
441
  'name .* is not defined',
444
442
  ):
445
443
  prompting.query('Compute 1 + 2', int)
@@ -677,7 +675,7 @@ class QueryStructureJsonTest(unittest.TestCase):
677
675
  override_attrs=True,
678
676
  ):
679
677
  with self.assertRaisesRegex(
680
- schema_lib.JsonError,
678
+ mapping.MappingError,
681
679
  'No JSON dict in the output',
682
680
  ):
683
681
  prompting.query('Compute 1 + 2', int, protocol='json')
langfun/core/structured/schema_generation_test.py CHANGED
@@ -14,8 +14,8 @@
14
14
  import inspect
15
15
  import unittest
16
16
 
17
- import langfun.core.coding as lf_coding
18
17
  from langfun.core.llms import fake
18
+ from langfun.core.structured import mapping
19
19
  from langfun.core.structured import schema_generation
20
20
 
21
21
 
@@ -92,7 +92,7 @@ class GenerateClassTest(unittest.TestCase):
92
92
  )
93
93
  self.assertIs(cls.__name__, 'B')
94
94
 
95
- with self.assertRaises(lf_coding.CodeError):
95
+ with self.assertRaises(mapping.MappingError):
96
96
  schema_generation.generate_class(
97
97
  'Foo',
98
98
  'Generate a Foo class with a field pointing to another class A',