langfun 0.1.1.dev20240826__py3-none-any.whl → 0.1.1.dev202408282153__py3-none-any.whl

This diff shows the changes between two publicly released versions of this package, as they appear in their public registries. It is provided for informational purposes only.
langfun/core/llms/google_genai.py CHANGED
@@ -18,12 +18,29 @@ import functools
 import os
 from typing import Annotated, Any, Literal

-import google.generativeai as genai
 import langfun.core as lf
 from langfun.core import modalities as lf_modalities
 import pyglove as pg


+try:
+  import google.generativeai as genai  # pylint: disable=g-import-not-at-top
+  BlobDict = genai.types.BlobDict
+  GenerativeModel = genai.GenerativeModel
+  Completion = genai.types.Completion
+  GenerationConfig = genai.GenerationConfig
+  GenerateContentResponse = genai.types.GenerateContentResponse
+  ChatResponse = genai.types.ChatResponse
+except ImportError:
+  genai = None
+  BlobDict = Any
+  GenerativeModel = Any
+  Completion = Any
+  GenerationConfig = Any
+  GenerateContentResponse = Any
+  ChatResponse = Any
+
+
 @lf.use_init_args(['model'])
 class GenAI(lf.LanguageModel):
   """Language models provided by Google GenAI."""
@@ -59,10 +76,16 @@ class GenAI(lf.LanguageModel):

   def _on_bound(self):
     super()._on_bound()
+    if genai is None:
+      raise RuntimeError(
+          'Please install "langfun[llm-google-genai]" to use '
+          'Google Generative AI models.'
+      )
     self.__dict__.pop('_api_initialized', None)

   @functools.cached_property
   def _api_initialized(self):
+    assert genai is not None
     api_key = self.api_key or os.environ.get('GOOGLE_API_KEY', None)
     if not api_key:
       raise ValueError(
@@ -78,6 +101,7 @@ class GenAI(lf.LanguageModel):
   @classmethod
   def dir(cls) -> list[str]:
     """Lists generative models."""
+    assert genai is not None
     return [
         m.name.lstrip('models/')
         for m in genai.list_models()
@@ -100,7 +124,7 @@ class GenAI(lf.LanguageModel):

   def _generation_config(self, options: lf.LMSamplingOptions) -> dict[str, Any]:
     """Creates generation config from langfun sampling options."""
-    return genai.GenerationConfig(
+    return GenerationConfig(
         candidate_count=options.n,
         temperature=options.temperature,
         top_p=options.top_p,
@@ -111,7 +135,7 @@ class GenAI(lf.LanguageModel):

   def _content_from_message(
       self, prompt: lf.Message
-  ) -> list[str | genai.types.BlobDict]:
+  ) -> list[str | BlobDict]:
     """Gets Evergreen formatted content from langfun message."""
     formatted = lf.UserMessage(prompt.text)
     formatted.source = prompt
@@ -131,7 +155,7 @@ class GenAI(lf.LanguageModel):
         if modality.is_text:
           chunk = modality.to_text()
         else:
-          chunk = genai.types.BlobDict(
+          chunk = BlobDict(
              data=modality.to_bytes(),
              mime_type=modality.mime_type
          )
@@ -143,7 +167,7 @@ class GenAI(lf.LanguageModel):
     return chunks

   def _response_to_result(
-      self, response: genai.types.GenerateContentResponse | pg.Dict
+      self, response: GenerateContentResponse | pg.Dict
   ) -> lf.LMSamplingResult:
     """Parses generative response into message."""
     samples = []
@@ -182,8 +206,8 @@ class _LegacyGenerativeModel(pg.Object):

   def generate_content(
       self,
-      input_content: list[str | genai.types.BlobDict],
-      generation_config: genai.GenerationConfig,
+      input_content: list[str | BlobDict],
+      generation_config: GenerationConfig,
   ) -> pg.Dict:
     """Generate content."""
     segments = []
@@ -195,7 +219,7 @@ class _LegacyGenerativeModel(pg.Object):

   @abc.abstractmethod
   def generate(
-      self, prompt: str, generation_config: genai.GenerationConfig) -> pg.Dict:
+      self, prompt: str, generation_config: GenerationConfig) -> pg.Dict:
     """Generate response based on prompt."""


@@ -203,9 +227,10 @@ class _LegacyCompletionModel(_LegacyGenerativeModel):
   """Legacy GenAI completion model."""

   def generate(
-      self, prompt: str, generation_config: genai.GenerationConfig
+      self, prompt: str, generation_config: GenerationConfig
   ) -> pg.Dict:
-    completion: genai.types.Completion = genai.generate_text(
+    assert genai is not None
+    completion: Completion = genai.generate_text(
         model=f'models/{self.model}',
         prompt=prompt,
         temperature=generation_config.temperature,
@@ -227,9 +252,10 @@ class _LegacyChatModel(_LegacyGenerativeModel):
   """Legacy GenAI chat model."""

   def generate(
-      self, prompt: str, generation_config: genai.GenerationConfig
+      self, prompt: str, generation_config: GenerationConfig
   ) -> pg.Dict:
-    response: genai.types.ChatResponse = genai.chat(
+    assert genai is not None
+    response: ChatResponse = genai.chat(
         model=f'models/{self.model}',
         messages=prompt,
         temperature=generation_config.temperature,
@@ -253,8 +279,9 @@ class _ModelHub:

   def get(
       self, model_name: str
-  ) -> genai.GenerativeModel | _LegacyGenerativeModel:
+  ) -> GenerativeModel | _LegacyGenerativeModel:
     """Gets a generative model by model id."""
+    assert genai is not None
     model = self._model_cache.get(model_name, None)
     if model is None:
       model_info = genai.get_model(f'models/{model_name}')
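The recurring pattern in this file: the hard dependency becomes an optional one by aliasing the SDK's types at import time, degrading them to Any when the import fails, and then failing fast at construction time. A minimal, runnable sketch of the idea (hypothetical some_sdk and mypkg names, not langfun code):

from typing import Any

try:
  import some_sdk  # the optional third-party dependency
  GenerationConfig = some_sdk.GenerationConfig  # alias real types when present
except ImportError:
  some_sdk = None  # sentinel checked before any API use
  GenerationConfig = Any  # keeps type annotations importable without the SDK


class Client:

  def __init__(self):
    if some_sdk is None:
      # Fail fast with an actionable message instead of a NameError later.
      raise RuntimeError('Please install "mypkg[some-sdk]" to use Client.')

Module import, class definition, and annotations all keep working without the extra installed; only actual use of the client fails.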
langfun/core/llms/openai.py CHANGED
@@ -16,15 +16,31 @@
 import collections
 import functools
 import os
-from typing import Annotated, Any, cast
+from typing import Annotated, Any

 import langfun.core as lf
 from langfun.core import modalities as lf_modalities
-import openai
-from openai import error as openai_error
-from openai import openai_object
 import pyglove as pg

+try:
+  import openai  # pylint: disable=g-import-not-at-top
+
+  if hasattr(openai, 'error'):
+    # For lower versions.
+    ServiceUnavailableError = openai.error.ServiceUnavailableError
+    RateLimitError = openai.error.RateLimitError
+    APITimeoutError = (
+        openai.error.APIError,
+        '.*The server had an error processing your request'
+    )
+  else:
+    # For higher versions.
+    ServiceUnavailableError = getattr(openai, 'InternalServerError')
+    RateLimitError = getattr(openai, 'RateLimitError')
+    APITimeoutError = getattr(openai, 'APITimeoutError')
+except ImportError:
+  openai = None
+

 # From https://platform.openai.com/settings/organization/limits
 _DEFAULT_TPM = 250000
@@ -119,6 +135,10 @@ class OpenAI(lf.LanguageModel):
   def _on_bound(self):
     super()._on_bound()
     self.__dict__.pop('_api_initialized', None)
+    if openai is None:
+      raise RuntimeError(
+          'Please install "langfun[llm-openai]" to use OpenAI models.'
+      )

   @functools.cached_property
   def _api_initialized(self):
@@ -149,6 +169,7 @@ class OpenAI(lf.LanguageModel):

   @classmethod
   def dir(cls):
+    assert openai is not None
     return openai.Model.list()

   @property
@@ -195,11 +216,11 @@ class OpenAI(lf.LanguageModel):
   ) -> list[lf.LMSamplingResult]:

     def _open_ai_completion(prompts):
+      assert openai is not None
       response = openai.Completion.create(
           prompt=[p.text for p in prompts],
           **self._get_request_args(self.sampling_options),
       )
-      response = cast(openai_object.OpenAIObject, response)
       # Parse response.
       samples_by_index = collections.defaultdict(list)
       for choice in response.choices:
@@ -222,12 +243,9 @@ class OpenAI(lf.LanguageModel):
         _open_ai_completion,
         [prompts],
         retry_on_errors=(
-            openai_error.ServiceUnavailableError,
-            openai_error.RateLimitError,
-            # Handling transient OpenAI server error (code 500). Check out
-            # https://platform.openai.com/docs/guides/error-codes/error-codes
-            (openai_error.APIError,
-             '.*The server had an error processing your request'),
+            ServiceUnavailableError,
+            RateLimitError,
+            APITimeoutError,
         ),
     )[0]

@@ -292,10 +310,8 @@ class OpenAI(lf.LanguageModel):
       )
       messages.append(dict(role='user', content=_content_from_message(prompt)))

-      response = cast(
-          openai_object.OpenAIObject,
-          openai.ChatCompletion.create(messages=messages, **request_args)
-      )
+      assert openai is not None
+      response = openai.ChatCompletion.create(messages=messages, **request_args)

       samples = []
       for choice in response.choices:
@@ -330,8 +346,9 @@ class OpenAI(lf.LanguageModel):
         _open_ai_chat_completion,
         prompts,
         retry_on_errors=(
-            openai_error.ServiceUnavailableError,
-            openai_error.RateLimitError,
+            ServiceUnavailableError,
+            RateLimitError,
+            APITimeoutError
         ),
     )

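The hasattr(openai, 'error') probe above is how one code path supports both SDK generations: pre-1.0 releases kept exceptions in openai.error, while 1.x exposes them at the top level. A sketch of resolving a retry tuple the same way (assumes the openai package is installed; attribute names are the ones used in the diff):

import openai

if hasattr(openai, 'error'):
  # openai < 1.0: exceptions live in the `error` submodule.
  RETRY_ON = (
      openai.error.ServiceUnavailableError,
      openai.error.RateLimitError,
  )
else:
  # openai >= 1.0: exceptions are top-level attributes.
  RETRY_ON = (
      openai.InternalServerError,
      openai.RateLimitError,
      openai.APITimeoutError,
  )

Note that the pre-1.0 alias for APITimeoutError in the diff is an (exception type, message regex) pair rather than an exception class; langfun's retry_on_errors accepts both forms, so one alias serves both SDK generations.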
langfun/core/llms/vertexai.py CHANGED
@@ -17,11 +17,28 @@ import functools
 import os
 from typing import Annotated, Any

-from google.auth import credentials as credentials_lib
 import langfun.core as lf
 from langfun.core import modalities as lf_modalities
 import pyglove as pg

+try:
+  # pylint: disable=g-import-not-at-top
+  from google.auth import credentials as credentials_lib
+  import vertexai
+  from google.cloud.aiplatform import models as aiplatform_models
+  from vertexai import generative_models
+  from vertexai import language_models
+  # pylint: enable=g-import-not-at-top
+
+  Credentials = credentials_lib.Credentials
+except ImportError:
+  credentials_lib = None  # pylint: disable=invalid-name
+  vertexai = None
+  generative_models = None
+  language_models = None
+  aiplatform_models = None
+  Credentials = Any
+

 SUPPORTED_MODELS_AND_SETTINGS = {
     'gemini-1.5-pro-001': pg.Dict(api='gemini', rpm=50),
@@ -78,7 +95,7 @@ class VertexAI(lf.LanguageModel):
   ] = None

   credentials: Annotated[
-      credentials_lib.Credentials | None,
+      Credentials | None,
      (
          'Credentials to use. If None, the default credentials to the '
          'environment will be used.'
@@ -93,6 +110,10 @@ class VertexAI(lf.LanguageModel):
   def _on_bound(self):
     super()._on_bound()
     self.__dict__.pop('_api_initialized', None)
+    if generative_models is None:
+      raise RuntimeError(
+          'Please install "langfun[llm-google-vertex]" to use Vertex AI models.'
+      )

   @functools.cached_property
   def _api_initialized(self):
@@ -112,7 +133,7 @@ class VertexAI(lf.LanguageModel):

     credentials = self.credentials
     # Placeholder for Google-internal credentials.
-    import vertexai
+    assert vertexai is not None
     vertexai.init(project=project, location=location, credentials=credentials)
     return True

@@ -138,7 +159,7 @@ class VertexAI(lf.LanguageModel):
       self, prompt: lf.Message, options: lf.LMSamplingOptions
   ) -> Any:  # generative_models.GenerationConfig
     """Creates generation config from langfun sampling options."""
-    from vertexai import generative_models
+    assert generative_models is not None
     # Users could use `metadata_json_schema` to pass additional
     # request arguments.
     json_schema = prompt.metadata.get('json_schema')
@@ -169,7 +190,7 @@ class VertexAI(lf.LanguageModel):
       self, prompt: lf.Message
   ) -> list[str | Any]:
     """Gets generation input from langfun message."""
-    from vertexai import generative_models
+    assert generative_models is not None
     chunks = []

     for lf_chunk in prompt.chunk():
@@ -296,8 +317,8 @@ class VertexAI(lf.LanguageModel):

   def _sample_endpoint_model(self, prompt: lf.Message) -> lf.LMSamplingResult:
     """Samples a text generation model."""
-    from google.cloud.aiplatform import models
-    model = models.Endpoint(self.endpoint_name)
+    assert aiplatform_models is not None
+    model = aiplatform_models.Endpoint(self.endpoint_name)
     # TODO(chengrun): Add support for stop_sequences.
     predict_options = dict(
         temperature=self.sampling_options.temperature
@@ -337,7 +358,7 @@ class _ModelHub:
     """Gets a generative model by model id."""
     model = self._generative_model_cache.get(model_id, None)
     if model is None:
-      from vertexai import generative_models
+      assert generative_models is not None
       model = generative_models.GenerativeModel(model_id)
       self._generative_model_cache[model_id] = model
     return model
@@ -348,7 +369,7 @@ class _ModelHub:
     """Gets a text generation model by model id."""
     model = self._text_generation_model_cache.get(model_id, None)
    if model is None:
-      from vertexai import language_models
+      assert language_models is not None
       model = language_models.TextGenerationModel.from_pretrained(model_id)
       self._text_generation_model_cache[model_id] = model
     return model
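Across all three LLM backends the one-time client setup hangs off functools.cached_property (_api_initialized), and _on_bound resets it by popping the cached value from the instance __dict__. A self-contained sketch of that mechanic (hypothetical Model class, not langfun code):

import functools


class Model:

  @functools.cached_property
  def _api_initialized(self) -> bool:
    print('initializing client...')  # runs at most once per instance
    return True

  def reset(self):
    # The same trick as _on_bound above: drop the cached value so the
    # next access re-runs initialization.
    self.__dict__.pop('_api_initialized', None)


m = Model()
assert m._api_initialized and m._api_initialized  # prints once
m.reset()
assert m._api_initialized                         # prints again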
langfun/core/message.py CHANGED
@@ -278,6 +278,16 @@ class Message(natural_language.NaturalLanguageFormattable, pg.Object):
   # API for supporting modalities.
   #

+  @property
+  def text_with_modality_hash(self) -> str:
+    """Returns text with modality objects placeheld by an 8-char MD5 hash."""
+    parts = [self.text]
+    for name, modality_obj in self.referred_modalities().items():
+      parts.append(
+          f'<{name}>{modality_obj.hash}</{name}>'
+      )
+    return ''.join(parts)
+
   def get_modality(
       self, var_name: str, default: Any = None, from_message_chain: bool = True
   ) -> modality.Modality | None:
langfun/core/message_test.py CHANGED
@@ -282,6 +282,20 @@ class MessageTest(unittest.TestCase):
         },
     )

+  def test_text_with_modality_hash(self):
+    m = message.UserMessage(
+        'hi, this is a <<[[img1]]>> and <<[[x.img2]]>>',
+        img1=CustomModality('foo'),
+        x=dict(img2=CustomModality('bar')),
+    )
+    self.assertEqual(
+        m.text_with_modality_hash,
+        (
+            'hi, this is a <<[[img1]]>> and <<[[x.img2]]>>'
+            '<img1>acbd18db</img1><x.img2>37b51d19</x.img2>'
+        )
+    )
+
   def test_chunking(self):
     m = message.UserMessage(
         inspect.cleandoc("""
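The expected placeholders in the new test are just MD5 prefixes of the modality bytes, which can be checked directly:

import hashlib

assert hashlib.md5(b'foo').hexdigest()[:8] == 'acbd18db'
assert hashlib.md5(b'bar').hexdigest()[:8] == '37b51d19'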
langfun/core/modalities/image.py CHANGED
@@ -15,9 +15,21 @@

 import functools
 import io
+from typing import Any

 from langfun.core.modalities import mime
-from PIL import Image as pil_image
+
+try:
+  from PIL import Image as pil_image  # pylint: disable=g-import-not-at-top
+  PILImage = pil_image.Image
+  pil_open = pil_image.open
+except ImportError:
+  PILImage = Any
+
+  def pil_open(*unused_args, **unused_kwargs):
+    raise RuntimeError(
+        'Please install "langfun[mime-pil]" to enable PIL image support.'
+    )


 class Image(mime.Mime):
@@ -34,14 +46,14 @@ class Image(mime.Mime):

   @functools.cached_property
   def size(self) -> tuple[int, int]:
-    img = pil_image.open(io.BytesIO(self.to_bytes()))
+    img = pil_open(io.BytesIO(self.to_bytes()))
     return img.size

-  def to_pil_image(self) -> pil_image.Image:
-    return pil_image.open(io.BytesIO(self.to_bytes()))
+  def to_pil_image(self) -> PILImage:  # pytype: disable=invalid-annotation
+    return pil_open(io.BytesIO(self.to_bytes()))

   @classmethod
-  def from_pil_image(cls, img: pil_image.Image) -> 'Image':
+  def from_pil_image(cls, img: PILImage) -> 'Image':  # pytype: disable=invalid-annotation
     buf = io.BytesIO()
     img.save(buf, format='PNG')
     return cls.from_bytes(buf.getvalue())
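With Pillow installed, to_pil_image/from_pil_image round-trip the image through PNG bytes. A sketch of the equivalent plain-PIL round trip (requires Pillow; not langfun code):

import io

from PIL import Image as pil_image

img = pil_image.new('RGB', (4, 4), color='red')  # tiny in-memory image
buf = io.BytesIO()
img.save(buf, format='PNG')                      # what from_pil_image stores
png_bytes = buf.getvalue()
assert pil_image.open(io.BytesIO(png_bytes)).size == (4, 4)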
langfun/core/modalities/mime.py CHANGED
@@ -17,11 +17,20 @@ import base64
 import functools
 from typing import Annotated, Iterable, Type, Union
 import langfun.core as lf
-import magic
 import pyglove as pg
 import requests


+try:
+  import magic  # pylint: disable=g-import-not-at-top
+  from_buffer = magic.from_buffer
+except ImportError:
+  def from_buffer(*unused_args, **unused_kwargs):
+    raise RuntimeError(
+        'Please install "langfun[mime-auto]" to enable automatic MIME support.'
+    )
+
+
 class Mime(lf.Modality):
   """Base for MIME data."""

@@ -38,7 +47,7 @@ class Mime(lf.Modality):
   @functools.cached_property
   def mime_type(self) -> str:
     """Returns the MIME type."""
-    mime = magic.from_buffer((self.to_bytes()), mime=True)
+    mime = from_buffer((self.to_bytes()), mime=True)
     if (
         self.MIME_PREFIX
         and not mime.lower().startswith(self.MIME_PREFIX)
@@ -136,14 +145,14 @@ class Mime(lf.Modality):
   def from_uri(cls, uri: str, **kwargs) -> 'Mime':
     if cls is Mime:
       content = cls.download(uri)
-      mime = magic.from_buffer(content, mime=True).lower()
+      mime = from_buffer(content, mime=True).lower()
       return cls.class_from_mime_type(mime)(uri=uri, content=content, **kwargs)
     return cls(uri=uri, content=None, **kwargs)

   @classmethod
   def from_bytes(cls, content: bytes | str, **kwargs) -> 'Mime':
     if cls is Mime:
-      mime = magic.from_buffer(content, mime=True).lower()
+      mime = from_buffer(content, mime=True).lower()
       return cls.class_from_mime_type(mime)(content=content, **kwargs)
     return cls(content=content, **kwargs)

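When python-magic is available, from_buffer is libmagic's content sniffing. The exact strings depend on the local magic database, but typical results look like this (assumes the mime-auto extra, i.e. python-magic plus a system libmagic):

import magic

print(magic.from_buffer(b'%PDF-1.7 ...', mime=True))  # typically 'application/pdf'
print(magic.from_buffer(b'hello', mime=True))         # typically 'text/plain'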
langfun/core/modalities/ms_office.py CHANGED
@@ -30,10 +30,15 @@ class Xlsx(mime.Mime):
   )

   def to_html(self) -> str:
-    import pandas as pd  # pylint: disable=g-import-not-at-top
-
-    df = pd.read_excel(io.BytesIO(self.to_bytes()))
-    return df.to_html()
+    try:
+      import pandas as pd  # pylint: disable=g-import-not-at-top
+      import openpyxl  # pylint: disable=g-import-not-at-top, unused-import
+      df = pd.read_excel(io.BytesIO(self.to_bytes()))
+      return df.to_html()
+    except ImportError as e:
+      raise RuntimeError(
+          'Please install "langfun[mime-xlsx]" to enable XLSX support.'
+      ) from e

   def _repr_html_(self) -> str:
     return self.to_html()
@@ -58,10 +63,14 @@ class Docx(mime.Mime):
   )

   def to_xml(self) -> str:
-    import docx  # pylint: disable=g-import-not-at-top
-
-    doc = docx.Document(io.BytesIO(self.to_bytes()))
-    return str(doc.element.xml)
+    try:
+      import docx  # pylint: disable=g-import-not-at-top
+      doc = docx.Document(io.BytesIO(self.to_bytes()))
+      return str(doc.element.xml)
+    except ImportError as e:
+      raise RuntimeError(
+          'Please install "langfun[mime-docx]" to enable Docx support.'
+      ) from e

   def _repr_html_(self) -> str:
     return self.to_xml()
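Xlsx and Docx take the other route: instead of module-level fallbacks, they import inside the method and chain the ImportError into an actionable RuntimeError via raise ... from e. The same idea as a generic helper (hypothetical require function, not part of langfun):

import importlib


def require(module_name: str, extra: str):
  """Imports a module, or points at the missing extra on failure."""
  try:
    return importlib.import_module(module_name)
  except ImportError as e:
    raise RuntimeError(f'Please install "langfun[{extra}]".') from e


json = require('json', 'never-missing')  # stdlib, so this always succeeds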
langfun/core/modality.py CHANGED
@@ -14,6 +14,8 @@
 """Interface for modality (e.g. Image, Video, etc.)."""

 import abc
+import functools
+import hashlib
 from typing import Any, ContextManager
 from langfun.core import component
 import pyglove as pg
@@ -35,6 +37,11 @@ class Modality(component.Component):
   REF_START = '<<[['
   REF_END = ']]>>'

+  def _on_bound(self):
+    super()._on_bound()
+    # Invalidate cached hash if modality member is changed.
+    self.__dict__.pop('hash', None)
+
   def format(self, *args, **kwargs) -> str:
     if self.referred_name is None or not pg.object_utils.thread_local_get(
         _TLS_MODALITY_AS_REF, False
@@ -46,6 +53,11 @@ class Modality(component.Component):
   def to_bytes(self) -> bytes:
     """Returns content in bytes."""

+  @functools.cached_property
+  def hash(self) -> str:
+    """Returns an 8-char MD5 hash identifying this modality object."""
+    return hashlib.md5(self.to_bytes()).hexdigest()[:8]
+
   @classmethod
   def text_marker(cls, var_name: str) -> str:
     """Returns a marker in the text for this object."""
langfun/core/modality_test.py CHANGED
@@ -32,6 +32,7 @@ class ModalityTest(unittest.TestCase):
     v = CustomModality('a')
     self.assertIsNone(v.referred_name)
     self.assertEqual(str(v), "CustomModality(\n  content = 'a'\n)")
+    self.assertEqual(v.hash, '0cc175b9')

     _ = pg.Dict(metadata=pg.Dict(x=pg.Dict(metadata=pg.Dict(y=v))))
     self.assertEqual(v.referred_name, 'x.metadata.y')
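Modality.hash is a functools.cached_property over the content bytes, and _on_bound pops it so the hash is recomputed whenever the object is rebound. A self-contained sketch of that caching-plus-invalidation behavior (hypothetical Blob class, not langfun code):

import functools
import hashlib


class Blob:

  def __init__(self, data: bytes):
    self._data = data

  @functools.cached_property
  def hash(self) -> str:
    # First 8 hex chars of the MD5 digest, as in Modality.hash.
    return hashlib.md5(self._data).hexdigest()[:8]

  def update(self, data: bytes):
    self._data = data
    self.__dict__.pop('hash', None)  # drop the stale cached hash


b = Blob(b'a')
assert b.hash == '0cc175b9'   # matches the expectation in the test above
b.update(b'foo')
assert b.hash == 'acbd18db'   # recomputed after invalidation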
langfun/core/structured/mapping.py CHANGED
@@ -352,6 +352,12 @@ class Mapping(lf.LangFunc):
       lm_output = self.postprocess_response(lm_output)
       lm_output.result = self.postprocess_result(self.parse_result(lm_output))
     except Exception as e:  # pylint: disable=broad-exception-caught
+      if (self.lm.cache is not None
+          and lm_output.lm_input.cache_seed is not None):
+        success = self.lm.cache.delete(
+            self.lm, lm_output.lm_input, lm_output.lm_input.cache_seed
+        )
+        assert success
       if self.default == lf.RAISE_IF_HAS_ERROR:
         raise MappingError(lm_output, e) from e
       lm_output.result = self.default
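The new except-branch keeps unparseable LM responses from being served from cache forever: before the error surfaces, the cached entry for this (input, cache seed) is deleted so the next call re-queries the model. The control flow in miniature (toy dict cache, not langfun's cache API):

cache = {}


def cached_call(key, fn, parse):
  if key not in cache:
    cache[key] = fn()          # cache the raw LM response
  try:
    return parse(cache[key])   # may fail on a bad response
  except ValueError:
    cache.pop(key, None)       # evict so the bad response is not reused
    raise


key = ('Compute 1 + 2', 1)     # (prompt, cache_seed)
try:
  cached_call(key, lambda: 'not a number', int)
except ValueError:
  pass
assert key not in cache        # mirrors assertEqual(len(cache), 0) below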
langfun/core/structured/prompting_test.py CHANGED
@@ -19,6 +19,7 @@ import unittest
 import langfun.core as lf
 from langfun.core import modalities
 from langfun.core.llms import fake
+from langfun.core.llms.cache import in_memory
 from langfun.core.structured import mapping
 from langfun.core.structured import prompting
 import pyglove as pg
@@ -799,15 +800,18 @@ class QueryStructureJsonTest(unittest.TestCase):
     self.assertIsNone(r.result[0].hotel)

   def test_bad_transform(self):
-    with lf.context(
-        lm=fake.StaticSequence(['3']),
-        override_attrs=True,
-    ):
-      with self.assertRaisesRegex(
-          mapping.MappingError,
-          'No JSON dict in the output',
+    with in_memory.lm_cache() as cache:
+      with lf.context(
+          lm=fake.StaticSequence(['3']),
+          override_attrs=True,
       ):
-        prompting.query('Compute 1 + 2', int, protocol='json')
+        with self.assertRaisesRegex(
+            mapping.MappingError,
+            'No JSON dict in the output',
+        ):
+          prompting.query('Compute 1 + 2', int, protocol='json', cache_seed=1)
+        # Make sure bad mapping does not impact cache.
+        self.assertEqual(len(cache), 0)

   def test_query(self):
     lm = fake.StaticSequence(['{"result": 1}'])
langfun/core/text_formatting.py CHANGED
@@ -16,7 +16,11 @@
 import io
 import re
 from typing import Any
-import termcolor
+
+try:
+  import termcolor  # pylint: disable=g-import-not-at-top
+except ImportError:
+  termcolor = None


 # Regular expression for ANSI color characters.
@@ -49,6 +53,8 @@ def colored(
   Returns:
     A string with ANSI color characters embracing the entire text.
   """
+  if not termcolor:
+    return text
   return termcolor.colored(
       text,
       color=color,
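colored now degrades to the identity function when termcolor is missing, which is exactly what the new test below pins down. The contract in a runnable sketch (works with or without termcolor installed; not langfun code):

try:
  import termcolor
except ImportError:
  termcolor = None


def colored(text: str, color: str | None = None) -> str:
  if not termcolor:
    return text  # no-op fallback when the dependency is missing
  return termcolor.colored(text, color=color)


out = colored('hi', color='red')
assert out == 'hi' or out.endswith('hi\x1b[0m')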
langfun/core/text_formatting_test.py CHANGED
@@ -42,6 +42,24 @@ class TextFormattingTest(unittest.TestCase):
     )
     self.assertEqual(text_formatting.decolored(colored_text), original_text)

+  def test_colored_without_termcolor(self):
+    termcolor = text_formatting.termcolor
+    text_formatting.termcolor = None
+    original_text = inspect.cleandoc("""
+        Hi {{ foo }}
+        {# print x if x is present #}
+        {% if x %}
+        {{ x }}
+        {% endif %}
+        """)
+
+    colored_text = text_formatting.colored_template(
+        text_formatting.colored(original_text, color='blue')
+    )
+    self.assertEqual(colored_text, original_text)
+    self.assertEqual(text_formatting.decolored(colored_text), original_text)
+    text_formatting.termcolor = termcolor
+

 if __name__ == '__main__':
   unittest.main()