langfun 0.0.2.dev20240429__py3-none-any.whl → 0.0.2.dev20240511__py3-none-any.whl

This diff compares the contents of two publicly released versions of the package as published to a supported registry. It is provided for informational purposes only and reflects the changes between the versions as they appear in their public registry.

Files changed (37)
  1. langfun/__init__.py +5 -0
  2. langfun/core/eval/__init__.py +14 -1
  3. langfun/core/eval/base.py +503 -112
  4. langfun/core/eval/base_test.py +185 -53
  5. langfun/core/eval/matching.py +22 -21
  6. langfun/core/eval/matching_test.py +23 -2
  7. langfun/core/eval/patching.py +130 -0
  8. langfun/core/eval/patching_test.py +170 -0
  9. langfun/core/eval/scoring.py +4 -4
  10. langfun/core/eval/scoring_test.py +19 -2
  11. langfun/core/langfunc.py +1 -17
  12. langfun/core/langfunc_test.py +4 -0
  13. langfun/core/language_model.py +6 -0
  14. langfun/core/llms/__init__.py +8 -0
  15. langfun/core/llms/fake.py +6 -6
  16. langfun/core/llms/google_genai.py +8 -0
  17. langfun/core/llms/openai.py +3 -2
  18. langfun/core/llms/openai_test.py +2 -1
  19. langfun/core/llms/vertexai.py +291 -0
  20. langfun/core/llms/vertexai_test.py +233 -0
  21. langfun/core/modalities/image.py +1 -3
  22. langfun/core/modalities/mime.py +6 -0
  23. langfun/core/modalities/video.py +1 -3
  24. langfun/core/structured/__init__.py +2 -0
  25. langfun/core/structured/mapping.py +5 -1
  26. langfun/core/structured/prompting.py +39 -11
  27. langfun/core/structured/prompting_test.py +43 -0
  28. langfun/core/structured/schema.py +34 -4
  29. langfun/core/structured/schema_test.py +32 -1
  30. langfun/core/structured/scoring.py +4 -1
  31. langfun/core/structured/scoring_test.py +6 -0
  32. langfun/core/template.py +22 -1
  33. {langfun-0.0.2.dev20240429.dist-info → langfun-0.0.2.dev20240511.dist-info}/METADATA +2 -2
  34. {langfun-0.0.2.dev20240429.dist-info → langfun-0.0.2.dev20240511.dist-info}/RECORD +37 -33
  35. {langfun-0.0.2.dev20240429.dist-info → langfun-0.0.2.dev20240511.dist-info}/LICENSE +0 -0
  36. {langfun-0.0.2.dev20240429.dist-info → langfun-0.0.2.dev20240511.dist-info}/WHEEL +0 -0
  37. {langfun-0.0.2.dev20240429.dist-info → langfun-0.0.2.dev20240511.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,291 @@ langfun/core/llms/vertexai.py
+ # Copyright 2023 The Langfun Authors
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ """Vertex AI generative models."""
+
+ import functools
+ import os
+ from typing import Annotated, Any
+
+ from google.auth import credentials as credentials_lib
+ import langfun.core as lf
+ from langfun.core import modalities as lf_modalities
+ import pyglove as pg
+
+
+ SUPPORTED_MODELS_AND_SETTINGS = {
+     'gemini-1.5-pro-preview-0409': pg.Dict(api='gemini', rpm=5),
+     'gemini-1.0-pro': pg.Dict(api='gemini', rpm=300),
+     'gemini-1.0-pro-vision': pg.Dict(api='gemini', rpm=100),
+     # PaLM APIs.
+     'text-bison': pg.Dict(api='palm', rpm=1600),
+     'text-bison-32k': pg.Dict(api='palm', rpm=300),
+     'text-unicorn': pg.Dict(api='palm', rpm=100),
+ }
+
+
+ @lf.use_init_args(['model'])
+ class VertexAI(lf.LanguageModel):
+   """Language model served on VertexAI."""
+
+   model: pg.typing.Annotated[
+       pg.typing.Enum(
+           pg.MISSING_VALUE, list(SUPPORTED_MODELS_AND_SETTINGS.keys())
+       ),
+       (
+           'Vertex AI model name. See '
+           'https://cloud.google.com/vertex-ai/generative-ai/docs/learn/models '
+           'for details.'
+       ),
+   ]
+
+   project: Annotated[
+       str | None,
+       (
+           'Vertex AI project ID. Or set from environment variable '
+           'VERTEXAI_PROJECT.'
+       ),
+   ] = None
+
+   location: Annotated[
+       str | None,
+       (
+           'Vertex AI service location. Or set from environment variable '
+           'VERTEXAI_LOCATION.'
+       ),
+   ] = None
+
+   credentials: Annotated[
+       credentials_lib.Credentials | None,
+       (
+           'Credentials to use. If None, the default credentials to the '
+           'environment will be used.'
+       ),
+   ] = None
+
+   multimodal: Annotated[bool, 'Whether this model has multimodal support.'] = (
+       False
+   )
+
+   def _on_bound(self):
+     super()._on_bound()
+     self.__dict__.pop('_api_initialized', None)
+
+   @functools.cached_property
+   def _api_initialized(self):
+     project = self.project or os.environ.get('VERTEXAI_PROJECT', None)
+     if not project:
+       raise ValueError(
+           'Please specify `project` during `__init__` or set environment '
+           'variable `VERTEXAI_PROJECT` with your Vertex AI project ID.'
+       )
+
+     location = self.location or os.environ.get('VERTEXAI_LOCATION', None)
+     if not location:
+       raise ValueError(
+           'Please specify `location` during `__init__` or set environment '
+           'variable `VERTEXAI_LOCATION` with your Vertex AI service location.'
+       )
+
+     credentials = self.credentials
+     # Placeholder for Google-internal credentials.
+     from google.cloud.aiplatform import vertexai  # pylint: disable=g-import-not-at-top
+     vertexai.init(project=project, location=location, credentials=credentials)
+     return True
+
+   @property
+   def model_id(self) -> str:
+     """Returns a string to identify the model."""
+     return f'VertexAI({self.model})'
+
+   @property
+   def resource_id(self) -> str:
+     """Returns a string to identify the resource for rate control."""
+     return self.model_id
+
+   @property
+   def max_concurrency(self) -> int:
+     """Returns the maximum number of concurrent requests."""
+     return self.rate_to_max_concurrency(
+         requests_per_min=SUPPORTED_MODELS_AND_SETTINGS[self.model].rpm,
+         tokens_per_min=0,
+     )
+
+   def _generation_config(
+       self, options: lf.LMSamplingOptions
+   ) -> Any:  # generative_models.GenerationConfig
+     """Creates generation config from langfun sampling options."""
+     from google.cloud.aiplatform.vertexai.preview import generative_models  # pylint: disable=g-import-not-at-top
+     return generative_models.GenerationConfig(
+         temperature=options.temperature,
+         top_p=options.top_p,
+         top_k=options.top_k,
+         max_output_tokens=options.max_tokens,
+         stop_sequences=options.stop,
+     )
+
+   def _content_from_message(
+       self, prompt: lf.Message
+   ) -> list[str | Any]:
+     """Gets generation input from langfun message."""
+     from google.cloud.aiplatform.vertexai.preview import generative_models  # pylint: disable=g-import-not-at-top
+     chunks = []
+     for lf_chunk in prompt.chunk():
+       if isinstance(lf_chunk, str):
+         chunk = lf_chunk
+       elif self.multimodal and isinstance(lf_chunk, lf_modalities.Image):
+         chunk = generative_models.Image.from_bytes(lf_chunk.to_bytes())
+       else:
+         raise ValueError(f'Unsupported modality: {lf_chunk!r}')
+       chunks.append(chunk)
+     return chunks
+
+   def _generation_response_to_message(
+       self,
+       response: Any,  # generative_models.GenerationResponse
+   ) -> lf.Message:
+     """Parses generative response into message."""
+     return lf.AIMessage(response.text)
+
+   def _sample(self, prompts: list[lf.Message]) -> list[lf.LMSamplingResult]:
+     assert self._api_initialized, 'Vertex AI API is not initialized.'
+     return lf.concurrent_execute(
+         self._sample_single,
+         prompts,
+         executor=self.resource_id,
+         max_workers=self.max_concurrency,
+         # NOTE(daiyip): Vertex has its own policy on handling
+         # rate limits, so we do not retry on errors.
+         retry_on_errors=None,
+     )
+
+   def _sample_single(self, prompt: lf.Message) -> lf.LMSamplingResult:
+     if self.sampling_options.n > 1:
+       raise ValueError(
+           f'`n` greater than 1 is not supported: {self.sampling_options.n}.'
+       )
+     api = SUPPORTED_MODELS_AND_SETTINGS[self.model].api
+     match api:
+       case 'gemini':
+         return self._sample_generative_model(prompt)
+       case 'palm':
+         return self._sample_text_generation_model(prompt)
+       case _:
+         raise ValueError(f'Unsupported API: {api}')
+
+   def _sample_generative_model(self, prompt: lf.Message) -> lf.LMSamplingResult:
+     """Samples a generative model."""
+     model = _VERTEXAI_MODEL_HUB.get_generative_model(self.model)
+     input_content = self._content_from_message(prompt)
+     response = model.generate_content(
+         input_content,
+         generation_config=self._generation_config(self.sampling_options),
+     )
+     usage_metadata = response.usage_metadata
+     usage = lf.LMSamplingUsage(
+         prompt_tokens=usage_metadata.prompt_token_count,
+         completion_tokens=usage_metadata.candidates_token_count,
+         total_tokens=usage_metadata.total_token_count,
+     )
+     return lf.LMSamplingResult(
+         [
+             # Scoring is not supported.
+             lf.LMSample(
+                 self._generation_response_to_message(response), score=0.0
+             ),
+         ],
+         usage=usage,
+     )
+
+   def _sample_text_generation_model(
+       self, prompt: lf.Message
+   ) -> lf.LMSamplingResult:
+     """Samples a text generation model."""
+     model = _VERTEXAI_MODEL_HUB.get_text_generation_model(self.model)
+     predict_options = dict(
+         temperature=self.sampling_options.temperature,
+         top_k=self.sampling_options.top_k,
+         top_p=self.sampling_options.top_p,
+         max_output_tokens=self.sampling_options.max_tokens,
+         stop_sequences=self.sampling_options.stop,
+     )
+     response = model.predict(prompt.text, **predict_options)
+     return lf.LMSamplingResult([
+         # Scoring is not supported.
+         lf.LMSample(lf.AIMessage(response.text), score=0.0)
+     ])
+
+
+ class _ModelHub:
+   """Vertex AI model hub."""
+
+   def __init__(self):
+     self._generative_model_cache = {}
+     self._text_generation_model_cache = {}
+
+   def get_generative_model(
+       self, model_id: str
+   ) -> Any:  # generative_models.GenerativeModel
+     """Gets a generative model by model id."""
+     model = self._generative_model_cache.get(model_id, None)
+     if model is None:
+       from google.cloud.aiplatform.vertexai.preview import generative_models  # pylint: disable=g-import-not-at-top
+       model = generative_models.GenerativeModel(model_id)
+       self._generative_model_cache[model_id] = model
+     return model
+
+   def get_text_generation_model(
+       self, model_id: str
+   ) -> Any:  # language_models.TextGenerationModel
+     """Gets a text generation model by model id."""
+     model = self._text_generation_model_cache.get(model_id, None)
+     if model is None:
+       from google.cloud.aiplatform.vertexai import language_models  # pylint: disable=g-import-not-at-top
+       model = language_models.TextGenerationModel.from_pretrained(model_id)
+       self._text_generation_model_cache[model_id] = model
+     return model
+
+
+ _VERTEXAI_MODEL_HUB = _ModelHub()
+
+
+ class VertexAIGeminiPro1_5(VertexAI):  # pylint: disable=invalid-name
+   """Vertex AI Gemini 1.5 Pro model."""
+
+   model = 'gemini-1.5-pro-preview-0409'
+   multimodal = True
+
+
+ class VertexAIGeminiPro1(VertexAI):  # pylint: disable=invalid-name
+   """Vertex AI Gemini 1.0 Pro model."""
+
+   model = 'gemini-1.0-pro'
+
+
+ class VertexAIGeminiPro1Vision(VertexAI):  # pylint: disable=invalid-name
+   """Vertex AI Gemini 1.0 Pro Vision model."""
+
+   model = 'gemini-1.0-pro-vision'
+   multimodal = True
+
+
+ class VertexAIPalm2(VertexAI):  # pylint: disable=invalid-name
+   """Vertex AI PaLM2 text generation model."""
+
+   model = 'text-bison'
+
+
+ class VertexAIPalm2_32K(VertexAI):  # pylint: disable=invalid-name
+   """Vertex AI PaLM2 text generation model (32K context length)."""
+
+   model = 'text-bison-32k'
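
For orientation, here is a minimal usage sketch of the new Vertex AI backend added above. It is an illustration, not part of the diff: it assumes the google-cloud-aiplatform dependency is installed, and 'my-project' / 'us-central1' are placeholder values for a real GCP project and location (alternatively, set VERTEXAI_PROJECT and VERTEXAI_LOCATION in the environment).

from langfun.core.llms import vertexai

# Configure the model; project/location may instead come from env vars.
lm = vertexai.VertexAIGeminiPro1(project='my-project', location='us-central1')

# Language models are callable on prompts; sampling options are per-call,
# mirroring how the tests below invoke them.
print(lm('Say hello.', temperature=0.7, max_tokens=64).text)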
@@ -0,0 +1,233 @@ langfun/core/llms/vertexai_test.py
+ # Copyright 2024 The Langfun Authors
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ """Tests for Gemini models."""
+
+ import os
+ import unittest
+ from unittest import mock
+
+ from google.cloud.aiplatform.vertexai.preview import generative_models
+ import langfun.core as lf
+ from langfun.core import modalities as lf_modalities
+ from langfun.core.llms import vertexai
+ import pyglove as pg
+
+
+ example_image = (
+     b'\x89PNG\r\n\x1a\n\x00\x00\x00\rIHDR\x00\x00\x00\x18\x00\x00\x00\x18\x04'
+     b'\x03\x00\x00\x00\x12Y \xcb\x00\x00\x00\x18PLTE\x00\x00'
+     b'\x00fff_chaag_cg_ch^ci_ciC\xedb\x94\x00\x00\x00\x08tRNS'
+     b'\x00\n\x9f*\xd4\xff_\xf4\xe4\x8b\xf3a\x00\x00\x00>IDATx'
+     b'\x01c \x05\x08)"\xd8\xcc\xae!\x06pNz\x88k\x19\\Q\xa8"\x10'
+     b'\xc1\x14\x95\x01%\xc1\n\xa143Ta\xa8"D-\x84\x03QM\x98\xc3'
+     b'\x1a\x1a\x1a@5\x0e\x04\xa0q\x88\x05\x00\x07\xf8\x18\xf9'
+     b'\xdao\xd0|\x00\x00\x00\x00IEND\xaeB`\x82'
+ )
+
+
+ def mock_generate_content(content, generation_config, **kwargs):
+   del kwargs
+   c = pg.Dict(generation_config.to_dict())
+   return generative_models.GenerationResponse.from_dict({
+       'candidates': [
+           {
+               'index': 0,
+               'content': {
+                   'role': 'model',
+                   'parts': [
+                       {
+                           'text': (
+                               f'This is a response to {content[0]} with '
+                               f'temperature={c.temperature}, '
+                               f'top_p={c.top_p}, '
+                               f'top_k={c.top_k}, '
+                               f'max_tokens={c.max_output_tokens}, '
+                               f'stop={"".join(c.stop_sequences)}.'
+                           )
+                       },
+                   ],
+               },
+           },
+       ]
+   })
+
+
+ class VertexAITest(unittest.TestCase):
+   """Tests for Vertex model."""
+
+   def test_content_from_message_text_only(self):
+     text = 'This is a beautiful day'
+     model = vertexai.VertexAIGeminiPro1()
+     chunks = model._content_from_message(lf.UserMessage(text))
+     self.assertEqual(chunks, [text])
+
+   def test_content_from_message_mm(self):
+     message = lf.UserMessage(
+         'This is an {{image}}, what is it?',
+         image=lf_modalities.Image.from_bytes(example_image),
+     )
+
+     # Non-multimodal model.
+     with self.assertRaisesRegex(ValueError, 'Unsupported modality'):
+       vertexai.VertexAIGeminiPro1()._content_from_message(message)
+
+     model = vertexai.VertexAIGeminiPro1Vision()
+     chunks = model._content_from_message(message)
+     self.maxDiff = None
+     self.assertEqual([chunks[0], chunks[2]], ['This is an', ', what is it?'])
+     self.assertIsInstance(chunks[1], generative_models.Image)
+
+   def test_generation_response_to_message_text_only(self):
+     response = generative_models.GenerationResponse.from_dict({
+         'candidates': [
+             {
+                 'index': 0,
+                 'content': {
+                     'role': 'model',
+                     'parts': [
+                         {
+                             'text': 'hello world',
+                         },
+                     ],
+                 },
+             },
+         ],
+     })
+     model = vertexai.VertexAIGeminiPro1()
+     message = model._generation_response_to_message(response)
+     self.assertEqual(message, lf.AIMessage('hello world'))
+
+   def test_model_hub(self):
+     with mock.patch(
+         'google.cloud.aiplatform.vertexai.preview.generative_models.'
+         'GenerativeModel.__init__'
+     ) as mock_model_init:
+       mock_model_init.side_effect = lambda *args, **kwargs: None
+       model = vertexai._VERTEXAI_MODEL_HUB.get_generative_model(
+           'gemini-1.0-pro'
+       )
+       self.assertIsNotNone(model)
+       self.assertIs(
+           vertexai._VERTEXAI_MODEL_HUB.get_generative_model('gemini-1.0-pro'),
+           model,
+       )
+
+     with mock.patch(
+         'google.cloud.aiplatform.vertexai.language_models.'
+         'TextGenerationModel.from_pretrained'
+     ) as mock_model_init:
+
+       class TextGenerationModel:
+         pass
+
+       mock_model_init.side_effect = lambda *args, **kw: TextGenerationModel()
+       model = vertexai._VERTEXAI_MODEL_HUB.get_text_generation_model(
+           'text-bison'
+       )
+       self.assertIsNotNone(model)
+       self.assertIs(
+           vertexai._VERTEXAI_MODEL_HUB.get_text_generation_model('text-bison'),
+           model,
+       )
+
+   def test_project_and_location_check(self):
+     with self.assertRaisesRegex(ValueError, 'Please specify `project`'):
+       _ = vertexai.VertexAIGeminiPro1()._api_initialized
+
+     with self.assertRaisesRegex(ValueError, 'Please specify `location`'):
+       _ = vertexai.VertexAIGeminiPro1(project='abc')._api_initialized
+
+     self.assertTrue(
+         vertexai.VertexAIGeminiPro1(
+             project='abc', location='us-central1'
+         )._api_initialized
+     )
+
+     os.environ['VERTEXAI_PROJECT'] = 'abc'
+     os.environ['VERTEXAI_LOCATION'] = 'us-central1'
+     self.assertTrue(vertexai.VertexAIGeminiPro1()._api_initialized)
+     del os.environ['VERTEXAI_PROJECT']
+     del os.environ['VERTEXAI_LOCATION']
+
+   def test_call_generative_model(self):
+     with mock.patch(
+         'google.cloud.aiplatform.vertexai.preview.generative_models.'
+         'GenerativeModel.__init__'
+     ) as mock_model_init:
+       mock_model_init.side_effect = lambda *args, **kwargs: None
+
+       with mock.patch(
+           'google.cloud.aiplatform.vertexai.preview.generative_models.'
+           'GenerativeModel.generate_content'
+       ) as mock_generate:
+         mock_generate.side_effect = mock_generate_content
+
+         lm = vertexai.VertexAIGeminiPro1(project='abc', location='us-central1')
+         self.assertEqual(
+             lm(
+                 'hello',
+                 temperature=2.0,
+                 top_p=1.0,
+                 top_k=20,
+                 max_tokens=1024,
+                 stop='\n',
+             ).text,
+             (
+                 'This is a response to hello with temperature=2.0, '
+                 'top_p=1.0, top_k=20.0, max_tokens=1024, stop=\n.'
+             ),
+         )
+
+   def test_call_text_generation_model(self):
+     with mock.patch(
+         'google.cloud.aiplatform.vertexai.language_models.'
+         'TextGenerationModel.from_pretrained'
+     ) as mock_model_init:
+
+       class TextGenerationModel:
+
+         def predict(self, prompt, **kwargs):
+           c = pg.Dict(kwargs)
+           return pg.Dict(
+               text=(
+                   f'This is a response to {prompt} with '
+                   f'temperature={c.temperature}, '
+                   f'top_p={c.top_p}, '
+                   f'top_k={c.top_k}, '
+                   f'max_tokens={c.max_output_tokens}, '
+                   f'stop={"".join(c.stop_sequences)}.'
+               )
+           )
+
+       mock_model_init.side_effect = lambda *args, **kw: TextGenerationModel()
+       lm = vertexai.VertexAIPalm2(project='abc', location='us-central1')
+       self.assertEqual(
+           lm(
+               'hello',
+               temperature=2.0,
+               top_p=1.0,
+               top_k=20,
+               max_tokens=1024,
+               stop='\n',
+           ).text,
+           (
+               'This is a response to hello with temperature=2.0, '
+               'top_p=1.0, top_k=20, max_tokens=1024, stop=\n.'
+           ),
+       )
+
+
+ if __name__ == '__main__':
+   unittest.main()
@@ -13,7 +13,6 @@ langfun/core/modalities/image.py
   # limitations under the License.
   """Image modality."""
 
-  import base64
   import imghdr
   from typing import cast
   from langfun.core.modalities import mime
@@ -36,5 +35,4 @@ class Image(mime.MimeType):
   def _repr_html_(self) -> str:
     if self.uri and self.uri.lower().startswith(('http:', 'https:', 'ftp:')):
       return f'<img src="{self.uri}">'
-    image_raw = base64.b64encode(self.to_bytes()).decode()
-    return f'<img src="data:image/{self.image_format};base64,{image_raw}">'
+    return f'<img src="{self.content_uri}">'
@@ -14,6 +14,7 @@ langfun/core/modalities/mime.py
   """MIME type data."""
 
   import abc
+  import base64
   from typing import Annotated, Union
   import langfun.core as lf
   import pyglove as pg
@@ -54,6 +55,11 @@ class MimeType(lf.Modality):
     self.rebind(content=content, skip_notification=True)
     return self.content
 
+  @property
+  def content_uri(self) -> str:
+    base64_content = base64.b64encode(self.to_bytes()).decode()
+    return f'data:{self.mime_type};base64,{base64_content}'
+
   @classmethod
   def from_uri(cls, uri: str, **kwargs) -> 'MimeType':
     return cls(uri=uri, content=None, **kwargs)
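
The new content_uri property centralizes the base64 data-URI construction that image.py (above) and video.py (below) previously inlined. A small sketch of what it yields; the file name 'example.png' is a placeholder:

from langfun.core import modalities as lf_modalities

# Build an Image modality from raw bytes (as the tests above do).
image = lf_modalities.Image.from_bytes(open('example.png', 'rb').read())

# content_uri returns a self-contained data URI,
# e.g. 'data:image/png;base64,iVBORw0KGgo...'.
print(image.content_uri)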
@@ -13,7 +13,6 @@ langfun/core/modalities/video.py
   # limitations under the License.
   """Video modality."""
 
-  import base64
   from typing import cast
   from langfun.core.modalities import mime
 
@@ -40,8 +39,7 @@ class Video(mime.MimeType):
   def _repr_html_(self) -> str:
     if self.uri and self.uri.lower().startswith(('http:', 'https:', 'ftp:')):
       return f'<video controls> <source src="{self.uri}"> </video>'
-    video_raw = base64.b64encode(self.to_bytes()).decode()
     return (
         '<video controls> <source'
-        f' src="data:video/{self.video_format};base64,{video_raw}"> </video>'
+        f' src="data:video/{self.content_uri}"> </video>'
     )
@@ -64,6 +64,8 @@ langfun/core/structured/__init__.py
   from langfun.core.structured.prompting import QueryStructureJson
   from langfun.core.structured.prompting import QueryStructurePython
   from langfun.core.structured.prompting import query
+  from langfun.core.structured.prompting import query_prompt
+  from langfun.core.structured.prompting import query_output
 
   from langfun.core.structured.description import DescribeStructure
   from langfun.core.structured.description import describe
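
The newly exported query_prompt and query_output split query into its two halves: rendering the structured prompt and parsing a model completion. A hedged sketch of the intended usage; the exact signatures live in prompting.py, which this release also changes:

from langfun.core import structured as lf_structured

# Render the prompt that `query` would send, without calling any LM.
prompt = lf_structured.query_prompt('Compute the result of 1 + 1.', int)

# Later, parse a raw LM completion into the requested type
# (e.g. this should yield the int 2 for an int schema).
result = lf_structured.query_output('2', int)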
@@ -251,7 +251,7 @@ class Mapping(lf.LangFunc):
 
       {%- if example.schema -%}
       {{ schema_title }}:
-      {{ example.schema_repr(protocol) | indent(2, True) }}
+      {{ example.schema_repr(protocol, include_methods=include_methods) | indent(2, True) }}
 
       {% endif -%}
 
@@ -279,6 +279,10 @@ class Mapping(lf.LangFunc):
       'The protocol for representing the schema and value.',
   ] = 'python'
 
+  include_methods: Annotated[
+      bool, 'If True, include method definitions in the schema.'
+  ] = False
+
   #
   # Other user-provided flags.
   #
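
The new include_methods flag controls whether method definitions appear when a schema is rendered into the prompt (note the matching schema_repr(..., include_methods=include_methods) change above). A hypothetical sketch, assuming the flag is threaded through to lf.query by the prompting.py changes in this release:

import langfun as lf
import pyglove as pg

class Itinerary(pg.Object):
  day: int
  activities: list[str]

  # With include_methods=True, this definition would be shown to the LM
  # as part of the schema, not just the fields above.
  def summary(self) -> str:
    return f'Day {self.day}: {", ".join(self.activities)}'

# Hypothetical call; `lm` is any configured language model.
# result = lf.query('Plan one day in Paris.', Itinerary, lm=lm,
#                   include_methods=True)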