langfun 0.0.2.dev20240530__py3-none-any.whl → 0.0.2.dev20240601__py3-none-any.whl

This diff shows the contents of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between these package versions as they appear in their respective public registries.

Potentially problematic release.


This version of langfun might be problematic. See the package's page on its registry for more details.

@@ -12,11 +12,13 @@
12
12
  # See the License for the specific language governing permissions and
13
13
  # limitations under the License.
14
14
  """Video tests."""
15
+ import base64
15
16
  import io
16
17
  import unittest
17
18
  from unittest import mock
18
19
 
19
20
  from langfun.core.modalities import ms_office as ms_office_lib
21
+ from langfun.core.modalities import pdf as pdf_lib
20
22
  import pyglove as pg
21
23
 
22
24
 
@@ -243,23 +245,72 @@ def pptx_mock_request(*args, **kwargs):
243
245
  return pg.Dict(content=pptx_bytes)
244
246
 
245
247
 
248
+ pdf_bytes = (
249
+ b'%PDF-1.1\n%\xc2\xa5\xc2\xb1\xc3\xab\n\n1 0 obj\n'
250
+ b'<< /Type /Catalog\n /Pages 2 0 R\n >>\nendobj\n\n2 0 obj\n '
251
+ b'<< /Type /Pages\n /Kids [3 0 R]\n '
252
+ b'/Count 1\n /MediaBox [0 0 300 144]\n '
253
+ b'>>\nendobj\n\n3 0 obj\n '
254
+ b'<< /Type /Page\n /Parent 2 0 R\n /Resources\n '
255
+ b'<< /Font\n'
256
+ b'<< /F1\n'
257
+ b'<< /Type /Font\n'
258
+ b'/Subtype /Type1\n'
259
+ b'/BaseFont /Times-Roman\n'
260
+ b'>>\n>>\n>>\n '
261
+ b'/Contents 4 0 R\n >>\nendobj\n\n4 0 obj\n '
262
+ b'<< /Length 55 >>\nstream\n BT\n /F1 18 Tf\n 0 0 Td\n '
263
+ b'(Hello World) Tj\n ET\nendstream\nendobj\n\nxref\n0 5\n0000000000 '
264
+ b'65535 f \n0000000018 00000 n \n0000000077 00000 n \n0000000178 00000 n '
265
+ b'\n0000000457 00000 n \ntrailer\n << /Root 1 0 R\n /Size 5\n '
266
+ b'>>\nstartxref\n565\n%%EOF\n'
267
+ )
268
+
269
+
270
+ def convert_mock_request(*args, **kwargs):
271
+ del args, kwargs
272
+
273
+ class Result:
274
+ def json(self):
275
+ return {
276
+ 'Files': [
277
+ {
278
+ 'FileData': base64.b64encode(pdf_bytes).decode()
279
+ }
280
+ ]
281
+ }
282
+ return Result()
283
+
284
+
246
285
  class DocxTest(unittest.TestCase):
247
286
 
248
- def test_content(self):
287
+ def test_from_bytes(self):
249
288
  content = ms_office_lib.Docx.from_bytes(docx_bytes)
