langfun 0.0.2.dev20240330__py3-none-any.whl → 0.1.2.dev202501140804__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (145) hide show
  1. langfun/__init__.py +22 -2
  2. langfun/core/__init__.py +17 -5
  3. langfun/core/agentic/__init__.py +30 -0
  4. langfun/core/agentic/action.py +854 -0
  5. langfun/core/agentic/action_eval.py +150 -0
  6. langfun/core/agentic/action_eval_test.py +109 -0
  7. langfun/core/agentic/action_test.py +136 -0
  8. langfun/core/coding/python/__init__.py +5 -11
  9. langfun/core/coding/python/correction.py +37 -28
  10. langfun/core/coding/python/correction_test.py +29 -3
  11. langfun/core/coding/python/execution.py +40 -216
  12. langfun/core/coding/python/execution_test.py +29 -89
  13. langfun/core/coding/python/generation.py +21 -11
  14. langfun/core/coding/python/generation_test.py +2 -2
  15. langfun/core/coding/python/parsing.py +108 -193
  16. langfun/core/coding/python/parsing_test.py +2 -105
  17. langfun/core/component.py +69 -2
  18. langfun/core/component_test.py +54 -0
  19. langfun/core/concurrent.py +414 -117
  20. langfun/core/concurrent_test.py +111 -24
  21. langfun/core/console.py +18 -5
  22. langfun/core/console_test.py +17 -0
  23. langfun/core/eval/__init__.py +17 -0
  24. langfun/core/eval/base.py +767 -140
  25. langfun/core/eval/base_test.py +238 -53
  26. langfun/core/eval/matching.py +80 -76
  27. langfun/core/eval/matching_test.py +19 -9
  28. langfun/core/eval/patching.py +130 -0
  29. langfun/core/eval/patching_test.py +170 -0
  30. langfun/core/eval/scoring.py +37 -28
  31. langfun/core/eval/scoring_test.py +21 -3
  32. langfun/core/eval/v2/__init__.py +42 -0
  33. langfun/core/eval/v2/checkpointing.py +380 -0
  34. langfun/core/eval/v2/checkpointing_test.py +228 -0
  35. langfun/core/eval/v2/eval_test_helper.py +136 -0
  36. langfun/core/eval/v2/evaluation.py +725 -0
  37. langfun/core/eval/v2/evaluation_test.py +180 -0
  38. langfun/core/eval/v2/example.py +305 -0
  39. langfun/core/eval/v2/example_test.py +128 -0
  40. langfun/core/eval/v2/experiment.py +1048 -0
  41. langfun/core/eval/v2/experiment_test.py +433 -0
  42. langfun/core/eval/v2/metric_values.py +156 -0
  43. langfun/core/eval/v2/metric_values_test.py +80 -0
  44. langfun/core/eval/v2/metrics.py +357 -0
  45. langfun/core/eval/v2/metrics_test.py +203 -0
  46. langfun/core/eval/v2/progress.py +348 -0
  47. langfun/core/eval/v2/progress_test.py +82 -0
  48. langfun/core/eval/v2/progress_tracking.py +210 -0
  49. langfun/core/eval/v2/progress_tracking_test.py +66 -0
  50. langfun/core/eval/v2/reporting.py +270 -0
  51. langfun/core/eval/v2/reporting_test.py +158 -0
  52. langfun/core/eval/v2/runners.py +488 -0
  53. langfun/core/eval/v2/runners_test.py +334 -0
  54. langfun/core/langfunc.py +3 -21
  55. langfun/core/langfunc_test.py +26 -8
  56. langfun/core/language_model.py +686 -48
  57. langfun/core/language_model_test.py +681 -44
  58. langfun/core/llms/__init__.py +100 -12
  59. langfun/core/llms/anthropic.py +488 -0
  60. langfun/core/llms/anthropic_test.py +235 -0
  61. langfun/core/llms/cache/base.py +21 -2
  62. langfun/core/llms/cache/in_memory.py +13 -0
  63. langfun/core/llms/cache/in_memory_test.py +88 -28
  64. langfun/core/llms/compositional.py +101 -0
  65. langfun/core/llms/compositional_test.py +73 -0
  66. langfun/core/llms/deepseek.py +117 -0
  67. langfun/core/llms/deepseek_test.py +61 -0
  68. langfun/core/llms/fake.py +39 -26
  69. langfun/core/llms/fake_test.py +136 -11
  70. langfun/core/llms/gemini.py +507 -0
  71. langfun/core/llms/gemini_test.py +195 -0
  72. langfun/core/llms/google_genai.py +62 -218
  73. langfun/core/llms/google_genai_test.py +9 -197
  74. langfun/core/llms/groq.py +276 -0
  75. langfun/core/llms/groq_test.py +64 -0
  76. langfun/core/llms/llama_cpp.py +15 -40
  77. langfun/core/llms/llama_cpp_test.py +4 -30
  78. langfun/core/llms/openai.py +436 -226
  79. langfun/core/llms/openai_compatible.py +179 -0
  80. langfun/core/llms/openai_compatible_test.py +495 -0
  81. langfun/core/llms/openai_test.py +35 -174
  82. langfun/core/llms/rest.py +113 -0
  83. langfun/core/llms/rest_test.py +111 -0
  84. langfun/core/llms/vertexai.py +192 -0
  85. langfun/core/llms/vertexai_test.py +52 -0
  86. langfun/core/logging.py +284 -0
  87. langfun/core/logging_test.py +125 -0
  88. langfun/core/message.py +319 -9
  89. langfun/core/message_test.py +190 -13
  90. langfun/core/modalities/__init__.py +6 -2
  91. langfun/core/modalities/audio.py +30 -0
  92. langfun/core/modalities/audio_test.py +63 -0
  93. langfun/core/modalities/image.py +39 -20
  94. langfun/core/modalities/image_test.py +52 -9
  95. langfun/core/modalities/mime.py +206 -29
  96. langfun/core/modalities/mime_test.py +90 -9
  97. langfun/core/modalities/ms_office.py +117 -0
  98. langfun/core/modalities/ms_office_test.py +389 -0
  99. langfun/core/modalities/pdf.py +22 -0
  100. langfun/core/modalities/pdf_test.py +57 -0
  101. langfun/core/modalities/video.py +9 -23
  102. langfun/core/modalities/video_test.py +3 -3
  103. langfun/core/modality.py +26 -3
  104. langfun/core/modality_test.py +2 -2
  105. langfun/core/sampling.py +11 -11
  106. langfun/core/structured/__init__.py +15 -16
  107. langfun/core/structured/completion.py +32 -5
  108. langfun/core/structured/completion_test.py +9 -8
  109. langfun/core/structured/description.py +2 -2
  110. langfun/core/structured/description_test.py +3 -3
  111. langfun/core/structured/function_generation.py +278 -0
  112. langfun/core/structured/function_generation_test.py +399 -0
  113. langfun/core/structured/mapping.py +150 -46
  114. langfun/core/structured/mapping_test.py +105 -0
  115. langfun/core/structured/parsing.py +33 -21
  116. langfun/core/structured/parsing_test.py +71 -22
  117. langfun/core/structured/querying.py +746 -0
  118. langfun/core/structured/{prompting_test.py → querying_test.py} +545 -60
  119. langfun/core/structured/schema.py +208 -99
  120. langfun/core/structured/schema_generation.py +1 -1
  121. langfun/core/structured/schema_generation_test.py +2 -2
  122. langfun/core/structured/schema_test.py +133 -34
  123. langfun/core/structured/scoring.py +125 -19
  124. langfun/core/structured/scoring_test.py +30 -0
  125. langfun/core/structured/tokenization.py +64 -0
  126. langfun/core/structured/tokenization_test.py +48 -0
  127. langfun/core/template.py +240 -11
  128. langfun/core/template_test.py +146 -1
  129. langfun/core/templates/conversation.py +9 -0
  130. langfun/core/templates/conversation_test.py +4 -3
  131. langfun/core/templates/selfplay_test.py +14 -2
  132. langfun-0.1.2.dev202501140804.dist-info/METADATA +225 -0
  133. langfun-0.1.2.dev202501140804.dist-info/RECORD +153 -0
  134. {langfun-0.0.2.dev20240330.dist-info → langfun-0.1.2.dev202501140804.dist-info}/WHEEL +1 -1
  135. langfun/core/coding/python/errors.py +0 -108
  136. langfun/core/coding/python/errors_test.py +0 -99
  137. langfun/core/coding/python/permissions.py +0 -90
  138. langfun/core/coding/python/permissions_test.py +0 -86
  139. langfun/core/structured/prompting.py +0 -217
  140. langfun/core/text_formatting.py +0 -162
  141. langfun/core/text_formatting_test.py +0 -47
  142. langfun-0.0.2.dev20240330.dist-info/METADATA +0 -99
  143. langfun-0.0.2.dev20240330.dist-info/RECORD +0 -102
  144. {langfun-0.0.2.dev20240330.dist-info → langfun-0.1.2.dev202501140804.dist-info}/LICENSE +0 -0
  145. {langfun-0.0.2.dev20240330.dist-info → langfun-0.1.2.dev202501140804.dist-info}/top_level.txt +0 -0
