langfun 0.1.2.dev202503240804__py3-none-any.whl → 0.1.2.dev202503250804__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- langfun/__init__.py +1 -0
- langfun/core/__init__.py +2 -0
- langfun/core/data/__init__.py +19 -0
- langfun/core/data/conversion/__init__.py +21 -0
- langfun/core/data/conversion/anthropic.py +131 -0
- langfun/core/data/conversion/anthropic_test.py +267 -0
- langfun/core/data/conversion/gemini.py +168 -0
- langfun/core/data/conversion/gemini_test.py +256 -0
- langfun/core/data/conversion/openai.py +131 -0
- langfun/core/data/conversion/openai_test.py +176 -0
- langfun/core/llms/anthropic.py +10 -52
- langfun/core/llms/gemini.py +15 -62
- langfun/core/llms/gemini_test.py +0 -32
- langfun/core/llms/openai_compatible.py +15 -19
- langfun/core/message.py +232 -27
- langfun/core/message_test.py +132 -16
- {langfun-0.1.2.dev202503240804.dist-info → langfun-0.1.2.dev202503250804.dist-info}/METADATA +1 -1
- {langfun-0.1.2.dev202503240804.dist-info → langfun-0.1.2.dev202503250804.dist-info}/RECORD +21 -13
- {langfun-0.1.2.dev202503240804.dist-info → langfun-0.1.2.dev202503250804.dist-info}/WHEEL +1 -1
- {langfun-0.1.2.dev202503240804.dist-info → langfun-0.1.2.dev202503250804.dist-info}/licenses/LICENSE +0 -0
- {langfun-0.1.2.dev202503240804.dist-info → langfun-0.1.2.dev202503250804.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,256 @@
|
|
1
|
+
# Copyright 2025 The Langfun Authors
|
2
|
+
#
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at
|
6
|
+
#
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
#
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
import base64
|
15
|
+
import unittest
|
16
|
+
import langfun.core as lf
|
17
|
+
from langfun.core import modalities as lf_modalities
|
18
|
+
from langfun.core.data.conversion import gemini # pylint: disable=unused-import
|
19
|
+
|
20
|
+
|
21
|
+
# Tiny PNG image (starts with the \x89PNG magic bytes) used as a shared
# in-memory fixture for the modality-conversion tests below.
image_content = (
    b'\x89PNG\r\n\x1a\n\x00\x00\x00\rIHDR\x00\x00\x00\x18\x00\x00\x00\x18\x04'
    b'\x03\x00\x00\x00\x12Y \xcb\x00\x00\x00\x18PLTE\x00\x00'
    b'\x00fff_chaag_cg_ch^ci_ciC\xedb\x94\x00\x00\x00\x08tRNS'
    b'\x00\n\x9f*\xd4\xff_\xf4\xe4\x8b\xf3a\x00\x00\x00>IDATx'
    b'\x01c \x05\x08)"\xd8\xcc\xae!\x06pNz\x88k\x19\\Q\xa8"\x10'
    b'\xc1\x14\x95\x01%\xc1\n\xa143Ta\xa8"D-\x84\x03QM\x98\xc3'
    b'\x1a\x1a\x1a@5\x0e\x04\xa0q\x88\x05\x00\x07\xf8\x18\xf9'
    b'\xdao\xd0|\x00\x00\x00\x00IEND\xaeB`\x82'
)
|
31
|
+
|
32
|
+
|
33
|
+
class GeminiConversionTest(unittest.TestCase):
  """Tests for Langfun message <-> Gemini API format conversion.

  Covers both directions: `as_format('gemini')` / `as_gemini_format()` for
  serializing Langfun messages, and `lf.Message.from_value(..., format='gemini')`
  / `from_gemini_format()` for parsing Gemini API dicts back into messages.
  """

  def test_as_format_with_role(self):
    # Langfun message classes map to Gemini roles:
    # UserMessage -> 'user', AIMessage -> 'model', SystemMessage -> 'system'.
    self.assertEqual(
        lf.UserMessage('hi').as_format('gemini'),
        {
            'role': 'user',
            'parts': [{'text': 'hi'}],
        },
    )
    self.assertEqual(
        lf.AIMessage('hi').as_format('gemini'),
        {
            'role': 'model',
            'parts': [{'text': 'hi'}],
        },
    )
    self.assertEqual(
        lf.SystemMessage('hi').as_format('gemini'),
        {
            'role': 'system',
            'parts': [{'text': 'hi'}],
        },
    )

  def test_as_format_with_image(self):
    # Modalities interleaved with text become separate Gemini parts:
    # byte-backed modalities -> 'inlineData' (base64); URI-backed modalities
    # -> 'fileData'. Note the youtube URI is emitted with mimeType 'video/*'
    # even though the input mime was 'text/html' — presumably the converter
    # special-cases youtube URLs; TODO confirm against gemini.py.
    self.assertEqual(
        lf.Template(
            'What are the common words from {{image}}, {{pdf}} and {{video}}?',
            image=lf_modalities.Image.from_bytes(image_content),
            pdf=lf_modalities.Custom.from_uri(
                'https://my.pdf', mime='application/pdf'
            ),
            video=lf_modalities.Custom.from_uri(
                'https://www.youtube.com/watch?v=abcd', mime='text/html'
            ),
        ).render().as_gemini_format(),
        {
            'role': 'user',
            'parts': [
                {
                    'text': 'What are the common words from'
                },
                {
                    'inlineData': {
                        'data': base64.b64encode(image_content).decode('utf-8'),
                        'mimeType': 'image/png',
                    }
                },
                {
                    'text': ','
                },
                {
                    'fileData': {
                        'fileUri': 'https://my.pdf',
                        'mimeType': 'application/pdf',
                    }
                },
                {
                    'text': 'and'
                },
                {
                    'fileData': {
                        'fileUri': 'https://www.youtube.com/watch?v=abcd',
                        'mimeType': 'video/*',
                    }
                },
                {
                    'text': '?'
                }
            ],
        },
    )

  def test_as_format_with_chunk_preprocessor(self):
    # A chunk_preprocessor returning None drops the chunk: the image chunk
    # is filtered out and only the text parts survive.
    self.assertEqual(
        lf.Template(
            'What is this {{image}}?',
            image=lf_modalities.Image.from_bytes(image_content)
        ).render().as_format(
            'gemini',
            chunk_preprocessor=lambda x: x if isinstance(x, str) else None
        ),
        {
            'role': 'user',
            'parts': [
                {
                    'text': 'What is this'
                },
                {
                    'text': '?'
                }
            ],
        },
    )

  def test_from_value_with_simple_text(self):
    # Without an explicit 'role', parsing defaults to an AIMessage.
    self.assertEqual(
        lf.Message.from_value(
            {
                'parts': [{'text': 'this is a text'}],
            },
            format='gemini',
        ),
        lf.AIMessage('this is a text'),
    )

  def test_from_value_with_role(self):
    # Gemini roles map back to message classes; unknown roles raise.
    self.assertEqual(
        lf.Message.from_value(
            {
                'role': 'user',
                'parts': [{'text': 'this is a text'}],
            },
            format='gemini',
        ),
        lf.UserMessage('this is a text'),
    )
    self.assertEqual(
        lf.Message.from_value(
            {
                'role': 'model',
                'parts': [{'text': 'this is a text'}],
            },
            format='gemini',
        ),
        lf.AIMessage('this is a text'),
    )
    self.assertEqual(
        lf.Message.from_value(
            {
                'role': 'system',
                'parts': [{'text': 'this is a text'}],
            },
            format='gemini',
        ),
        lf.SystemMessage('this is a text'),
    )
    with self.assertRaisesRegex(ValueError, 'Unsupported role: .*'):
      lf.Message.from_value(
          {
              'role': 'function',
              'parts': [{'text': 'this is a text'}],
          },
          format='gemini',
      )

  def test_from_value_with_thoughts(self):
    # Parts flagged with 'thought': True go into the message's `thought`
    # metadata rather than its main text.
    message = lf.Message.from_value(
        {
            'role': 'user',
            'parts': [
                {
                    'text': 'this is a red round object',
                    'thought': True
                },
                {
                    'text': 'this is a apple',
                },
            ],
        },
        format='gemini',
    )
    self.assertEqual(message.text, 'this is a apple')
    self.assertEqual(message.thought, 'this is a red round object')

  def test_from_value_with_modalities(self):
    # Round-trip parse: inlineData/fileData parts become modality objects
    # referenced from the text via <<[[objN]]>> placeholders.
    m = lf.Message.from_gemini_format(
        {
            'role': 'user',
            'parts': [
                {
                    'text': 'What are the common words from'
                },
                {
                    'inlineData': {
                        'data': base64.b64encode(image_content).decode('utf-8'),
                        'mimeType': 'image/png',
                    }
                },
                {
                    'text': ','
                },
                {
                    'fileData': {
                        'fileUri': 'https://my.pdf',
                        'mimeType': 'application/pdf',
                    }
                },
                {
                    'text': 'and'
                },
                {
                    'fileData': {
                        'fileUri': 'https://www.youtube.com/watch?v=abcd',
                        'mimeType': 'video/*',
                    }
                },
                {
                    'text': '?'
                }
            ],
        },
    )
    self.assertEqual(
        m.text,
        (
            'What are the common words from <<[[obj0]]>> , <<[[obj1]]>> '
            'and <<[[obj2]]>> ?'
        )
    )
    self.assertIsInstance(m.obj0, lf_modalities.Image)
    self.assertEqual(m.obj0.mime_type, 'image/png')
    self.assertEqual(m.obj0.to_bytes(), image_content)

    self.assertIsInstance(m.obj1, lf_modalities.PDF)
    self.assertEqual(m.obj1.uri, 'https://my.pdf')

    self.assertIsInstance(m.obj2, lf_modalities.Video)
    self.assertEqual(m.obj2.uri, 'https://www.youtube.com/watch?v=abcd')
|
254
|
+
|
255
|
+
# Allow running this test module directly as a script.
if __name__ == '__main__':
  unittest.main()
|
@@ -0,0 +1,131 @@
|
|
1
|
+
# Copyright 2025 The Langfun Authors
|
2
|
+
#
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at
|
6
|
+
#
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
#
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
"""OpenAI API message conversion."""
|
15
|
+
|
16
|
+
from typing import Annotated, Any, Callable
|
17
|
+
|
18
|
+
import langfun.core as lf
|
19
|
+
from langfun.core import modalities as lf_modalities
|
20
|
+
|
21
|
+
|
22
|
+
class OpenAIMessageConverter(lf.MessageConverter):
  """Converter between Langfun messages and the OpenAI chat message format.

  Registered under format id 'openai', so it is reachable through
  `lf.Message.as_format('openai')` and
  `lf.Message.from_value(..., format='openai')`.
  """

  FORMAT_ID = 'openai'

  chunk_preprocessor: Annotated[
      Callable[[str | lf.Modality], Any] | None,
      (
          'Chunk preprocessor for Langfun chunk to OpenAI chunk conversion. '
          'It will be applied before each Langfun chunk is converted. '
          'If returns None, the chunk will be skipped.'
      )
  ] = None

  def to_value(self, message: lf.Message) -> dict[str, Any]:
    """Converts a Langfun message to OpenAI API."""
    parts = []
    for chunk in message.chunk():
      if self.chunk_preprocessor is not None:
        chunk = self.chunk_preprocessor(chunk)
        # A preprocessor returning None means "drop this chunk".
        if chunk is None:
          continue

      if isinstance(chunk, str):
        item = dict(type='text', text=chunk)
      elif isinstance(chunk, lf_modalities.Image):
        # Images are sent as data/URI references via `image_url`.
        item = dict(
            type='image_url', image_url=dict(url=chunk.embeddable_uri)
        )
      # TODO(daiyip): Support audio_input.
      else:
        raise ValueError(f'Unsupported content type: {chunk!r}.')
      parts.append(item)
    return dict(
        role=self.get_role(message),
        content=parts,
    )

  def get_role(self, message: lf.Message) -> str:
    """Returns the role of the message."""
    if isinstance(message, lf.SystemMessage):
      return 'system'
    elif isinstance(message, lf.UserMessage):
      return 'user'
    elif isinstance(message, lf.AIMessage):
      return 'assistant'
    else:
      raise ValueError(f'Unsupported message type: {message!r}.')

  def get_message_cls(self, role: str) -> type[lf.Message]:
    """Returns the message class of the message."""
    match role:
      case 'system':
        return lf.SystemMessage
      case 'user':
        return lf.UserMessage
      case 'assistant':
        return lf.AIMessage
      case _:
        raise ValueError(f'Unsupported role: {role!r}.')

  def from_value(self, value: dict[str, Any]) -> lf.Message:
    """Returns a Langfun message from OpenAI message."""
    # OpenAI messages may omit 'role'; default to 'assistant' (AIMessage).
    message_cls = self.get_message_cls(
        self._safe_read(value, 'role', default='assistant')
    )
    content = self._safe_read(value, 'content')
    # 'content' may be a plain string (simple text message) ...
    if isinstance(content, str):
      return message_cls(content)

    # ... or a list of typed items ('text' / 'image_url').
    assert isinstance(content, list)
    chunks = []
    for item in content:
      t = self._safe_read(item, 'type')
      if t == 'text':
        chunk = self._safe_read(item, 'text')
      elif t == 'image_url':
        chunk = lf_modalities.Image.from_uri(
            self._safe_read(self._safe_read(item, 'image_url'), 'url')
        )
      else:
        raise ValueError(f'Unsupported content type: {item!r}.')
      chunks.append(chunk)
    return message_cls.from_chunks(chunks)
|
106
|
+
|
107
|
+
|
108
|
+
def _as_openai_format(
    self,
    chunk_preprocessor: Callable[[str | lf.Modality], Any] | None = None,
    **kwargs
) -> dict[str, Any]:
  """Returns this message in OpenAI chat message format.

  Args:
    chunk_preprocessor: Optional callable applied to each chunk before
      conversion; returning None skips the chunk.
    **kwargs: Extra keyword arguments forwarded to `OpenAIMessageConverter`.

  Returns:
    A dict with 'role' and 'content' keys per the OpenAI API.
  """
  return OpenAIMessageConverter(
      chunk_preprocessor=chunk_preprocessor, **kwargs
  ).to_value(self)
|
117
|
+
|
118
|
+
|
119
|
+
@classmethod
def _from_openai_format(
    cls,
    openai_message: dict[str, Any],
    **kwargs
) -> lf.Message:
  """Returns a Langfun message parsed from an OpenAI chat message dict.

  Args:
    openai_message: A dict in OpenAI chat message format ('role'/'content').
    **kwargs: Extra keyword arguments forwarded to `OpenAIMessageConverter`.

  Returns:
    The parsed `lf.Message`.
  """
  del cls
  return OpenAIMessageConverter(**kwargs).from_value(openai_message)
|
128
|
+
|
129
|
+
# Set shortcut methods in lf.Message.
# Monkey-patched so importing this module enables
# `message.as_openai_format()` and `lf.Message.from_openai_format(...)`.
lf.Message.as_openai_format = _as_openai_format
lf.Message.from_openai_format = _from_openai_format
|
@@ -0,0 +1,176 @@
|
|
1
|
+
# Copyright 2025 The Langfun Authors
|
2
|
+
#
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at
|
6
|
+
#
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
#
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
import base64
|
15
|
+
import unittest
|
16
|
+
import langfun.core as lf
|
17
|
+
from langfun.core import modalities as lf_modalities
|
18
|
+
from langfun.core.data.conversion import openai # pylint: disable=unused-import
|
19
|
+
|
20
|
+
|
21
|
+
# Tiny PNG image (starts with the \x89PNG magic bytes) used as a shared
# in-memory fixture for the OpenAI conversion tests below.
image_content = (
    b'\x89PNG\r\n\x1a\n\x00\x00\x00\rIHDR\x00\x00\x00\x18\x00\x00\x00\x18\x04'
    b'\x03\x00\x00\x00\x12Y \xcb\x00\x00\x00\x18PLTE\x00\x00'
    b'\x00fff_chaag_cg_ch^ci_ciC\xedb\x94\x00\x00\x00\x08tRNS'
    b'\x00\n\x9f*\xd4\xff_\xf4\xe4\x8b\xf3a\x00\x00\x00>IDATx'
    b'\x01c \x05\x08)"\xd8\xcc\xae!\x06pNz\x88k\x19\\Q\xa8"\x10'
    b'\xc1\x14\x95\x01%\xc1\n\xa143Ta\xa8"D-\x84\x03QM\x98\xc3'
    b'\x1a\x1a\x1a@5\x0e\x04\xa0q\x88\x05\x00\x07\xf8\x18\xf9'
    b'\xdao\xd0|\x00\x00\x00\x00IEND\xaeB`\x82'
)
|
31
|
+
|
32
|
+
|
33
|
+
class OpenAIConversionTest(unittest.TestCase):
  """Tests for Langfun message <-> OpenAI chat format conversion.

  Covers `as_format('openai')` / `as_openai_format()` for serialization and
  `lf.Message.from_value(..., format='openai')` / `from_openai_format()` for
  parsing OpenAI message dicts back into Langfun messages.
  """

  def test_as_format_with_role(self):
    # Message classes map to OpenAI roles:
    # UserMessage -> 'user', AIMessage -> 'assistant',
    # SystemMessage -> 'system'.
    self.assertEqual(
        lf.UserMessage('hi').as_format('openai'),
        {
            'role': 'user',
            'content': [{'type': 'text', 'text': 'hi'}],
        },
    )
    self.assertEqual(
        lf.AIMessage('hi').as_format('openai'),
        {
            'role': 'assistant',
            'content': [{'type': 'text', 'text': 'hi'}],
        },
    )
    self.assertEqual(
        lf.SystemMessage('hi').as_format('openai'),
        {
            'role': 'system',
            'content': [{'type': 'text', 'text': 'hi'}],
        },
    )

  def test_as_format_with_image(self):
    # Byte-backed images are serialized as base64 data URIs under
    # 'image_url', interleaved with the surrounding text chunks.
    self.assertEqual(
        lf.Template(
            'What is this {{image}}?',
            image=lf_modalities.Image.from_bytes(image_content)
        ).render().as_format('openai'),
        {
            'role': 'user',
            'content': [
                {
                    'type': 'text',
                    'text': 'What is this'
                },
                {
                    'type': 'image_url',
                    'image_url': {
                        'url': (
                            'data:image/png;base64,'
                            + base64.b64encode(image_content).decode('utf-8')
                        )
                    }
                },
                {
                    'type': 'text',
                    'text': '?'
                }
            ],
        },
    )

  def test_as_format_with_chunk_preprocessor(self):
    # A chunk_preprocessor returning None drops the chunk: the image is
    # filtered out, leaving only the text items.
    self.assertEqual(
        lf.Template(
            'What is this {{image}}?',
            image=lf_modalities.Image.from_bytes(image_content)
        ).render().as_openai_format(
            chunk_preprocessor=lambda x: x if isinstance(x, str) else None
        ),
        {
            'role': 'user',
            'content': [
                {
                    'type': 'text',
                    'text': 'What is this'
                },
                {
                    'type': 'text',
                    'text': '?'
                }
            ],
        },
    )

  def test_from_value_with_simple_text(self):
    # Plain string 'content' and no 'role' parses to an AIMessage.
    self.assertEqual(
        lf.Message.from_value(
            {
                'content': 'this is a text',
            },
            format='openai',
        ),
        lf.AIMessage('this is a text'),
    )

  def test_from_value_with_role(self):
    # OpenAI roles map back to message classes; unknown roles raise.
    self.assertEqual(
        lf.Message.from_value(
            {
                'role': 'user',
                'content': [{'type': 'text', 'text': 'hi'}],
            },
            format='openai',
        ),
        lf.UserMessage('hi'),
    )
    self.assertEqual(
        lf.Message.from_value(
            {
                'role': 'assistant',
                'content': [{'type': 'text', 'text': 'hi'}],
            },
            format='openai',
        ),
        lf.AIMessage('hi'),
    )
    self.assertEqual(
        lf.Message.from_value(
            {
                'role': 'system',
                'content': [{'type': 'text', 'text': 'hi'}],
            },
            format='openai',
        ),
        lf.SystemMessage('hi'),
    )
    with self.assertRaisesRegex(ValueError, 'Unsupported role: .*'):
      lf.Message.from_value(
          {
              'role': 'function',
              'content': [{'type': 'text', 'text': 'hi'}],
          },
          format='openai',
      )

  def test_from_value_with_image(self):
    # Round trip: serialize a message with an image, parse it back, and
    # check the image modality is restored byte-for-byte.
    m = lf.Message.from_openai_format(
        lf.Template(
            'What is this {{image}}?',
            image=lf_modalities.Image.from_bytes(image_content)
        ).render().as_format('openai'),
    )
    self.assertEqual(m.text, 'What is this <<[[obj0]]>> ?')
    self.assertIsInstance(m.obj0, lf_modalities.Image)
    self.assertEqual(m.obj0.mime_type, 'image/png')
    self.assertEqual(m.obj0.to_bytes(), image_content)
|
173
|
+
|
174
|
+
|
175
|
+
# Allow running this test module directly as a script.
if __name__ == '__main__':
  unittest.main()
|
langfun/core/llms/anthropic.py
CHANGED
@@ -13,7 +13,6 @@
|
|
13
13
|
# limitations under the License.
|
14
14
|
"""Language models from Anthropic."""
|
15
15
|
|
16
|
-
import base64
|
17
16
|
import datetime
|
18
17
|
import functools
|
19
18
|
import os
|
@@ -21,6 +20,7 @@ from typing import Annotated, Any
|
|
21
20
|
|
22
21
|
import langfun.core as lf
|
23
22
|
from langfun.core import modalities as lf_modalities
|
23
|
+
from langfun.core.data.conversion import anthropic as anthropic_conversion # pylint: disable=unused-import
|
24
24
|
from langfun.core.llms import rest
|
25
25
|
import pyglove as pg
|
26
26
|
|
@@ -502,10 +502,17 @@ class Anthropic(rest.REST):
|
|
502
502
|
"""Returns the JSON input for a message."""
|
503
503
|
request = dict()
|
504
504
|
request.update(self._request_args(sampling_options))
|
505
|
+
|
506
|
+
def modality_check(chunk: Any) -> Any:
|
507
|
+
if isinstance(chunk, lf_modalities.Mime):
|
508
|
+
if not self.supports_input(chunk.mime_type):
|
509
|
+
raise ValueError(f'Unsupported modality: {chunk!r}.')
|
510
|
+
return chunk
|
511
|
+
|
505
512
|
request.update(
|
506
513
|
dict(
|
507
514
|
messages=[
|
508
|
-
|
515
|
+
prompt.as_format('anthropic', chunk_preprocessor=modality_check)
|
509
516
|
]
|
510
517
|
)
|
511
518
|
)
|
@@ -548,43 +555,8 @@ class Anthropic(rest.REST):
|
|
548
555
|
args.pop('top_p', None)
|
549
556
|
return args
|
550
557
|
|
551
|
-
def _content_from_message(self, prompt: lf.Message) -> list[dict[str, Any]]:
|
552
|
-
"""Converts an message to Anthropic's content protocol (list of dicts)."""
|
553
|
-
# Refer: https://docs.anthropic.com/claude/reference/messages-examples
|
554
|
-
content = []
|
555
|
-
for chunk in prompt.chunk():
|
556
|
-
if isinstance(chunk, str):
|
557
|
-
content.append(dict(type='text', text=chunk))
|
558
|
-
elif isinstance(chunk, lf_modalities.Mime):
|
559
|
-
if not self.supports_input(chunk.mime_type):
|
560
|
-
raise ValueError(f'Unsupported modality: {chunk!r}.')
|
561
|
-
if isinstance(chunk, lf_modalities.Image):
|
562
|
-
item = dict(
|
563
|
-
type='image',
|
564
|
-
source=dict(
|
565
|
-
type='base64',
|
566
|
-
media_type=chunk.mime_type,
|
567
|
-
data=base64.b64encode(chunk.to_bytes()).decode(),
|
568
|
-
),
|
569
|
-
)
|
570
|
-
elif isinstance(chunk, lf_modalities.PDF):
|
571
|
-
item = dict(
|
572
|
-
type='document',
|
573
|
-
source=dict(
|
574
|
-
type='base64',
|
575
|
-
media_type=chunk.mime_type,
|
576
|
-
data=base64.b64encode(chunk.to_bytes()).decode(),
|
577
|
-
),
|
578
|
-
)
|
579
|
-
else:
|
580
|
-
raise NotImplementedError(
|
581
|
-
f'Modality conversion not implemented: {chunk!r}'
|
582
|
-
)
|
583
|
-
content.append(item)
|
584
|
-
return content
|
585
|
-
|
586
558
|
def result(self, json: dict[str, Any]) -> lf.LMSamplingResult:
|
587
|
-
message =
|
559
|
+
message = lf.Message.from_value(json, format='anthropic')
|
588
560
|
input_tokens = json['usage']['input_tokens']
|
589
561
|
output_tokens = json['usage']['output_tokens']
|
590
562
|
return lf.LMSamplingResult(
|
@@ -596,20 +568,6 @@ class Anthropic(rest.REST):
|
|
596
568
|
),
|
597
569
|
)
|
598
570
|
|
599
|
-
def _message_from_content(self, content: list[dict[str, Any]]) -> lf.Message:
|
600
|
-
"""Converts Anthropic's content protocol to message."""
|
601
|
-
# Refer: https://docs.anthropic.com/claude/reference/messages-examples
|
602
|
-
# Thinking: https://docs.anthropic.com/en/docs/build-with-claude/extended-thinking#implementing-extended-thinking # pylint: disable=line-too-long
|
603
|
-
response = lf.AIMessage.from_chunks(
|
604
|
-
[x['text'] for x in content if x['type'] == 'text']
|
605
|
-
)
|
606
|
-
thinking = lf.AIMessage.from_chunks(
|
607
|
-
[x['thinking'] for x in content if x['type'] == 'thinking']
|
608
|
-
)
|
609
|
-
# thinking is added into the metadata.thinking field.
|
610
|
-
response.set('thinking', thinking)
|
611
|
-
return response
|
612
|
-
|
613
571
|
|
614
572
|
class Claude37(Anthropic):
|
615
573
|
"""Base class for Claude 3.7 models."""
|