langfun 0.0.2.dev20240531__tar.gz → 0.0.2.dev20240601__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (124) hide show
  1. {langfun-0.0.2.dev20240531 → langfun-0.0.2.dev20240601}/PKG-INFO +4 -3
  2. {langfun-0.0.2.dev20240531 → langfun-0.0.2.dev20240601}/langfun/__init__.py +2 -0
  3. {langfun-0.0.2.dev20240531 → langfun-0.0.2.dev20240601}/langfun/core/__init__.py +1 -0
  4. {langfun-0.0.2.dev20240531 → langfun-0.0.2.dev20240601}/langfun/core/llms/google_genai.py +66 -13
  5. {langfun-0.0.2.dev20240531 → langfun-0.0.2.dev20240601}/langfun/core/llms/google_genai_test.py +1 -1
  6. {langfun-0.0.2.dev20240531 → langfun-0.0.2.dev20240601}/langfun/core/llms/vertexai.py +67 -14
  7. {langfun-0.0.2.dev20240531 → langfun-0.0.2.dev20240601}/langfun/core/llms/vertexai_test.py +1 -1
  8. {langfun-0.0.2.dev20240531 → langfun-0.0.2.dev20240601}/langfun/core/modalities/__init__.py +1 -1
  9. {langfun-0.0.2.dev20240531 → langfun-0.0.2.dev20240601}/langfun/core/modalities/audio.py +1 -1
  10. {langfun-0.0.2.dev20240531 → langfun-0.0.2.dev20240601}/langfun/core/modalities/image.py +1 -1
  11. {langfun-0.0.2.dev20240531 → langfun-0.0.2.dev20240601}/langfun/core/modalities/image_test.py +23 -6
  12. langfun-0.0.2.dev20240601/langfun/core/modalities/mime.py +191 -0
  13. {langfun-0.0.2.dev20240531 → langfun-0.0.2.dev20240601}/langfun/core/modalities/mime_test.py +18 -3
  14. {langfun-0.0.2.dev20240531 → langfun-0.0.2.dev20240601}/langfun/core/modalities/ms_office.py +38 -10
  15. {langfun-0.0.2.dev20240531 → langfun-0.0.2.dev20240601}/langfun/core/modalities/ms_office_test.py +93 -16
  16. {langfun-0.0.2.dev20240531 → langfun-0.0.2.dev20240601}/langfun/core/modalities/pdf.py +1 -1
  17. {langfun-0.0.2.dev20240531 → langfun-0.0.2.dev20240601}/langfun/core/modalities/video.py +1 -1
  18. {langfun-0.0.2.dev20240531 → langfun-0.0.2.dev20240601}/langfun/core/modality.py +4 -0
  19. {langfun-0.0.2.dev20240531 → langfun-0.0.2.dev20240601}/langfun.egg-info/PKG-INFO +4 -3
  20. {langfun-0.0.2.dev20240531 → langfun-0.0.2.dev20240601}/langfun.egg-info/requires.txt +3 -2
  21. langfun-0.0.2.dev20240531/langfun/core/modalities/mime.py +0 -102
  22. {langfun-0.0.2.dev20240531 → langfun-0.0.2.dev20240601}/LICENSE +0 -0
  23. {langfun-0.0.2.dev20240531 → langfun-0.0.2.dev20240601}/README.md +0 -0
  24. {langfun-0.0.2.dev20240531 → langfun-0.0.2.dev20240601}/langfun/core/coding/__init__.py +0 -0
  25. {langfun-0.0.2.dev20240531 → langfun-0.0.2.dev20240601}/langfun/core/coding/python/__init__.py +0 -0
  26. {langfun-0.0.2.dev20240531 → langfun-0.0.2.dev20240601}/langfun/core/coding/python/correction.py +0 -0
  27. {langfun-0.0.2.dev20240531 → langfun-0.0.2.dev20240601}/langfun/core/coding/python/correction_test.py +0 -0
  28. {langfun-0.0.2.dev20240531 → langfun-0.0.2.dev20240601}/langfun/core/coding/python/errors.py +0 -0
  29. {langfun-0.0.2.dev20240531 → langfun-0.0.2.dev20240601}/langfun/core/coding/python/errors_test.py +0 -0
  30. {langfun-0.0.2.dev20240531 → langfun-0.0.2.dev20240601}/langfun/core/coding/python/execution.py +0 -0
  31. {langfun-0.0.2.dev20240531 → langfun-0.0.2.dev20240601}/langfun/core/coding/python/execution_test.py +0 -0
  32. {langfun-0.0.2.dev20240531 → langfun-0.0.2.dev20240601}/langfun/core/coding/python/generation.py +0 -0
  33. {langfun-0.0.2.dev20240531 → langfun-0.0.2.dev20240601}/langfun/core/coding/python/generation_test.py +0 -0
  34. {langfun-0.0.2.dev20240531 → langfun-0.0.2.dev20240601}/langfun/core/coding/python/parsing.py +0 -0
  35. {langfun-0.0.2.dev20240531 → langfun-0.0.2.dev20240601}/langfun/core/coding/python/parsing_test.py +0 -0
  36. {langfun-0.0.2.dev20240531 → langfun-0.0.2.dev20240601}/langfun/core/coding/python/permissions.py +0 -0
  37. {langfun-0.0.2.dev20240531 → langfun-0.0.2.dev20240601}/langfun/core/coding/python/permissions_test.py +0 -0
  38. {langfun-0.0.2.dev20240531 → langfun-0.0.2.dev20240601}/langfun/core/component.py +0 -0
  39. {langfun-0.0.2.dev20240531 → langfun-0.0.2.dev20240601}/langfun/core/component_test.py +0 -0
  40. {langfun-0.0.2.dev20240531 → langfun-0.0.2.dev20240601}/langfun/core/concurrent.py +0 -0
  41. {langfun-0.0.2.dev20240531 → langfun-0.0.2.dev20240601}/langfun/core/concurrent_test.py +0 -0
  42. {langfun-0.0.2.dev20240531 → langfun-0.0.2.dev20240601}/langfun/core/console.py +0 -0
  43. {langfun-0.0.2.dev20240531 → langfun-0.0.2.dev20240601}/langfun/core/console_test.py +0 -0
  44. {langfun-0.0.2.dev20240531 → langfun-0.0.2.dev20240601}/langfun/core/eval/__init__.py +0 -0
  45. {langfun-0.0.2.dev20240531 → langfun-0.0.2.dev20240601}/langfun/core/eval/base.py +0 -0
  46. {langfun-0.0.2.dev20240531 → langfun-0.0.2.dev20240601}/langfun/core/eval/base_test.py +0 -0
  47. {langfun-0.0.2.dev20240531 → langfun-0.0.2.dev20240601}/langfun/core/eval/matching.py +0 -0
  48. {langfun-0.0.2.dev20240531 → langfun-0.0.2.dev20240601}/langfun/core/eval/matching_test.py +0 -0
  49. {langfun-0.0.2.dev20240531 → langfun-0.0.2.dev20240601}/langfun/core/eval/patching.py +0 -0
  50. {langfun-0.0.2.dev20240531 → langfun-0.0.2.dev20240601}/langfun/core/eval/patching_test.py +0 -0
  51. {langfun-0.0.2.dev20240531 → langfun-0.0.2.dev20240601}/langfun/core/eval/scoring.py +0 -0
  52. {langfun-0.0.2.dev20240531 → langfun-0.0.2.dev20240601}/langfun/core/eval/scoring_test.py +0 -0
  53. {langfun-0.0.2.dev20240531 → langfun-0.0.2.dev20240601}/langfun/core/langfunc.py +0 -0
  54. {langfun-0.0.2.dev20240531 → langfun-0.0.2.dev20240601}/langfun/core/langfunc_test.py +0 -0
  55. {langfun-0.0.2.dev20240531 → langfun-0.0.2.dev20240601}/langfun/core/language_model.py +0 -0
  56. {langfun-0.0.2.dev20240531 → langfun-0.0.2.dev20240601}/langfun/core/language_model_test.py +0 -0
  57. {langfun-0.0.2.dev20240531 → langfun-0.0.2.dev20240601}/langfun/core/llms/__init__.py +0 -0
  58. {langfun-0.0.2.dev20240531 → langfun-0.0.2.dev20240601}/langfun/core/llms/anthropic.py +0 -0
  59. {langfun-0.0.2.dev20240531 → langfun-0.0.2.dev20240601}/langfun/core/llms/anthropic_test.py +0 -0
  60. {langfun-0.0.2.dev20240531 → langfun-0.0.2.dev20240601}/langfun/core/llms/cache/__init__.py +0 -0
  61. {langfun-0.0.2.dev20240531 → langfun-0.0.2.dev20240601}/langfun/core/llms/cache/base.py +0 -0
  62. {langfun-0.0.2.dev20240531 → langfun-0.0.2.dev20240601}/langfun/core/llms/cache/in_memory.py +0 -0
  63. {langfun-0.0.2.dev20240531 → langfun-0.0.2.dev20240601}/langfun/core/llms/cache/in_memory_test.py +0 -0
  64. {langfun-0.0.2.dev20240531 → langfun-0.0.2.dev20240601}/langfun/core/llms/fake.py +0 -0
  65. {langfun-0.0.2.dev20240531 → langfun-0.0.2.dev20240601}/langfun/core/llms/fake_test.py +0 -0
  66. {langfun-0.0.2.dev20240531 → langfun-0.0.2.dev20240601}/langfun/core/llms/groq.py +0 -0
  67. {langfun-0.0.2.dev20240531 → langfun-0.0.2.dev20240601}/langfun/core/llms/groq_test.py +0 -0
  68. {langfun-0.0.2.dev20240531 → langfun-0.0.2.dev20240601}/langfun/core/llms/llama_cpp.py +0 -0
  69. {langfun-0.0.2.dev20240531 → langfun-0.0.2.dev20240601}/langfun/core/llms/llama_cpp_test.py +0 -0
  70. {langfun-0.0.2.dev20240531 → langfun-0.0.2.dev20240601}/langfun/core/llms/openai.py +0 -0
  71. {langfun-0.0.2.dev20240531 → langfun-0.0.2.dev20240601}/langfun/core/llms/openai_test.py +0 -0
  72. {langfun-0.0.2.dev20240531 → langfun-0.0.2.dev20240601}/langfun/core/memories/__init__.py +0 -0
  73. {langfun-0.0.2.dev20240531 → langfun-0.0.2.dev20240601}/langfun/core/memories/conversation_history.py +0 -0
  74. {langfun-0.0.2.dev20240531 → langfun-0.0.2.dev20240601}/langfun/core/memories/conversation_history_test.py +0 -0
  75. {langfun-0.0.2.dev20240531 → langfun-0.0.2.dev20240601}/langfun/core/memory.py +0 -0
  76. {langfun-0.0.2.dev20240531 → langfun-0.0.2.dev20240601}/langfun/core/message.py +0 -0
  77. {langfun-0.0.2.dev20240531 → langfun-0.0.2.dev20240601}/langfun/core/message_test.py +0 -0
  78. {langfun-0.0.2.dev20240531 → langfun-0.0.2.dev20240601}/langfun/core/modalities/audio_test.py +0 -0
  79. {langfun-0.0.2.dev20240531 → langfun-0.0.2.dev20240601}/langfun/core/modalities/pdf_test.py +0 -0
  80. {langfun-0.0.2.dev20240531 → langfun-0.0.2.dev20240601}/langfun/core/modalities/video_test.py +0 -0
  81. {langfun-0.0.2.dev20240531 → langfun-0.0.2.dev20240601}/langfun/core/modality_test.py +0 -0
  82. {langfun-0.0.2.dev20240531 → langfun-0.0.2.dev20240601}/langfun/core/natural_language.py +0 -0
  83. {langfun-0.0.2.dev20240531 → langfun-0.0.2.dev20240601}/langfun/core/natural_language_test.py +0 -0
  84. {langfun-0.0.2.dev20240531 → langfun-0.0.2.dev20240601}/langfun/core/sampling.py +0 -0
  85. {langfun-0.0.2.dev20240531 → langfun-0.0.2.dev20240601}/langfun/core/sampling_test.py +0 -0
  86. {langfun-0.0.2.dev20240531 → langfun-0.0.2.dev20240601}/langfun/core/structured/__init__.py +0 -0
  87. {langfun-0.0.2.dev20240531 → langfun-0.0.2.dev20240601}/langfun/core/structured/completion.py +0 -0
  88. {langfun-0.0.2.dev20240531 → langfun-0.0.2.dev20240601}/langfun/core/structured/completion_test.py +0 -0
  89. {langfun-0.0.2.dev20240531 → langfun-0.0.2.dev20240601}/langfun/core/structured/description.py +0 -0
  90. {langfun-0.0.2.dev20240531 → langfun-0.0.2.dev20240601}/langfun/core/structured/description_test.py +0 -0
  91. {langfun-0.0.2.dev20240531 → langfun-0.0.2.dev20240601}/langfun/core/structured/function_generation.py +0 -0
  92. {langfun-0.0.2.dev20240531 → langfun-0.0.2.dev20240601}/langfun/core/structured/function_generation_test.py +0 -0
  93. {langfun-0.0.2.dev20240531 → langfun-0.0.2.dev20240601}/langfun/core/structured/mapping.py +0 -0
  94. {langfun-0.0.2.dev20240531 → langfun-0.0.2.dev20240601}/langfun/core/structured/mapping_test.py +0 -0
  95. {langfun-0.0.2.dev20240531 → langfun-0.0.2.dev20240601}/langfun/core/structured/parsing.py +0 -0
  96. {langfun-0.0.2.dev20240531 → langfun-0.0.2.dev20240601}/langfun/core/structured/parsing_test.py +0 -0
  97. {langfun-0.0.2.dev20240531 → langfun-0.0.2.dev20240601}/langfun/core/structured/prompting.py +0 -0
  98. {langfun-0.0.2.dev20240531 → langfun-0.0.2.dev20240601}/langfun/core/structured/prompting_test.py +0 -0
  99. {langfun-0.0.2.dev20240531 → langfun-0.0.2.dev20240601}/langfun/core/structured/schema.py +0 -0
  100. {langfun-0.0.2.dev20240531 → langfun-0.0.2.dev20240601}/langfun/core/structured/schema_generation.py +0 -0
  101. {langfun-0.0.2.dev20240531 → langfun-0.0.2.dev20240601}/langfun/core/structured/schema_generation_test.py +0 -0
  102. {langfun-0.0.2.dev20240531 → langfun-0.0.2.dev20240601}/langfun/core/structured/schema_test.py +0 -0
  103. {langfun-0.0.2.dev20240531 → langfun-0.0.2.dev20240601}/langfun/core/structured/scoring.py +0 -0
  104. {langfun-0.0.2.dev20240531 → langfun-0.0.2.dev20240601}/langfun/core/structured/scoring_test.py +0 -0
  105. {langfun-0.0.2.dev20240531 → langfun-0.0.2.dev20240601}/langfun/core/subscription.py +0 -0
  106. {langfun-0.0.2.dev20240531 → langfun-0.0.2.dev20240601}/langfun/core/subscription_test.py +0 -0
  107. {langfun-0.0.2.dev20240531 → langfun-0.0.2.dev20240601}/langfun/core/template.py +0 -0
  108. {langfun-0.0.2.dev20240531 → langfun-0.0.2.dev20240601}/langfun/core/template_test.py +0 -0
  109. {langfun-0.0.2.dev20240531 → langfun-0.0.2.dev20240601}/langfun/core/templates/__init__.py +0 -0
  110. {langfun-0.0.2.dev20240531 → langfun-0.0.2.dev20240601}/langfun/core/templates/completion.py +0 -0
  111. {langfun-0.0.2.dev20240531 → langfun-0.0.2.dev20240601}/langfun/core/templates/completion_test.py +0 -0
  112. {langfun-0.0.2.dev20240531 → langfun-0.0.2.dev20240601}/langfun/core/templates/conversation.py +0 -0
  113. {langfun-0.0.2.dev20240531 → langfun-0.0.2.dev20240601}/langfun/core/templates/conversation_test.py +0 -0
  114. {langfun-0.0.2.dev20240531 → langfun-0.0.2.dev20240601}/langfun/core/templates/demonstration.py +0 -0
  115. {langfun-0.0.2.dev20240531 → langfun-0.0.2.dev20240601}/langfun/core/templates/demonstration_test.py +0 -0
  116. {langfun-0.0.2.dev20240531 → langfun-0.0.2.dev20240601}/langfun/core/templates/selfplay.py +0 -0
  117. {langfun-0.0.2.dev20240531 → langfun-0.0.2.dev20240601}/langfun/core/templates/selfplay_test.py +0 -0
  118. {langfun-0.0.2.dev20240531 → langfun-0.0.2.dev20240601}/langfun/core/text_formatting.py +0 -0
  119. {langfun-0.0.2.dev20240531 → langfun-0.0.2.dev20240601}/langfun/core/text_formatting_test.py +0 -0
  120. {langfun-0.0.2.dev20240531 → langfun-0.0.2.dev20240601}/langfun.egg-info/SOURCES.txt +0 -0
  121. {langfun-0.0.2.dev20240531 → langfun-0.0.2.dev20240601}/langfun.egg-info/dependency_links.txt +0 -0
  122. {langfun-0.0.2.dev20240531 → langfun-0.0.2.dev20240601}/langfun.egg-info/top_level.txt +0 -0
  123. {langfun-0.0.2.dev20240531 → langfun-0.0.2.dev20240601}/setup.cfg +0 -0
  124. {langfun-0.0.2.dev20240531 → langfun-0.0.2.dev20240601}/setup.py +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: langfun