@@ -18,6 +18,7 @@ import inspect
18
18
  import typing
19
19
  import unittest
20
20
 
21
+ import langfun.core as lf
21
22
  from langfun.core.llms import fake
22
23
  from langfun.core.structured import schema as schema_lib
23
24
  import pyglove as pg
@@ -192,9 +193,9 @@ class SchemaTest(unittest.TestCase):
192
193
  self.assertEqual(schema.parse('{"result": 1}'), 1)
193
194
  schema = schema_lib.Schema(dict[str, int])
194
195
  self.assertEqual(
195
- schema.parse(
196
- '{"result": {"_type": "Unknown", "x": 1}}}', force_dict=True),
197
- dict(x=1))
196
+ schema.parse('{"result": {"x": 1}}}'),
197
+ dict(x=1)
198
+ )
198
199
  with self.assertRaisesRegex(
199
200
  schema_lib.SchemaError, 'Expect .* but encountered .*'):
200
201
  schema.parse('{"result": "def"}')
@@ -260,15 +261,18 @@ class ClassDependenciesTest(unittest.TestCase):
260
261
  class A(pg.Object):
261
262
  foo: tuple[Foo, int]
262
263
 
264
+ class B(pg.Object):
265
+ pass
266
+
263
267
  class X(pg.Object):
