langfun 0.1.2.dev202503240804__py3-none-any.whl → 0.1.2.dev202503260804__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- langfun/__init__.py +1 -0
- langfun/core/__init__.py +2 -0
- langfun/core/data/__init__.py +19 -0
- langfun/core/data/conversion/__init__.py +23 -0
- langfun/core/data/conversion/anthropic.py +131 -0
- langfun/core/data/conversion/anthropic_test.py +267 -0
- langfun/core/data/conversion/gemini.py +168 -0
- langfun/core/data/conversion/gemini_test.py +256 -0
- langfun/core/data/conversion/openai.py +131 -0
- langfun/core/data/conversion/openai_test.py +176 -0
- langfun/core/llms/anthropic.py +10 -52
- langfun/core/llms/gemini.py +15 -62
- langfun/core/llms/gemini_test.py +0 -32
- langfun/core/llms/openai_compatible.py +15 -19
- langfun/core/message.py +232 -27
- langfun/core/message_test.py +130 -16
- langfun/core/modalities/image.py +1 -0
- {langfun-0.1.2.dev202503240804.dist-info → langfun-0.1.2.dev202503260804.dist-info}/METADATA +1 -1
- {langfun-0.1.2.dev202503240804.dist-info → langfun-0.1.2.dev202503260804.dist-info}/RECORD +22 -14
- {langfun-0.1.2.dev202503240804.dist-info → langfun-0.1.2.dev202503260804.dist-info}/WHEEL +1 -1
- {langfun-0.1.2.dev202503240804.dist-info → langfun-0.1.2.dev202503260804.dist-info}/licenses/LICENSE +0 -0
- {langfun-0.1.2.dev202503240804.dist-info → langfun-0.1.2.dev202503260804.dist-info}/top_level.txt +0 -0
langfun/core/llms/gemini.py
CHANGED
@@ -13,13 +13,13 @@
|
|
13
13
|
# limitations under the License.
|
14
14
|
"""Gemini REST API (Shared by Google GenAI and Vertex AI)."""
|
15
15
|
|
16
|
-
import base64
|
17
16
|
import datetime
|
18
17
|
import functools
|
19
18
|
from typing import Annotated, Any
|
20
19
|
|
21
20
|
import langfun.core as lf
|
22
21
|
from langfun.core import modalities as lf_modalities
|
22
|
+
from langfun.core.data.conversion import gemini as gemini_conversion # pylint: disable=unused-import
|
23
23
|
from langfun.core.llms import rest
|
24
24
|
import pyglove as pg
|
25
25
|
|
@@ -559,7 +559,19 @@ class Gemini(rest.REST):
|
|
559
559
|
request = dict(
|
560
560
|
generationConfig=self._generation_config(prompt, sampling_options)
|
561
561
|
)
|
562
|
-
|
562
|
+
def modality_conversion(chunk: str | lf.Modality) -> Any:
|
563
|
+
if isinstance(chunk, lf_modalities.Mime):
|
564
|
+
try:
|
565
|
+
return chunk.make_compatible(
|
566
|
+
self.model_info.input_modalities + ['text/plain']
|
567
|
+
)
|
568
|
+
except lf.ModalityError as e:
|
569
|
+
raise lf.ModalityError(f'Unsupported modality: {chunk!r}') from e
|
570
|
+
return chunk
|
571
|
+
|
572
|
+
request['contents'] = [
|
573
|
+
prompt.as_format('gemini', chunk_preprocessor=modality_conversion)
|
574
|
+
]
|
563
575
|
return request
|
564
576
|
|
565
577
|
def _generation_config(
|
@@ -593,49 +605,9 @@ class Gemini(rest.REST):
|
|
593
605
|
)
|
594
606
|
return config
|
595
607
|
|
596
|
-
def _content_from_message(self, prompt: lf.Message) -> dict[str, Any]:
|
597
|
-
"""Gets generation content from langfun message."""
|
598
|
-
parts = []
|
599
|
-
for lf_chunk in prompt.chunk():
|
600
|
-
if isinstance(lf_chunk, str):
|
601
|
-
parts.append({'text': lf_chunk})
|
602
|
-
elif isinstance(lf_chunk, lf_modalities.Mime):
|
603
|
-
try:
|
604
|
-
modalities = lf_chunk.make_compatible(
|
605
|
-
self.model_info.input_modalities + ['text/plain']
|
606
|
-
)
|
607
|
-
if isinstance(modalities, lf_modalities.Mime):
|
608
|
-
modalities = [modalities]
|
609
|
-
for modality in modalities:
|
610
|
-
if modality.is_text:
|
611
|
-
# Add YouTube video into the context window.
|
612
|
-
# https://ai.google.dev/gemini-api/docs/vision?lang=python#youtube
|
613
|
-
if modality.mime_type == 'text/html' and modality.uri.startswith(
|
614
|
-
'https://www.youtube.com/watch?v='
|
615
|
-
):
|
616
|
-
parts.append({
|
617
|
-
'fileData': {'mimeType': 'video/*', 'fileUri': modality.uri}
|
618
|
-
})
|
619
|
-
else:
|
620
|
-
parts.append({'text': modality.to_text()})
|
621
|
-
else:
|
622
|
-
parts.append({
|
623
|
-
'inlineData': {
|
624
|
-
'data': base64.b64encode(modality.to_bytes()).decode(),
|
625
|
-
'mimeType': modality.mime_type,
|
626
|
-
}
|
627
|
-
})
|
628
|
-
except lf.ModalityError as e:
|
629
|
-
raise lf.ModalityError(f'Unsupported modality: {lf_chunk!r}') from e
|
630
|
-
else:
|
631
|
-
raise NotImplementedError(
|
632
|
-
f'Input conversion not implemented: {lf_chunk!r}'
|
633
|
-
)
|
634
|
-
return dict(role='user', parts=parts)
|
635
|
-
|
636
608
|
def result(self, json: dict[str, Any]) -> lf.LMSamplingResult:
|
637
609
|
messages = [
|
638
|
-
|
610
|
+
lf.Message.from_value(candidate['content'], format='gemini')
|
639
611
|
for candidate in json['candidates']
|
640
612
|
]
|
641
613
|
usage = json['usageMetadata']
|
@@ -652,22 +624,3 @@ class Gemini(rest.REST):
|
|
652
624
|
total_tokens=input_tokens + output_tokens,
|
653
625
|
),
|
654
626
|
)
|
655
|
-
|
656
|
-
def _message_from_content_parts(
|
657
|
-
self, parts: list[dict[str, Any]]
|
658
|
-
) -> lf.Message:
|
659
|
-
"""Converts Vertex AI's content parts protocol to message."""
|
660
|
-
chunks = []
|
661
|
-
thought_chunks = []
|
662
|
-
for part in parts:
|
663
|
-
if text_part := part.get('text'):
|
664
|
-
if part.get('thought'):
|
665
|
-
thought_chunks.append(text_part)
|
666
|
-
else:
|
667
|
-
chunks.append(text_part)
|
668
|
-
else:
|
669
|
-
raise ValueError(f'Unsupported part: {part}')
|
670
|
-
message = lf.AIMessage.from_chunks(chunks)
|
671
|
-
if thought_chunks:
|
672
|
-
message.set('thought', lf.AIMessage.from_chunks(thought_chunks))
|
673
|
-
return message
|
langfun/core/llms/gemini_test.py
CHANGED
@@ -13,13 +13,11 @@
|
|
13
13
|
# limitations under the License.
|
14
14
|
"""Tests for Gemini API."""
|
15
15
|
|
16
|
-
import base64
|
17
16
|
from typing import Any
|
18
17
|
import unittest
|
19
18
|
from unittest import mock
|
20
19
|
|
21
20
|
import langfun.core as lf
|
22
|
-
from langfun.core import modalities as lf_modalities
|
23
21
|
from langfun.core.llms import gemini
|
24
22
|
import pyglove as pg
|
25
23
|
import requests
|
@@ -105,36 +103,6 @@ class GeminiTest(unittest.TestCase):
|
|
105
103
|
0.51
|
106
104
|
)
|
107
105
|
|
108
|
-
def test_content_from_message_text_only(self):
|
109
|
-
text = 'This is a beautiful day'
|
110
|
-
model = gemini.Gemini('gemini-1.5-pro', api_endpoint='')
|
111
|
-
chunks = model._content_from_message(lf.UserMessage(text))
|
112
|
-
self.assertEqual(chunks, {'role': 'user', 'parts': [{'text': text}]})
|
113
|
-
|
114
|
-
def test_content_from_message_mm(self):
|
115
|
-
image = lf_modalities.Image.from_bytes(example_image)
|
116
|
-
message = lf.UserMessage(
|
117
|
-
'This is an <<[[image]]>>, what is it?', image=image
|
118
|
-
)
|
119
|
-
model = gemini.Gemini('gemini-1.5-pro', api_endpoint='')
|
120
|
-
content = model._content_from_message(message)
|
121
|
-
self.assertEqual(
|
122
|
-
content,
|
123
|
-
{
|
124
|
-
'role': 'user',
|
125
|
-
'parts': [
|
126
|
-
{'text': 'This is an'},
|
127
|
-
{
|
128
|
-
'inlineData': {
|
129
|
-
'data': base64.b64encode(example_image).decode(),
|
130
|
-
'mimeType': 'image/png',
|
131
|
-
}
|
132
|
-
},
|
133
|
-
{'text': ', what is it?'},
|
134
|
-
],
|
135
|
-
},
|
136
|
-
)
|
137
|
-
|
138
106
|
def test_generation_config(self):
|
139
107
|
model = gemini.Gemini('gemini-1.5-pro', api_endpoint='')
|
140
108
|
json_schema = {
|
langfun/core/llms/openai_compatible.py
CHANGED
@@ -17,6 +17,7 @@ from typing import Annotated, Any
|
|
17
17
|
|
18
18
|
import langfun.core as lf
|
19
19
|
from langfun.core import modalities as lf_modalities
|
20
|
+
from langfun.core.data.conversion import openai as openai_conversion # pylint: disable=unused-import
|
20
21
|
from langfun.core.llms import rest
|
21
22
|
import pyglove as pg
|
22
23
|
|
@@ -61,20 +62,6 @@ class OpenAICompatible(rest.REST):
|
|
61
62
|
args['seed'] = options.random_seed
|
62
63
|
return args
|
63
64
|
|
64
|
-
def _content_from_message(self, message: lf.Message) -> list[dict[str, Any]]:
|
65
|
-
"""Returns a OpenAI content object from a Langfun message."""
|
66
|
-
content = []
|
67
|
-
for chunk in message.chunk():
|
68
|
-
if isinstance(chunk, str):
|
69
|
-
item = dict(type='text', text=chunk)
|
70
|
-
elif (isinstance(chunk, lf_modalities.Image)
|
71
|
-
and self.supports_input(chunk.mime_type)):
|
72
|
-
item = dict(type='image_url', image_url=dict(url=chunk.embeddable_uri))
|
73
|
-
else:
|
74
|
-
raise ValueError(f'Unsupported modality: {chunk!r}.')
|
75
|
-
content.append(item)
|
76
|
-
return content
|
77
|
-
|
78
65
|
def request(
|
79
66
|
self,
|
80
67
|
prompt: lf.Message,
|
@@ -114,16 +101,25 @@ class OpenAICompatible(rest.REST):
|
|
114
101
|
|
115
102
|
# Prepare messages.
|
116
103
|
messages = []
|
104
|
+
|
105
|
+
def modality_check(chunk: str | lf.Modality) -> Any:
|
106
|
+
if (isinstance(chunk, lf_modalities.Mime)
|
107
|
+
and not self.supports_input(chunk.mime_type)):
|
108
|
+
raise ValueError(
|
109
|
+
f'Unsupported modality: {chunk!r}.'
|
110
|
+
)
|
111
|
+
return chunk
|
112
|
+
|
117
113
|
# Users could use `metadata_system_message` to pass system message.
|
118
114
|
system_message = prompt.metadata.get('system_message')
|
119
115
|
if system_message:
|
120
|
-
system_message = lf.SystemMessage.from_value(system_message)
|
121
116
|
messages.append(
|
122
|
-
|
123
|
-
|
117
|
+
lf.SystemMessage.from_value(system_message).as_format(
|
118
|
+
'openai', chunk_preprocessor=modality_check
|
119
|
+
)
|
124
120
|
)
|
125
121
|
messages.append(
|
126
|
-
|
122
|
+
prompt.as_format('openai', chunk_preprocessor=modality_check)
|
127
123
|
)
|
128
124
|
request = dict()
|
129
125
|
request.update(request_args)
|
@@ -145,7 +141,7 @@ class OpenAICompatible(rest.REST):
|
|
145
141
|
for t in choice_logprobs['content']
|
146
142
|
]
|
147
143
|
return lf.LMSample(
|
148
|
-
choice['message']
|
144
|
+
lf.Message.from_value(choice['message'], format='openai'),
|
149
145
|
score=0.0,
|
150
146
|
logprobs=logprobs,
|
151
147
|
)
|
langfun/core/message.py
CHANGED
@@ -13,10 +13,14 @@
|
|
13
13
|
# limitations under the License.
|
14
14
|
"""Messages that are exchanged between users and agents."""
|
15
15
|
|
16
|
+
import abc
|
17
|
+
import collections
|
18
|
+
from collections.abc import Mapping
|
16
19
|
import contextlib
|
17
20
|
import functools
|
21
|
+
import inspect
|
18
22
|
import io
|
19
|
-
from typing import Annotated, Any, Optional, Union
|
23
|
+
from typing import Annotated, Any, ClassVar, Optional, Type, Union
|
20
24
|
|
21
25
|
from langfun.core import modality
|
22
26
|
from langfun.core import natural_language
|
@@ -146,13 +150,91 @@ class Message(
|
|
146
150
|
self._source = source
|
147
151
|
|
148
152
|
@classmethod
|
149
|
-
def from_value(
|
150
|
-
|
153
|
+
def from_value(
|
154
|
+
cls,
|
155
|
+
value: Union[str, 'Message', Any],
|
156
|
+
*,
|
157
|
+
format: str | None = None, # pylint: disable=redefined-builtin
|
158
|
+
**kwargs
|
159
|
+
) -> 'Message':
|
160
|
+
"""Creates a message from a str, message or object of registered format.
|
161
|
+
|
162
|
+
Example:
|
163
|
+
# Create a user message from a str.
|
164
|
+
lf.UserMessage.from_value('hi')
|
165
|
+
|
166
|
+
# Create a user message from a multi-modal object.
|
167
|
+
lf.UserMessage.from_value(lf.Image.from_uri('...'))
|
168
|
+
|
169
|
+
# Create a user message from OpenAI API format.
|
170
|
+
lf.Message.from_value(
|
171
|
+
{
|
172
|
+
'role': 'user',
|
173
|
+
'content': [{'type': 'text', 'text': 'hi'}],
|
174
|
+
},
|
175
|
+
format='openai.api',
|
176
|
+
)
|
177
|
+
|
178
|
+
Args:
|
179
|
+
value: The value to create a message from.
|
180
|
+
format: The format ID to convert from. If None, the conversion will be
|
181
|
+
performed according to the type of `value`. Otherwise, the converter
|
182
|
+
registered for the format will be used.
|
183
|
+
**kwargs: The keyword arguments passed to the __init__ of the converter.
|
184
|
+
|
185
|
+
Returns:
|
186
|
+
A message created from the value.
|
187
|
+
"""
|
151
188
|
if isinstance(value, modality.Modality):
|
152
189
|
return cls('<<[[object]]>>', object=value)
|
153
190
|
if isinstance(value, Message):
|
154
191
|
return value
|
155
|
-
|
192
|
+
if isinstance(value, str):
|
193
|
+
return cls(value)
|
194
|
+
value_type = type(value)
|
195
|
+
if format is None:
|
196
|
+
converter = MessageConverter.get_by_type(value_type, **kwargs)
|
197
|
+
else:
|
198
|
+
converter = MessageConverter.get_by_format(format, **kwargs)
|
199
|
+
if (converter.OUTPUT_TYPE is not None
|
200
|
+
and not isinstance(value, converter.OUTPUT_TYPE)):
|
201
|
+
raise ValueError(f'{format!r} is not applicable to {value!r}.')
|
202
|
+
return converter.from_value(value)
|
203
|
+
|
204
|
+
#
|
205
|
+
# Conversion to other formats or types.
|
206
|
+
#
|
207
|
+
|
208
|
+
def as_format(self, format_or_type: str | type[Any], **kwargs) -> Any:
|
209
|
+
"""Converts the message to a registered format or type.
|
210
|
+
|
211
|
+
Example:
|
212
|
+
|
213
|
+
m = lf.Template('What is this {{image}}?').render()
|
214
|
+
m.as_format('openai') # Convert to OpenAI message format.
|
215
|
+
m.as_format('gemini') # Convert to Gemini message format.
|
216
|
+
m.as_format('anthropic') # Convert to Anthropic message format.
|
217
|
+
|
218
|
+
Args:
|
219
|
+
format_or_type: The format ID or type to convert to.
|
220
|
+
**kwargs: The conversion arguments.
|
221
|
+
|
222
|
+
Returns:
|
223
|
+
The converted object according to the format or type.
|
224
|
+
"""
|
225
|
+
return MessageConverter.get(format_or_type, **kwargs).to_value(self)
|
226
|
+
|
227
|
+
@classmethod
|
228
|
+
@property
|
229
|
+
def convertible_formats(cls) -> list[str]:
|
230
|
+
"""Returns the supported formats for message conversion."""
|
231
|
+
return MessageConverter.convertible_formats()
|
232
|
+
|
233
|
+
@classmethod
|
234
|
+
@property
|
235
|
+
def convertible_types(cls) -> list[str]:
|
236
|
+
"""Returns supported types for message conversion."""
|
237
|
+
return MessageConverter.convertible_types()
|
156
238
|
|
157
239
|
#
|
158
240
|
# Unified interface for accessing text, result and metadata.
|
@@ -377,29 +459,6 @@ class Message(
|
|
377
459
|
ref_index += 1
|
378
460
|
return cls(fused_text.getvalue().strip(), metadata=metadata)
|
379
461
|
|
380
|
-
#
|
381
|
-
# API for testing the message types.
|
382
|
-
#
|
383
|
-
|
384
|
-
@property
|
385
|
-
def from_user(self) -> bool:
|
386
|
-
"""Returns True if it's user message."""
|
387
|
-
return isinstance(self, UserMessage)
|
388
|
-
|
389
|
-
@property
|
390
|
-
def from_agent(self) -> bool:
|
391
|
-
"""Returns True if it's agent message."""
|
392
|
-
return isinstance(self, AIMessage)
|
393
|
-
|
394
|
-
@property
|
395
|
-
def from_system(self) -> bool:
|
396
|
-
"""Returns True if it's agent message."""
|
397
|
-
return isinstance(self, SystemMessage)
|
398
|
-
|
399
|
-
@property
|
400
|
-
def from_memory(self) -> bool:
|
401
|
-
return isinstance(self, MemoryRecord)
|
402
|
-
|
403
462
|
#
|
404
463
|
# Tagging
|
405
464
|
#
|
@@ -799,6 +858,152 @@ class Message(
|
|
799
858
|
]
|
800
859
|
|
801
860
|
|
861
|
+
class _MessageConverterRegistry:
|
862
|
+
"""Message converter registry."""
|
863
|
+
|
864
|
+
def __init__(self):
|
865
|
+
self._name_to_converter: dict[str, Type[MessageConverter]] = {}
|
866
|
+
self._type_to_converters: dict[Type[Any], list[Type[MessageConverter]]] = (
|
867
|
+
collections.defaultdict(list)
|
868
|
+
)
|
869
|
+
|
870
|
+
def register(self, converter: Type['MessageConverter']) -> None:
|
871
|
+
"""Registers a message converter."""
|
872
|
+
self._name_to_converter[converter.FORMAT_ID] = converter
|
873
|
+
if converter.OUTPUT_TYPE is not None:
|
874
|
+
self._type_to_converters[converter.OUTPUT_TYPE].append(converter)
|
875
|
+
|
876
|
+
def get_by_type(self, t: Type[Any], **kwargs) -> 'MessageConverter':
|
877
|
+
"""Returns a message converter for the given type."""
|
878
|
+
t = self._type_to_converters[t]
|
879
|
+
if not t:
|
880
|
+
raise TypeError(
|
881
|
+
f'Cannot convert Message to {t!r}.'
|
882
|
+
)
|
883
|
+
if len(t) > 1:
|
884
|
+
raise TypeError(
|
885
|
+
f'More than one converters found for output type {t!r}. '
|
886
|
+
f'Please specify one for this conversion: {[x.FORMAT_ID for x in t]}.'
|
887
|
+
)
|
888
|
+
return t[0](**kwargs)
|
889
|
+
|
890
|
+
def get_by_format(self, format: str, **kwargs) -> 'MessageConverter': # pylint: disable=redefined-builtin
|
891
|
+
"""Returns a message converter for the given format."""
|
892
|
+
if format not in self._name_to_converter:
|
893
|
+
raise ValueError(f'Unsupported format: {format!r}.')
|
894
|
+
return self._name_to_converter[format](**kwargs)
|
895
|
+
|
896
|
+
def get(
|
897
|
+
self,
|
898
|
+
format_or_type: str | Type[Any], **kwargs
|
899
|
+
) -> 'MessageConverter':
|
900
|
+
"""Returns a message converter for the given format or type."""
|
901
|
+
if isinstance(format_or_type, str):
|
902
|
+
return self.get_by_format(format_or_type, **kwargs)
|
903
|
+
assert isinstance(format_or_type, type), format_or_type
|
904
|
+
return self.get_by_type(format_or_type, **kwargs)
|
905
|
+
|
906
|
+
def convertible_formats(self) -> list[str]:
|
907
|
+
"""Returns a list of converter names."""
|
908
|
+
return sorted(list(self._name_to_converter.keys()))
|
909
|
+
|
910
|
+
def convertible_types(self) -> list[Type[Any]]:
|
911
|
+
"""Returns a list of converter types."""
|
912
|
+
return list(self._type_to_converters.keys())
|
913
|
+
|
914
|
+
|
915
|
+
class MessageConverter(pg.Object):
|
916
|
+
"""Interface for converting a Langfun message to other formats."""
|
917
|
+
|
918
|
+
# A global unique identifier for the converter.
|
919
|
+
FORMAT_ID: ClassVar[str]
|
920
|
+
|
921
|
+
# The output type of the converter.
|
922
|
+
# If None, the converter will not be registered to handle
|
923
|
+
# `lf.Message.as_format(output_type)`.
|
924
|
+
OUTPUT_TYPE: ClassVar[Type[Any] | None] = None
|
925
|
+
|
926
|
+
_REGISTRY = _MessageConverterRegistry()
|
927
|
+
|
928
|
+
def __init_subclass__(cls, *args, **kwargs):
|
929
|
+
super().__init_subclass__(*args, **kwargs)
|
930
|
+
if not inspect.isabstract(cls):
|
931
|
+
cls._REGISTRY.register(cls)
|
932
|
+
|
933
|
+
@abc.abstractmethod
|
934
|
+
def to_value(self, message: Message) -> Any:
|
935
|
+
"""Converts a Langfun message to other formats."""
|
936
|
+
|
937
|
+
@abc.abstractmethod
|
938
|
+
def from_value(self, value: Message) -> Message:
|
939
|
+
"""Converts a value in this converter's format to a Langfun message."""
|
940
|
+
|
941
|
+
@classmethod
|
942
|
+
def _safe_read(
|
943
|
+
cls,
|
944
|
+
data: Mapping[str, Any],
|
945
|
+
key: str,
|
946
|
+
default: Any = pg.MISSING_VALUE
|
947
|
+
) -> Any:
|
948
|
+
"""Safely reads a key from a mapping."""
|
949
|
+
if not isinstance(data, Mapping):
|
950
|
+
raise ValueError(f'Invalid data type: {data!r}.')
|
951
|
+
if key not in data:
|
952
|
+
if pg.MISSING_VALUE == default:
|
953
|
+
raise ValueError(f'Missing key {key!r} in {data!r}')
|
954
|
+
return default
|
955
|
+
return data[key]
|
956
|
+
|
957
|
+
@classmethod
|
958
|
+
def get_role(cls, message: Message) -> str:
|
959
|
+
"""Returns the role of the message."""
|
960
|
+
if isinstance(message, SystemMessage):
|
961
|
+
return 'system'
|
962
|
+
elif isinstance(message, UserMessage):
|
963
|
+
return 'user'
|
964
|
+
elif isinstance(message, AIMessage):
|
965
|
+
return 'assistant'
|
966
|
+
else:
|
967
|
+
raise ValueError(f'Unsupported message type: {message!r}.')
|
968
|
+
|
969
|
+
@classmethod
|
970
|
+
def get_message_cls(cls, role: str) -> type[Message]:
|
971
|
+
"""Returns the message class for the given role."""
|
972
|
+
match role:
|
973
|
+
case 'system':
|
974
|
+
return SystemMessage
|
975
|
+
case 'user':
|
976
|
+
return UserMessage
|
977
|
+
case 'assistant':
|
978
|
+
return AIMessage
|
979
|
+
case _:
|
980
|
+
raise ValueError(f'Unsupported role: {role!r}.')
|
981
|
+
|
982
|
+
@classmethod
|
983
|
+
def get(cls, format_or_type: str | Type[Any], **kwargs) -> 'MessageConverter':
|
984
|
+
"""Returns a message converter."""
|
985
|
+
return cls._REGISTRY.get(format_or_type, **kwargs)
|
986
|
+
|
987
|
+
@classmethod
|
988
|
+
def get_by_format(cls, format: str, **kwargs) -> 'MessageConverter': # pylint: disable=redefined-builtin
|
989
|
+
"""Returns a message converter for the given format."""
|
990
|
+
return cls._REGISTRY.get_by_format(format, **kwargs)
|
991
|
+
|
992
|
+
@classmethod
|
993
|
+
def get_by_type(cls, t: Type[Any], **kwargs) -> 'MessageConverter':
|
994
|
+
"""Returns a message converter for the given type."""
|
995
|
+
return cls._REGISTRY.get_by_type(t, **kwargs)
|
996
|
+
|
997
|
+
@classmethod
|
998
|
+
def convertible_formats(cls) -> list[str]:
|
999
|
+
"""Returns a list of converter names."""
|
1000
|
+
return cls._REGISTRY.convertible_formats()
|
1001
|
+
|
1002
|
+
@classmethod
|
1003
|
+
def convertible_types(cls) -> list[Type[Any]]:
|
1004
|
+
"""Returns a list of converter types."""
|
1005
|
+
return cls._REGISTRY.convertible_types()
|
1006
|
+
|
802
1007
|
#
|
803
1008
|
# Messages of different roles.
|
804
1009
|
#
|
langfun/core/message_test.py
CHANGED
@@ -187,10 +187,6 @@ class MessageTest(unittest.TestCase):
|
|
187
187
|
m = message.UserMessage('hi')
|
188
188
|
self.assertEqual(m.text, 'hi')
|
189
189
|
self.assertEqual(m.sender, 'User')
|
190
|
-
self.assertTrue(m.from_user)
|
191
|
-
self.assertFalse(m.from_agent)
|
192
|
-
self.assertFalse(m.from_system)
|
193
|
-
self.assertFalse(m.from_memory)
|
194
190
|
self.assertEqual(str(m), m.text)
|
195
191
|
|
196
192
|
m = message.UserMessage('hi', sender='Tom')
|
@@ -201,10 +197,6 @@ class MessageTest(unittest.TestCase):
|
|
201
197
|
m = message.AIMessage('hi')
|
202
198
|
self.assertEqual(m.text, 'hi')
|
203
199
|
self.assertEqual(m.sender, 'AI')
|
204
|
-
self.assertFalse(m.from_user)
|
205
|
-
self.assertTrue(m.from_agent)
|
206
|
-
self.assertFalse(m.from_system)
|
207
|
-
self.assertFalse(m.from_memory)
|
208
200
|
self.assertEqual(str(m), m.text)
|
209
201
|
|
210
202
|
m = message.AIMessage('hi', sender='Model')
|
@@ -215,10 +207,6 @@ class MessageTest(unittest.TestCase):
|
|
215
207
|
m = message.SystemMessage('hi')
|
216
208
|
self.assertEqual(m.text, 'hi')
|
217
209
|
self.assertEqual(m.sender, 'System')
|
218
|
-
self.assertFalse(m.from_user)
|
219
|
-
self.assertFalse(m.from_agent)
|
220
|
-
self.assertTrue(m.from_system)
|
221
|
-
self.assertFalse(m.from_memory)
|
222
210
|
self.assertEqual(str(m), m.text)
|
223
211
|
|
224
212
|
m = message.SystemMessage('hi', sender='Environment1')
|
@@ -229,10 +217,6 @@ class MessageTest(unittest.TestCase):
|
|
229
217
|
m = message.MemoryRecord('hi')
|
230
218
|
self.assertEqual(m.text, 'hi')
|
231
219
|
self.assertEqual(m.sender, 'Memory')
|
232
|
-
self.assertFalse(m.from_user)
|
233
|
-
self.assertFalse(m.from_agent)
|
234
|
-
self.assertFalse(m.from_system)
|
235
|
-
self.assertTrue(m.from_memory)
|
236
220
|
self.assertEqual(str(m), m.text)
|
237
221
|
|
238
222
|
m = message.MemoryRecord('hi', sender="Someone's Memory")
|
@@ -493,5 +477,135 @@ class MessageTest(unittest.TestCase):
|
|
493
477
|
)
|
494
478
|
|
495
479
|
|
480
|
+
class MessageConverterTest(unittest.TestCase):
|
481
|
+
|
482
|
+
def test_basics(self):
|
483
|
+
|
484
|
+
class IntConverter(message.MessageConverter):
|
485
|
+
OUTPUT_TYPE = int
|
486
|
+
|
487
|
+
class TestConverter(IntConverter): # pylint: disable=unused-variable
|
488
|
+
FORMAT_ID = 'test_format1'
|
489
|
+
|
490
|
+
def to_value(self, m: message.Message) -> int:
|
491
|
+
return int(m.text)
|
492
|
+
|
493
|
+
def from_value(self, value: int) -> message.Message:
|
494
|
+
return message.UserMessage(str(value))
|
495
|
+
|
496
|
+
class TestConverter2(IntConverter): # pylint: disable=unused-variable
|
497
|
+
FORMAT_ID = 'test_format2'
|
498
|
+
|
499
|
+
def to_value(self, m: message.Message) -> int:
|
500
|
+
return int(m.text) + 1
|
501
|
+
|
502
|
+
def from_value(self, value: int) -> message.Message:
|
503
|
+
return message.UserMessage(str(value - 1))
|
504
|
+
|
505
|
+
class TestConverter3(message.MessageConverter): # pylint: disable=unused-variable
|
506
|
+
FORMAT_ID = 'test_format3'
|
507
|
+
OUTPUT_TYPE = tuple
|
508
|
+
|
509
|
+
def to_value(self, m: message.Message) -> tuple[int, ...]:
|
510
|
+
return tuple(int(x) for x in m.text.split(','))
|
511
|
+
|
512
|
+
def from_value(self, value: tuple[int, ...]) -> message.Message:
|
513
|
+
return message.UserMessage(','.join(str(x) for x in value))
|
514
|
+
|
515
|
+
self.assertIn('test_format1', message.Message.convertible_formats)
|
516
|
+
self.assertIn('test_format2', message.Message.convertible_formats)
|
517
|
+
self.assertIn('test_format3', message.Message.convertible_formats)
|
518
|
+
|
519
|
+
self.assertIn(int, message.Message.convertible_types)
|
520
|
+
self.assertIn(tuple, message.Message.convertible_types)
|
521
|
+
self.assertEqual(
|
522
|
+
message.Message.from_value(1, format='test_format1'),
|
523
|
+
message.UserMessage('1')
|
524
|
+
)
|
525
|
+
self.assertEqual(
|
526
|
+
message.UserMessage('1').as_format('test_format1'),
|
527
|
+
1
|
528
|
+
)
|
529
|
+
self.assertEqual(
|
530
|
+
message.Message.from_value(1, format='test_format2'),
|
531
|
+
message.UserMessage('0')
|
532
|
+
)
|
533
|
+
self.assertEqual(
|
534
|
+
message.UserMessage('1').as_format('test_format2'),
|
535
|
+
2
|
536
|
+
)
|
537
|
+
with self.assertRaisesRegex(ValueError, 'Unsupported format: .*'):
|
538
|
+
message.UserMessage('1').as_format('test4')
|
539
|
+
|
540
|
+
with self.assertRaisesRegex(TypeError, 'Cannot convert Message to .*'):
|
541
|
+
message.UserMessage('1').as_format(float)
|
542
|
+
|
543
|
+
with self.assertRaisesRegex(
|
544
|
+
TypeError, 'More than one converters found for output type .*'
|
545
|
+
):
|
546
|
+
message.UserMessage('1').as_format(int)
|
547
|
+
self.assertEqual(
|
548
|
+
message.UserMessage('1,2,3').as_format('test_format3'),
|
549
|
+
(1, 2, 3)
|
550
|
+
)
|
551
|
+
self.assertEqual(
|
552
|
+
message.UserMessage('1,2,3').as_format(tuple),
|
553
|
+
(1, 2, 3)
|
554
|
+
)
|
555
|
+
self.assertEqual(
|
556
|
+
message.Message.from_value((1, 2, 3)),
|
557
|
+
message.UserMessage('1,2,3')
|
558
|
+
)
|
559
|
+
|
560
|
+
def test_get_role(self):
|
561
|
+
self.assertEqual(
|
562
|
+
message.MessageConverter.get_role(message.SystemMessage('hi')),
|
563
|
+
'system',
|
564
|
+
)
|
565
|
+
self.assertEqual(
|
566
|
+
message.MessageConverter.get_role(message.UserMessage('hi')),
|
567
|
+
'user',
|
568
|
+
)
|
569
|
+
self.assertEqual(
|
570
|
+
message.MessageConverter.get_role(message.AIMessage('hi')),
|
571
|
+
'assistant',
|
572
|
+
)
|
573
|
+
with self.assertRaisesRegex(ValueError, 'Unsupported message type: .*'):
|
574
|
+
message.MessageConverter.get_role(message.MemoryRecord('hi'))
|
575
|
+
|
576
|
+
def test_get_message_cls(self):
|
577
|
+
self.assertEqual(
|
578
|
+
message.MessageConverter.get_message_cls('system'),
|
579
|
+
message.SystemMessage,
|
580
|
+
)
|
581
|
+
self.assertEqual(
|
582
|
+
message.MessageConverter.get_message_cls('user'),
|
583
|
+
message.UserMessage,
|
584
|
+
)
|
585
|
+
self.assertEqual(
|
586
|
+
message.MessageConverter.get_message_cls('assistant'),
|
587
|
+
message.AIMessage,
|
588
|
+
)
|
589
|
+
with self.assertRaisesRegex(ValueError, 'Unsupported role: .*'):
|
590
|
+
message.MessageConverter.get_message_cls('foo')
|
591
|
+
|
592
|
+
def test_safe_read(self):
|
593
|
+
self.assertEqual(
|
594
|
+
message.MessageConverter._safe_read({'a': 1}, 'a'),
|
595
|
+
1,
|
596
|
+
)
|
597
|
+
self.assertEqual(
|
598
|
+
message.MessageConverter._safe_read({'a': 1}, 'a', default=2),
|
599
|
+
1,
|
600
|
+
)
|
601
|
+
self.assertEqual(
|
602
|
+
message.MessageConverter._safe_read({'a': 1}, 'b', default=2),
|
603
|
+
2,
|
604
|
+
)
|
605
|
+
with self.assertRaisesRegex(ValueError, 'Invalid data type: .*'):
|
606
|
+
message.MessageConverter._safe_read(1, 'a')
|
607
|
+
with self.assertRaisesRegex(ValueError, 'Missing key .*'):
|
608
|
+
message.MessageConverter._safe_read({'a': 1}, 'b')
|
609
|
+
|
496
610
|
if __name__ == '__main__':
|
497
611
|
unittest.main()
|