3
- Version: 0.0.2.dev20240531
3
+ Version: 0.0.2.dev20240601
4
4
  Summary: Langfun: Language as Functions.
5
5
  Home-page: https://github.com/google/langfun
6
6
  Author: Langfun Authors
@@ -25,13 +25,14 @@ Requires-Dist: google-cloud-aiplatform>=1.5.0
25
25
  Requires-Dist: google-generativeai>=0.3.2
26
26
  Requires-Dist: jinja2>=3.1.2
27
27
  Requires-Dist: openai==0.27.2
28
+ Requires-Dist: openpyxl>=3.1.0
29
+ Requires-Dist: pandas>=2.1.4
28
30
  Requires-Dist: pyglove>=0.4.5.dev20240423
31
+ Requires-Dist: python-docx>=0.8.11
29
32
  Requires-Dist: python-magic>=0.4.27
30
33
  Requires-Dist: requests>=2.31.0
31
34
  Requires-Dist: termcolor==1.1.0
32
35
  Requires-Dist: tqdm>=4.64.1
33
- Requires-Dist: python-docx>=0.8.11
34
- Requires-Dist: pandas>=2.1.4
35
36
 
36
37
  <div align="center">
37
38
  <img src="https://raw.githubusercontent.com/google/langfun/main/docs/_static/logo.svg" width="520px" alt="logo"></img>