264
- k: int
268
+ k: dict[str, B]
265
269
 
266
- class B(A):
270
+ class C(A):
267
271
  bar: Bar
268
272
  foo2: Foo | X
269
273
 
270
274
  a = A(foo=(Foo(1), 0))
271
- self.assertEqual(schema_lib.class_dependencies(a), [Foo, A, Bar, X, B])
275
+ self.assertEqual(schema_lib.class_dependencies(a), [Foo, A, Bar, B, X, C])
272
276
 
273
277
  self.assertEqual(schema_lib.class_dependencies(1), [])
274
278
 
@@ -280,9 +284,10 @@ class SchemaPythonReprTest(unittest.TestCase):
280
284
  value_spec: pg.typing.ValueSpec,
281
285
  expected_annotation: str,
282
286
  strict: bool = False,
287
+ **kwargs,
283
288
  ) -> None:
284
289
  self.assertEqual(
285
- schema_lib.annotation(value_spec, strict=strict),
290
+ schema_lib.annotation(value_spec, strict=strict, **kwargs),
286
291
  expected_annotation,
287
292
  )
288
293
 
@@ -358,11 +363,27 @@ class SchemaPythonReprTest(unittest.TestCase):
358
363
  self.assert_annotation(
359
364
  pg.typing.Object(Activity).noneable(), 'Activity | None'
360
365
  )
366
+ self.assert_annotation(
367
+ pg.typing.Object(Activity).noneable(), 'Activity | None',
368
+ allowed_dependencies=set([Activity]),
369
+ )
370
+ self.assert_annotation(
371
+ pg.typing.Object(Activity).noneable(), 'Any | None',
372
+ allowed_dependencies=set(),
373
+ )
361
374
 
362
375
  # List.
363
376
  self.assert_annotation(
364
377
  pg.typing.List(pg.typing.Object(Activity)), 'list[Activity]'
365
378
  )
