langfun 0.1.2.dev202503240804__py3-none-any.whl → 0.1.2.dev202503250804__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -13,13 +13,13 @@
13
13
  # limitations under the License.
14
14
  """Gemini REST API (Shared by Google GenAI and Vertex AI)."""
15
15
 
16
- import base64
17
16
  import datetime
18
17
  import functools
19
18
  from typing import Annotated, Any
20
19
 
21
20
  import langfun.core as lf
22
21
  from langfun.core import modalities as lf_modalities
22
+ from langfun.core.data.conversion import gemini as gemini_conversion # pylint: disable=unused-import
23
23
  from langfun.core.llms import rest
24
24
  import pyglove as pg
25
25
 
@@ -559,7 +559,19 @@ class Gemini(rest.REST):
559
559
  request = dict(
560
560
  generationConfig=self._generation_config(prompt, sampling_options)
561
561
  )
562
- request['contents'] = [self._content_from_message(prompt)]
562
    def modality_conversion(chunk: str | lf.Modality) -> Any:
      """Chunk preprocessor that adapts MIME chunks before Gemini conversion.

      Args:
        chunk: A message chunk, either a text string or a modality object.

      Returns:
        For `lf_modalities.Mime` chunks, whatever `make_compatible` returns
        (NOTE(review): presumably a single Mime or a list of them - confirm
        against `Mime.make_compatible`). Any other chunk is returned as is.

      Raises:
        lf.ModalityError: If `make_compatible` raises `lf.ModalityError`; the
          original error is attached as the cause.
      """
      if isinstance(chunk, lf_modalities.Mime):
        try:
          # 'text/plain' is always offered as a target in addition to the
          # input modalities declared by the model info.
          return chunk.make_compatible(
              self.model_info.input_modalities + ['text/plain']
          )
        except lf.ModalityError as e:
          # Re-raise with the offending chunk in the message.
          raise lf.ModalityError(f'Unsupported modality: {chunk!r}') from e
      return chunk
571
+
572
+ request['contents'] = [
573
+ prompt.as_format('gemini', chunk_preprocessor=modality_conversion)
574
+ ]
563
575
  return request
564
576
 
565
577
  def _generation_config(
@@ -593,49 +605,9 @@ class Gemini(rest.REST):
593
605
  )
594
606
  return config
595
607
 
596
- def _content_from_message(self, prompt: lf.Message) -> dict[str, Any]:
597
- """Gets generation content from langfun message."""
598
- parts = []
599
- for lf_chunk in prompt.chunk():
600
- if isinstance(lf_chunk, str):
601
- parts.append({'text': lf_chunk})
602
- elif isinstance(lf_chunk, lf_modalities.Mime):
603
- try:
604
- modalities = lf_chunk.make_compatible(
605
- self.model_info.input_modalities + ['text/plain']
606
- )
607
- if isinstance(modalities, lf_modalities.Mime):
608
- modalities = [modalities]
609
- for modality in modalities:
610
- if modality.is_text:
611
- # Add YouTube video into the context window.
612
- # https://ai.google.dev/gemini-api/docs/vision?lang=python#youtube
613
- if modality.mime_type == 'text/html' and modality.uri.startswith(
614
- 'https://www.youtube.com/watch?v='
615
- ):
616
- parts.append({
617
- 'fileData': {'mimeType': 'video/*', 'fileUri': modality.uri}
618
- })
619
- else:
620
- parts.append({'text': modality.to_text()})
621
- else:
622
- parts.append({
623
- 'inlineData': {
624
- 'data': base64.b64encode(modality.to_bytes()).decode(),
625
- 'mimeType': modality.mime_type,
626
- }
627
- })
628
- except lf.ModalityError as e:
629
- raise lf.ModalityError(f'Unsupported modality: {lf_chunk!r}') from e
630
- else:
631
- raise NotImplementedError(
632
- f'Input conversion not implemented: {lf_chunk!r}'
633
- )
634
- return dict(role='user', parts=parts)
635
-
636
608
  def result(self, json: dict[str, Any]) -> lf.LMSamplingResult:
637
609
  messages = [
638
- self._message_from_content_parts(candidate['content'].get('parts', []))
610
+ lf.Message.from_value(candidate['content'], format='gemini')
639
611
  for candidate in json['candidates']
640
612
  ]
641
613
  usage = json['usageMetadata']