@@ -57,6 +57,8 @@ from langfun.core import memories
57
57
 
58
58
  from langfun.core import modalities
59
59
 
60
+ Mime = modalities.Mime
61
+ MimeType = Mime # For backwards compatibility.
60
62
  Image = modalities.Image
61
63
  Video = modalities.Video
62
64
  PDF = modalities.PDF
@@ -94,6 +94,7 @@ from langfun.core.message import MemoryRecord
94
94
  # Interface for modality.
95
95
  from langfun.core.modality import Modality
96
96
  from langfun.core.modality import ModalityRef
97
+ from langfun.core.modality import ModalityError
97
98
 
98
99
  # Interfaces for language models.
99
100
  from langfun.core.language_model import LanguageModel
@@ -49,9 +49,10 @@ class GenAI(lf.LanguageModel):
49
49
  ),
50
50
  ] = None
51
51
 
52
- multimodal: Annotated[bool, 'Whether this model has multimodal support.'] = (
53
- False
54
- )
52
+ supported_modalities: Annotated[
53
+ list[str],
54
+ 'A list of MIME types for supported modalities'
55
+ ] = []
55
56
 
56
57
  # Set the default max concurrency to 8 workers.
57
58
  max_concurrency = 8
@@ -118,14 +119,27 @@ class GenAI(lf.LanguageModel):
118
119
  chunks = []