379
+ self.assert_annotation(
380
+ pg.typing.List(pg.typing.Object(Activity)), 'list[Activity]',
381
+ allowed_dependencies=set([Activity]),
382
+ )
383
+ self.assert_annotation(
384
+ pg.typing.List(pg.typing.Object(Activity)), 'list[Any]',
385
+ allowed_dependencies=set(),
386
+ )
366
387
  self.assert_annotation(
367
388
  pg.typing.List(pg.typing.Object(Activity)).noneable(),
368
389
  'list[Activity] | None',
@@ -374,16 +395,35 @@ class SchemaPythonReprTest(unittest.TestCase):
374
395
 
375
396
  # Tuple.
376
397
  self.assert_annotation(
377
- pg.typing.Tuple([pg.typing.Int(), pg.typing.Str()]), 'tuple[int, str]'
398
+ pg.typing.Tuple([Activity, pg.typing.Str()]), 'tuple[Activity, str]'
399
+ )
400
+ self.assert_annotation(
401
+ pg.typing.Tuple([Activity, pg.typing.Str()]), 'tuple[Activity, str]',
402
+ allowed_dependencies=set([Activity]),
403
+ )
404
+ self.assert_annotation(
405
+ pg.typing.Tuple([Activity, pg.typing.Str()]), 'tuple[Any, str]',
406
+ allowed_dependencies=set(),
378
407
  )
379
408
  self.assert_annotation(
380
- pg.typing.Tuple([pg.typing.Int(), pg.typing.Str()]).noneable(),
381
- 'tuple[int, str] | None',
409
+ pg.typing.Tuple([Activity, pg.typing.Str()]).noneable(),
410
+ 'tuple[Activity, str] | None',
382
411
  )
383
412
 
384
413
  # Dict.
385
414
  self.assert_annotation(
386
- pg.typing.Dict({'x': int, 'y': str}), '{\'x\': int, \'y\': str}'
415
+ pg.typing.Dict({'x': Activity, 'y': str}),
416
+ '{\'x\': Activity, \'y\': str}'
417
+ )
418
+ self.assert_annotation(
419
+ pg.typing.Dict({'x': Activity, 'y': str}),
420
+ '{\'x\': Activity, \'y\': str}',
421
+ allowed_dependencies=set([Activity]),
422
+ )
423
+ self.assert_annotation(
424
+ pg.typing.Dict({'x': Activity, 'y': str}),
425
+ '{\'x\': Any, \'y\': str}',
426
+ allowed_dependencies=set(),
387
427
  )
388
428
  self.assert_annotation(
389
429
  pg.typing.Dict({'x': int, 'y': str}),
@@ -395,6 +435,15 @@ class SchemaPythonReprTest(unittest.TestCase):
395
435
  'dict[str, Any]',
396
436
  strict=False,
397
437
  )
438
+
439
+ class DictValue(pg.Object):
440
+ pass
441
+
442
+ self.assert_annotation(
443
+ pg.typing.Dict([(pg.typing.StrKey(), DictValue)]),
444
+ 'dict[str, DictValue]',
445
+ strict=False,
446
+ )
398
447
  self.assert_annotation(
399
448
  pg.typing.Dict(),
400
449
  'dict[str, Any]',
@@ -408,6 +457,13 @@ class SchemaPythonReprTest(unittest.TestCase):
408
457
  ).noneable(),
409
458
  'Union[Activity, Itinerary, None]',
410
459
  )
460
+ self.assert_annotation(
461
+ pg.typing.Union(
462
+ [pg.typing.Object(Activity), pg.typing.Object(Itinerary)]
463
+ ).noneable(),
464
+ 'Union[Activity, Any, None]',
465
+ allowed_dependencies=set([Activity]),
466
+ )
411
467
 
412
468
  # Any.
413
469
  self.assert_annotation(pg.typing.Any(), 'Any')
@@ -415,13 +471,13 @@ class SchemaPythonReprTest(unittest.TestCase):
415
471
 
416
472
  def test_class_definition(self):
417
473
  self.assertEqual(
418
- schema_lib.class_definition(Activity),
474
+ schema_lib.class_definition(Activity, allowed_dependencies=set()),
419
475
  'class Activity:\n description: str\n',
420
476
  )
421
477
  self.assertEqual(
422
478
  schema_lib.class_definition(Itinerary),
423
479
  inspect.cleandoc("""
424
- class Itinerary:
480
+ class Itinerary(Object):
425
481
  \"\"\"A travel itinerary for a day.\"\"\"
426
482
  day: int(min=1)
427
483
  type: Literal['daytime', 'nighttime']
@@ -431,7 +487,9 @@ class SchemaPythonReprTest(unittest.TestCase):
431
487
  """) + '\n',
432
488
  )