@@ -652,22 +624,3 @@ class Gemini(rest.REST):
652
624
  total_tokens=input_tokens + output_tokens,
653
625
  ),
654
626
  )
655
-
656
- def _message_from_content_parts(
657
- self, parts: list[dict[str, Any]]
658
- ) -> lf.Message:
659
- """Converts Vertex AI's content parts protocol to message."""
660
- chunks = []
661
- thought_chunks = []
662
- for part in parts:
663
- if text_part := part.get('text'):
664
- if part.get('thought'):
665
- thought_chunks.append(text_part)
666
- else:
667
- chunks.append(text_part)
668
- else:
669
- raise ValueError(f'Unsupported part: {part}')
670
- message = lf.AIMessage.from_chunks(chunks)
671
- if thought_chunks:
672
- message.set('thought', lf.AIMessage.from_chunks(thought_chunks))
673
- return message
@@ -13,13 +13,11 @@
13
13
  # limitations under the License.
14
14
  """Tests for Gemini API."""
15
15
 
16
- import base64
17
16
  from typing import Any
18
17
  import unittest
19
18
  from unittest import mock
20
19
 
21
20
  import langfun.core as lf
22
- from langfun.core import modalities as lf_modalities
23
21
  from langfun.core.llms import gemini
24
22
  import pyglove as pg
25
23
  import requests
@@ -105,36 +103,6 @@ class GeminiTest(unittest.TestCase):
105
103
  0.51
106
104
  )
107
105
 
108
- def test_content_from_message_text_only(self):
109
- text = 'This is a beautiful day'
110
- model = gemini.Gemini('gemini-1.5-pro', api_endpoint='')
111
- chunks = model._content_from_message(lf.UserMessage(text))
112
- self.assertEqual(chunks, {'role': 'user', 'parts': [{'text': text}]})
113
-
114
- def test_content_from_message_mm(self):
115
- image = lf_modalities.Image.from_bytes(example_image)
116
- message = lf.UserMessage(
117
- 'This is an <<[[image]]>>, what is it?', image=image
118
- )
119
- model = gemini.Gemini('gemini-1.5-pro', api_endpoint='')
120
- content = model._content_from_message(message)
121
- self.assertEqual(
122
- content,
123
- {
124
- 'role': 'user',
125
- 'parts': [
126
- {'text': 'This is an'},
127
- {
128
- 'inlineData': {
129
- 'data': base64.b64encode(example_image).decode(),
130
- 'mimeType': 'image/png',
131
- }
132
- },
133
- {'text': ', what is it?'},
134
- ],
135
- },
136
- )
137
-
138
106
  def test_generation_config(self):
139
107
  model = gemini.Gemini('gemini-1.5-pro', api_endpoint='')