119
120
  for lf_chunk in formatted.chunk():
120
121
  if isinstance(lf_chunk, str):
121
- chunk = lf_chunk
122
- elif self.multimodal and isinstance(lf_chunk, lf_modalities.MimeType):
123
- chunk = genai.types.BlobDict(
124
- data=lf_chunk.to_bytes(), mime_type=lf_chunk.mime_type
125
- )
122
+ chunks.append(lf_chunk)
123
+ elif isinstance(lf_chunk, lf_modalities.Mime):
124
+ try:
125
+ modalities = lf_chunk.make_compatible(
126
+ self.supported_modalities + ['text/plain']
127
+ )
128
+ if isinstance(modalities, lf_modalities.Mime):
129
+ modalities = [modalities]
130
+ for modality in modalities:
131
+ if modality.is_text:
132
+ chunk = modality.to_text()
133
+ else:
134
+ chunk = genai.types.BlobDict(
135
+ data=modality.to_bytes(),
136
+ mime_type=modality.mime_type
137
+ )
138
+ chunks.append(chunk)
139
+ except lf.ModalityError as e:
140
+ raise lf.ModalityError(f'Unsupported modality: {lf_chunk!r}') from e
126
141
  else:
127
- raise ValueError(f'Unsupported modality: {lf_chunk!r}')
128
- chunks.append(chunk)
142
+ raise lf.ModalityError(f'Unsupported modality: {lf_chunk!r}')
129
143
  return chunks
