langfun 0.0.2.dev20240420__py3-none-any.whl → 0.0.2.dev20240423__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- langfun/core/component.py +6 -0
- langfun/core/component_test.py +1 -0
- langfun/core/language_model.py +14 -0
- langfun/core/language_model_test.py +32 -0
- langfun/core/llms/__init__.py +7 -0
- langfun/core/llms/anthropic.py +36 -22
- langfun/core/llms/anthropic_test.py +7 -7
- langfun/core/llms/groq.py +260 -0
- langfun/core/llms/groq_test.py +170 -0
- langfun/core/llms/openai.py +55 -50
- langfun/core/llms/openai_test.py +3 -3
- langfun/core/template.py +26 -8
- langfun/core/template_test.py +9 -0
- {langfun-0.0.2.dev20240420.dist-info → langfun-0.0.2.dev20240423.dist-info}/METADATA +1 -1
- {langfun-0.0.2.dev20240420.dist-info → langfun-0.0.2.dev20240423.dist-info}/RECORD +18 -16
- {langfun-0.0.2.dev20240420.dist-info → langfun-0.0.2.dev20240423.dist-info}/LICENSE +0 -0
- {langfun-0.0.2.dev20240420.dist-info → langfun-0.0.2.dev20240423.dist-info}/WHEEL +0 -0
- {langfun-0.0.2.dev20240420.dist-info → langfun-0.0.2.dev20240423.dist-info}/top_level.txt +0 -0
langfun/core/component.py
CHANGED
@@ -210,6 +210,12 @@ def get_contextual_override(var_name: str) -> ContextualOverride | None:
   return _get_scoped_value(_global_tls, _CONTEXT_OVERRIDES, var_name)
 
 
+def all_contextual_values() -> dict[str, Any]:
+  """Returns all contextual values provided from `lf.context` in scope."""
+  overrides = getattr(_global_tls, _CONTEXT_OVERRIDES, {})
+  return {k: v.value for k, v in overrides.items()}
+
+
 @contextlib.contextmanager
 def _contextual_scope(
     tls: threading.local, tls_key, **variables
langfun/core/component_test.py
CHANGED
@@ -84,6 +84,7 @@ class ComponentContextTest(unittest.TestCase):
         lf.get_contextual_override('y'),
         lf.ContextualOverride(3, cascade=False, override_attrs=False),
     )
+    self.assertEqual(lf.all_contextual_values(), dict(x=3, y=3, z=3))
 
     # Member attributes take precedence over `lf.context`.
     self.assertEqual(a1.x, 1)
langfun/core/language_model.py
CHANGED
@@ -24,6 +24,9 @@ from langfun.core import console
 from langfun.core import message as message_lib
 import pyglove as pg
 
+TOKENS_PER_REQUEST = 250  # Estimated num tokens for a single request
+DEFAULT_MAX_CONCURRENCY = 1  # Use this as max concurrency if no RPM or TPM data
+
 
 class LMSample(pg.Object):
   """Response candidate."""
@@ -604,3 +607,14 @@ class LanguageModel(component.Component):
           f'score: {r.score}',
           color='blue',
       )
+
+  def rate_to_max_concurrency(
+      self, requests_per_min: float = 0, tokens_per_min: float = 0
+  ) -> int:
+    """Converts a rate to a max concurrency."""
+    if tokens_per_min > 0:
+      return max(int(tokens_per_min / TOKENS_PER_REQUEST / 60), 1)
+    elif requests_per_min > 0:
+      return max(int(requests_per_min / 60), 1)  # Max concurrency can't be zero
+    else:
+      return DEFAULT_MAX_CONCURRENCY  # Default of 1
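To make the conversion concrete, a short worked example with hypothetical rates (each provider module supplies its own; `MockModel` is the fake model used in the tests below):

    lm = MockModel()
    # TPM takes precedence: int(400000 / 250 / 60) = 26 concurrent requests.
    assert lm.rate_to_max_concurrency(requests_per_min=4000, tokens_per_min=400000) == 26
    # RPM only: int(3000 / 60) = 50.
    assert lm.rate_to_max_concurrency(requests_per_min=3000) == 50
    # No rate data at all: DEFAULT_MAX_CONCURRENCY, i.e. 1.
    assert lm.rate_to_max_concurrency() == 1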
langfun/core/language_model_test.py
CHANGED
@@ -394,6 +394,38 @@ class LanguageModelTest(unittest.TestCase):
     with self.assertRaises(NotImplementedError):
       MockModel().score('hi', ['1', '2'])
 
+  def test_rate_to_max_concurrency_no_rpm_no_tpm(self) -> None:
+    lm = MockModel()
+    self.assertEqual(
+        lm_lib.DEFAULT_MAX_CONCURRENCY,
+        lm.rate_to_max_concurrency(requests_per_min=0, tokens_per_min=0),
+    )
+    self.assertEqual(
+        lm_lib.DEFAULT_MAX_CONCURRENCY,
+        lm.rate_to_max_concurrency(requests_per_min=-1, tokens_per_min=-1),
+    )
+
+  def test_rate_to_max_concurrency_only_rpm_specified_uses_rpm(self) -> None:
+    lm = MockModel()
+    test_rpm = 1e4
+    self.assertEqual(
+        lm.rate_to_max_concurrency(requests_per_min=test_rpm),
+        int(test_rpm / 60)
+    )
+
+  def test_rate_to_max_concurrency_tpm_specified_uses_tpm(self) -> None:
+    lm = MockModel()
+    test_tpm = 1e7
+    self.assertEqual(
+        lm.rate_to_max_concurrency(requests_per_min=1, tokens_per_min=test_tpm),
+        int(test_tpm / lm_lib.TOKENS_PER_REQUEST / 60),
+    )
+
+  def test_rate_to_max_concurrency_small_rate_returns_one(self) -> None:
+    lm = MockModel()
+    self.assertEqual(lm.rate_to_max_concurrency(requests_per_min=1), 1)
+    self.assertEqual(lm.rate_to_max_concurrency(tokens_per_min=1), 1)
+
 
 if __name__ == '__main__':
   unittest.main()
langfun/core/llms/__init__.py
CHANGED
@@ -66,6 +66,13 @@ from langfun.core.llms.anthropic import Claude3Opus
 from langfun.core.llms.anthropic import Claude3Sonnet
 from langfun.core.llms.anthropic import Claude3Haiku
 
+from langfun.core.llms.groq import Groq
+from langfun.core.llms.groq import GroqLlama3_70B
+from langfun.core.llms.groq import GroqLlama3_8B
+from langfun.core.llms.groq import GroqLlama2_70B
+from langfun.core.llms.groq import GroqMistral_8x7B
+from langfun.core.llms.groq import GroqGemma7B_IT
+
 
 # LLaMA C++ models.
 from langfun.core.llms.llama_cpp import LlamaCppRemote
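With these re-exports, the Groq models are importable from the `langfun.core.llms` namespace like any other provider, e.g. (a sketch; the key may instead come from the GROQ_API_KEY environment variable, per the class definition below):

    from langfun.core import llms

    lm = llms.GroqLlama3_70B(api_key='your-groq-key')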
langfun/core/llms/anthropic.py
CHANGED
@@ -26,12 +26,15 @@ import requests
 
 SUPPORTED_MODELS_AND_SETTINGS = {
     # See https://docs.anthropic.com/claude/docs/models-overview
-
-
-
-    'claude-
-    'claude-
-    'claude-
+    # Rate limits from https://docs.anthropic.com/claude/reference/rate-limits
+    # RPM/TPM for Claude-2.1, Claude-2.0, and Claude-Instant-1.2 estimated
+    # as RPM/TPM of the largest-available model (Claude-3-Opus).
+    'claude-3-opus-20240229': pg.Dict(max_tokens=4096, rpm=4000, tpm=400000),
+    'claude-3-sonnet-20240229': pg.Dict(max_tokens=4096, rpm=4000, tpm=400000),
+    'claude-3-haiku-20240307': pg.Dict(max_tokens=4096, rpm=4000, tpm=400000),
+    'claude-2.1': pg.Dict(max_tokens=4096, rpm=4000, tpm=400000),
+    'claude-2.0': pg.Dict(max_tokens=4096, rpm=4000, tpm=400000),
+    'claude-instant-1.2': pg.Dict(max_tokens=4096, rpm=4000, tpm=400000),
 }
 
 
@@ -81,6 +84,7 @@ class Anthropic(lf.LanguageModel):
     super()._on_bound()
     self._api_key = None
     self.__dict__.pop('_api_initialized', None)
+    self.__dict__.pop('_session', None)
 
   @functools.cached_property
   def _api_initialized(self):
@@ -93,6 +97,17 @@ class Anthropic(lf.LanguageModel):
     self._api_key = api_key
     return True
 
+  @functools.cached_property
+  def _session(self) -> requests.Session:
+    assert self._api_initialized
+    s = requests.Session()
+    s.headers.update({
+        'x-api-key': self._api_key,
+        'anthropic-version': _ANTHROPIC_API_VERSION,
+        'content-type': 'application/json',
+    })
+    return s
+
   @property
   def model_id(self) -> str:
     """Returns a string to identify the model."""
@@ -100,7 +115,11 @@ class Anthropic(lf.LanguageModel):
 
   @property
   def max_concurrency(self) -> int:
-
+    rpm = SUPPORTED_MODELS_AND_SETTINGS[self.model].get('rpm', 0)
+    tpm = SUPPORTED_MODELS_AND_SETTINGS[self.model].get('tpm', 0)
+    return self.rate_to_max_concurrency(
+        requests_per_min=rpm, tokens_per_min=tpm
+    )
 
   def _sample(self, prompts: list[lf.Message]) -> list[lf.LMSamplingResult]:
     assert self._api_initialized
@@ -165,8 +184,8 @@ class Anthropic(lf.LanguageModel):
   def _parse_response(self, response: requests.Response) -> lf.LMSamplingResult:
     """Parses Anthropic's response."""
     # NOTE(daiyip): Refer https://docs.anthropic.com/claude/reference/errors
-    output = response.json()
     if response.status_code == 200:
+      output = response.json()
       message = self._message_from_content(output['content'])
       input_tokens = output['usage']['input_tokens']
       output_tokens = output['usage']['output_tokens']
@@ -181,12 +200,11 @@ class Anthropic(lf.LanguageModel):
     else:
       if response.status_code == 429:
         error_cls = RateLimitError
-      elif response.status_code
+      elif response.status_code in (502, 529):
         error_cls = OverloadedError
       else:
         error_cls = AnthropicError
-
-      raise error_cls(f'{error["type"]}: {error["message"]}')
+      raise error_cls(f'{response.status_code}: {response.content}')
 
   def _sample_single(self, prompt: lf.Message) -> lf.LMSamplingResult:
     request = dict()
@@ -198,17 +216,13 @@ class Anthropic(lf.LanguageModel):
             ]
         )
     )
-
-
-
-
-
-
-
-        },
-        timeout=self.timeout,
-    )
-    return self._parse_response(response)
+    try:
+      response = self._session.post(
+          _ANTHROPIC_MESSAGE_API_ENDPOINT, json=request, timeout=self.timeout,
+      )
+      return self._parse_response(response)
+    except ConnectionError as e:
+      raise OverloadedError(str(e)) from e
 
 
 class Claude3(Anthropic):
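The switch from bare `requests.post` calls to a cached per-model `requests.Session` keeps the auth headers in one place and reuses connections across samples. A standalone sketch of the same pattern (the endpoint and version string are assumed values of module constants defined outside these hunks):

    import requests

    session = requests.Session()
    session.headers.update({
        'x-api-key': 'my-api-key',           # per-model API key
        'anthropic-version': '2023-06-01',   # assumed value of _ANTHROPIC_API_VERSION
        'content-type': 'application/json',
    })
    # Subsequent posts share one connection pool and one set of headers.
    response = session.post(
        'https://api.anthropic.com/v1/messages',  # assumed _ANTHROPIC_MESSAGE_API_ENDPOINT
        json={'model': 'claude-3-haiku-20240307', 'max_tokens': 16,
              'messages': [{'role': 'user', 'content': 'Hi'}]},
        timeout=120,
    )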
langfun/core/llms/anthropic_test.py
CHANGED
@@ -98,20 +98,20 @@ def mock_requests_post_error(status_code, error_type, error_message):
   return _mock_requests
 
 
-class AuthropicTest(unittest.TestCase):
+class AnthropicTest(unittest.TestCase):
 
   def test_basics(self):
     self.assertEqual(
         anthropic.Claude3Haiku().model_id, 'claude-3-haiku-20240307'
     )
-    self.
+    self.assertGreater(anthropic.Claude3Haiku().max_concurrency, 0)
 
   def test_api_key(self):
     lm = anthropic.Claude3Haiku()
     with self.assertRaisesRegex(ValueError, 'Please specify `api_key`'):
       lm('hi')
 
-    with mock.patch('requests.post') as mock_request:
+    with mock.patch('requests.Session.post') as mock_request:
       mock_request.side_effect = mock_requests_post
 
       lm = anthropic.Claude3Haiku(api_key='fake key')
@@ -123,7 +123,7 @@ class AuthropicTest(unittest.TestCase):
       del os.environ['ANTHROPIC_API_KEY']
 
   def test_call(self):
-    with mock.patch('requests.post') as mock_request:
+    with mock.patch('requests.Session.post') as mock_request:
       mock_request.side_effect = mock_requests_post
       lm = anthropic.Claude3Haiku(api_key='fake_key')
       response = lm('hello', temperature=0.0, top_k=0.1, top_p=0.2, stop=['\n'])
@@ -140,7 +140,7 @@ class AuthropicTest(unittest.TestCase):
       self.assertIsNotNone(response.usage.total_tokens, 3)
 
   def test_mm_call(self):
-    with mock.patch('requests.post') as mock_mm_request:
+    with mock.patch('requests.Session.post') as mock_mm_request:
       mock_mm_request.side_effect = mock_mm_requests_post
       lm = anthropic.Claude3Haiku(api_key='fake_key')
       response = lm(lf_modalities.Image.from_bytes(image_content), lm=lm)
@@ -152,13 +152,13 @@ class AuthropicTest(unittest.TestCase):
         (529, 'service_unavailable', 'Service unavailable.'),
         (500, 'bad_request', 'Bad request.'),
     ]:
-      with mock.patch('requests.post') as mock_mm_request:
+      with mock.patch('requests.Session.post') as mock_mm_request:
         mock_mm_request.side_effect = mock_requests_post_error(
            status_code, error_type, error_message
        )
        lm = anthropic.Claude3Haiku(api_key='fake_key')
        with self.assertRaisesRegex(
-            Exception, f'{
+            Exception, f'.*{status_code}: .*{error_message}'
        ):
          lm('hello', lm=lm, max_attempts=1)
langfun/core/llms/groq.py
ADDED
@@ -0,0 +1,260 @@
+# Copyright 2024 The Langfun Authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Language models from Groq."""
+
+import functools
+import os
+from typing import Annotated, Any
+
+import langfun.core as lf
+from langfun.core import modalities as lf_modalities
+import pyglove as pg
+import requests
+
+
+SUPPORTED_MODELS_AND_SETTINGS = {
+    # Refer https://console.groq.com/docs/models
+    'llama3-8b-8192': pg.Dict(max_tokens=8192, max_concurrency=16),
+    'llama3-70b-8192': pg.Dict(max_tokens=8192, max_concurrency=16),
+    'llama2-70b-4096': pg.Dict(max_tokens=4096, max_concurrency=16),
+    'mixtral-8x7b-32768': pg.Dict(max_tokens=32768, max_concurrency=16),
+    'gemma-7b-it': pg.Dict(max_tokens=8192, max_concurrency=16),
+}
+
+
+class GroqError(Exception):  # pylint: disable=g-bad-exception-name
+  """Base class for Groq errors."""
+
+
+class RateLimitError(GroqError):
+  """Error for rate limit reached."""
+
+
+class OverloadedError(GroqError):
+  """Groq's server is temporarily overloaded."""
+
+
+_CHAT_COMPLETE_API_ENDPOINT = 'https://api.groq.com/openai/v1/chat/completions'
+
+
+@lf.use_init_args(['model'])
+class Groq(lf.LanguageModel):
+  """Groq LLMs through REST APIs (OpenAI compatible).
+
+  See https://platform.openai.com/docs/api-reference/chat
+  """
+
+  model: pg.typing.Annotated[
+      pg.typing.Enum(
+          pg.MISSING_VALUE, list(SUPPORTED_MODELS_AND_SETTINGS.keys())
+      ),
+      'The name of the model to use.',
+  ]
+
+  multimodal: Annotated[bool, 'Whether this model has multimodal support.'] = (
+      False
+  )
+
+  api_key: Annotated[
+      str | None,
+      (
+          'API key. If None, the key will be read from environment variable '
+          "'GROQ_API_KEY'."
+      ),
+  ] = None
+
+  def _on_bound(self):
+    super()._on_bound()
+    self._api_key = None
+    self.__dict__.pop('_api_initialized', None)
+    self.__dict__.pop('_session', None)
+
+  @functools.cached_property
+  def _api_initialized(self):
+    api_key = self.api_key or os.environ.get('GROQ_API_KEY', None)
+    if not api_key:
+      raise ValueError(
+          'Please specify `api_key` during `__init__` or set environment '
+          'variable `GROQ_API_KEY` with your Groq API key.'
+      )
+    self._api_key = api_key
+    return True
+
+  @functools.cached_property
+  def _session(self) -> requests.Session:
+    assert self._api_initialized
+    s = requests.Session()
+    s.headers.update({
+        'Authorization': f'Bearer {self._api_key}',
+        'Content-Type': 'application/json',
+    })
+    return s
+
+  @property
+  def model_id(self) -> str:
+    """Returns a string to identify the model."""
+    return self.model
+
+  @property
+  def max_concurrency(self) -> int:
+    return SUPPORTED_MODELS_AND_SETTINGS[self.model].max_concurrency
+
+  def _get_request_args(self, options: lf.LMSamplingOptions) -> dict[str, Any]:
+    """Returns a dict as request arguments."""
+    # `logprobs` and `top_logprobs` flags are not supported on Groq yet.
+    args = dict(
+        model=self.model,
+        n=options.n,
+        stream=False,
+    )
+
+    if options.temperature is not None:
+      args['temperature'] = options.temperature
+    if options.max_tokens is not None:
+      args['max_tokens'] = options.max_tokens
+    if options.top_p is not None:
+      args['top_p'] = options.top_p
+    if options.stop:
+      args['stop'] = options.stop
+    return args
+
+  def _content_from_message(self, prompt: lf.Message) -> list[dict[str, Any]]:
+    """Converts an message to Groq's content protocol (list of dicts)."""
+    # Refer: https://platform.openai.com/docs/api-reference/chat/create
+    content = []
+    for chunk in prompt.chunk():
+      if isinstance(chunk, str):
+        item = dict(type='text', text=chunk)
+      elif (
+          self.multimodal
+          and isinstance(chunk, lf_modalities.Image)
+          and chunk.uri
+      ):
+        # NOTE(daiyip): Groq only support image URL.
+        item = dict(type='image_url', image_url=chunk.uri)
+      else:
+        raise ValueError(f'Unsupported modality object: {chunk!r}.')
+      content.append(item)
+    return content
+
+  def _message_from_choice(self, choice: dict[str, Any]) -> lf.Message:
+    """Converts Groq's content protocol to message."""
+    # Refer: https://platform.openai.com/docs/api-reference/chat/create
+    content = choice['message']['content']
+    if isinstance(content, str):
+      return lf.AIMessage(content)
+    return lf.AIMessage.from_chunks(
+        [x['text'] for x in content if x['type'] == 'text']
+    )
+
+  def _parse_response(self, response: requests.Response) -> lf.LMSamplingResult:
+    """Parses Groq's response."""
+    # Refer: https://platform.openai.com/docs/api-reference/chat/object
+    if response.status_code == 200:
+      output = response.json()
+      samples = [
+          lf.LMSample(self._message_from_choice(choice), score=0.0)
+          for choice in output['choices']
+      ]
+      usage = output['usage']
+      return lf.LMSamplingResult(
+          samples,
+          usage=lf.LMSamplingUsage(
+              prompt_tokens=usage['prompt_tokens'],
+              completion_tokens=usage['completion_tokens'],
+              total_tokens=usage['total_tokens'],
+          ),
+      )
+    else:
+      # https://platform.openai.com/docs/guides/error-codes/api-errors
+      if response.status_code == 429:
+        error_cls = RateLimitError
+      elif response.status_code in (500, 502, 503):
+        error_cls = OverloadedError
+      else:
+        error_cls = GroqError
+      raise error_cls(f'{response.status_code}: {response.content}')
+
+  def _sample(self, prompts: list[lf.Message]) -> list[lf.LMSamplingResult]:
+    assert self._api_initialized
+    return self._parallel_execute_with_currency_control(
+        self._sample_single,
+        prompts,
+        retry_on_errors=(RateLimitError, OverloadedError),
+    )
+
+  def _sample_single(self, prompt: lf.Message) -> lf.LMSamplingResult:
+    request = dict()
+    request.update(self._get_request_args(self.sampling_options))
+    request.update(
+        dict(
+            messages=[
+                dict(role='user', content=self._content_from_message(prompt))
+            ]
+        )
+    )
+    try:
+      response = self._session.post(
+          _CHAT_COMPLETE_API_ENDPOINT,
+          json=request,
+          timeout=self.timeout,
+      )
+      return self._parse_response(response)
+    except ConnectionError as e:
+      raise OverloadedError(str(e)) from e
+
+
+class GroqLlama3_8B(Groq):  # pylint: disable=invalid-name
+  """Llama3-8B with 8K context window.
+
+  See: https://huggingface.co/meta-llama/Meta-Llama-3-8B
+  """
+
+  model = 'llama3-8b-8192'
+
+
+class GroqLlama3_70B(Groq):  # pylint: disable=invalid-name
+  """Llama3-70B with 8K context window.
+
+  See: https://huggingface.co/meta-llama/Meta-Llama-3-70B
+  """
+
+  model = 'llama3-70b-8192'
+
+
+class GroqLlama2_70B(Groq):  # pylint: disable=invalid-name
+  """Llama2-70B with 4K context window.
+
+  See: https://huggingface.co/meta-llama/Llama-2-70b
+  """
+
+  model = 'llama2-70b-4096'
+
+
+class GroqMistral_8x7B(Groq):  # pylint: disable=invalid-name
+  """Mixtral 8x7B with 32K context window.
+
+  See: https://huggingface.co/meta-llama/Llama-2-70b
+  """
+
+  model = 'mixtral-8x7b-32768'
+
+
+class GroqGemma7B_IT(Groq):  # pylint: disable=invalid-name
+  """Gemma 7B with 8K context window.
+
+  See: https://huggingface.co/google/gemma-1.1-7b-it
+  """
+
+  model = 'gemma-7b-it'
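A quick usage sketch of the new module (assuming a valid key in GROQ_API_KEY; sampling options are passed per call, as in the tests below):

    from langfun.core.llms import groq

    lm = groq.GroqMistral_8x7B()  # reads GROQ_API_KEY from the environment
    response = lm('What is 1 + 1?', temperature=0.0)
    print(response.text, response.usage.total_tokens)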
langfun/core/llms/groq_test.py
ADDED
@@ -0,0 +1,170 @@
+# Copyright 2023 The Langfun Authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Tests for Groq models."""
+
+import os
+from typing import Any
+import unittest
+from unittest import mock
+from langfun.core import modalities as lf_modalities
+from langfun.core.llms import groq
+import pyglove as pg
+import requests
+
+
+def mock_requests_post(url: str, json: dict[str, Any], **kwargs):
+  del url, kwargs
+
+  response = requests.Response()
+  response.status_code = 200
+  response._content = pg.to_json_str({
+      'choices': [{
+          'message': {
+              'content': [{
+                  'type': 'text',
+                  'text': (
+                      f'hello with temperature={json.get("temperature")}, '
+                      f'top_p={json.get("top_p")}, '
+                      f'max_tokens={json.get("max_tokens")}, '
+                      f'stop={json.get("stop")}.'
+                  ),
+              }],
+          }
+      }],
+      'usage': {
+          'prompt_tokens': 2,
+          'completion_tokens': 1,
+          'total_tokens': 3,
+      },
+  }).encode()
+  return response
+
+
+def mock_mm_requests_post(url: str, json: dict[str, Any], **kwargs):
+  del url, kwargs
+  v = json['messages'][0]['content'][0]
+  image = lf_modalities.Image.from_uri(v['image_url'])
+
+  response = requests.Response()
+  response.status_code = 200
+  response._content = pg.to_json_str({
+      'choices': [
+          {
+              'message': {
+                  'content': [{
+                      'type': 'text',
+                      'text': image.uri,
+                  }],
+              }
+          }
+      ],
+      'usage': {
+          'prompt_tokens': 2,
+          'completion_tokens': 1,
+          'total_tokens': 3,
+      },
+  }).encode()
+  return response
+
+
+def mock_requests_post_error(status_code, error_type, error_message):
+  def _mock_requests(url: str, json: dict[str, Any], **kwargs):
+    del url, json, kwargs
+    response = requests.Response()
+    response.status_code = status_code
+    response._content = pg.to_json_str(
+        {
+            'error': {
+                'type': error_type,
+                'message': error_message,
+            }
+        }
+    ).encode()
+    return response
+
+  return _mock_requests
+
+
+class AuthropicTest(unittest.TestCase):
+
+  def test_basics(self):
+    self.assertEqual(groq.GroqMistral_8x7B().model_id, 'mixtral-8x7b-32768')
+    self.assertEqual(groq.GroqMistral_8x7B().max_concurrency, 16)
+
+  def test_api_key(self):
+    lm = groq.GroqMistral_8x7B()
+    with self.assertRaisesRegex(ValueError, 'Please specify `api_key`'):
+      lm('hi')
+
+    with mock.patch('requests.Session.post') as mock_request:
+      mock_request.side_effect = mock_requests_post
+
+      lm = groq.GroqMistral_8x7B(api_key='fake key')
+      self.assertRegex(lm('hi').text, 'hello.*')
+
+      os.environ['GROQ_API_KEY'] = 'abc'
+      lm = groq.GroqMistral_8x7B()
+      self.assertRegex(lm('hi').text, 'hello.*')
+      del os.environ['GROQ_API_KEY']
+
+  def test_call(self):
+    with mock.patch('requests.Session.post') as mock_request:
+      mock_request.side_effect = mock_requests_post
+      lm = groq.GroqLlama3_70B(api_key='fake_key')
+      response = lm(
+          'hello',
+          temperature=0.0,
+          max_tokens=1024,
+          top_k=0.1,
+          top_p=0.2,
+          stop=['\n'],
+      )
+      self.assertEqual(
+          response.text,
+          (
+              'hello with temperature=0.0, top_p=0.2, '
+              "max_tokens=1024, stop=['\\n']."
+          ),
+      )
+      self.assertIsNotNone(response.usage)
+      self.assertIsNotNone(response.usage.prompt_tokens, 2)
+      self.assertIsNotNone(response.usage.completion_tokens, 1)
+      self.assertIsNotNone(response.usage.total_tokens, 3)
+
+  def test_mm_call(self):
+    with mock.patch('requests.Session.post') as mock_mm_request:
+      mock_mm_request.side_effect = mock_mm_requests_post
+      lm = groq.GroqLlama3_70B(multimodal=True, api_key='fake_key')
+      response = lm(lf_modalities.Image.from_uri('https://fake/image.jpg'))
+      self.assertEqual(response.text, 'https://fake/image.jpg')
+
+  def test_call_errors(self):
+    for status_code, error_type, error_message in [
+        (429, 'rate_limit', 'Rate limit exceeded.'),
+        (503, 'service_unavailable', 'Service unavailable.'),
+        (500, 'bad_request', 'Bad request.'),
+    ]:
+      with mock.patch('requests.Session.post') as mock_mm_request:
+        mock_mm_request.side_effect = mock_requests_post_error(
+            status_code, error_type, error_message
+        )
+        lm = groq.GroqLlama3_70B(api_key='fake_key')
+        with self.assertRaisesRegex(
+            Exception, f'{status_code}:.*{error_type}'
+        ):
+          lm('hello', lm=lm, max_attempts=1)
+
+
+if __name__ == '__main__':
+  unittest.main()
langfun/core/llms/openai.py
CHANGED
@@ -26,54 +26,55 @@ from openai import openai_object
 import pyglove as pg
 
 
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+# From https://platform.openai.com/settings/organization/limits
+_DEFAULT_TPM = 250000
+_DEFAULT_RPM = 3000
+
+SUPPORTED_MODELS_AND_SETTINGS = {
+    # Models from https://platform.openai.com/docs/models
+    # RPM is from https://platform.openai.com/docs/guides/rate-limits
+    # GPT-4-Turbo models
+    'gpt-4-turbo': pg.Dict(rpm=10000, tpm=1500000),
+    'gpt-4-turbo-2024-04-09': pg.Dict(rpm=10000, tpm=1500000),
+    'gpt-4-turbo-preview': pg.Dict(rpm=10000, tpm=1500000),
+    'gpt-4-0125-preview': pg.Dict(rpm=10000, tpm=1500000),
+    'gpt-4-1106-preview': pg.Dict(rpm=10000, tpm=1500000),
+    'gpt-4-vision-preview': pg.Dict(rpm=10000, tpm=1500000),
+    'gpt-4-1106-vision-preview': pg.Dict(
+        rpm=10000, tpm=1500000
+    ),
+    # GPT-4 models
+    'gpt-4': pg.Dict(rpm=10000, tpm=300000),
+    'gpt-4-0613': pg.Dict(rpm=10000, tpm=300000),
+    'gpt-4-0314': pg.Dict(rpm=10000, tpm=300000),
+    'gpt-4-32k': pg.Dict(rpm=10000, tpm=300000),
+    'gpt-4-32k-0613': pg.Dict(rpm=10000, tpm=300000),
+    'gpt-4-32k-0314': pg.Dict(rpm=10000, tpm=300000),
+    # GPT-3.5-Turbo models
+    'gpt-3.5-turbo': pg.Dict(rpm=10000, tpm=2000000),
+    'gpt-3.5-turbo-0125': pg.Dict(rpm=10000, tpm=2000000),
+    'gpt-3.5-turbo-1106': pg.Dict(rpm=10000, tpm=2000000),
+    'gpt-3.5-turbo-0613': pg.Dict(rpm=10000, tpm=2000000),
+    'gpt-3.5-turbo-0301': pg.Dict(rpm=10000, tpm=2000000),
+    'gpt-3.5-turbo-16k': pg.Dict(rpm=10000, tpm=2000000),
+    'gpt-3.5-turbo-16k-0613': pg.Dict(rpm=10000, tpm=2000000),
+    'gpt-3.5-turbo-16k-0301': pg.Dict(rpm=10000, tpm=2000000),
+    # GPT-3.5 models
+    'text-davinci-003': pg.Dict(rpm=_DEFAULT_RPM, tpm=_DEFAULT_TPM),
+    'text-davinci-002': pg.Dict(rpm=_DEFAULT_RPM, tpm=_DEFAULT_TPM),
+    'code-davinci-002': pg.Dict(rpm=_DEFAULT_RPM, tpm=_DEFAULT_TPM),
+    # GPT-3 instruction-tuned models
+    'text-curie-001': pg.Dict(rpm=_DEFAULT_RPM, tpm=_DEFAULT_TPM),
+    'text-babbage-001': pg.Dict(rpm=_DEFAULT_RPM, tpm=_DEFAULT_TPM),
+    'text-ada-001': pg.Dict(rpm=_DEFAULT_RPM, tpm=_DEFAULT_TPM),
+    'davinci': pg.Dict(rpm=_DEFAULT_RPM, tpm=_DEFAULT_TPM),
+    'curie': pg.Dict(rpm=_DEFAULT_RPM, tpm=_DEFAULT_TPM),
+    'babbage': pg.Dict(rpm=_DEFAULT_RPM, tpm=_DEFAULT_TPM),
+    'ada': pg.Dict(rpm=_DEFAULT_RPM, tpm=_DEFAULT_TPM),
+    # GPT-3 base models
+    'babbage-002': pg.Dict(rpm=_DEFAULT_RPM, tpm=_DEFAULT_TPM),
+    'davinci-002': pg.Dict(rpm=_DEFAULT_RPM, tpm=_DEFAULT_TPM),
+}
 }
 
 
 @lf.use_init_args(['model'])
@@ -82,7 +83,7 @@ class OpenAI(lf.LanguageModel):
 
   model: pg.typing.Annotated[
       pg.typing.Enum(
-          pg.MISSING_VALUE,
+          pg.MISSING_VALUE, list(SUPPORTED_MODELS_AND_SETTINGS.keys())
      ),
      'The name of the model to use.',
  ] = 'gpt-3.5-turbo'
@@ -134,7 +135,11 @@ class OpenAI(lf.LanguageModel):
 
   @property
   def max_concurrency(self) -> int:
-
+    rpm = SUPPORTED_MODELS_AND_SETTINGS[self.model].get('rpm', 0)
+    tpm = SUPPORTED_MODELS_AND_SETTINGS[self.model].get('tpm', 0)
+    return self.rate_to_max_concurrency(
+        requests_per_min=rpm, tokens_per_min=tpm
+    )
 
   @classmethod
   def dir(cls):
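For scale: with the conversion added in language_model.py (TOKENS_PER_REQUEST = 250), a `gpt-4` entry (tpm=300000) resolves to int(300000 / 250 / 60) = 20 concurrent requests, and a default-limit model (tpm=250000) to 16. The test below therefore only asserts that the derived value is positive rather than pinning an exact number.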
langfun/core/llms/openai_test.py
CHANGED
@@ -11,7 +11,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-"""Tests for
+"""Tests for OpenAI models."""
 
 import unittest
 from unittest import mock
@@ -85,7 +85,7 @@ def mock_chat_completion_query_vision(messages, *, n=1, **kwargs):
     )
 
 
-class OpenaiTest(unittest.TestCase):
+class OpenAITest(unittest.TestCase):
   """Tests for OpenAI language model."""
 
   def test_model_id(self):
@@ -98,7 +98,7 @@ class OpenaiTest(unittest.TestCase):
     )
 
   def test_max_concurrency(self):
-    self.
+    self.assertGreater(openai.Gpt35(api_key='test_key').max_concurrency, 0)
 
   def test_get_request_args(self):
     self.assertEqual(
langfun/core/template.py
CHANGED
@@ -38,6 +38,10 @@ NO_TEMPLATE_DOCSTR_SIGN = 'THIS IS NOT A TEMPLATE'
 _TLS_RENDER_STACK = '_template_render_stack'
 _TLS_RENDER_RESULT_CACHE = '_template_render_result_cache'
 
+# The prefix for fields or contextual attributes to be treated as additional
+# metadata for rendered message.
+_ADDITIONAL_METADATA_PREFIX = 'metadata_'
+
 
 class Template(
     natural_language.NaturalLanguageFormattable,
@@ -303,19 +307,19 @@ class Template(
       with modality.format_modality_as_ref():
        rendered_text = self._template.render(**inputs)
 
+      # Carry additional metadata.
+      metadata = self.additional_metadata()
+
      if self.clean:
        rendered_text = rendered_text.strip()
 
-
-
-          text=rendered_text,
-          metadata={
-              k: pg.Ref(v)
-              for k, v in inputs.items()
-              if not inspect.ismethod(v)
-          },
+      metadata.update(
+          {k: pg.Ref(v) for k, v in inputs.items() if not inspect.ismethod(v)}
      )
 
+      # Fill the variables for rendering the template as metadata.
+      message = message_cls(text=rendered_text, metadata=metadata)
+
      # Tag input as rendered message.
      message.tag(message_lib.Message.TAG_RENDERED)
 
@@ -340,6 +344,20 @@ class Template(
       top = pg.object_utils.thread_local_pop(_TLS_RENDER_STACK)
       assert top is self, (top, self)
 
+  def additional_metadata(self) -> dict[str, Any]:
+    """Returns additional metadta to be carried in the rendered message."""
+    metadata = {}
+    # Carry metadata from `lf.context`.
+    for k, v in component.all_contextual_values().items():
+      if k.startswith(_ADDITIONAL_METADATA_PREFIX):
+        metadata[k.removeprefix(_ADDITIONAL_METADATA_PREFIX)] = v
+
+    # Carry metadata from fields.
+    for k, v in self.sym_init_args.items():
+      if k.startswith(_ADDITIONAL_METADATA_PREFIX):
+        metadata[k.removeprefix(_ADDITIONAL_METADATA_PREFIX)] = v
+    return metadata
+
   #
   # Implements `pg.typing.CustomTyping`.
   #
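A minimal sketch of the resulting `metadata_` convention (mirroring `template_test.py` below): a template field or contextual value named `metadata_<key>` surfaces as metadata key `<key>` on the rendered message.

    from langfun.core import component
    from langfun.core import message as message_lib
    from langfun.core.template import Template

    # Via a template field...
    t = Template('hi', metadata_weights=1.0)
    assert t.render() == message_lib.UserMessage('hi', weights=1.0)

    # ...or via the contextual scope.
    with component.context(metadata_weights=1.0):
      assert Template('hi').render() == message_lib.UserMessage('hi', weights=1.0)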
langfun/core/template_test.py
CHANGED
@@ -16,6 +16,7 @@ import inspect
 import unittest
 
 from langfun.core import component
+from langfun.core import message as message_lib
 from langfun.core import modality
 from langfun.core import subscription
 from langfun.core.template import Template
@@ -427,6 +428,14 @@ class RenderTest(unittest.TestCase):
     # Test len.
     self.assert_partial(Template('Hello {{len(x)}}'), 'Hello {{len(x)}}')
 
+  def test_additional_metadata(self):
+    t = Template('hi', metadata_weights=1.0, y=2)
+    self.assertEqual(t.render(), message_lib.UserMessage('hi', weights=1.0))
+
+    t = Template('hi')
+    with component.context(metadata_weights=1.0, y=2):
+      self.assertEqual(t.render(), message_lib.UserMessage('hi', weights=1.0))
+
 
 class TemplateRenderEventTest(unittest.TestCase):
 
{langfun-0.0.2.dev20240420.dist-info → langfun-0.0.2.dev20240423.dist-info}/RECORD
CHANGED
@@ -1,15 +1,15 @@
 langfun/__init__.py,sha256=3iCC7F8XoRZ7Gvus11NT50e4KDOQJxIPn9a7TlLzuVI,1880
 langfun/core/__init__.py,sha256=6QEuXOZ9BXxm6TjpaMXuLwUBTYO3pkFDqn9QVBXyyPQ,4248
-langfun/core/component.py,sha256=
-langfun/core/component_test.py,sha256=
+langfun/core/component.py,sha256=oxesbC0BoE_TbtxwW5x-BAZWxZyyJbuPiX5S38RqCv0,9909
+langfun/core/component_test.py,sha256=uR-_Sz_42Jxc5qzLIB-f5_pXmNwnC01Xlbv5NOQSeSU,8021
 langfun/core/concurrent.py,sha256=TRc49pJ3HQro2kb5FtcWkHjhBm8UcgE8RJybU5cU3-0,24537
 langfun/core/concurrent_test.py,sha256=mwFMZhDUdppnDr7vDSTwcbMHwrdsIoKJwRYNtl4ZWL4,15185
 langfun/core/console.py,sha256=bk5rNPNm9rMGW5YT2HixxU04p2umnoabn5SDz6Dqe88,2317
 langfun/core/console_test.py,sha256=5SYJdxpJGLgdSSQqqMPoA1X6jpsLD8rgcyk-EgI65oE,1077
 langfun/core/langfunc.py,sha256=WXdTc3QsmGD_n80KD9dFRr5MHpGZ9E_y_Rhtk4t9-3w,11852
 langfun/core/langfunc_test.py,sha256=sQaKuZpGGmG80GRifhbxkj7nfzQLJKj4Vuw5y1s1K3U,8378
-langfun/core/language_model.py,sha256=
-langfun/core/language_model_test.py,sha256=
+langfun/core/language_model.py,sha256=mJfQ_Zqq9IyVyZUdYMQ1BPrpo4Gn8yxDJb_RghQFP_I,18911
+langfun/core/language_model_test.py,sha256=oWQjnyiJugSpHJKda-qLaSvmbm1sx_v-ZXrHvw_kNk4,14172
 langfun/core/memory.py,sha256=f-asN1F7Vehgdn_fK84v73GrEUOxRtaW934keutTKjk,2416
 langfun/core/message.py,sha256=QhvV9t5qaryPcruyxxcXi3gm9QDInkSldwTtK6sVJ3c,15734
 langfun/core/message_test.py,sha256=Z23pUM5vPnDrYkIIibe2KL73D5HKur_awI0ut_EQFQA,9501
@@ -21,8 +21,8 @@ langfun/core/sampling.py,sha256=vygWvgC8MFw0_AKNSmz-ywMXJYWf8cl0tI8QycvAmyI,5795
 langfun/core/sampling_test.py,sha256=U7PANpMsl9E_pa4_Y4FzesSjcwg-u-LKHGCWSgv-8FY,3663
 langfun/core/subscription.py,sha256=euawEuSZP-BHydaT-AQpfYFL0m5pWPGcW0upFhrojqc,10930
 langfun/core/subscription_test.py,sha256=Y4ZdbZEwm83YNZBxHff0QR4QUa4rdaNXA3_jfIcArBo,8717
-langfun/core/template.py,sha256=
-langfun/core/template_test.py,sha256=
+langfun/core/template.py,sha256=dr3tZCbXH2qWzigO_EFVHe0GDnnCu58Tru5Mvlzin4o,18447
+langfun/core/template_test.py,sha256=xty7PgdNhGpw7ZRZ6QGwhKZWG6dyRgI16Lg3p7IMLJg,13944
 langfun/core/text_formatting.py,sha256=ytjj7opnRJ6w-pkglL2CZUyfYDXLpNf65E42LBb31gc,5158
 langfun/core/text_formatting_test.py,sha256=nyKC6tn2L4hPJiqQHgxcbQsJJi4A4Nbj8FiO8iT6B80,1514
 langfun/core/coding/__init__.py,sha256=5utju_fwEsImaiftx4oXKl9FAM8p281k8-Esdh_-m1w,835
@@ -46,17 +46,19 @@ langfun/core/eval/matching.py,sha256=aqNlYrlav7YmsB7rUlsdfoi1RLA5CYqn2RGPxRlPc78
 langfun/core/eval/matching_test.py,sha256=FFHYD7IDuKe5RMjkx74ksukiwUhO5a_SS340JaIPMws,4898
 langfun/core/eval/scoring.py,sha256=aKeanBJf1yO3Q9JEtgPWoiZk_3M_GiqwXVXX7x_g22w,6172
 langfun/core/eval/scoring_test.py,sha256=YH1cIxBWtfdKcAV9Fh10vLkV5J-gxk8b6nxW4Z2u5pk,4024
-langfun/core/llms/__init__.py,sha256=
-langfun/core/llms/anthropic.py,sha256=
-langfun/core/llms/anthropic_test.py,sha256=
+langfun/core/llms/__init__.py,sha256=1bPg1QI8duOZCYINm-jWi094x0JtLmsk4KX60qIC_gs,3245
+langfun/core/llms/anthropic.py,sha256=7W9YdPN3SlAFhAIQlihMkrpo7tTY_4NvD0KIlCrqcsk,8505
+langfun/core/llms/anthropic_test.py,sha256=TMM30myyEhwF99Le4RvJEXOn8RYl0q1FRkt9Q9nl1jk,5540
 langfun/core/llms/fake.py,sha256=b-Xk5IPTbUt-elsyzd_i3n1tqzc_kgETXrEvgJruSMk,2824
 langfun/core/llms/fake_test.py,sha256=ZlDQgL41EX3eYTfBQNp2nB2LciqCmtoHgCsGvW4XhwI,4184
 langfun/core/llms/google_genai.py,sha256=n8zyJwh9UCTgb6-8LyvmjVNFGZQ4-zfzZ0ulkhHAnR8,8624
 langfun/core/llms/google_genai_test.py,sha256=_UcGTfl16-aDUlEWFC2W2F8y9jPUs53RBYA6MOCpGXw,7525
+langfun/core/llms/groq.py,sha256=NaGItVL_pkOpqPpI4bPGU27xLFRoaeizZ49v2s-4ERs,7844
+langfun/core/llms/groq_test.py,sha256=M6GtlrsOvDun_j-sR8cPh4W_moHWZNSTiThu3kuwbbc,5281
 langfun/core/llms/llama_cpp.py,sha256=Y_KkMUf3Xfac49koMUtUslKl3h-HWp3-ntq7Jaa3bdo,2385
 langfun/core/llms/llama_cpp_test.py,sha256=ZxC6defGd_HX9SFRU9U4cJiQnBKundbOrchbXuC1Z2M,1683
-langfun/core/llms/openai.py,sha256=
-langfun/core/llms/openai_test.py,sha256=
+langfun/core/llms/openai.py,sha256=06nPhmw0zIA5Zqv3eqsrZtYLHnKwW7N8yt3LlFUFVpI,13247
+langfun/core/llms/openai_test.py,sha256=Yt_W6k8YXpT3bs0JroARofCGmn_Uq3u61LmZxqWS2DQ,8272
 langfun/core/llms/cache/__init__.py,sha256=QAo3InUMDM_YpteNnVCSejI4zOsnjSMWKJKzkb3VY64,993
 langfun/core/llms/cache/base.py,sha256=cFfYvOIUae842pncqCAsRvqXCk2AnAsRYVx0mcIoAeY,3338
 langfun/core/llms/cache/in_memory.py,sha256=YfFyJEhLs73cUiB0ZfhMxYpdE8Iuxxw-dvMFwGHTSHw,4742
@@ -99,8 +101,8 @@ langfun/core/templates/demonstration.py,sha256=vCrgYubdZM5Umqcgp8NUVGXgr4P_c-fik
 langfun/core/templates/demonstration_test.py,sha256=SafcDQ0WgI7pw05EmPI2S4v1t3ABKzup8jReCljHeK4,2162
 langfun/core/templates/selfplay.py,sha256=yhgrJbiYwq47TgzThmHrDQTF4nDrTI09CWGhuQPNv-s,2273
 langfun/core/templates/selfplay_test.py,sha256=DYVrkk7uNKCqJGEHH31HssU2BPuMItU1vJLzfcXIlYg,2156
-langfun-0.0.2.
-langfun-0.0.2.
-langfun-0.0.2.
-langfun-0.0.2.
-langfun-0.0.2.
+langfun-0.0.2.dev20240423.dist-info/LICENSE,sha256=WNHhf_5RCaeuKWyq_K39vmp9F28LxKsB4SpomwSZ2L0,11357
+langfun-0.0.2.dev20240423.dist-info/METADATA,sha256=tK11XSi0smD_bL0tVdbT_YsNXu-xs1KKYbTX2powidg,3405
+langfun-0.0.2.dev20240423.dist-info/WHEEL,sha256=GJ7t_kWBFywbagK5eo9IoUwLW6oyOeTKmQ-9iHFVNxQ,92
+langfun-0.0.2.dev20240423.dist-info/top_level.txt,sha256=RhlEkHxs1qtzmmtWSwYoLVJAc1YrbPtxQ52uh8Z9VvY,8
+langfun-0.0.2.dev20240423.dist-info/RECORD,,
{langfun-0.0.2.dev20240420.dist-info → langfun-0.0.2.dev20240423.dist-info}/LICENSE
File without changes
{langfun-0.0.2.dev20240420.dist-info → langfun-0.0.2.dev20240423.dist-info}/WHEEL
File without changes
{langfun-0.0.2.dev20240420.dist-info → langfun-0.0.2.dev20240423.dist-info}/top_level.txt
File without changes