140
108
  json_schema = {
@@ -17,6 +17,7 @@ from typing import Annotated, Any
17
17
 
18
18
  import langfun.core as lf
19
19
  from langfun.core import modalities as lf_modalities
20
+ from langfun.core.data.conversion import openai as openai_conversion # pylint: disable=unused-import
20
21
  from langfun.core.llms import rest
21
22
  import pyglove as pg
22
23
 
@@ -61,20 +62,6 @@ class OpenAICompatible(rest.REST):
61
62
  args['seed'] = options.random_seed
62
63
  return args
63
64
 
64
- def _content_from_message(self, message: lf.Message) -> list[dict[str, Any]]:
65
- """Returns a OpenAI content object from a Langfun message."""
66
- content = []
67
- for chunk in message.chunk():
68
- if isinstance(chunk, str):
69
- item = dict(type='text', text=chunk)
70
- elif (isinstance(chunk, lf_modalities.Image)
71
- and self.supports_input(chunk.mime_type)):
72
- item = dict(type='image_url', image_url=dict(url=chunk.embeddable_uri))
73
- else:
74
- raise ValueError(f'Unsupported modality: {chunk!r}.')
75
- content.append(item)
76
- return content
77
-
78
65
  def request(
79
66
  self,
80
67
  prompt: lf.Message,
@@ -114,16 +101,25 @@ class OpenAICompatible(rest.REST):
114
101
 
115
102
  # Prepare messages.
116
103
  messages = []
104
+
105
    def modality_check(chunk: str | lf.Modality) -> Any:
      """Chunk preprocessor that rejects MIME chunks this model cannot accept.

      Args:
        chunk: A message chunk, either a text string or a modality object.

      Returns:
        `chunk` itself, unchanged.

      Raises:
        ValueError: If `chunk` is an `lf_modalities.Mime` whose `mime_type` is
          not accepted by `self.supports_input`.
      """
      if (isinstance(chunk, lf_modalities.Mime)
          and not self.supports_input(chunk.mime_type)):
        raise ValueError(
            f'Unsupported modality: {chunk!r}.'
        )
      return chunk
112
+
117
113
  # Users could use `metadata_system_message` to pass system message.
118
114
  system_message = prompt.metadata.get('system_message')
119
115
  if system_message:
120
- system_message = lf.SystemMessage.from_value(system_message)
121
116
  messages.append(
122
- dict(role='system',
123
- content=self._content_from_message(system_message))
117
+ lf.SystemMessage.from_value(system_message).as_format(
118
+ 'openai', chunk_preprocessor=modality_check
119
+ )
124
120
  )
125
121
  messages.append(
126
- dict(role='user', content=self._content_from_message(prompt))
122
+ prompt.as_format('openai', chunk_preprocessor=modality_check)
127
123
  )
128
124
  request = dict()
129
125
  request.update(request_args)
@@ -145,7 +141,7 @@ class OpenAICompatible(rest.REST):
145
141
  for t in choice_logprobs['content']
146
142
  ]
147
143
  return lf.LMSample(
148
- choice['message']['content'],
144
+ lf.Message.from_value(choice['message'], format='openai'),
149
145
  score=0.0,
150
146
  logprobs=logprobs,
151
147
  )
langfun/core/message.py CHANGED
@@ -13,10 +13,14 @@
13
13
  # limitations under the License.
14
14
  """Messages that are exchanged between users and agents."""
15
15
 
16
+ import abc
17
+ import collections
18
+ from collections.abc import Mapping
16
19
  import contextlib
17
20
  import functools
21
+ import inspect
18
22
  import io
19
- from typing import Annotated, Any, Optional, Union
23
+ from typing import Annotated, Any, ClassVar, Optional, Type, Union
20
24
 
21
25
  from langfun.core import modality
22
26
  from langfun.core import natural_language
@@ -146,13 +150,91 @@ class Message(
146
150
  self._source = source
147
151
 
148
152
  @classmethod
149
- def from_value(cls, value: Union[str, 'Message']) -> 'Message':
150
- """Creates a message from a value or return value itself if a Message."""
153
  def from_value(
      cls,
      value: Union[str, 'Message', Any],
      *,
      format: str | None = None,  # pylint: disable=redefined-builtin
      **kwargs
  ) -> 'Message':
    """Creates a message from a str, message or object of registered format.

    Example:
      # Create a user message from a str.
      lf.UserMessage.from_value('hi')

      # Create a user message from a multi-modal object.
      lf.UserMessage.from_value(lf.Image.from_uri('...'))

      # Create a user message from OpenAI API format.
      lf.Message.from_value(
          {
              'role': 'user',
              'content': [{'type': 'text', 'text': 'hi'}],
          },
          format='openai',
      )

    Args:
      value: The value to create a message from. Modality objects, messages
        and strings are handled directly and never reach a converter, even
        when `format` is given.
      format: The format ID to convert from. If None, the converter is looked
        up by the exact type of `value`. Otherwise, the converter registered
        for the format will be used.
      **kwargs: The keyword arguments passed to the __init__ of the converter.

    Returns:
      A message created from the value. If `value` is already a message, it is
      returned as is.

    Raises:
      ValueError: If `format` is not registered, or if the converter for
        `format` declares an `OUTPUT_TYPE` that `value` is not an instance of.
      TypeError: If `format` is None and zero or multiple converters are
        registered for the type of `value`.
    """
    if isinstance(value, modality.Modality):
      return cls('<<[[object]]>>', object=value)
    if isinstance(value, Message):
      return value
    if isinstance(value, str):
      return cls(value)
    value_type = type(value)
    if format is None:
      converter = MessageConverter.get_by_type(value_type, **kwargs)
    else:
      converter = MessageConverter.get_by_format(format, **kwargs)
      # A converter without OUTPUT_TYPE accepts any value; otherwise the value
      # must match the type the converter was declared for.
      if (converter.OUTPUT_TYPE is not None
          and not isinstance(value, converter.OUTPUT_TYPE)):
        raise ValueError(f'{format!r} is not applicable to {value!r}.')
    return converter.from_value(value)
203
+
204
+ #
205
+ # Conversion to other formats or types.
206
+ #
207
+
208
+ def as_format(self, format_or_type: str | type[Any], **kwargs) -> Any:
209
+ """Converts the message to a registered format or type.
210
+
211
+ Example:
212
+
213
+ m = lf.Template('What is this {{image}}?').render()
214
+ m.as_format('openai') # Convert to OpenAI message format.
215
+ m.as_format('gemini') # Convert to Gemini message format.
216
+ m.as_format('anthropic') # Convert to Anthropic message format.
217
+
218
+ Args:
219
+ format_or_type: The format ID or type to convert to.
220
+ **kwargs: The conversion arguments.
221
+
222
+ Returns:
223
+ The converted object according to the format or type.
224
+ """
225
+ return MessageConverter.get(format_or_type, **kwargs).to_value(self)
226
+
227
  # NOTE(review): stacking `@classmethod` on top of `@property` was deprecated
  # in Python 3.11 and removed in Python 3.13. On such interpreters
  # `Message.convertible_formats` would no longer evaluate to a list; confirm
  # the supported Python versions.
  @classmethod
  @property
  def convertible_formats(cls) -> list[str]:
    """Returns the format IDs registered for message conversion.

    Delegates to `MessageConverter.convertible_formats()`.
    """
    return MessageConverter.convertible_formats()
232
+
233
  # NOTE(review): stacking `@classmethod` on top of `@property` was deprecated
  # in Python 3.11 and removed in Python 3.13. On such interpreters
  # `Message.convertible_types` would no longer evaluate to a list; confirm
  # the supported Python versions.
  @classmethod
  @property
  def convertible_types(cls) -> list[Type[Any]]:
    """Returns the types registered for message conversion.

    Delegates to `MessageConverter.convertible_types()`, which returns type
    objects rather than strings.
    """
    return MessageConverter.convertible_types()
156
238
 
157
239
  #
158
240
  # Unified interface for accessing text, result and metadata.
@@ -377,29 +459,6 @@ class Message(
377
459
  ref_index += 1
378
460
  return cls(fused_text.getvalue().strip(), metadata=metadata)
379
461
 
380
- #
381
- # API for testing the message types.
382
- #
383
-
384
- @property
385
- def from_user(self) -> bool:
386
- """Returns True if it's user message."""
387
- return isinstance(self, UserMessage)
388
-
389
- @property
390
- def from_agent(self) -> bool:
391
- """Returns True if it's agent message."""
392
- return isinstance(self, AIMessage)
393
-
394
- @property
395
- def from_system(self) -> bool:
396
- """Returns True if it's agent message."""
397
- return isinstance(self, SystemMessage)
398
-
399
- @property
400
- def from_memory(self) -> bool:
401
- return isinstance(self, MemoryRecord)
402
-
403
462
  #
404
463
  # Tagging
405
464
  #
@@ -799,6 +858,152 @@ class Message(
799
858
  ]
800
859
 
801
860
 
861
+ class _MessageConverterRegistry:
862
+ """Message converter registry."""
863
+
864
+ def __init__(self):
865
+ self._name_to_converter: dict[str, Type[MessageConverter]] = {}
866
+ self._type_to_converters: dict[Type[Any], list[Type[MessageConverter]]] = (
867
+ collections.defaultdict(list)
868
+ )
869
+
870
+ def register(self, converter: Type['MessageConverter']) -> None:
871
+ """Registers a message converter."""
872
+ self._name_to_converter[converter.FORMAT_ID] = converter
873
+ if converter.OUTPUT_TYPE is not None:
874
+ self._type_to_converters[converter.OUTPUT_TYPE].append(converter)
875
+
876
+ def get_by_type(self, t: Type[Any], **kwargs) -> 'MessageConverter':
877
+ """Returns a message converter for the given type."""
878
+ t = self._type_to_converters[t]
879
+ if not t:
880
+ raise TypeError(
881
+ f'Cannot convert Message to {t!r}.'
882
+ )
883
+ if len(t) > 1:
884
+ raise TypeError(
885
+ f'More than one converters found for output type {t!r}. '
886
+ f'Please specify one for this conversion: {[x.FORMAT_ID for x in t]}.'
887
+ )
888
+ return t[0](**kwargs)
889
+
890
+ def get_by_format(self, format: str, **kwargs) -> 'MessageConverter': # pylint: disable=redefined-builtin
891
+ """Returns a message converter for the given format."""
892
+ if format not in self._name_to_converter:
893
+ raise ValueError(f'Unsupported format: {format!r}.')
894
+ return self._name_to_converter[format](**kwargs)
895
+
896
+ def get(
897
+ self,
898
+ format_or_type: str | Type[Any], **kwargs
899
+ ) -> 'MessageConverter':
900
+ """Returns a message converter for the given format or type."""
901
+ if isinstance(format_or_type, str):
902
+ return self.get_by_format(format_or_type, **kwargs)
903
+ assert isinstance(format_or_type, type), format_or_type
904
+ return self.get_by_type(format_or_type, **kwargs)
905
+
906
+ def convertible_formats(self) -> list[str]:
907
+ """Returns a list of converter names."""
908
+ return sorted(list(self._name_to_converter.keys()))
909
+
910
+ def convertible_types(self) -> list[Type[Any]]:
911
+ """Returns a list of converter types."""
912
+ return list(self._type_to_converters.keys())
913
+
914
+
915
class MessageConverter(pg.Object):
  """Interface for converting Langfun messages to and from other formats.

  Subclasses for which `inspect.isabstract` is False are added to a shared
  registry when the class is created, indexed by `FORMAT_ID` and, when set,
  by `OUTPUT_TYPE`.
  """

  # A global unique identifier for the converter.
  FORMAT_ID: ClassVar[str]

  # The output type of the converter.
  # If None, the converter will not be registered to handle
  # `lf.Message.as_format(output_type)`; it can only be selected by format ID.
  OUTPUT_TYPE: ClassVar[Type[Any] | None] = None

  # Registry instance shared by this class and all of its subclasses.
  _REGISTRY = _MessageConverterRegistry()

  def __init_subclass__(cls, *args, **kwargs):
    """Registers non-abstract subclasses with the shared registry."""
    super().__init_subclass__(*args, **kwargs)
    if not inspect.isabstract(cls):
      cls._REGISTRY.register(cls)

  @abc.abstractmethod
  def to_value(self, message: Message) -> Any:
    """Converts a Langfun message to the converter's target format."""

  @abc.abstractmethod
  def from_value(self, value: Any) -> Message:
    """Converts a value in the converter's format to a Langfun message."""

  @classmethod
  def _safe_read(
      cls,
      data: Mapping[str, Any],
      key: str,
      default: Any = pg.MISSING_VALUE
  ) -> Any:
    """Reads a key from a mapping, validating the container and the key.

    Args:
      data: The mapping to read from.
      key: The key to read.
      default: The value returned when `key` is absent. When left as
        `pg.MISSING_VALUE`, an absent key is an error.

    Returns:
      `data[key]` if present, otherwise `default`.

    Raises:
      ValueError: If `data` is not a `Mapping`, or if `key` is absent and no
        default was supplied.
    """
    if not isinstance(data, Mapping):
      raise ValueError(f'Invalid data type: {data!r}.')
    if key not in data:
      if pg.MISSING_VALUE == default:
        raise ValueError(f'Missing key {key!r} in {data!r}')
      return default
    return data[key]

  @classmethod
  def get_role(cls, message: Message) -> str:
    """Returns the role string for a message.

    Args:
      message: The message to inspect.

    Returns:
      'system', 'user' or 'assistant' for `SystemMessage`, `UserMessage` and
      `AIMessage` instances respectively.

    Raises:
      ValueError: If `message` is an instance of none of those classes.
    """
    if isinstance(message, SystemMessage):
      return 'system'
    elif isinstance(message, UserMessage):
      return 'user'
    elif isinstance(message, AIMessage):
      return 'assistant'
    else:
      raise ValueError(f'Unsupported message type: {message!r}.')

  @classmethod
  def get_message_cls(cls, role: str) -> type[Message]:
    """Returns the message class for a role string.

    This is the inverse of `get_role`.

    Args:
      role: One of 'system', 'user' or 'assistant'.

    Returns:
      `SystemMessage`, `UserMessage` or `AIMessage` respectively.

    Raises:
      ValueError: If `role` is any other value.
    """
    match role:
      case 'system':
        return SystemMessage
      case 'user':
        return UserMessage
      case 'assistant':
        return AIMessage
      case _:
        raise ValueError(f'Unsupported role: {role!r}.')

  @classmethod
  def get(cls, format_or_type: str | Type[Any], **kwargs) -> 'MessageConverter':
    """Returns a converter instance for a format ID or an output type."""
    return cls._REGISTRY.get(format_or_type, **kwargs)

  @classmethod
  def get_by_format(cls, format: str, **kwargs) -> 'MessageConverter':  # pylint: disable=redefined-builtin
    """Returns a message converter for the given format."""
    return cls._REGISTRY.get_by_format(format, **kwargs)

  @classmethod
  def get_by_type(cls, t: Type[Any], **kwargs) -> 'MessageConverter':
    """Returns a message converter for the given type."""
    return cls._REGISTRY.get_by_type(t, **kwargs)

  @classmethod
  def convertible_formats(cls) -> list[str]:
    """Returns the list of registered format IDs."""
    return cls._REGISTRY.convertible_formats()

  @classmethod
  def convertible_types(cls) -> list[Type[Any]]:
    """Returns the list of types that have registered converters."""
    return cls._REGISTRY.convertible_types()
1006
+
802
1007
  #
803
1008
  # Messages of different roles.
804
1009
  #
@@ -187,10 +187,6 @@ class MessageTest(unittest.TestCase):
187
187
  m = message.UserMessage('hi')
188
188
  self.assertEqual(m.text, 'hi')
189
189
  self.assertEqual(m.sender, 'User')
190
- self.assertTrue(m.from_user)
191
- self.assertFalse(m.from_agent)
192
- self.assertFalse(m.from_system)
193
- self.assertFalse(m.from_memory)
194
190
  self.assertEqual(str(m), m.text)
195
191
 
196
192
  m = message.UserMessage('hi', sender='Tom')
@@ -201,10 +197,6 @@ class MessageTest(unittest.TestCase):
201
197
  m = message.AIMessage('hi')
202
198
  self.assertEqual(m.text, 'hi')
203
199
  self.assertEqual(m.sender, 'AI')
204
- self.assertFalse(m.from_user)
205
- self.assertTrue(m.from_agent)
206
- self.assertFalse(m.from_system)
207
- self.assertFalse(m.from_memory)
208
200
  self.assertEqual(str(m), m.text)
209
201
 
210
202
  m = message.AIMessage('hi', sender='Model')
@@ -215,10 +207,6 @@ class MessageTest(unittest.TestCase):
215
207
  m = message.SystemMessage('hi')
216
208
  self.assertEqual(m.text, 'hi')
217
209
  self.assertEqual(m.sender, 'System')
218
- self.assertFalse(m.from_user)
219
- self.assertFalse(m.from_agent)
220
- self.assertTrue(m.from_system)
221
- self.assertFalse(m.from_memory)
222
210
  self.assertEqual(str(m), m.text)
223
211
 
224
212
  m = message.SystemMessage('hi', sender='Environment1')
@@ -229,10 +217,6 @@ class MessageTest(unittest.TestCase):
229
217
  m = message.MemoryRecord('hi')
230
218
  self.assertEqual(m.text, 'hi')
231
219
  self.assertEqual(m.sender, 'Memory')
232
- self.assertFalse(m.from_user)
233
- self.assertFalse(m.from_agent)
234
- self.assertFalse(m.from_system)
235
- self.assertTrue(m.from_memory)
236
220
  self.assertEqual(str(m), m.text)
237
221
 
238
222
  m = message.MemoryRecord('hi', sender="Someone's Memory")
@@ -493,5 +477,137 @@ class MessageTest(unittest.TestCase):
493
477
  )
494
478
 
495
479
 
480
class MessageConverterTest(unittest.TestCase):
  """Tests for `message.MessageConverter` and its registry-backed lookups."""

  def test_basics(self):
    """Defines ad-hoc converters and looks them up by format ID and by type."""

    # Intermediate base that only fixes OUTPUT_TYPE. The format assertion
    # below expects only the three concrete converters to be listed.
    class IntConverter(message.MessageConverter):
      OUTPUT_TYPE = int

    class TestConverter(IntConverter):  # pylint: disable=unused-variable
      FORMAT_ID = 'test_format1'

      def to_value(self, m: message.Message) -> int:
        return int(m.text)

      def from_value(self, value: int) -> message.Message:
        return message.UserMessage(str(value))

    # Second converter for `int`, making lookup by type `int` ambiguous.
    class TestConverter2(IntConverter):  # pylint: disable=unused-variable
      FORMAT_ID = 'test_format2'

      def to_value(self, m: message.Message) -> int:
        return int(m.text) + 1

      def from_value(self, value: int) -> message.Message:
        return message.UserMessage(str(value - 1))

    # The only converter for `tuple`, so lookup by type is unambiguous.
    class TestConverter3(message.MessageConverter):  # pylint: disable=unused-variable
      FORMAT_ID = 'test_format3'
      OUTPUT_TYPE = tuple

      def to_value(self, m: message.Message) -> tuple[int, ...]:
        return tuple(int(x) for x in m.text.split(','))

      def from_value(self, value: tuple[int, ...]) -> message.Message:
        return message.UserMessage(','.join(str(x) for x in value))

    # NOTE(review): these two assertions assume no other converters have been
    # registered in this process - confirm against the test's imports.
    self.assertEqual(
        message.Message.convertible_formats,
        ['test_format1', 'test_format2', 'test_format3']
    )
    self.assertEqual(
        message.Message.convertible_types,
        [int, tuple]
    )
    self.assertEqual(
        message.Message.from_value(1, format='test_format1'),
        message.UserMessage('1')
    )
    self.assertEqual(
        message.UserMessage('1').as_format('test_format1'),
        1
    )
    self.assertEqual(
        message.Message.from_value(1, format='test_format2'),
        message.UserMessage('0')
    )
    self.assertEqual(
        message.UserMessage('1').as_format('test_format2'),
        2
    )
    # Unknown format ID.
    with self.assertRaisesRegex(ValueError, 'Unsupported format: .*'):
      message.UserMessage('1').as_format('test4')

    # No converter registered for `float`.
    with self.assertRaisesRegex(TypeError, 'Cannot convert Message to .*'):
      message.UserMessage('1').as_format(float)

    # Two converters are registered for `int`.
    with self.assertRaisesRegex(
        TypeError, 'More than one converters found for output type .*'
    ):
      message.UserMessage('1').as_format(int)
    self.assertEqual(
        message.UserMessage('1,2,3').as_format('test_format3'),
        (1, 2, 3)
    )
    self.assertEqual(
        message.UserMessage('1,2,3').as_format(tuple),
        (1, 2, 3)
    )
    # Without `format`, the converter is selected by the type of the value.
    self.assertEqual(
        message.Message.from_value((1, 2, 3)),
        message.UserMessage('1,2,3')
    )

  def test_get_role(self):
    """Checks the message-class-to-role mapping and the unsupported case."""
    self.assertEqual(
        message.MessageConverter.get_role(message.SystemMessage('hi')),
        'system',
    )
    self.assertEqual(
        message.MessageConverter.get_role(message.UserMessage('hi')),
        'user',
    )
    self.assertEqual(
        message.MessageConverter.get_role(message.AIMessage('hi')),
        'assistant',
    )
    with self.assertRaisesRegex(ValueError, 'Unsupported message type: .*'):
      message.MessageConverter.get_role(message.MemoryRecord('hi'))

  def test_get_message_cls(self):
    """Checks the role-to-message-class mapping and the unsupported case."""
    self.assertEqual(
        message.MessageConverter.get_message_cls('system'),
        message.SystemMessage,
    )
    self.assertEqual(
        message.MessageConverter.get_message_cls('user'),
        message.UserMessage,
    )
    self.assertEqual(
        message.MessageConverter.get_message_cls('assistant'),
        message.AIMessage,
    )
    with self.assertRaisesRegex(ValueError, 'Unsupported role: .*'):
      message.MessageConverter.get_message_cls('foo')

  def test_safe_read(self):
    """Covers present keys, defaults, non-mapping input and missing keys."""
    self.assertEqual(
        message.MessageConverter._safe_read({'a': 1}, 'a'),
        1,
    )
    self.assertEqual(
        message.MessageConverter._safe_read({'a': 1}, 'a', default=2),
        1,
    )
    self.assertEqual(
        message.MessageConverter._safe_read({'a': 1}, 'b', default=2),
        2,
    )
    with self.assertRaisesRegex(ValueError, 'Invalid data type: .*'):
      message.MessageConverter._safe_read(1, 'a')
    with self.assertRaisesRegex(ValueError, 'Missing key .*'):
      message.MessageConverter._safe_read({'a': 1}, 'b')
611
+
496
612
  if __name__ == '__main__':
497
613
  unittest.main()