130
144
 
131
145
  def _response_to_result(
@@ -264,18 +278,57 @@ _GOOGLE_GENAI_MODEL_HUB = _ModelHub()
264
278
  #
265
279
 
266
280
 
281
+ _IMAGE_TYPES = [
282
+ 'image/png',
283
+ 'image/jpeg',
284
+ 'image/webp',
285
+ 'image/heic',
286
+ 'image/heif',
287
+ ]
288
+
289
+ _AUDIO_TYPES = [
290
+ 'audio/aac',
291
+ 'audio/flac',
292
+ 'audio/mp3',
293
+ 'audio/m4a',
294
+ 'audio/mpeg',
295
+ 'audio/mpga',
296
+ 'audio/mp4',
297
+ 'audio/opus',
298
+ 'audio/pcm',
299
+ 'audio/wav',
300
+ 'audio/webm'
301
+ ]
302
+
303
+ _VIDEO_TYPES = [
304
+ 'video/mov',
305
+ 'video/mpeg',
306
+ 'video/mpegps',
307
+ 'video/mpg',
308
+ 'video/mp4',
309
+ 'video/webm',
310
+ 'video/wmv',
311
+ 'video/x-flv',
312
+ 'video/3gpp',
313
+ ]
314
+
315
+ _PDF = [
316
+ 'application/pdf',
317
+ ]
318
+
319
+
267
320
  class GeminiPro1_5(GenAI): # pylint: disable=invalid-name
268
321
  """Gemini Pro latest model."""
269
322
 
270
323
  model = 'gemini-1.5-pro-latest'
271
- multimodal = True
324
+ supported_modalities = _PDF + _IMAGE_TYPES + _AUDIO_TYPES + _VIDEO_TYPES
272
325
 
273
326
 
274
327
  class GeminiFlash1_5(GenAI): # pylint: disable=invalid-name
275
328
  """Gemini Flash latest model."""
276
329
 
277
330
  model = 'gemini-1.5-flash-latest'
278
- multimodal = True
331
+ supported_modalities = _PDF + _IMAGE_TYPES + _AUDIO_TYPES + _VIDEO_TYPES
279
332
 
280
333
 
281
334
  class GeminiPro(GenAI):
@@ -288,7 +341,7 @@ class GeminiProVision(GenAI):
288
341
  """Gemini Pro vision model."""
289
342
 
290
343
  model = 'gemini-pro-vision'
291
- multimodal = True
344
+ supported_modalities = _IMAGE_TYPES + _VIDEO_TYPES
292
345
 
293
346
 
294
347
  class Palm2(GenAI):
@@ -107,7 +107,7 @@ class GenAITest(unittest.TestCase):
107
107
  )
108
108
 
109
109
  # Non-multimodal model.
110
- with self.assertRaisesRegex(ValueError, 'Unsupported modality'):
110
+ with self.assertRaisesRegex(lf.ModalityError, 'Unsupported modality'):
111
111
  google_genai.GeminiPro()._content_from_message(message)
112
112
 
113
113
  model = google_genai.GeminiProVision()
@@ -75,9 +75,10 @@ class VertexAI(lf.LanguageModel):
75
75
  ),
76
76
  ] = None
77
77
 
78
- multimodal: Annotated[bool, 'Whether this model has multimodal support.'] = (
79
- False
80
- )
78
+ supported_modalities: Annotated[
79
+ list[str],
80
+ 'A list of MIME types for supported modalities'
81
+ ] = []
81
82
 
82
83
  def _on_bound(self):
83
84
  super()._on_bound()
@@ -142,16 +143,29 @@ class VertexAI(lf.LanguageModel):
142
143
  """Gets generation input from langfun message."""
143
144
  from vertexai import generative_models
144
145
  chunks = []
146
+
145
147
  for lf_chunk in prompt.chunk():
146
148
  if isinstance(lf_chunk, str):
147
- chunk = lf_chunk
148
- elif self.multimodal and isinstance(lf_chunk, lf_modalities.MimeType):
149
- chunk = generative_models.Part.from_data(
150
- lf_chunk.to_bytes(), lf_chunk.mime_type
151
- )
149
+ chunks.append(lf_chunk)
150
+ elif isinstance(lf_chunk, lf_modalities.Mime):
151
+ try:
152
+ modalities = lf_chunk.make_compatible(
153
+ self.supported_modalities + ['text/plain']
154
+ )
155
+ if isinstance(modalities, lf_modalities.Mime):
156
+ modalities = [modalities]
157
+ for modality in modalities:
158
+ if modality.is_text:
159
+ chunk = modality.to_text()
160
+ else:
161
+ chunk = generative_models.Part.from_data(
162
+ modality.to_bytes(), modality.mime_type
163
+ )
164
+ chunks.append(chunk)
165
+ except lf.ModalityError as e:
166
+ raise lf.ModalityError(f'Unsupported modality: {lf_chunk!r}') from e
152
167
  else:
153
- raise ValueError(f'Unsupported modality: {lf_chunk!r}')
154
- chunks.append(chunk)
168
+ raise lf.ModalityError(f'Unsupported modality: {lf_chunk!r}')
155
169
  return chunks