250
- self.assertEqual(
289
+ self.assertIn(
251
290
  content.mime_type,
252
- 'application/vnd.openxmlformats-officedocument.wordprocessingml.document',
291
+ (
292
+ 'application/vnd.openxmlformats-officedocument.wordprocessingml.document',
293
+ 'application/octet-stream',
294
+ ),
253
295
  )
254
296
  self.assertEqual(content.to_bytes(), docx_bytes)
297
+ self.assertTrue(content.is_compatible('text/plain'))
298
+ self.assertFalse(content.is_compatible('application/pdf'))
299
+ self.assertEqual(
300
+ content.make_compatible(['image/png', 'text/plain']).mime_type,
301
+ 'text/plain'
302
+ )
255
303
 
256
- def test_file(self):
304
+ def test_from_uri(self):
257
305
  content = ms_office_lib.Docx.from_uri('http://mock/web/a.docx')
258
306
  with mock.patch('requests.get') as mock_requests_get:
259
307
  mock_requests_get.side_effect = docx_mock_request
260
- self.assertEqual(
308
+ self.assertIn(
261
309
  content.mime_type,
262
- 'application/vnd.openxmlformats-officedocument.wordprocessingml.document',
310
+ (
311
+ 'application/vnd.openxmlformats-officedocument.wordprocessingml.document',
312
+ 'application/octet-stream',
313
+ ),
263
314
  )
264
315
  self.assertEqual(content.to_bytes(), docx_bytes)
265
316
  self.assertEqual(content.to_xml(), expected_docx_xml)
@@ -267,21 +318,33 @@ class DocxTest(unittest.TestCase):
267
318
 
268
319
  class XlsxTest(unittest.TestCase):
269
320
 
270
- def test_content(self):
321
+ def test_from_bytes(self):
271
322
  content = ms_office_lib.Xlsx.from_bytes(xlsx_bytes)
272
- self.assertEqual(
323
+ self.assertIn(
273
324
  content.mime_type,
274
- 'application/vnd.openxmlformats-officedocument.spreadsheetml.sheet',
325
+ (
326
+ 'application/vnd.openxmlformats-officedocument.spreadsheetml.sheet',
327
+ 'application/octet-stream',
328
+ ),
275
329
  )
276
330
  self.assertEqual(content.to_bytes(), xlsx_bytes)
331
+ self.assertTrue(content.is_compatible('text/plain'))
332
+ self.assertFalse(content.is_compatible('application/pdf'))
333
+ self.assertEqual(
334
+ content.make_compatible('text/plain').mime_type,
335
+ 'text/html'
336
+ )
277
337
 
278
- def test_file(self):
338
+ def test_from_uri(self):
279
339
  content = ms_office_lib.Xlsx.from_uri('http://mock/web/a.xlsx')
280
340
  with mock.patch('requests.get') as mock_requests_get:
281
341
  mock_requests_get.side_effect = xlsx_mock_request
282
- self.assertEqual(
342
+ self.assertIn(
283
343
  content.mime_type,
284
- 'application/vnd.openxmlformats-officedocument.spreadsheetml.sheet',
344
+ (
345
+ 'application/vnd.openxmlformats-officedocument.spreadsheetml.sheet',
346
+ 'application/octet-stream',
347
+ ),
285
348
  )
286
349
  self.assertEqual(content.to_bytes(), xlsx_bytes)
287
350
  self.assertEqual(content.to_html(), expected_xlsx_html)
@@ -291,22 +354,36 @@ class PptxTest(unittest.TestCase):
291
354
 
292
355
  def test_content(self):
293
356
  content = ms_office_lib.Pptx.from_bytes(pptx_bytes)
294
- self.assertEqual(
357
+ self.assertIn(
295
358
  content.mime_type,
296
- 'application/vnd.openxmlformats-officedocument.presentationml.presentation',
359
+ (
360
+ 'application/vnd.openxmlformats-officedocument.presentationml.presentation',
361
+ 'application/octet-stream',
362
+ ),
297
363
  )
298
364
  self.assertEqual(content.to_bytes(), pptx_bytes)
299
365
 
300
366
  def test_file(self):
301
367
  content = ms_office_lib.Pptx.from_uri('http://mock/web/a.pptx')
368
+ self.assertFalse(content.is_compatible('text/plain'))
369
+ self.assertTrue(content.is_compatible('application/pdf'))
302
370
  with mock.patch('requests.get') as mock_requests_get:
303
371
  mock_requests_get.side_effect = pptx_mock_request
304
- self.assertEqual(
372
+ self.assertIn(
305
373
  content.mime_type,
306
- 'application/vnd.openxmlformats-officedocument.presentationml.presentation',
374
+ (
375
+ 'application/vnd.openxmlformats-officedocument.presentationml.presentation',
376
+ 'application/octet-stream',
377
+ ),
307
378
  )
308
379
  self.assertEqual(content.to_bytes(), pptx_bytes)
309
380
 
381
+ with mock.patch('requests.post') as mock_requests_post:
382
+ mock_requests_post.side_effect = convert_mock_request
383
+ self.assertIsInstance(
384
+ content.make_compatible('application/pdf'), pdf_lib.PDF
385
+ )
386
+
310
387
 
311
388
  if __name__ == '__main__':
312
389
  unittest.main()
@@ -16,7 +16,7 @@
16
16
  from langfun.core.modalities import mime
17
17
 
18
18
 
19
- class PDF(mime.MimeType):
19
+ class PDF(mime.Mime):
20
20
  """PDF document."""
21
21
 
22
22
  MIME_PREFIX = 'application/pdf'
@@ -17,7 +17,7 @@ import functools
17
17
  from langfun.core.modalities import mime
18
18
 
19
19
 
20
- class Video(mime.MimeType):
20
+ class Video(mime.Mime):
21
21
  """Video."""
22
22
 
23
23
  MIME_PREFIX = 'video'
langfun/core/modality.py CHANGED
@@ -108,3 +108,7 @@ class ModalityRef(pg.Object, pg.typing.CustomTyping):
108
108
  return ModalityRef(name=value.sym_path + k)
109
109
  return v
110
110
  return value.clone().rebind(_placehold, raise_on_no_change=False)
111
+
112
+
113
+ class ModalityError(RuntimeError): # pylint: disable=g-bad-exception-name
114
+ """Exception raised when modality is not supported."""
@@ -16,6 +16,8 @@
16
16
  # pylint: disable=g-bad-import-order
17
17
  # pylint: disable=g-importing-member
18
18
 
19
+ from langfun.core.structured.schema import include_method_in_prompt
20
+
19
21
  from langfun.core.structured.schema import Missing
20
22
  from langfun.core.structured.schema import MISSING
21
23
  from langfun.core.structured.schema import Unknown
@@ -107,7 +107,9 @@ class CompleteStructure(mapping.Mapping):
107
107
 
108
108
  def class_defs_repr(self, value: Any) -> str | None:
109
109
  return schema_lib.class_definitions(
110
- self.missing_type_dependencies(value), markdown=True
110
+ self.missing_type_dependencies(value),
111
+ markdown=True,
112
+ allowed_dependencies=set()
111
113
  )
112
114
 
113
115
  def postprocess_result(self, result: Any) -> Any:
@@ -251,7 +251,7 @@ class Mapping(lf.LangFunc):
251
251
 
252
252
  {%- if example.schema -%}
253
253
  {{ schema_title }}:
254
- {{ example.schema_repr(protocol, include_methods=include_methods) | indent(2, True) }}
254
+ {{ example.schema_repr(protocol) | indent(2, True) }}
255
255
 
256
256
  {% endif -%}
257
257
 
@@ -279,10 +279,6 @@ class Mapping(lf.LangFunc):
279
279
  'The protocol for representing the schema and value.',
280
280
  ] = 'python'
281
281
 
282
- include_methods: Annotated[
283
- bool, 'If True, include method definitions in the schema.'
284
- ] = False
285
-
286
282
  #
287
283
  # Other user-provided flags.
288
284
  #
@@ -114,7 +114,6 @@ def query(
114
114
  autofix: int = 0,
115
115
  autofix_lm: lf.LanguageModel | None = None,
116
116
  protocol: schema_lib.SchemaProtocol = 'python',
117
- include_methods: bool = False,
118
117
  returns_message: bool = False,
119
118
  skip_lm: bool = False,
120
119
  **kwargs,
@@ -174,8 +173,6 @@ def query(
174
173
  will use `lm`.
175
174
  protocol: The protocol for schema/value representation. Applicable values
176
175
  are 'json' and 'python'. By default `python` will be used.
177
- include_methods: If True, include method definitions in the output type
178
- during prompting.
179
176
  returns_message: If True, returns `lf.Message` as the output, instead of
180
177
  returning the structured `message.result`.
181
178
  skip_lm: If True, returns the rendered prompt as a UserMessage object.
@@ -225,7 +222,6 @@ def query(
225
222
  schema=schema,
226
223
  default=default,
227
224
  examples=examples,
228
- include_methods=include_methods,
229
225
  response_postprocess=response_postprocess,
230
226
  autofix=autofix if protocol == 'python' else 0,
231
227
  **kwargs,
@@ -16,6 +16,7 @@
16
16
  import abc
17
17
  import inspect
18
18
  import io
19
+ import re
19
20
  import textwrap
20
21
  import typing
21
22
  from typing import Any, Literal, Sequence, Type, Union
@@ -24,6 +25,17 @@ from langfun.core.coding.python import correction
24
25
  import pyglove as pg
25
26
 
26
27
 
28
+ def include_method_in_prompt(method):
29
+ """Decorator to include a method in the class definition of the prompt."""
30
+ setattr(method, '__show_in_prompt__', True)
31
+ return method
32
+
33
+
34
+ def should_include_method_in_prompt(method):
35
+ """Returns true if the method should be shown in the prompt."""
36
+ return getattr(method, '__show_in_prompt__', False)
37
+
38
+
27
39
  def parse_value_spec(value) -> pg.typing.ValueSpec:
28
40
  """Parses a PyGlove ValueSpec equivalence into a ValueSpec."""
29
41
  if isinstance(value, pg.typing.ValueSpec):
@@ -163,9 +175,12 @@ class Schema(lf.NaturalLanguageFormattable, pg.Object):
163
175
 
164
176
  def class_dependencies(
165
177
  self,
178
+ include_base_classes: bool = True,
166
179
  include_subclasses: bool = True) -> list[Type[Any]]:
167
180
  """Returns a list of class dependencies for current schema."""
168
- return class_dependencies(self.spec, include_subclasses)
181
+ return class_dependencies(
182
+ self.spec, include_base_classes, include_subclasses
183
+ )
169
184
 
170
185
  @classmethod
171
186
  def from_value(cls, value) -> 'Schema':
@@ -198,11 +213,12 @@ def class_dependencies(
198
213
  Type[pg.Object],
199
214
  tuple[Union[pg.typing.ValueSpec, Type[pg.Object]], ...],
200
215
  ],
216
+ include_base_classes: bool = True,
201
217
  include_subclasses: bool = True,
202
218
  ) -> list[Type[Any]]:
203
219
  """Returns a list of class dependencies from a value or specs."""
204
220
  if isinstance(value_or_spec, Schema):
205
- return value_or_spec.class_dependencies(include_subclasses)
221
+ value_or_spec = value_or_spec.spec
206
222
 
207
223
  if inspect.isclass(value_or_spec) or isinstance(
208
224
  value_or_spec, pg.typing.ValueSpec
@@ -236,13 +252,14 @@ def class_dependencies(
236
252
  if vs.cls not in seen:
237
253
  seen.add(vs.cls)
238
254
 
239
- # Add base classes as dependencies.
240
- for base_cls in vs.cls.__bases__:
241
- # We only keep track of user-defined symbolic classes.
242
- if base_cls is not object and base_cls is not pg.Object:
243
- _fill_dependencies(
244
- pg.typing.Object(base_cls), include_subclasses=False
245
- )
255
+ if include_base_classes:
256
+ # Add base classes as dependencies.
257
+ for base_cls in vs.cls.__bases__:
258
+ # We only keep track of user-defined symbolic classes.
259
+ if base_cls is not object and base_cls is not pg.Object:
260
+ _fill_dependencies(
261
+ pg.typing.Object(base_cls), include_subclasses=False
262
+ )
246
263
 
247
264
  # Add members as dependencies.
248
265
  for field in _pg_schema(vs.cls).values():
@@ -301,7 +318,6 @@ class SchemaPythonRepr(SchemaRepr):
301
318
  schema: Schema,
302
319
  *,
303
320
  include_result_definition: bool = True,
304
- include_methods: bool = False,
305
321
  markdown: bool = True,
306
322
  **kwargs,
307
323
  ) -> str:
@@ -309,15 +325,27 @@ class SchemaPythonRepr(SchemaRepr):
309
325
  if include_result_definition:
310
326
  ret += self.result_definition(schema)
311
327
  class_definition_str = self.class_definitions(
312
- schema, markdown=markdown, include_methods=include_methods, **kwargs
328
+ schema, markdown=markdown, **kwargs
313
329
  )
314
330
  if class_definition_str:
315
331
  ret += f'\n\n{class_definition_str}'
316
332
  return ret.strip()
317
333
 
318
- def class_definitions(self, schema: Schema, **kwargs) -> str | None:
319
- deps = schema.class_dependencies(include_subclasses=True)
320
- return class_definitions(deps, **kwargs)
334
+ def class_definitions(
335
+ self,
336
+ schema: Schema,
337
+ additional_dependencies: list[Type[Any]] | None = None,
338
+ **kwargs
339
+ ) -> str | None:
340
+ """Returns a string containing of class definitions from a schema."""
341
+ deps = schema.class_dependencies(
342
+ include_base_classes=False, include_subclasses=True
343
+ )
344
+ allowed_dependencies = set(deps)
345
+ if additional_dependencies:
346
+ allowed_dependencies.update(additional_dependencies)
347
+ return class_definitions(
348
+ deps, allowed_dependencies=allowed_dependencies, **kwargs)
321
349
 
322
350
  def result_definition(self, schema: Schema) -> str:
323
351
  return annotation(schema.spec)
@@ -331,8 +359,7 @@ def source_form(value, markdown: bool = False) -> str:
331
359
  def class_definitions(
332
360
  classes: Sequence[Type[Any]],
333
361
  *,
334
- include_pg_object_as_base: bool = False,
335
- include_methods: bool = False,
362
+ allowed_dependencies: set[Type[Any]] | None = None,
336
363
  strict: bool = False,
337
364
  markdown: bool = False,
338
365
  ) -> str | None:
@@ -347,8 +374,7 @@ def class_definitions(
347
374
  class_definition(
348
375
  cls,
349
376
  strict=strict,
350
- include_pg_object_as_base=include_pg_object_as_base,
351
- include_methods=include_methods,
377
+ allowed_dependencies=allowed_dependencies,
352
378
  )
353
379
  )
354
380
  ret = def_str.getvalue()
@@ -360,8 +386,7 @@ def class_definitions(
360
386
  def class_definition(
361
387
  cls,
362
388
  strict: bool = False,
363
- include_pg_object_as_base: bool = False,
364
- include_methods: bool = False,
389
+ allowed_dependencies: set[Type[Any]] | None = None,
365
390
  ) -> str:
366
391
  """Returns the Python class definition."""
367
392
  out = io.StringIO()
@@ -369,7 +394,7 @@ def class_definition(
369
394
  eligible_bases = []
370
395
  for base_cls in cls.__bases__:
371
396
  if base_cls is not object:
372
- if include_pg_object_as_base or base_cls is not pg.Object:
397
+ if allowed_dependencies is None or base_cls in allowed_dependencies:
373
398
  eligible_bases.append(base_cls.__name__)
374
399
 
375
400
  if eligible_bases:
@@ -406,32 +431,41 @@ def class_definition(
406
431
  out.write(' # ')
407
432
  out.write(line)
408
433
  out.write('\n')
409
- out.write(f' {field.key}: {annotation(field.value, strict=strict)}')
410
- out.write('\n')
411
- empty_class = False
412
434
 
413
- if include_methods:
414
- for method in _iter_newly_defined_methods(cls):
415
- out.write('\n')
416
- out.write(
417
- textwrap.indent(
418
- inspect.cleandoc('\n' + inspect.getsource(method)), ' ' * 2)
435
+ annotation_str = annotation(
436
+ field.value, strict=strict, allowed_dependencies=allowed_dependencies
419
437
  )
438
+ out.write(f' {field.key}: {annotation_str}')
420
439
  out.write('\n')
421
440
  empty_class = False
422
441
 
442
+ for method in _iter_newly_defined_methods(cls, allowed_dependencies):
443
+ source = inspect.getsource(method)
444
+ # Remove decorators from the method definition.
445
+ source = re.sub(r'\s*@.*\.include_method_in_prompt.*\n', '', source)
446
+ out.write('\n')
447
+ out.write(
448
+ textwrap.indent(
449
+ inspect.cleandoc('\n' + source), ' ' * 2)
450
+ )
451
+ out.write('\n')
452
+ empty_class = False
453
+
423
454
  if empty_class:
424
455
  out.write(' pass\n')
425
456
  return out.getvalue()
426
457
 
427
458
 
428
- def _iter_newly_defined_methods(cls):
429
- names = set(dir(cls))
459
+ def _iter_newly_defined_methods(
460
+ cls, allowed_dependencies: set[Type[Any]] | None):
461
+ names = {attr_name: True for attr_name in dir(cls)}
430
462
  for base in cls.__bases__:
431
- names -= set(dir(base))
432
- for name in names:
463
+ if allowed_dependencies is None or base in allowed_dependencies:
464
+ for name in dir(base):
465
+ names.pop(name, None)
466
+ for name in names.keys():
433
467
  attr = getattr(cls, name)
434
- if callable(attr):
468
+ if callable(attr) and should_include_method_in_prompt(attr):
435
469
  yield attr
436
470
 
437
471
 
@@ -439,8 +473,12 @@ def annotation(
439
473
  vs: pg.typing.ValueSpec,
440
474
  annotate_optional: bool = True,
441
475
  strict: bool = False,
476
+ allowed_dependencies: set[Type[Any]] | None = None,
442
477
  ) -> str:
443
478
  """Returns the annotation string for a value spec."""
479
+ child_annotation_kwargs = dict(
480
+ strict=strict, allowed_dependencies=allowed_dependencies
481
+ )
444
482
  if isinstance(vs, pg.typing.Any):
445
483
  return 'Any'
446
484
  elif isinstance(vs, pg.typing.Enum):
@@ -449,7 +487,7 @@ def annotation(
449
487
  elif isinstance(vs, pg.typing.Union):
450
488
  candidate_str = ', '.join(
451
489
  [
452
- annotation(c, annotate_optional=False, strict=strict)
490
+ annotation(c, annotate_optional=False, **child_annotation_kwargs)
453
491
  for c in vs.candidates
454
492
  ]
455
493
  )
@@ -485,20 +523,23 @@ def annotation(
485
523
  )
486
524
  x += '(' + ', '.join(constraints) + ')'
487
525
  elif isinstance(vs, pg.typing.Object):
488
- x = vs.cls.__name__
526
+ if allowed_dependencies is None or vs.cls in allowed_dependencies:
527
+ x = vs.cls.__name__
528
+ else:
529
+ x = 'Any'
489
530
  elif isinstance(vs, pg.typing.List):
490
- item_str = annotation(vs.element.value, strict=strict)
531
+ item_str = annotation(vs.element.value, **child_annotation_kwargs)
491
532
  x = f'list[{item_str}]'
492
533
  elif isinstance(vs, pg.typing.Tuple):
493
534
  elem_str = ', '.join(
494
- [annotation(el.value, strict=strict) for el in vs.elements]
535
+ [annotation(el.value, **child_annotation_kwargs) for el in vs.elements]
495
536
  )
496
537
  x = f'tuple[{elem_str}]'
497
538
  elif isinstance(vs, pg.typing.Dict):
498
539
  kv_pairs = None
499
540
  if vs.schema is not None:
500
541
  kv_pairs = [
501
- (k, annotation(f.value, strict=strict))
542
+ (k, annotation(f.value, **child_annotation_kwargs))
502
543
  for k, f in vs.schema.items()
503
544
  if isinstance(k, pg.typing.ConstStrKey)
504
545
  ]
@@ -509,7 +550,7 @@ def annotation(
509
550
  if strict:
510
551
  x = f'pg.typing.Dict({x})'
511
552
  elif vs.schema and vs.schema.dynamic_field:
512
- v = annotation(vs.schema.dynamic_field.value, strict=strict)
553
+ v = annotation(vs.schema.dynamic_field.value, **child_annotation_kwargs)
513
554
  x = f'dict[str, {v}]'
514
555
  else:
515
556
  x = 'dict[str, Any]'
@@ -604,7 +645,12 @@ class ValuePythonRepr(ValueRepr):
604
645
  cls_schema = Schema.from_value(value)
605
646
  if isinstance(cls_schema.spec, pg.typing.Object):
606
647
  object_code = SchemaPythonRepr().class_definitions(
607
- cls_schema, markdown=markdown, include_pg_object_as_base=True
648
+ cls_schema,
649
+ markdown=markdown,
650
+ # We add `pg.Object` as additional dependencies to the class
651
+ # definition so exemplars for class generation could show
652
+ # pg.Object as their bases.
653
+ additional_dependencies=[pg.Object]
608
654
  )
609
655
  assert object_code is not None
610
656
  return object_code