langfun 0.0.2.dev20240422__py3-none-any.whl → 0.0.2.dev20240423__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- langfun/core/component.py +6 -0
- langfun/core/component_test.py +1 -0
- langfun/core/language_model.py +14 -0
- langfun/core/language_model_test.py +32 -0
- langfun/core/llms/anthropic.py +36 -22
- langfun/core/llms/anthropic_test.py +7 -7
- langfun/core/llms/groq.py +27 -18
- langfun/core/llms/groq_test.py +5 -5
- langfun/core/llms/openai.py +55 -50
- langfun/core/llms/openai_test.py +3 -3
- langfun/core/template.py +26 -8
- langfun/core/template_test.py +9 -0
- {langfun-0.0.2.dev20240422.dist-info → langfun-0.0.2.dev20240423.dist-info}/METADATA +1 -1
- {langfun-0.0.2.dev20240422.dist-info → langfun-0.0.2.dev20240423.dist-info}/RECORD +17 -17
- {langfun-0.0.2.dev20240422.dist-info → langfun-0.0.2.dev20240423.dist-info}/LICENSE +0 -0
- {langfun-0.0.2.dev20240422.dist-info → langfun-0.0.2.dev20240423.dist-info}/WHEEL +0 -0
- {langfun-0.0.2.dev20240422.dist-info → langfun-0.0.2.dev20240423.dist-info}/top_level.txt +0 -0
langfun/core/component.py
CHANGED
@@ -210,6 +210,12 @@ def get_contextual_override(var_name: str) -> ContextualOverride | None:
   return _get_scoped_value(_global_tls, _CONTEXT_OVERRIDES, var_name)
 
 
+def all_contextual_values() -> dict[str, Any]:
+  """Returns all contextual values provided from `lf.context` in scope."""
+  overrides = getattr(_global_tls, _CONTEXT_OVERRIDES, {})
+  return {k: v.value for k, v in overrides.items()}
+
+
 @contextlib.contextmanager
 def _contextual_scope(
     tls: threading.local, tls_key, **variables
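Note: `all_contextual_values()` complements `get_contextual_override()` by collecting every variable currently provided through `lf.context`, unwrapping each `ContextualOverride` to its raw value. A minimal usage sketch, assuming the top-level `lf` re-exports used in the tests below:

    import langfun as lf

    with lf.context(x=1, y=2):
      with lf.context(y=3, z=4):
        # Inner scopes take precedence; override wrappers are unwrapped.
        assert lf.all_contextual_values() == dict(x=1, y=3, z=4)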
langfun/core/component_test.py
CHANGED
@@ -84,6 +84,7 @@ class ComponentContextTest(unittest.TestCase):
         lf.get_contextual_override('y'),
         lf.ContextualOverride(3, cascade=False, override_attrs=False),
     )
+    self.assertEqual(lf.all_contextual_values(), dict(x=3, y=3, z=3))
 
     # Member attributes take precedence over `lf.context`.
     self.assertEqual(a1.x, 1)
langfun/core/language_model.py
CHANGED
@@ -24,6 +24,9 @@ from langfun.core import console
 from langfun.core import message as message_lib
 import pyglove as pg
 
+TOKENS_PER_REQUEST = 250  # Estimated num tokens for a single request
+DEFAULT_MAX_CONCURRENCY = 1  # Use this as max concurrency if no RPM or TPM data
+
 
 class LMSample(pg.Object):
   """Response candidate."""
@@ -604,3 +607,14 @@ class LanguageModel(component.Component):
           f'score: {r.score}',
           color='blue',
       )
+
+  def rate_to_max_concurrency(
+      self, requests_per_min: float = 0, tokens_per_min: float = 0
+  ) -> int:
+    """Converts a rate to a max concurrency."""
+    if tokens_per_min > 0:
+      return max(int(tokens_per_min / TOKENS_PER_REQUEST / 60), 1)
+    elif requests_per_min > 0:
+      return max(int(requests_per_min / 60), 1)  # Max concurrency can't be zero
+    else:
+      return DEFAULT_MAX_CONCURRENCY  # Default of 1
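Note: `rate_to_max_concurrency()` prefers the token rate when both rates are given: TPM is divided by the estimated 250 tokens per request and then by 60 seconds, RPM is divided by 60 alone, and both paths floor at 1. A quick check of that arithmetic:

    # TPM takes precedence: 400,000 TPM -> 400_000 / 250 / 60 -> 26.
    assert max(int(400_000 / 250 / 60), 1) == 26
    # RPM only: 4,000 RPM -> 4_000 / 60 -> 66.
    assert max(int(4_000 / 60), 1) == 66
    # Tiny rates floor at 1 rather than truncating to zero.
    assert max(int(1 / 60), 1) == 1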
langfun/core/language_model_test.py
CHANGED
@@ -394,6 +394,38 @@ class LanguageModelTest(unittest.TestCase):
     with self.assertRaises(NotImplementedError):
       MockModel().score('hi', ['1', '2'])
 
+  def test_rate_to_max_concurrency_no_rpm_no_tpm(self) -> None:
+    lm = MockModel()
+    self.assertEqual(
+        lm_lib.DEFAULT_MAX_CONCURRENCY,
+        lm.rate_to_max_concurrency(requests_per_min=0, tokens_per_min=0),
+    )
+    self.assertEqual(
+        lm_lib.DEFAULT_MAX_CONCURRENCY,
+        lm.rate_to_max_concurrency(requests_per_min=-1, tokens_per_min=-1),
+    )
+
+  def test_rate_to_max_concurrency_only_rpm_specified_uses_rpm(self) -> None:
+    lm = MockModel()
+    test_rpm = 1e4
+    self.assertEqual(
+        lm.rate_to_max_concurrency(requests_per_min=test_rpm),
+        int(test_rpm / 60)
+    )
+
+  def test_rate_to_max_concurrency_tpm_specified_uses_tpm(self) -> None:
+    lm = MockModel()
+    test_tpm = 1e7
+    self.assertEqual(
+        lm.rate_to_max_concurrency(requests_per_min=1, tokens_per_min=test_tpm),
+        int(test_tpm / lm_lib.TOKENS_PER_REQUEST / 60),
+    )
+
+  def test_rate_to_max_concurrency_small_rate_returns_one(self) -> None:
+    lm = MockModel()
+    self.assertEqual(lm.rate_to_max_concurrency(requests_per_min=1), 1)
+    self.assertEqual(lm.rate_to_max_concurrency(tokens_per_min=1), 1)
+
 
 if __name__ == '__main__':
   unittest.main()
langfun/core/llms/anthropic.py
CHANGED
@@ -26,12 +26,15 @@ import requests
 
 SUPPORTED_MODELS_AND_SETTINGS = {
     # See https://docs.anthropic.com/claude/docs/models-overview
-    'claude-3-opus-20240229': pg.Dict(max_tokens=4096),
-    'claude-3-sonnet-20240229': pg.Dict(max_tokens=4096),
-    'claude-3-haiku-20240307': pg.Dict(max_tokens=4096),
-    'claude-2.1': pg.Dict(max_tokens=4096),
-    'claude-2.0': pg.Dict(max_tokens=4096),
-    'claude-instant-1.2': pg.Dict(max_tokens=4096),
+    # Rate limits from https://docs.anthropic.com/claude/reference/rate-limits
+    # RPM/TPM for Claude-2.1, Claude-2.0, and Claude-Instant-1.2 estimated
+    # as RPM/TPM of the largest-available model (Claude-3-Opus).
+    'claude-3-opus-20240229': pg.Dict(max_tokens=4096, rpm=4000, tpm=400000),
+    'claude-3-sonnet-20240229': pg.Dict(max_tokens=4096, rpm=4000, tpm=400000),
+    'claude-3-haiku-20240307': pg.Dict(max_tokens=4096, rpm=4000, tpm=400000),
+    'claude-2.1': pg.Dict(max_tokens=4096, rpm=4000, tpm=400000),
+    'claude-2.0': pg.Dict(max_tokens=4096, rpm=4000, tpm=400000),
+    'claude-instant-1.2': pg.Dict(max_tokens=4096, rpm=4000, tpm=400000),
 }
 
 
@@ -81,6 +84,7 @@ class Anthropic(lf.LanguageModel):
     super()._on_bound()
     self._api_key = None
     self.__dict__.pop('_api_initialized', None)
+    self.__dict__.pop('_session', None)
 
   @functools.cached_property
   def _api_initialized(self):
@@ -93,6 +97,17 @@ class Anthropic(lf.LanguageModel):
     self._api_key = api_key
     return True
 
+  @functools.cached_property
+  def _session(self) -> requests.Session:
+    assert self._api_initialized
+    s = requests.Session()
+    s.headers.update({
+        'x-api-key': self._api_key,
+        'anthropic-version': _ANTHROPIC_API_VERSION,
+        'content-type': 'application/json',
+    })
+    return s
+
   @property
   def model_id(self) -> str:
     """Returns a string to identify the model."""
@@ -100,7 +115,11 @@ class Anthropic(lf.LanguageModel):
 
   @property
   def max_concurrency(self) -> int:
-    …
+    rpm = SUPPORTED_MODELS_AND_SETTINGS[self.model].get('rpm', 0)
+    tpm = SUPPORTED_MODELS_AND_SETTINGS[self.model].get('tpm', 0)
+    return self.rate_to_max_concurrency(
+        requests_per_min=rpm, tokens_per_min=tpm
+    )
 
   def _sample(self, prompts: list[lf.Message]) -> list[lf.LMSamplingResult]:
     assert self._api_initialized
@@ -165,8 +184,8 @@
   def _parse_response(self, response: requests.Response) -> lf.LMSamplingResult:
     """Parses Anthropic's response."""
     # NOTE(daiyip): Refer https://docs.anthropic.com/claude/reference/errors
-    output = response.json()
     if response.status_code == 200:
+      output = response.json()
       message = self._message_from_content(output['content'])
       input_tokens = output['usage']['input_tokens']
       output_tokens = output['usage']['output_tokens']
@@ -181,12 +200,11 @@
     else:
       if response.status_code == 429:
         error_cls = RateLimitError
-      elif response.status_code …
+      elif response.status_code in (502, 529):
         error_cls = OverloadedError
       else:
         error_cls = AnthropicError
-      error = output['error']
-      raise error_cls(f'{error["type"]}: {error["message"]}')
+      raise error_cls(f'{response.status_code}: {response.content}')
 
   def _sample_single(self, prompt: lf.Message) -> lf.LMSamplingResult:
     request = dict()
@@ -198,17 +216,13 @@
             ]
         )
     )
-    response = requests.post(
-        _ANTHROPIC_MESSAGE_API_ENDPOINT,
-        json=request,
-        headers={
-            'x-api-key': self._api_key,
-            'anthropic-version': _ANTHROPIC_API_VERSION,
-            'content-type': 'application/json',
-        },
-        timeout=self.timeout,
-    )
-    return self._parse_response(response)
+    try:
+      response = self._session.post(
+          _ANTHROPIC_MESSAGE_API_ENDPOINT, json=request, timeout=self.timeout,
+      )
+      return self._parse_response(response)
+    except ConnectionError as e:
+      raise OverloadedError(str(e)) from e
 
 
 class Claude3(Anthropic):
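Note: two effects of the Anthropic changes. First, since every Claude entry carries rpm=4000 and tpm=400000, `max_concurrency` now resolves through `rate_to_max_concurrency` to int(400000 / 250 / 60) = 26. Second, the cached `requests.Session` sets the auth headers once and reuses pooled connections across `_sample_single` calls. A sketch of the resulting behavior (the API key is a placeholder):

    from langfun.core.llms import anthropic

    lm = anthropic.Claude3Haiku(api_key='<placeholder>')
    # tpm=400000 -> int(400000 / 250 / 60) == 26 concurrent requests.
    assert lm.max_concurrency == 26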
langfun/core/llms/anthropic_test.py
CHANGED
@@ -98,20 +98,20 @@ def mock_requests_post_error(status_code, error_type, error_message):
   return _mock_requests
 
 
-class AuthropicTest(unittest.TestCase):
+class AnthropicTest(unittest.TestCase):
 
   def test_basics(self):
     self.assertEqual(
         anthropic.Claude3Haiku().model_id, 'claude-3-haiku-20240307'
     )
-    self.…
+    self.assertGreater(anthropic.Claude3Haiku().max_concurrency, 0)
 
   def test_api_key(self):
     lm = anthropic.Claude3Haiku()
     with self.assertRaisesRegex(ValueError, 'Please specify `api_key`'):
       lm('hi')
 
-    with mock.patch('requests.post') as mock_request:
+    with mock.patch('requests.Session.post') as mock_request:
       mock_request.side_effect = mock_requests_post
 
       lm = anthropic.Claude3Haiku(api_key='fake key')
@@ -123,7 +123,7 @@ class AuthropicTest(unittest.TestCase):
       del os.environ['ANTHROPIC_API_KEY']
 
   def test_call(self):
-    with mock.patch('requests.post') as mock_request:
+    with mock.patch('requests.Session.post') as mock_request:
       mock_request.side_effect = mock_requests_post
       lm = anthropic.Claude3Haiku(api_key='fake_key')
       response = lm('hello', temperature=0.0, top_k=0.1, top_p=0.2, stop=['\n'])
@@ -140,7 +140,7 @@ class AuthropicTest(unittest.TestCase):
     self.assertIsNotNone(response.usage.total_tokens, 3)
 
   def test_mm_call(self):
-    with mock.patch('requests.post') as mock_mm_request:
+    with mock.patch('requests.Session.post') as mock_mm_request:
       mock_mm_request.side_effect = mock_mm_requests_post
       lm = anthropic.Claude3Haiku(api_key='fake_key')
       response = lm(lf_modalities.Image.from_bytes(image_content), lm=lm)
@@ -152,13 +152,13 @@ class AuthropicTest(unittest.TestCase):
         (529, 'service_unavailable', 'Service unavailable.'),
         (500, 'bad_request', 'Bad request.'),
     ]:
-      with mock.patch('requests.post') as mock_mm_request:
+      with mock.patch('requests.Session.post') as mock_mm_request:
         mock_mm_request.side_effect = mock_requests_post_error(
             status_code, error_type, error_message
         )
        lm = anthropic.Claude3Haiku(api_key='fake_key')
         with self.assertRaisesRegex(
-            Exception, f'{…
+            Exception, f'.*{status_code}: .*{error_message}'
         ):
           lm('hello', lm=lm, max_attempts=1)
langfun/core/llms/groq.py
CHANGED
@@ -78,6 +78,7 @@ class Groq(lf.LanguageModel):
     super()._on_bound()
     self._api_key = None
     self.__dict__.pop('_api_initialized', None)
+    self.__dict__.pop('_session', None)
 
   @functools.cached_property
   def _api_initialized(self):
@@ -85,11 +86,21 @@ class Groq(lf.LanguageModel):
     if not api_key:
       raise ValueError(
           'Please specify `api_key` during `__init__` or set environment '
-          'variable `GROQ_API_KEY` with your …
+          'variable `GROQ_API_KEY` with your Groq API key.'
       )
     self._api_key = api_key
     return True
 
+  @functools.cached_property
+  def _session(self) -> requests.Session:
+    assert self._api_initialized
+    s = requests.Session()
+    s.headers.update({
+        'Authorization': f'Bearer {self._api_key}',
+        'Content-Type': 'application/json',
+    })
+    return s
+
   @property
   def model_id(self) -> str:
     """Returns a string to identify the model."""
@@ -119,7 +130,7 @@ class Groq(lf.LanguageModel):
     return args
 
   def _content_from_message(self, prompt: lf.Message) -> list[dict[str, Any]]:
-    """Converts an message to …
+    """Converts an message to Groq's content protocol (list of dicts)."""
     # Refer: https://platform.openai.com/docs/api-reference/chat/create
     content = []
     for chunk in prompt.chunk():
@@ -138,7 +149,7 @@ class Groq(lf.LanguageModel):
     return content
 
   def _message_from_choice(self, choice: dict[str, Any]) -> lf.Message:
-    """Converts …
+    """Converts Groq's content protocol to message."""
     # Refer: https://platform.openai.com/docs/api-reference/chat/create
     content = choice['message']['content']
     if isinstance(content, str):
@@ -148,10 +159,10 @@ class Groq(lf.LanguageModel):
     )
 
   def _parse_response(self, response: requests.Response) -> lf.LMSamplingResult:
-    """Parses …
+    """Parses Groq's response."""
     # Refer: https://platform.openai.com/docs/api-reference/chat/object
-    output = response.json()
     if response.status_code == 200:
+      output = response.json()
       samples = [
           lf.LMSample(self._message_from_choice(choice), score=0.0)
           for choice in output['choices']
@@ -169,12 +180,11 @@ class Groq(lf.LanguageModel):
       # https://platform.openai.com/docs/guides/error-codes/api-errors
       if response.status_code == 429:
         error_cls = RateLimitError
-      elif response.status_code in (500, 503):
+      elif response.status_code in (500, 502, 503):
         error_cls = OverloadedError
       else:
         error_cls = GroqError
-      error = output['error']
-      raise error_cls(f'{error["type"]}: {error["message"]}')
+      raise error_cls(f'{response.status_code}: {response.content}')
 
   def _sample(self, prompts: list[lf.Message]) -> list[lf.LMSamplingResult]:
     assert self._api_initialized
@@ -194,16 +204,15 @@ class Groq(lf.LanguageModel):
             ]
         )
     )
-    response = requests.post(
-        _CHAT_COMPLETE_API_ENDPOINT,
-        json=request,
-        headers={
-            'Authorization': f'Bearer {self._api_key}',
-            'Content-Type': 'application/json',
-        },
-        timeout=self.timeout,
-    )
-    return self._parse_response(response)
+    try:
+      response = self._session.post(
+          _CHAT_COMPLETE_API_ENDPOINT,
+          json=request,
+          timeout=self.timeout,
+      )
+      return self._parse_response(response)
+    except ConnectionError as e:
+      raise OverloadedError(str(e)) from e
 
 
 class GroqLlama3_8B(Groq):  # pylint: disable=invalid-name
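Note: the Groq client mirrors the Anthropic changes: one pooled `requests.Session` with the bearer token set once, error classes chosen from the raw status code (502 is now treated as overload), and `ConnectionError` re-raised as `OverloadedError` so retries handle dropped connections like 5xx responses. A simplified standalone rendition of the status mapping, with the error types stubbed for illustration (not the shipped classes):

    class GroqError(Exception): ...
    class RateLimitError(GroqError): ...
    class OverloadedError(GroqError): ...

    def error_cls_for(status_code: int) -> type[GroqError]:
      # Mirrors the mapping in Groq._parse_response.
      if status_code == 429:
        return RateLimitError
      if status_code in (500, 502, 503):
        return OverloadedError
      return GroqError

    assert error_cls_for(429) is RateLimitError
    assert error_cls_for(502) is OverloadedError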
langfun/core/llms/groq_test.py
CHANGED
@@ -107,7 +107,7 @@ class AuthropicTest(unittest.TestCase):
     with self.assertRaisesRegex(ValueError, 'Please specify `api_key`'):
       lm('hi')
 
-    with mock.patch('requests.post') as mock_request:
+    with mock.patch('requests.Session.post') as mock_request:
       mock_request.side_effect = mock_requests_post
 
       lm = groq.GroqMistral_8x7B(api_key='fake key')
@@ -119,7 +119,7 @@ class AuthropicTest(unittest.TestCase):
       del os.environ['GROQ_API_KEY']
 
   def test_call(self):
-    with mock.patch('requests.post') as mock_request:
+    with mock.patch('requests.Session.post') as mock_request:
       mock_request.side_effect = mock_requests_post
       lm = groq.GroqLlama3_70B(api_key='fake_key')
       response = lm(
@@ -143,7 +143,7 @@ class AuthropicTest(unittest.TestCase):
     self.assertIsNotNone(response.usage.total_tokens, 3)
 
   def test_mm_call(self):
-    with mock.patch('requests.post') as mock_mm_request:
+    with mock.patch('requests.Session.post') as mock_mm_request:
       mock_mm_request.side_effect = mock_mm_requests_post
       lm = groq.GroqLlama3_70B(multimodal=True, api_key='fake_key')
       response = lm(lf_modalities.Image.from_uri('https://fake/image.jpg'))
@@ -155,13 +155,13 @@ class AuthropicTest(unittest.TestCase):
         (503, 'service_unavailable', 'Service unavailable.'),
         (500, 'bad_request', 'Bad request.'),
     ]:
-      with mock.patch('requests.post') as mock_mm_request:
+      with mock.patch('requests.Session.post') as mock_mm_request:
         mock_mm_request.side_effect = mock_requests_post_error(
            status_code, error_type, error_message
         )
         lm = groq.GroqLlama3_70B(api_key='fake_key')
         with self.assertRaisesRegex(
-            Exception, f'{…
+            Exception, f'{status_code}:.*{error_type}'
         ):
           lm('hello', lm=lm, max_attempts=1)
langfun/core/llms/openai.py
CHANGED
@@ -26,54 +26,55 @@ from openai import openai_object
 import pyglove as pg
 
 
-… (48 lines: the previous SUPPORTED_MODELS_AND_SETTINGS listing, not preserved in the source diff)
+# From https://platform.openai.com/settings/organization/limits
+_DEFAULT_TPM = 250000
+_DEFAULT_RPM = 3000
+
+SUPPORTED_MODELS_AND_SETTINGS = {
+    # Models from https://platform.openai.com/docs/models
+    # RPM is from https://platform.openai.com/docs/guides/rate-limits
+    # GPT-4-Turbo models
+    'gpt-4-turbo': pg.Dict(rpm=10000, tpm=1500000),
+    'gpt-4-turbo-2024-04-09': pg.Dict(rpm=10000, tpm=1500000),
+    'gpt-4-turbo-preview': pg.Dict(rpm=10000, tpm=1500000),
+    'gpt-4-0125-preview': pg.Dict(rpm=10000, tpm=1500000),
+    'gpt-4-1106-preview': pg.Dict(rpm=10000, tpm=1500000),
+    'gpt-4-vision-preview': pg.Dict(rpm=10000, tpm=1500000),
+    'gpt-4-1106-vision-preview': pg.Dict(
+        rpm=10000, tpm=1500000
+    ),
+    # GPT-4 models
+    'gpt-4': pg.Dict(rpm=10000, tpm=300000),
+    'gpt-4-0613': pg.Dict(rpm=10000, tpm=300000),
+    'gpt-4-0314': pg.Dict(rpm=10000, tpm=300000),
+    'gpt-4-32k': pg.Dict(rpm=10000, tpm=300000),
+    'gpt-4-32k-0613': pg.Dict(rpm=10000, tpm=300000),
+    'gpt-4-32k-0314': pg.Dict(rpm=10000, tpm=300000),
+    # GPT-3.5-Turbo models
+    'gpt-3.5-turbo': pg.Dict(rpm=10000, tpm=2000000),
+    'gpt-3.5-turbo-0125': pg.Dict(rpm=10000, tpm=2000000),
+    'gpt-3.5-turbo-1106': pg.Dict(rpm=10000, tpm=2000000),
+    'gpt-3.5-turbo-0613': pg.Dict(rpm=10000, tpm=2000000),
+    'gpt-3.5-turbo-0301': pg.Dict(rpm=10000, tpm=2000000),
+    'gpt-3.5-turbo-16k': pg.Dict(rpm=10000, tpm=2000000),
+    'gpt-3.5-turbo-16k-0613': pg.Dict(rpm=10000, tpm=2000000),
+    'gpt-3.5-turbo-16k-0301': pg.Dict(rpm=10000, tpm=2000000),
+    # GPT-3.5 models
+    'text-davinci-003': pg.Dict(rpm=_DEFAULT_RPM, tpm=_DEFAULT_TPM),
+    'text-davinci-002': pg.Dict(rpm=_DEFAULT_RPM, tpm=_DEFAULT_TPM),
+    'code-davinci-002': pg.Dict(rpm=_DEFAULT_RPM, tpm=_DEFAULT_TPM),
+    # GPT-3 instruction-tuned models
+    'text-curie-001': pg.Dict(rpm=_DEFAULT_RPM, tpm=_DEFAULT_TPM),
+    'text-babbage-001': pg.Dict(rpm=_DEFAULT_RPM, tpm=_DEFAULT_TPM),
+    'text-ada-001': pg.Dict(rpm=_DEFAULT_RPM, tpm=_DEFAULT_TPM),
+    'davinci': pg.Dict(rpm=_DEFAULT_RPM, tpm=_DEFAULT_TPM),
+    'curie': pg.Dict(rpm=_DEFAULT_RPM, tpm=_DEFAULT_TPM),
+    'babbage': pg.Dict(rpm=_DEFAULT_RPM, tpm=_DEFAULT_TPM),
+    'ada': pg.Dict(rpm=_DEFAULT_RPM, tpm=_DEFAULT_TPM),
+    # GPT-3 base models
+    'babbage-002': pg.Dict(rpm=_DEFAULT_RPM, tpm=_DEFAULT_TPM),
+    'davinci-002': pg.Dict(rpm=_DEFAULT_RPM, tpm=_DEFAULT_TPM),
+}
 
 
 @lf.use_init_args(['model'])
@@ -82,7 +83,7 @@ class OpenAI(lf.LanguageModel):
 
   model: pg.typing.Annotated[
       pg.typing.Enum(
-          pg.MISSING_VALUE, …
+          pg.MISSING_VALUE, list(SUPPORTED_MODELS_AND_SETTINGS.keys())
      ),
       'The name of the model to use.',
   ] = 'gpt-3.5-turbo'
@@ -134,7 +135,11 @@ class OpenAI(lf.LanguageModel):
 
   @property
   def max_concurrency(self) -> int:
-    …
+    rpm = SUPPORTED_MODELS_AND_SETTINGS[self.model].get('rpm', 0)
+    tpm = SUPPORTED_MODELS_AND_SETTINGS[self.model].get('tpm', 0)
+    return self.rate_to_max_concurrency(
+        requests_per_min=rpm, tokens_per_min=tpm
+    )
 
   @classmethod
   def dir(cls):
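Note: with per-model rpm/tpm settings, OpenAI's `max_concurrency` is now derived from the table above rather than hard-coded. Because TPM takes precedence in `rate_to_max_concurrency`, the resulting values can be checked directly; a sketch of the arithmetic, with `expected_concurrency` as an illustrative helper:

    TOKENS_PER_REQUEST = 250  # estimate defined in langfun/core/language_model.py

    def expected_concurrency(tpm: int) -> int:
      return max(int(tpm / TOKENS_PER_REQUEST / 60), 1)

    assert expected_concurrency(300_000) == 20     # gpt-4 family
    assert expected_concurrency(2_000_000) == 133  # gpt-3.5-turbo family
    assert expected_concurrency(250_000) == 16     # _DEFAULT_TPM tier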
langfun/core/llms/openai_test.py
CHANGED
@@ -11,7 +11,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-"""Tests for …
+"""Tests for OpenAI models."""
 
 import unittest
 from unittest import mock
@@ -85,7 +85,7 @@ def mock_chat_completion_query_vision(messages, *, n=1, **kwargs):
     )
 
 
-class OpenaiTest(unittest.TestCase):
+class OpenAITest(unittest.TestCase):
   """Tests for OpenAI language model."""
 
   def test_model_id(self):
@@ -98,7 +98,7 @@ class OpenaiTest(unittest.TestCase):
     )
 
   def test_max_concurrency(self):
-    self.…
+    self.assertGreater(openai.Gpt35(api_key='test_key').max_concurrency, 0)
 
   def test_get_request_args(self):
     self.assertEqual(
langfun/core/template.py
CHANGED
@@ -38,6 +38,10 @@ NO_TEMPLATE_DOCSTR_SIGN = 'THIS IS NOT A TEMPLATE'
 _TLS_RENDER_STACK = '_template_render_stack'
 _TLS_RENDER_RESULT_CACHE = '_template_render_result_cache'
 
+# The prefix for fields or contextual attributes to be treated as additional
+# metadata for rendered message.
+_ADDITIONAL_METADATA_PREFIX = 'metadata_'
+
 
 class Template(
     natural_language.NaturalLanguageFormattable,
@@ -303,19 +307,19 @@ class Template(
       with modality.format_modality_as_ref():
         rendered_text = self._template.render(**inputs)
 
+      # Carry additional metadata.
+      metadata = self.additional_metadata()
+
       if self.clean:
         rendered_text = rendered_text.strip()
 
-      # Fill the variables for rendering the template as metadata.
-      message = message_cls(
-          text=rendered_text,
-          metadata={
-              k: pg.Ref(v)
-              for k, v in inputs.items()
-              if not inspect.ismethod(v)
-          },
+      metadata.update(
+          {k: pg.Ref(v) for k, v in inputs.items() if not inspect.ismethod(v)}
       )
 
+      # Fill the variables for rendering the template as metadata.
+      message = message_cls(text=rendered_text, metadata=metadata)
+
       # Tag input as rendered message.
       message.tag(message_lib.Message.TAG_RENDERED)
 
@@ -340,6 +344,20 @@ class Template(
       top = pg.object_utils.thread_local_pop(_TLS_RENDER_STACK)
       assert top is self, (top, self)
 
+  def additional_metadata(self) -> dict[str, Any]:
+    """Returns additional metadta to be carried in the rendered message."""
+    metadata = {}
+    # Carry metadata from `lf.context`.
+    for k, v in component.all_contextual_values().items():
+      if k.startswith(_ADDITIONAL_METADATA_PREFIX):
+        metadata[k.removeprefix(_ADDITIONAL_METADATA_PREFIX)] = v
+
+    # Carry metadata from fields.
+    for k, v in self.sym_init_args.items():
+      if k.startswith(_ADDITIONAL_METADATA_PREFIX):
+        metadata[k.removeprefix(_ADDITIONAL_METADATA_PREFIX)] = v
+    return metadata
+
   #
   # Implements `pg.typing.CustomTyping`.
   #
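Note: the `metadata_` prefix gives callers two routes for attaching metadata to a rendered message: declare it as a template field, or provide it through `lf.context`. In both cases the prefix is stripped before the value lands in the message's metadata. A minimal sketch mirroring the test added below:

    import langfun as lf

    # As a template field: `metadata_weights` surfaces as metadata `weights`.
    t = lf.Template('hi', metadata_weights=1.0)
    assert t.render() == lf.UserMessage('hi', weights=1.0)

    # Via scoped context: the same prefix rule applies.
    with lf.context(metadata_weights=1.0):
      assert lf.Template('hi').render() == lf.UserMessage('hi', weights=1.0)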
langfun/core/template_test.py
CHANGED
@@ -16,6 +16,7 @@ import inspect
 import unittest
 
 from langfun.core import component
+from langfun.core import message as message_lib
 from langfun.core import modality
 from langfun.core import subscription
 from langfun.core.template import Template
@@ -427,6 +428,14 @@ class RenderTest(unittest.TestCase):
     # Test len.
     self.assert_partial(Template('Hello {{len(x)}}'), 'Hello {{len(x)}}')
 
+  def test_additional_metadata(self):
+    t = Template('hi', metadata_weights=1.0, y=2)
+    self.assertEqual(t.render(), message_lib.UserMessage('hi', weights=1.0))
+
+    t = Template('hi')
+    with component.context(metadata_weights=1.0, y=2):
+      self.assertEqual(t.render(), message_lib.UserMessage('hi', weights=1.0))
+
 
 class TemplateRenderEventTest(unittest.TestCase):
{langfun-0.0.2.dev20240422.dist-info → langfun-0.0.2.dev20240423.dist-info}/RECORD
CHANGED
@@ -1,15 +1,15 @@
 langfun/__init__.py,sha256=3iCC7F8XoRZ7Gvus11NT50e4KDOQJxIPn9a7TlLzuVI,1880
 langfun/core/__init__.py,sha256=6QEuXOZ9BXxm6TjpaMXuLwUBTYO3pkFDqn9QVBXyyPQ,4248
-langfun/core/component.py,sha256=…
-langfun/core/component_test.py,sha256=…
+langfun/core/component.py,sha256=oxesbC0BoE_TbtxwW5x-BAZWxZyyJbuPiX5S38RqCv0,9909
+langfun/core/component_test.py,sha256=uR-_Sz_42Jxc5qzLIB-f5_pXmNwnC01Xlbv5NOQSeSU,8021
 langfun/core/concurrent.py,sha256=TRc49pJ3HQro2kb5FtcWkHjhBm8UcgE8RJybU5cU3-0,24537
 langfun/core/concurrent_test.py,sha256=mwFMZhDUdppnDr7vDSTwcbMHwrdsIoKJwRYNtl4ZWL4,15185
 langfun/core/console.py,sha256=bk5rNPNm9rMGW5YT2HixxU04p2umnoabn5SDz6Dqe88,2317
 langfun/core/console_test.py,sha256=5SYJdxpJGLgdSSQqqMPoA1X6jpsLD8rgcyk-EgI65oE,1077
 langfun/core/langfunc.py,sha256=WXdTc3QsmGD_n80KD9dFRr5MHpGZ9E_y_Rhtk4t9-3w,11852
 langfun/core/langfunc_test.py,sha256=sQaKuZpGGmG80GRifhbxkj7nfzQLJKj4Vuw5y1s1K3U,8378
-langfun/core/language_model.py,sha256=…
-langfun/core/language_model_test.py,sha256=…
+langfun/core/language_model.py,sha256=mJfQ_Zqq9IyVyZUdYMQ1BPrpo4Gn8yxDJb_RghQFP_I,18911
+langfun/core/language_model_test.py,sha256=oWQjnyiJugSpHJKda-qLaSvmbm1sx_v-ZXrHvw_kNk4,14172
 langfun/core/memory.py,sha256=f-asN1F7Vehgdn_fK84v73GrEUOxRtaW934keutTKjk,2416
 langfun/core/message.py,sha256=QhvV9t5qaryPcruyxxcXi3gm9QDInkSldwTtK6sVJ3c,15734
 langfun/core/message_test.py,sha256=Z23pUM5vPnDrYkIIibe2KL73D5HKur_awI0ut_EQFQA,9501
@@ -21,8 +21,8 @@ langfun/core/sampling.py,sha256=vygWvgC8MFw0_AKNSmz-ywMXJYWf8cl0tI8QycvAmyI,5795
 langfun/core/sampling_test.py,sha256=U7PANpMsl9E_pa4_Y4FzesSjcwg-u-LKHGCWSgv-8FY,3663
 langfun/core/subscription.py,sha256=euawEuSZP-BHydaT-AQpfYFL0m5pWPGcW0upFhrojqc,10930
 langfun/core/subscription_test.py,sha256=Y4ZdbZEwm83YNZBxHff0QR4QUa4rdaNXA3_jfIcArBo,8717
-langfun/core/template.py,sha256=…
-langfun/core/template_test.py,sha256=…
+langfun/core/template.py,sha256=dr3tZCbXH2qWzigO_EFVHe0GDnnCu58Tru5Mvlzin4o,18447
+langfun/core/template_test.py,sha256=xty7PgdNhGpw7ZRZ6QGwhKZWG6dyRgI16Lg3p7IMLJg,13944
 langfun/core/text_formatting.py,sha256=ytjj7opnRJ6w-pkglL2CZUyfYDXLpNf65E42LBb31gc,5158
 langfun/core/text_formatting_test.py,sha256=nyKC6tn2L4hPJiqQHgxcbQsJJi4A4Nbj8FiO8iT6B80,1514
 langfun/core/coding/__init__.py,sha256=5utju_fwEsImaiftx4oXKl9FAM8p281k8-Esdh_-m1w,835
@@ -47,18 +47,18 @@ langfun/core/eval/matching_test.py,sha256=FFHYD7IDuKe5RMjkx74ksukiwUhO5a_SS340Ja
 langfun/core/eval/scoring.py,sha256=aKeanBJf1yO3Q9JEtgPWoiZk_3M_GiqwXVXX7x_g22w,6172
 langfun/core/eval/scoring_test.py,sha256=YH1cIxBWtfdKcAV9Fh10vLkV5J-gxk8b6nxW4Z2u5pk,4024
 langfun/core/llms/__init__.py,sha256=1bPg1QI8duOZCYINm-jWi094x0JtLmsk4KX60qIC_gs,3245
-langfun/core/llms/anthropic.py,sha256=…
-langfun/core/llms/anthropic_test.py,sha256=…
+langfun/core/llms/anthropic.py,sha256=7W9YdPN3SlAFhAIQlihMkrpo7tTY_4NvD0KIlCrqcsk,8505
+langfun/core/llms/anthropic_test.py,sha256=TMM30myyEhwF99Le4RvJEXOn8RYl0q1FRkt9Q9nl1jk,5540
 langfun/core/llms/fake.py,sha256=b-Xk5IPTbUt-elsyzd_i3n1tqzc_kgETXrEvgJruSMk,2824
 langfun/core/llms/fake_test.py,sha256=ZlDQgL41EX3eYTfBQNp2nB2LciqCmtoHgCsGvW4XhwI,4184
 langfun/core/llms/google_genai.py,sha256=n8zyJwh9UCTgb6-8LyvmjVNFGZQ4-zfzZ0ulkhHAnR8,8624
 langfun/core/llms/google_genai_test.py,sha256=_UcGTfl16-aDUlEWFC2W2F8y9jPUs53RBYA6MOCpGXw,7525
-langfun/core/llms/groq.py,sha256=…
-langfun/core/llms/groq_test.py,sha256=…
+langfun/core/llms/groq.py,sha256=NaGItVL_pkOpqPpI4bPGU27xLFRoaeizZ49v2s-4ERs,7844
+langfun/core/llms/groq_test.py,sha256=M6GtlrsOvDun_j-sR8cPh4W_moHWZNSTiThu3kuwbbc,5281
 langfun/core/llms/llama_cpp.py,sha256=Y_KkMUf3Xfac49koMUtUslKl3h-HWp3-ntq7Jaa3bdo,2385
 langfun/core/llms/llama_cpp_test.py,sha256=ZxC6defGd_HX9SFRU9U4cJiQnBKundbOrchbXuC1Z2M,1683
-langfun/core/llms/openai.py,sha256=…
-langfun/core/llms/openai_test.py,sha256=…
+langfun/core/llms/openai.py,sha256=06nPhmw0zIA5Zqv3eqsrZtYLHnKwW7N8yt3LlFUFVpI,13247
+langfun/core/llms/openai_test.py,sha256=Yt_W6k8YXpT3bs0JroARofCGmn_Uq3u61LmZxqWS2DQ,8272
 langfun/core/llms/cache/__init__.py,sha256=QAo3InUMDM_YpteNnVCSejI4zOsnjSMWKJKzkb3VY64,993
 langfun/core/llms/cache/base.py,sha256=cFfYvOIUae842pncqCAsRvqXCk2AnAsRYVx0mcIoAeY,3338
 langfun/core/llms/cache/in_memory.py,sha256=YfFyJEhLs73cUiB0ZfhMxYpdE8Iuxxw-dvMFwGHTSHw,4742
@@ -101,8 +101,8 @@ langfun/core/templates/demonstration.py,sha256=vCrgYubdZM5Umqcgp8NUVGXgr4P_c-fik
 langfun/core/templates/demonstration_test.py,sha256=SafcDQ0WgI7pw05EmPI2S4v1t3ABKzup8jReCljHeK4,2162
 langfun/core/templates/selfplay.py,sha256=yhgrJbiYwq47TgzThmHrDQTF4nDrTI09CWGhuQPNv-s,2273
 langfun/core/templates/selfplay_test.py,sha256=DYVrkk7uNKCqJGEHH31HssU2BPuMItU1vJLzfcXIlYg,2156
-langfun-0.0.2.dev20240422.dist-info/LICENSE,sha256=WNHhf_5RCaeuKWyq_K39vmp9F28LxKsB4SpomwSZ2L0,11357
-langfun-0.0.2.dev20240422.dist-info/METADATA,sha256=…
-langfun-0.0.2.dev20240422.dist-info/WHEEL,sha256=GJ7t_kWBFywbagK5eo9IoUwLW6oyOeTKmQ-9iHFVNxQ,92
-langfun-0.0.2.dev20240422.dist-info/top_level.txt,sha256=RhlEkHxs1qtzmmtWSwYoLVJAc1YrbPtxQ52uh8Z9VvY,8
-langfun-0.0.2.dev20240422.dist-info/RECORD,,
+langfun-0.0.2.dev20240423.dist-info/LICENSE,sha256=WNHhf_5RCaeuKWyq_K39vmp9F28LxKsB4SpomwSZ2L0,11357
+langfun-0.0.2.dev20240423.dist-info/METADATA,sha256=tK11XSi0smD_bL0tVdbT_YsNXu-xs1KKYbTX2powidg,3405
+langfun-0.0.2.dev20240423.dist-info/WHEEL,sha256=GJ7t_kWBFywbagK5eo9IoUwLW6oyOeTKmQ-9iHFVNxQ,92
+langfun-0.0.2.dev20240423.dist-info/top_level.txt,sha256=RhlEkHxs1qtzmmtWSwYoLVJAc1YrbPtxQ52uh8Z9VvY,8
+langfun-0.0.2.dev20240423.dist-info/RECORD,,
{langfun-0.0.2.dev20240422.dist-info → langfun-0.0.2.dev20240423.dist-info}/LICENSE
File without changes
{langfun-0.0.2.dev20240422.dist-info → langfun-0.0.2.dev20240423.dist-info}/WHEEL
File without changes
{langfun-0.0.2.dev20240422.dist-info → langfun-0.0.2.dev20240423.dist-info}/top_level.txt
File without changes