156
170
 
157
171
  def _generation_response_to_message(
@@ -265,25 +279,64 @@ class _ModelHub:
265
279
  _VERTEXAI_MODEL_HUB = _ModelHub()
266
280
 
267
281
 
282
+ _IMAGE_TYPES = [
283
+ 'image/png',
284
+ 'image/jpeg',
285
+ 'image/webp',
286
+ 'image/heic',
287
+ 'image/heif',
288
+ ]
289
+
290
+ _AUDIO_TYPES = [
291
+ 'audio/aac',
292
+ 'audio/flac',
293
+ 'audio/mp3',
294
+ 'audio/m4a',
295
+ 'audio/mpeg',
296
+ 'audio/mpga',
297
+ 'audio/mp4',
298
+ 'audio/opus',
299
+ 'audio/pcm',
300
+ 'audio/wav',
301
+ 'audio/webm'
302
+ ]
303
+
304
+ _VIDEO_TYPES = [
305
+ 'video/mov',
306
+ 'video/mpeg',
307
+ 'video/mpegps',
308
+ 'video/mpg',
309
+ 'video/mp4',
310
+ 'video/webm',
311
+ 'video/wmv',
312
+ 'video/x-flv',
313
+ 'video/3gpp',
314
+ ]
315
+
316
+ _PDF = [
317
+ 'application/pdf',
318
+ ]
319
+
320
+
268
321
  class VertexAIGeminiPro1_5(VertexAI): # pylint: disable=invalid-name
269
322
  """Vertex AI Gemini 1.5 Pro model."""
270
323
 
271
324
  model = 'gemini-1.5-pro-preview-0514'
272
- multimodal = True
325
+ supported_modalities = _PDF + _IMAGE_TYPES + _AUDIO_TYPES + _VIDEO_TYPES
273
326
 
274
327
 
275
328
  class VertexAIGeminiPro1_5_0409(VertexAI): # pylint: disable=invalid-name
276
329
  """Vertex AI Gemini 1.5 Pro model."""
277
330
 
278
331
  model = 'gemini-1.5-pro-preview-0409'
279
- multimodal = True
332
+ supported_modalities = _PDF + _IMAGE_TYPES + _AUDIO_TYPES + _VIDEO_TYPES
280
333
 
281
334
 
282
335
  class VertexAIGeminiFlash1_5(VertexAI): # pylint: disable=invalid-name
283
336
  """Vertex AI Gemini 1.5 Flash model."""
284
337
 
285
338
  model = 'gemini-1.5-flash-preview-0514'
286
- multimodal = True
339
+ supported_modalities = _PDF + _IMAGE_TYPES + _AUDIO_TYPES + _VIDEO_TYPES
287
340
 
288
341
 
289
342
  class VertexAIGeminiPro1(VertexAI): # pylint: disable=invalid-name
@@ -296,7 +349,7 @@ class VertexAIGeminiPro1Vision(VertexAI): # pylint: disable=invalid-name
296
349
  """Vertex AI Gemini 1.0 Pro model."""
297
350
 
298
351
  model = 'gemini-1.0-pro-vision'
299
- multimodal = True
352
+ supported_modalities = _IMAGE_TYPES + _VIDEO_TYPES
300
353
 
301
354
 
302
355
  class VertexAIPalm2(VertexAI): # pylint: disable=invalid-name
@@ -79,7 +79,7 @@ class VertexAITest(unittest.TestCase):
79
79
  )
80
80
 
81
81
  # Non-multimodal model.
82
- with self.assertRaisesRegex(ValueError, 'Unsupported modality'):
82
+ with self.assertRaisesRegex(lf.ModalityError, 'Unsupported modality'):
83
83
  vertexai.VertexAIGeminiPro1()._content_from_message(message)
84
84
 
85
85
  model = vertexai.VertexAIGeminiPro1Vision()
@@ -18,7 +18,7 @@
18
18
  # pylint: disable=g-import-not-at-top
19
19
 
20
20
  from langfun.core.modalities.audio import Audio
21
- from langfun.core.modalities.mime import MimeType
21
+ from langfun.core.modalities.mime import Mime
22
22
  from langfun.core.modalities.mime import Custom
23
23
  from langfun.core.modalities.ms_office import Docx
24
24
  from langfun.core.modalities.ms_office import Pptx
@@ -17,7 +17,7 @@ import functools
17
17
  from langfun.core.modalities import mime
18
18
 
19
19
 
20
- class Audio(mime.MimeType):
20
+ class Audio(mime.Mime):
21
21
  """Audio."""
22
22
 
23
23
  MIME_PREFIX = 'audio'
@@ -17,7 +17,7 @@ import functools
17
17
  from langfun.core.modalities import mime
18
18
 
19
19
 
20
- class Image(mime.MimeType):
20
+ class Image(mime.Mime):
21
21
  """Image."""
22
22
 
23
23
  MIME_PREFIX = 'image'
@@ -15,7 +15,9 @@
15
15
  import unittest
16
16
  from unittest import mock
17
17
 
18
+ import langfun.core as lf
18
19
  from langfun.core.modalities import image as image_lib
20
+ from langfun.core.modalities import mime as mime_lib
19
21
  import pyglove as pg
20
22
 
21
23
 
@@ -36,23 +38,29 @@ def mock_request(*args, **kwargs):
36
38
  return pg.Dict(content=image_content)