433
489
  self.assertEqual(
434
- schema_lib.class_definition(PlaceOfInterest),
490
+ schema_lib.class_definition(
491
+ PlaceOfInterest, allowed_dependencies=set()
492
+ ),
435
493
  inspect.cleandoc("""
436
494
  class PlaceOfInterest:
437
495
  \"\"\"The name of a place of interest.
@@ -447,11 +505,11 @@ class SchemaPythonReprTest(unittest.TestCase):
447
505
  pass
448
506
 
449
507
  self.assertEqual(
450
- schema_lib.class_definition(A),
508
+ schema_lib.class_definition(A, allowed_dependencies=set()),
451
509
  'class A:\n pass\n',
452
510
  )
453
511
  self.assertEqual(
454
- schema_lib.class_definition(A, include_pg_object_as_base=True),
512
+ schema_lib.class_definition(A),
455
513
  'class A(Object):\n pass\n',
456
514
  )
457
515
 
@@ -459,9 +517,27 @@ class SchemaPythonReprTest(unittest.TestCase):
459
517
  x: str
460
518
  __kwargs__: typing.Any
461
519
 
462
- with self.assertRaisesRegex(
463
- TypeError, 'Variable-length keyword arguments is not supported'):
464
- schema_lib.class_definition(C)
520
+ self.assertEqual(
521
+ schema_lib.class_definition(C), 'class C(Object):\n x: str\n'
522
+ )
523
+
524
+ class D(pg.Object):
525
+ x: str
526
+ @schema_lib.include_method_in_prompt
527
+ def __call__(self, y: int) -> int:
528
+ return len(self.x) + y
529
+
530
+ self.assertEqual(
531
+ schema_lib.class_definition(D),
532
+ inspect.cleandoc(
533
+ """
534
+ class D(Object):
535
+ x: str
536
+
537
+ def __call__(self, y: int) -> int:
538
+ return len(self.x) + y
539
+ """) + '\n'
540
+ )
465
541
 
466
542
  def test_repr(self):
467
543
  class Foo(pg.Object):
@@ -479,10 +555,21 @@ class SchemaPythonReprTest(unittest.TestCase):
479
555
  class A(pg.Object):
480
556
  foo: Foo
481
557
 
558
+ @schema_lib.include_method_in_prompt
559
+ def foo_value(self) -> int:
560
+ return self.foo.x
561
+
562
+ def baz_value(self) -> str:
563
+ return 'baz'
564
+
482
565
  class B(A):
483
566
  bar: Bar
484
567
  foo2: Foo
485
568
 
569
+ @schema_lib.include_method_in_prompt
570
+ def bar_value(self) -> str:
571
+ return self.bar.y
572
+
486
573
  schema = schema_lib.Schema([B])
487
574
  self.assertEqual(
488
575
  schema_lib.SchemaPythonRepr().class_definitions(schema),
@@ -490,9 +577,6 @@ class SchemaPythonReprTest(unittest.TestCase):
490
577
  class Foo:
491
578
  x: int
492
579
 
493
- class A:
494
- foo: Foo
495
-
496
580
  class Bar:
497
581
  """Class Bar."""
498
582
  y: str
@@ -501,10 +585,16 @@ class SchemaPythonReprTest(unittest.TestCase):
501
585
  """Baz(y: str)"""
502
586
  y: str
503
587
 
504
- class B(A):
588
+ class B:
505
589
  foo: Foo
506
590
  bar: Bar
507
591
  foo2: Foo
592
+
593
+ def bar_value(self) -> str:
594
+ return self.bar.y
595
+
596
+ def foo_value(self) -> int:
597
+ return self.foo.x
508
598
  ''') + '\n',
509
599
  )
510
600
 
@@ -521,9 +611,6 @@ class SchemaPythonReprTest(unittest.TestCase):
521
611
  class Foo:
522
612
  x: int
523
613
 
524
- class A:
525
- foo: Foo
526
-
527
614
  class Bar:
528
615
  """Class Bar."""
529
616
  y: str
@@ -532,10 +619,16 @@ class SchemaPythonReprTest(unittest.TestCase):
532
619
  """Baz(y: str)"""
533
620
  y: str
534
621
 
535
- class B(A):
622
+ class B:
536
623
  foo: Foo
537
624
  bar: Bar
538
625
  foo2: Foo
626
+
627
+ def bar_value(self) -> str:
628
+ return self.bar.y
629
+
630
+ def foo_value(self) -> int:
631
+ return self.foo.x
539
632
  ```
540
633
  '''),
541
634
  )
@@ -543,16 +636,12 @@ class SchemaPythonReprTest(unittest.TestCase):
543
636
  schema_lib.SchemaPythonRepr().repr(
544
637
  schema,
545
638
  include_result_definition=False,
546
- include_pg_object_as_base=True,
547
639
  markdown=False,
548
640
  ),
549
641
  inspect.cleandoc('''
550
- class Foo(Object):
642
+ class Foo:
551
643
  x: int
552
644
 
553
- class A(Object):
554
- foo: Foo
555
-
556
645
  class Bar:
557
646
  """Class Bar."""
558
647
  y: str
@@ -561,10 +650,16 @@ class SchemaPythonReprTest(unittest.TestCase):
561
650
  """Baz(y: str)"""
562
651
  y: str
563
652
 
564
- class B(A):
653
+ class B:
565
654
  foo: Foo
566
655
  bar: Bar
567
656
  foo2: Foo
657
+
658
+ def bar_value(self) -> str:
659
+ return self.bar.y
660
+
661
+ def foo_value(self) -> int:
662
+ return self.foo.x
568
663
  '''),
569
664
  )
570
665
 
@@ -600,6 +695,10 @@ class ValuePythonReprTest(unittest.TestCase):
600
695
  schema_lib.ValuePythonRepr().repr(1, schema_lib.Schema(int)),
601
696
  '```python\n1\n```'
602
697
  )
698
+ self.assertEqual(
699
+ schema_lib.ValuePythonRepr().repr(lf.Template('hi, {{a}}', a='foo')),
700
+ 'hi, foo'
701
+ )
603
702
  self.assertEqual(
604
703
  schema_lib.ValuePythonRepr().repr(
605
704
  A([Foo(1), Foo(2)], 'bar'), schema_lib.Schema(A), markdown=False,
@@ -612,7 +711,7 @@ class ValuePythonReprTest(unittest.TestCase):
612
711
  ```python
613
712
  class Foo(Object):
614
713
  x: int
615
-
714
+
616
715
  class A(Object):
617
716
  foo: list[Foo]
618
717
  y: str | None
@@ -17,13 +17,13 @@ from typing import Any, Type, Union
17
17
 
18
18
  import langfun.core as lf
19
19
  from langfun.core.structured import mapping
20
- from langfun.core.structured import prompting
20
+ from langfun.core.structured import querying
21
21
  from langfun.core.structured import schema as schema_lib
22
22
  import pyglove as pg
23
23
 
24
24
 
25
25
  def score(
26
- prompt: Union[str, pg.Symbolic],
26
+ prompt: Union[str, pg.Symbolic] | list[str | pg.Symbolic],
27
27
  completions: list[str | pg.Symbolic],
28
28
  schema: Union[
29
29
  schema_lib.Schema, Type[Any], list[Type[Any]], dict[str, Any], None
@@ -32,9 +32,58 @@ def score(
32
32
  lm: lf.LanguageModel | None = None,
33
33
  examples: list[mapping.MappingExample] | None = None,
34
34
  protocol: schema_lib.SchemaProtocol = 'python',
35
+ return_scoring_results: bool = False,
35
36
  **kwargs,
36
- ) -> list[float]:
37
- """Scores the outputs based on the prompt."""
37
+ ) -> list[float] | list[lf.LMScoringResult]:
38
+ """Scores the outputs based on the prompt.
39
+
40
+ Examples:
41
+ ```
42
+ # Example 1: Scoring text output based on the user prompt.
43
+ scores = lf.score('{{x}} + {{y}} =', ['1', '2', '3'], lm=lm, x=1, y=2)
44
+ assert len(scores) == 3
45
+
46
+ # Example 2: Scoring int output based on the formulated OOP prompt.
47
+ scores = lf.score('1 + 1 =', [1, 2, 3], lm=lm)
48
+ assert len(scores) == 3
49
+
50
+ class Answer(pg.Object):
51
+ result: int
52
+
53
+ # Example 3: Scoring object output based on the formulated OOP prompt.
54
+ scores = lf.score('1 + 1 =', [Answer(1), Answer(2), Answer(3)], lm=lm)
55
+ assert len(scores) == 3
56
+
57
+ # Example 4: Scoring object field value based on the formulated OOP prompt
58
+ # and the generated tokens before the first `pg.oneof`.
59
+ scores = lf.score('1 + 1 =', [Answer(pg.oneof([1, 2, 3]))], lm=lm)
60
+ assert len(scores) == 3
61
+
62
+ # Example 5: Scoring multiple prompt/completion pairs.
63
+ scores = lf.score(
64
+ ['1 + 1=', '2 + 3='],
65
+ ['2', '4'],
66
+ lm=lm
67
+ )
68
+ assert len(scores) == 2
69
+ ```
70
+
71
+ Args:
72
+ prompt: The prompt(s) based on which each completion will be scored.
73
+ completions: A list of strings or symbolic objects as the output.
74
+ schema: The schema as the output type. If None, it will be inferred from
75
+ the completions.
76
+ lm: The language model used for scoring.
77
+ examples: Fewshot exemplars used together with the prompt in getting the
78
+ completions.
79
+ protocol: The protocol for formulating the prompt based on objects.
80
+ return_scoring_results: If True, returns a list of `lf.LMScoringResult`,
81
+ otherwise returns a list of floats as the scores of each completion.
82
+ **kwargs: Keyword arguments that are referred by the prompt.
83
+
84
+ Returns:
85
+ A list of floats or `lf.LMScoringResult` as the score of each completion.
86
+ """
38
87
  if not completions:
39
88
  raise ValueError('`completions` must not be empty.')
40
89
 
@@ -48,28 +97,85 @@ def score(
48
97
  f'{[type(c) for c in completions]}.'
49
98
  )
50
99
 
51
- input_message = prompting.query(
52
- prompt,
53
- schema,
54
- examples=examples,
55
- protocol=protocol,
56
- skip_lm=True,
57
- returns_message=True,
58
- **kwargs,
59
- )
100
+ if isinstance(prompt, list):
101
+ prompts = []
102
+ for p in prompt:
103
+ prompts.append(
104
+ querying.query_prompt(
105
+ p,
106
+ schema,
107
+ examples=examples,
108
+ protocol=protocol,
109
+ **kwargs,
110
+ )
111
+ )
112
+ input_message = prompts
113
+ else:
114
+ input_message = querying.query_prompt(
115
+ prompt,
116
+ schema,
117
+ examples=examples,
118
+ protocol=protocol,
119
+ **kwargs,
120
+ )
60
121
  if lm is None:
61
122
  lm_override = lf.get_contextual_override('lm')
62
123
  if lm_override is None:
63
124
  raise ValueError('`lm` must be specified or provided from `lf.context`.')
64
125
  lm = lm_override.value
65
126
 
127
+ completion_reprs = []
128
+ for c in completions:
129
+ if isinstance(c, mapping.MappingError):
130
+ completion_reprs.append(c.lm_response)
131
+ else:
132
+ rep = mapping.MappingExample.value_repr(
133
+ c, protocol=protocol, compact=False, verbose=False
134
+ )
135
+
136
+ # NOTE(daiyip): supporting scenario of scoring object field with
137
+ # `pg.oneof`.
138
+ oneof_pos = rep.find('OneOf(')
139
+ if oneof_pos == -1:
140
+ completion_reprs.append(rep)
141
+ else:
142
+ assert protocol == 'python', protocol
143
+ if isinstance(input_message, list):
144
+ raise ValueError(
145
+ 'Scoring on object fields using `pg.oneof` must share the '
146
+ f'same prompt. Encountered: {prompt}'
147
+ )
148
+ input_message.text += '\n' + rep[:oneof_pos]
149
+ oneof = _get_first_oneof(c)
150
+ for v in oneof.candidates:
151
+ completion_reprs.append(
152
+ pg.format(
153
+ v,
154
+ python_format=True,
155
+ compact=False,
156
+ verbose=False,
157
+ root_indent=oneof.sym_path.depth
158
+ )
159
+ )
160
+
66
161
  results = lm.score(
67
162
  input_message,
68
- [
69
- mapping.MappingExample.value_repr(
70
- c, protocol=protocol, compact=False, verbose=False
71
- )
72
- for c in completions
73
- ],
163
+ completion_reprs,
74
164
  )
165
+ if return_scoring_results:
166
+ return results
75
167
  return [r.score for r in results]
168
+
169
+
170
def _get_first_oneof(value: Any) -> pg.hyper.OneOf:
  """Returns the first `pg.oneof` found in a symbolic value.

  The value is walked in pre-order with `pg.traverse`. `pg.hyper.OneOf` nodes
  are not descended into, and the walk stops as soon as the first one is
  found, since only that one is needed.

  Args:
    value: A symbolic value (e.g. a `pg.Object`) expected to contain at least
      one `pg.oneof`.

  Returns:
    The first `pg.hyper.OneOf` encountered during the pre-order traversal.

  Raises:
    ValueError: If `value` does not contain any `pg.oneof`.
  """
  found = []

  def _visit(k, v, p):
    del k, p
    if isinstance(v, pg.hyper.OneOf):
      found.append(v)
      # Only the first `pg.oneof` matters; abort the remaining traversal.
      return pg.TraverseAction.STOP
    return pg.TraverseAction.ENTER

  pg.traverse(value, _visit)
  if not found:
    # Explicit error instead of `assert`, which is stripped under `python -O`.
    raise ValueError(f'No `pg.oneof` found in {value!r}.')
  return found[0]
@@ -16,6 +16,11 @@ import unittest
16
16
  import langfun.core as lf
17
17
  from langfun.core.llms import fake
18
18
  from langfun.core.structured import scoring
19
+ import pyglove as pg
20
+
21
+
22
# Structured completion type used by the scoring tests below (e.g. wrapping a
# `pg.oneof` to score each candidate field value). Kept docstring-free on
# purpose: a docstring would become part of the schema text in the prompt.
class Answer(pg.Object):
  # The integer value carried by the completion.
  result: int
19
24
 
20
25
 
21
26
  class ScoringTest(unittest.TestCase):
@@ -32,9 +37,34 @@ class ScoringTest(unittest.TestCase):
32
37
  with self.assertRaisesRegex(ValueError, '`lm` must be specified'):
33
38
  scoring.score('hi', [1, 2])
34
39
 
40
+ with self.assertRaisesRegex(
41
+ ValueError,
42
+ 'Scoring on object fields using `pg.oneof` must share the same prompt',
43
+ ):
44
+ scoring.score(
45
+ ['1 + 1=', '2 + 3='],
46
+ [Answer(pg.oneof([1, 2, 3]))],
47
+ lm=fake.Echo(),
48
+ )
49
+
35
50
  def test_score(self):
36
51
  self.assertEqual(scoring.score('hi', [1, 2], lm=fake.Echo()), [0.0, -1.0])
37
52
 
53
+ def test_score_on_field_values(self):
54
+ self.assertEqual(
55
+ scoring.score(
56
+ '1 + 1=',
57
+ [Answer(pg.oneof([1, 2, 3]))], lm=fake.Echo()
58
+ ),
59
+ [0.0, -1.0, -2.0]
60
+ )
61
+
62
+ def test_score_returning_scoring_results(self):
63
+ self.assertEqual(scoring.score(
64
+ 'hi', [1, 2], lm=fake.Echo(), return_scoring_results=True),
65
+ [lf.LMScoringResult(score=0.0, gradients=None),
66
+ lf.LMScoringResult(score=-1.0, gradients=None)])
67
+
38
68
  def test_scope_with_lm_from_the_context(self):
39
69
  with lf.context(lm=fake.Echo()):
40
70
  self.assertEqual(scoring.score('hi', [1, 2]), [0.0, -1.0])
@@ -0,0 +1,64 @@
1
+ # Copyright 2023 The Langfun Authors
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ """Tokenize the prompt for `lf.query`."""
15
+
16
+ from typing import Any, Type, Union
17
+
18
+ import langfun.core as lf
19
+ from langfun.core.structured import mapping
20
+ from langfun.core.structured import querying
21
+ from langfun.core.structured import schema as schema_lib
22
+ import pyglove as pg
23
+
24
+
25
def tokenize(
    prompt: Union[str, pg.Symbolic] | list[str | pg.Symbolic],
    schema: Union[
        schema_lib.Schema, Type[Any], list[Type[Any]], dict[str, Any], None
    ] = None,
    *,
    lm: lf.LanguageModel | None = None,
    examples: list[mapping.MappingExample] | None = None,
    protocol: schema_lib.SchemaProtocol = 'python',
    **kwargs,
) -> list[tuple[str | bytes, int]]:
  """Tokenizes the prompt that `lf.query` would send to the language model.

  The prompt is formulated with `querying.query_prompt` from `prompt`,
  `schema`, `examples`, `protocol` and `kwargs`. The resulting message is then
  tokenized by `lm.tokenize`.

  Args:
    prompt: The prompt to formulate and tokenize.
    schema: Optional output schema, passed through to `querying.query_prompt`
      when formulating the prompt.
    lm: The language model whose tokenizer is used. If None, the `lm`
      provided via `lf.context` is used.
    examples: Fewshot exemplars used together with the prompt when
      formulating it.
    protocol: The protocol for formulating the prompt based on objects.
    **kwargs: Keyword arguments that are referred by the prompt.

  Returns:
    A list of (text, token_id) tuples, as returned by `lm.tokenize`.

  Raises:
    ValueError: If `lm` is neither specified nor provided from `lf.context`.
  """
  # Resolve the language model first, so a missing `lm` fails fast before
  # any prompt formulation work is done.
  if lm is None:
    lm_override = lf.get_contextual_override('lm')
    if lm_override is None:
      raise ValueError('`lm` must be specified or provided from `lf.context`.')
    lm = lm_override.value

  input_message = querying.query_prompt(
      prompt,
      schema,
      examples=examples,
      protocol=protocol,
      **kwargs,
  )
  return lm.tokenize(input_message)
@@ -0,0 +1,48 @@
1
+ # Copyright 2023 The Langfun Authors
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import unittest
16
+ import langfun.core as lf
17
+ from langfun.core.llms import fake
18
+ from langfun.core.structured import tokenization
19
+ import pyglove as pg
20
+
21
+
22
# Structured output type with a single int field.
# NOTE(review): not referenced by any test in this module; likely left over
# from scoring_test.py — consider removing or adding a schema-based test.
class Answer(pg.Object):
  # The integer value of the answer.
  result: int
24
+
25
+
26
class TokenizationTest(unittest.TestCase):
  """Tests for `tokenization.tokenize`."""

  def test_bad_call(self):
    # Neither an explicit `lm` nor one from `lf.context` is available.
    with self.assertRaisesRegex(ValueError, '`lm` must be specified'):
      tokenization.tokenize('hi')

  def test_tokenize(self):
    tokens = tokenization.tokenize('hi', lm=fake.Echo())
    self.assertEqual(tokens, [('hi', 0)])

  def test_tokenize_with_lm_from_the_context(self):
    # The language model is picked up from the surrounding `lf.context`.
    with lf.context(lm=fake.Echo()):
      tokens = tokenization.tokenize('hi')
    self.assertEqual(tokens, [('hi', 0)])
45
+
46
+
47
# Run the tests in this module when it is executed directly as a script.
if __name__ == '__main__':
  unittest.main()