37
39
 
38
40
 
39
- class ImageContentTest(unittest.TestCase):
41
+ class ImageTest(unittest.TestCase):
40
42
 
41
- def test_image_content(self):
43
+ def test_from_bytes(self):
42
44
  image = image_lib.Image.from_bytes(image_content)
43
45
  self.assertEqual(image.image_format, 'png')
44
46
  self.assertIn('data:image/png;base64,', image._repr_html_())
45
47
  self.assertEqual(image.to_bytes(), image_content)
48
+ with self.assertRaisesRegex(
49
+ lf.ModalityError, '.* cannot be converted to text'
50
+ ):
51
+ image.to_text()
46
52
 
47
- def test_bad_image(self):
53
+ def test_from_bytes_invalid(self):
48
54
  image = image_lib.Image.from_bytes(b'bad')
49
55
  with self.assertRaisesRegex(ValueError, 'Expected MIME type'):
50
56
  _ = image.image_format
51
57
 
58
+ def test_from_bytes_base_cls(self):
59
+ self.assertIsInstance(
60
+ mime_lib.Mime.from_bytes(image_content), image_lib.Image
61
+ )
52
62
 
53
- class ImageFileTest(unittest.TestCase):
54
-
55
- def test_image_file(self):
63
+ def test_from_uri(self):
56
64
  image = image_lib.Image.from_uri('http://mock/web/a.png')
57
65
  with mock.patch('requests.get') as mock_requests_get:
58
66
  mock_requests_get.side_effect = mock_request
@@ -60,6 +68,15 @@ class ImageFileTest(unittest.TestCase):
60
68
  self.assertEqual(image._repr_html_(), '<img src="http://mock/web/a.png">')
61
69
  self.assertEqual(image.to_bytes(), image_content)
62
70
 
71
+ def test_from_uri_base_cls(self):
72
+ with mock.patch('requests.get') as mock_requests_get:
73
+ mock_requests_get.side_effect = mock_request
74
+ image = mime_lib.Mime.from_uri('http://mock/web/a.png')
75
+ self.assertIsInstance(image, image_lib.Image)
76
+ self.assertEqual(image.image_format, 'png')
77
+ self.assertEqual(image._repr_html_(), '<img src="http://mock/web/a.png">')
78
+ self.assertEqual(image.to_bytes(), image_content)
79
+
63
80
 
64
81
  if __name__ == '__main__':
65
82
  unittest.main()
@@ -0,0 +1,191 @@
1
+ # Copyright 2023 The Langfun Authors
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ """MIME type data."""
15
+
16
+ import base64
17
+ import functools
18
+ from typing import Annotated, Iterable, Type, Union
19
+ import langfun.core as lf
20
+ import magic
21
+ import pyglove as pg
22
+ import requests
23
+
24
+
25
+ class Mime(lf.Modality):
26
+ """Base for MIME data."""
27
+
28
+ # The prefix (matched via `str.startswith`) that the MIME type str must have.
29
+ # If None, the MIME type is dynamic. Subclass could override.
30
+ MIME_PREFIX = None
31
+
32
+ uri: Annotated[str | None, 'The URI for locating the MIME data. '] = None
33
+
34
+ content: Annotated[
35
+ Union[str, bytes, None], 'The raw content of the MIME type.'
36
+ ] = None
37
+
38
+ @functools.cached_property
39
+ def mime_type(self) -> str:
40
+ """Returns the MIME type."""
41
+ mime = magic.from_buffer((self.to_bytes()), mime=True)
42
+ if (
43
+ self.MIME_PREFIX
44
+ and not mime.lower().startswith(self.MIME_PREFIX)
45
+ # NOTE(daiyip): libmagic fails to detect the MIME type of some binary
46
+ # files.
47
+ and mime != 'application/octet-stream'
48
+ ):
49
+ raise ValueError(
50
+ f'Expected MIME type: {self.MIME_PREFIX}, Encountered: {mime}'
51
+ )
52
+ return mime
53
+
54
+ @functools.cached_property
55
+ def is_text(self) -> bool:
56
+ return self.mime_type.startswith(
57
+ (
58
+ 'text/',
59
+ 'application/javascript',
60
+ 'application/json',
61
+ 'application/ld+json',
62
+ 'application/plain',
63
+ 'application/xhtml+xml',
64
+ 'application/xml',
65
+ 'application/x-tex',
66
+ 'application/x-yaml',
67
+ )
68
+ )
69
+
70
+ @property
71
+ def is_binary(self) -> bool:
72
+ """Returns True if the MIME type is a binary type."""
73
+ return not self.is_text
74
+
75
+ def to_text(self) -> str:
76
+ """Returns the text content of the MIME type."""
77
+ if not self.is_text:
78
+ raise lf.ModalityError(
79
+ f'MIME type {self.mime_type!r} cannot be converted to text.'
80
+ )
81
+ return self.to_bytes().decode()
82
+
83
+ def is_compatible(
84
+ self, mime_types: str | Iterable[str]
85
+ ) -> bool:
86
+ """Returns True if this object is compatible to any of the MIME types."""
87
+ if isinstance(mime_types, str):
88
+ mime_types = {mime_types}
89
+ return self._is_compatible(mime_types)
90
+
91
+ def _is_compatible(self, mime_types: Iterable[str]):
92
+ return self.mime_type in mime_types
93
+
94
+ def make_compatible(
95
+ self,
96
+ mime_types: str | Iterable[str]
97
+ ) -> Union['Mime', list['Mime']]:
98
+ """Makes compatible MIME objects from this object."""
99
+ if isinstance(mime_types, str):
100
+ mime_types = {mime_types}
101
+ if not self._is_compatible(mime_types):
102
+ raise lf.ModalityError(
103
+ f'MIME type {self.mime_type!r} cannot be converted to supported '
104
+ f'types: {mime_types!r}.'
105
+ )
106
+ return self._make_compatible(mime_types)
107
+
108
+ def _make_compatible(
109
+ self,
110
+ mime_types: Iterable[str]
111
+ ) -> Union['Mime', list['Mime']]:
112
+ """Makes compatible MIME objects from this object."""
113
+ del mime_types
114
+ return self
115
+
116
+ def _on_bound(self):
117
+ super()._on_bound()
118
+ if self.uri is None and self.content is None:
119
+ raise ValueError('Either uri or content must be provided.')
120
+
121
+ def to_bytes(self) -> bytes:
122
+ if self.content is not None:
123
+ return self.content
124
+
125
+ self.rebind(content=self.download(self.uri), skip_notification=True)
126
+ return self.content
127
+
128
+ @property
129
+ def content_uri(self) -> str:
130
+ base64_content = base64.b64encode(self.to_bytes()).decode()
131
+ return f'data:{self.mime_type};base64,{base64_content}'
132
+
133
+ @classmethod
134
+ def from_uri(cls, uri: str, **kwargs) -> 'Mime':
135
+ if cls is Mime:
136
+ content = cls.download(uri)
137
+ mime = magic.from_buffer(content, mime=True).lower()
138
+ return cls.class_from_mime_type(mime)(uri=uri, content=content, **kwargs)
139
+ return cls(uri=uri, content=None, **kwargs)
140
+
141
+ @classmethod
142
+ def from_bytes(cls, content: bytes | str, **kwargs) -> 'Mime':
143
+ if cls is Mime:
144
+ mime = magic.from_buffer(content, mime=True).lower()
145
+ return cls.class_from_mime_type(mime)(content=content, **kwargs)
146
+ return cls(content=content, **kwargs)
147
+
148
+ @classmethod
149
+ def class_from_mime_type(cls, mime_type: str) -> Type['Mime']:
150
+ """Subclass from the given MIME type."""
151
+ for subcls in cls.__subclasses__():
152
+ if subcls.MIME_PREFIX is not None and mime_type.startswith(
153
+ subcls.MIME_PREFIX):
154
+ return subcls
155
+ return cls
156
+
157
+ @classmethod
158
+ def download(cls, uri: str) -> bytes | str:
159
+ """Downloads the content of the given URI."""
160
+ if uri.lower().startswith(('http:', 'https:', 'ftp:')):
161
+ return requests.get(
162
+ uri,
163
+ headers={'User-Agent': 'Mozilla/5.0'},
164
+ ).content
165
+ else:
166
+ content = pg.io.readfile(uri, mode='rb')
167
+ assert content is not None
168
+ return content
169
+
170
+ def _repr_html_(self) -> str:
171
+ if self.uri and self.uri.lower().startswith(('http:', 'https:', 'ftp:')):
172
+ uri = self.uri
173
+ else:
174
+ uri = self.content_uri
175
+ return self._html(uri)
176
+
177
+ def _html(self, uri) -> str:
178
+ return f'<embed type="{self.mime_type}" src="{uri}"/>'
179
+
180
+
181
+ @pg.use_init_args(['mime', 'content', 'uri'])
182
+ class Custom(Mime):
183
+ """Custom MIME data."""
184
+
185
+ mime: Annotated[
186
+ str, 'The MIME type of the data. E.g. text/plain, or image/png. '
187
+ ]
188
+
189
+ @property
190
+ def mime_type(self) -> str:
191
+ return self.mime
@@ -15,6 +15,7 @@
15
15
  import unittest
16
16
  from unittest import mock
17
17
 
18
+ import langfun.core as lf
18
19
  from langfun.core.modalities import mime
19
20
  import pyglove as pg
20
21
 
@@ -31,10 +32,24 @@ def mock_readfile(*args, **kwargs):
31
32
 
32
33
  class CustomMimeTest(unittest.TestCase):
33
34
 
34
- def test_content(self):
35
- content = mime.Custom('text/plain', 'foo')
36
- self.assertEqual(content.to_bytes(), 'foo')
35
+ def test_from_byes(self):
36
+ content = mime.Mime.from_bytes(b'hello')
37
+ self.assertIs(content.__class__, mime.Mime)
38
+
39
+ content = mime.Custom('text/plain', b'foo')
40
+ self.assertEqual(content.to_bytes(), b'foo')
37
41
  self.assertEqual(content.mime_type, 'text/plain')
42
+ self.assertTrue(content.is_text)
43
+ self.assertFalse(content.is_binary)
44
+ self.assertEqual(content.to_text(), 'foo')
45
+ self.assertTrue(content.is_compatible('text/plain'))
46
+ self.assertFalse(content.is_compatible('text/xml'))
47
+ self.assertIs(content.make_compatible('text/plain'), content)
48
+
49
+ with self.assertRaisesRegex(
50
+ lf.ModalityError, '.* cannot be converted to supported types'
51
+ ):
52
+ content.make_compatible('application/pdf')
38
53
 
39
54
  with self.assertRaisesRegex(
40
55
  ValueError, 'Either uri or content must be provided.'