langfun 0.0.2.dev20240429__py3-none-any.whl → 0.1.2.dev202501150804__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (144):
  1. langfun/__init__.py +20 -2
  2. langfun/core/__init__.py +16 -5
  3. langfun/core/agentic/__init__.py +30 -0
  4. langfun/core/agentic/action.py +854 -0
  5. langfun/core/agentic/action_eval.py +150 -0
  6. langfun/core/agentic/action_eval_test.py +109 -0
  7. langfun/core/agentic/action_test.py +136 -0
  8. langfun/core/coding/python/__init__.py +5 -11
  9. langfun/core/coding/python/correction.py +37 -21
  10. langfun/core/coding/python/correction_test.py +29 -3
  11. langfun/core/coding/python/execution.py +40 -216
  12. langfun/core/coding/python/execution_test.py +29 -89
  13. langfun/core/coding/python/generation.py +21 -11
  14. langfun/core/coding/python/generation_test.py +2 -2
  15. langfun/core/coding/python/parsing.py +108 -193
  16. langfun/core/coding/python/parsing_test.py +2 -105
  17. langfun/core/component.py +63 -2
  18. langfun/core/component_test.py +53 -0
  19. langfun/core/concurrent.py +414 -117
  20. langfun/core/concurrent_test.py +111 -24
  21. langfun/core/console.py +17 -5
  22. langfun/core/console_test.py +17 -0
  23. langfun/core/eval/__init__.py +16 -1
  24. langfun/core/eval/base.py +622 -174
  25. langfun/core/eval/base_test.py +200 -54
  26. langfun/core/eval/matching.py +63 -76
  27. langfun/core/eval/matching_test.py +17 -8
  28. langfun/core/eval/patching.py +130 -0
  29. langfun/core/eval/patching_test.py +170 -0
  30. langfun/core/eval/scoring.py +26 -26
  31. langfun/core/eval/scoring_test.py +19 -2
  32. langfun/core/eval/v2/__init__.py +42 -0
  33. langfun/core/eval/v2/checkpointing.py +380 -0
  34. langfun/core/eval/v2/checkpointing_test.py +228 -0
  35. langfun/core/eval/v2/eval_test_helper.py +136 -0
  36. langfun/core/eval/v2/evaluation.py +725 -0
  37. langfun/core/eval/v2/evaluation_test.py +180 -0
  38. langfun/core/eval/v2/example.py +305 -0
  39. langfun/core/eval/v2/example_test.py +128 -0
  40. langfun/core/eval/v2/experiment.py +1048 -0
  41. langfun/core/eval/v2/experiment_test.py +433 -0
  42. langfun/core/eval/v2/metric_values.py +156 -0
  43. langfun/core/eval/v2/metric_values_test.py +80 -0
  44. langfun/core/eval/v2/metrics.py +357 -0
  45. langfun/core/eval/v2/metrics_test.py +203 -0
  46. langfun/core/eval/v2/progress.py +348 -0
  47. langfun/core/eval/v2/progress_test.py +82 -0
  48. langfun/core/eval/v2/progress_tracking.py +210 -0
  49. langfun/core/eval/v2/progress_tracking_test.py +66 -0
  50. langfun/core/eval/v2/reporting.py +270 -0
  51. langfun/core/eval/v2/reporting_test.py +158 -0
  52. langfun/core/eval/v2/runners.py +488 -0
  53. langfun/core/eval/v2/runners_test.py +334 -0
  54. langfun/core/langfunc.py +4 -17
  55. langfun/core/langfunc_test.py +22 -6
  56. langfun/core/language_model.py +577 -39
  57. langfun/core/language_model_test.py +470 -56
  58. langfun/core/llms/__init__.py +87 -16
  59. langfun/core/llms/anthropic.py +312 -87
  60. langfun/core/llms/anthropic_test.py +71 -3
  61. langfun/core/llms/cache/base.py +21 -2
  62. langfun/core/llms/cache/in_memory.py +13 -0
  63. langfun/core/llms/cache/in_memory_test.py +53 -2
  64. langfun/core/llms/compositional.py +101 -0
  65. langfun/core/llms/compositional_test.py +73 -0
  66. langfun/core/llms/deepseek.py +117 -0
  67. langfun/core/llms/deepseek_test.py +61 -0
  68. langfun/core/llms/fake.py +11 -7
  69. langfun/core/llms/fake_test.py +14 -0
  70. langfun/core/llms/gemini.py +507 -0
  71. langfun/core/llms/gemini_test.py +195 -0
  72. langfun/core/llms/google_genai.py +62 -218
  73. langfun/core/llms/google_genai_test.py +9 -202
  74. langfun/core/llms/groq.py +160 -144
  75. langfun/core/llms/groq_test.py +31 -137
  76. langfun/core/llms/llama_cpp.py +15 -42
  77. langfun/core/llms/llama_cpp_test.py +4 -30
  78. langfun/core/llms/openai.py +395 -203
  79. langfun/core/llms/openai_compatible.py +179 -0
  80. langfun/core/llms/openai_compatible_test.py +495 -0
  81. langfun/core/llms/openai_test.py +30 -395
  82. langfun/core/llms/rest.py +113 -0
  83. langfun/core/llms/rest_test.py +111 -0
  84. langfun/core/llms/vertexai.py +192 -0
  85. langfun/core/llms/vertexai_test.py +52 -0
  86. langfun/core/logging.py +284 -0
  87. langfun/core/logging_test.py +125 -0
  88. langfun/core/message.py +319 -9
  89. langfun/core/message_test.py +190 -13
  90. langfun/core/modalities/__init__.py +6 -2
  91. langfun/core/modalities/audio.py +30 -0
  92. langfun/core/modalities/audio_test.py +63 -0
  93. langfun/core/modalities/image.py +39 -20
  94. langfun/core/modalities/image_test.py +52 -9
  95. langfun/core/modalities/mime.py +206 -29
  96. langfun/core/modalities/mime_test.py +90 -9
  97. langfun/core/modalities/ms_office.py +117 -0
  98. langfun/core/modalities/ms_office_test.py +389 -0
  99. langfun/core/modalities/pdf.py +22 -0
  100. langfun/core/modalities/pdf_test.py +57 -0
  101. langfun/core/modalities/video.py +9 -26
  102. langfun/core/modalities/video_test.py +3 -3
  103. langfun/core/modality.py +26 -3
  104. langfun/core/modality_test.py +2 -2
  105. langfun/core/sampling.py +11 -11
  106. langfun/core/structured/__init__.py +12 -16
  107. langfun/core/structured/completion.py +32 -5
  108. langfun/core/structured/completion_test.py +7 -6
  109. langfun/core/structured/description.py +2 -2
  110. langfun/core/structured/description_test.py +3 -3
  111. langfun/core/structured/function_generation.py +60 -27
  112. langfun/core/structured/function_generation_test.py +72 -2
  113. langfun/core/structured/mapping.py +97 -47
  114. langfun/core/structured/mapping_test.py +90 -2
  115. langfun/core/structured/parsing.py +33 -21
  116. langfun/core/structured/parsing_test.py +53 -9
  117. langfun/core/structured/querying.py +746 -0
  118. langfun/core/structured/{prompting_test.py → querying_test.py} +469 -51
  119. langfun/core/structured/schema.py +204 -97
  120. langfun/core/structured/schema_generation.py +1 -1
  121. langfun/core/structured/schema_test.py +130 -29
  122. langfun/core/structured/scoring.py +125 -19
  123. langfun/core/structured/scoring_test.py +30 -0
  124. langfun/core/structured/tokenization.py +64 -0
  125. langfun/core/structured/tokenization_test.py +48 -0
  126. langfun/core/template.py +115 -1
  127. langfun/core/template_test.py +71 -1
  128. langfun/core/templates/conversation.py +9 -0
  129. langfun/core/templates/conversation_test.py +4 -3
  130. langfun/core/templates/selfplay_test.py +10 -2
  131. langfun-0.1.2.dev202501150804.dist-info/METADATA +225 -0
  132. langfun-0.1.2.dev202501150804.dist-info/RECORD +153 -0
  133. {langfun-0.0.2.dev20240429.dist-info → langfun-0.1.2.dev202501150804.dist-info}/WHEEL +1 -1
  134. langfun/core/coding/python/errors.py +0 -108
  135. langfun/core/coding/python/errors_test.py +0 -99
  136. langfun/core/coding/python/permissions.py +0 -90
  137. langfun/core/coding/python/permissions_test.py +0 -86
  138. langfun/core/structured/prompting.py +0 -238
  139. langfun/core/text_formatting.py +0 -162
  140. langfun/core/text_formatting_test.py +0 -47
  141. langfun-0.0.2.dev20240429.dist-info/METADATA +0 -100
  142. langfun-0.0.2.dev20240429.dist-info/RECORD +0 -108
  143. {langfun-0.0.2.dev20240429.dist-info → langfun-0.1.2.dev202501150804.dist-info}/LICENSE +0 -0
  144. {langfun-0.0.2.dev20240429.dist-info → langfun-0.1.2.dev202501150804.dist-info}/top_level.txt +0 -0
@@ -18,6 +18,7 @@ import inspect
18
18
  import typing
19
19
  import unittest
20
20
 
21
+ import langfun.core as lf
21
22
  from langfun.core.llms import fake
22
23
  from langfun.core.structured import schema as schema_lib
23
24
  import pyglove as pg
@@ -260,15 +261,18 @@ class ClassDependenciesTest(unittest.TestCase):
260
261
  class A(pg.Object):
261
262
  foo: tuple[Foo, int]
262
263
 
264
+ class B(pg.Object):
265
+ pass
266
+
263
267
  class X(pg.Object):
264
- k: int
268
+ k: dict[str, B]
265
269
 
266
- class B(A):
270
+ class C(A):
267
271
  bar: Bar
268
272
  foo2: Foo | X
269
273
 
270
274
  a = A(foo=(Foo(1), 0))
271
- self.assertEqual(schema_lib.class_dependencies(a), [Foo, A, Bar, X, B])
275
+ self.assertEqual(schema_lib.class_dependencies(a), [Foo, A, Bar, B, X, C])
272
276
 
273
277
  self.assertEqual(schema_lib.class_dependencies(1), [])
274
278
 
@@ -280,9 +284,10 @@ class SchemaPythonReprTest(unittest.TestCase):
280
284
  value_spec: pg.typing.ValueSpec,
281
285
  expected_annotation: str,
282
286
  strict: bool = False,
287
+ **kwargs,
283
288
  ) -> None:
284
289
  self.assertEqual(
285
- schema_lib.annotation(value_spec, strict=strict),
290
+ schema_lib.annotation(value_spec, strict=strict, **kwargs),
286
291
  expected_annotation,
287
292
  )
288
293
 
@@ -358,11 +363,27 @@ class SchemaPythonReprTest(unittest.TestCase):
358
363
  self.assert_annotation(
359
364
  pg.typing.Object(Activity).noneable(), 'Activity | None'
360
365
  )
366
+ self.assert_annotation(
367
+ pg.typing.Object(Activity).noneable(), 'Activity | None',
368
+ allowed_dependencies=set([Activity]),
369
+ )
370
+ self.assert_annotation(
371
+ pg.typing.Object(Activity).noneable(), 'Any | None',
372
+ allowed_dependencies=set(),
373
+ )
361
374
 
362
375
  # List.
363
376
  self.assert_annotation(
364
377
  pg.typing.List(pg.typing.Object(Activity)), 'list[Activity]'
365
378
  )
379
+ self.assert_annotation(
380
+ pg.typing.List(pg.typing.Object(Activity)), 'list[Activity]',
381
+ allowed_dependencies=set([Activity]),
382
+ )
383
+ self.assert_annotation(
384
+ pg.typing.List(pg.typing.Object(Activity)), 'list[Any]',
385
+ allowed_dependencies=set(),
386
+ )
366
387
  self.assert_annotation(
367
388
  pg.typing.List(pg.typing.Object(Activity)).noneable(),
368
389
  'list[Activity] | None',
@@ -374,16 +395,35 @@ class SchemaPythonReprTest(unittest.TestCase):
374
395
 
375
396
  # Tuple.
376
397
  self.assert_annotation(
377
- pg.typing.Tuple([pg.typing.Int(), pg.typing.Str()]), 'tuple[int, str]'
398
+ pg.typing.Tuple([Activity, pg.typing.Str()]), 'tuple[Activity, str]'
399
+ )
400
+ self.assert_annotation(
401
+ pg.typing.Tuple([Activity, pg.typing.Str()]), 'tuple[Activity, str]',
402
+ allowed_dependencies=set([Activity]),
403
+ )
404
+ self.assert_annotation(
405
+ pg.typing.Tuple([Activity, pg.typing.Str()]), 'tuple[Any, str]',
406
+ allowed_dependencies=set(),
378
407
  )
379
408
  self.assert_annotation(
380
- pg.typing.Tuple([pg.typing.Int(), pg.typing.Str()]).noneable(),
381
- 'tuple[int, str] | None',
409
+ pg.typing.Tuple([Activity, pg.typing.Str()]).noneable(),
410
+ 'tuple[Activity, str] | None',
382
411
  )
383
412
 
384
413
  # Dict.
385
414
  self.assert_annotation(
386
- pg.typing.Dict({'x': int, 'y': str}), '{\'x\': int, \'y\': str}'
415
+ pg.typing.Dict({'x': Activity, 'y': str}),
416
+ '{\'x\': Activity, \'y\': str}'
417
+ )
418
+ self.assert_annotation(
419
+ pg.typing.Dict({'x': Activity, 'y': str}),
420
+ '{\'x\': Activity, \'y\': str}',
421
+ allowed_dependencies=set([Activity]),
422
+ )
423
+ self.assert_annotation(
424
+ pg.typing.Dict({'x': Activity, 'y': str}),
425
+ '{\'x\': Any, \'y\': str}',
426
+ allowed_dependencies=set(),
387
427
  )
388
428
  self.assert_annotation(
389
429
  pg.typing.Dict({'x': int, 'y': str}),
@@ -395,6 +435,15 @@ class SchemaPythonReprTest(unittest.TestCase):
395
435
  'dict[str, Any]',
396
436
  strict=False,
397
437
  )
438
+
439
+ class DictValue(pg.Object):
440
+ pass
441
+
442
+ self.assert_annotation(
443
+ pg.typing.Dict([(pg.typing.StrKey(), DictValue)]),
444
+ 'dict[str, DictValue]',
445
+ strict=False,
446
+ )
398
447
  self.assert_annotation(
399
448
  pg.typing.Dict(),
400
449
  'dict[str, Any]',
@@ -408,6 +457,13 @@ class SchemaPythonReprTest(unittest.TestCase):
408
457
  ).noneable(),
409
458
  'Union[Activity, Itinerary, None]',
410
459
  )
460
+ self.assert_annotation(
461
+ pg.typing.Union(
462
+ [pg.typing.Object(Activity), pg.typing.Object(Itinerary)]
463
+ ).noneable(),
464
+ 'Union[Activity, Any, None]',
465
+ allowed_dependencies=set([Activity]),
466
+ )
411
467
 
412
468
  # Any.
413
469
  self.assert_annotation(pg.typing.Any(), 'Any')
@@ -415,13 +471,13 @@ class SchemaPythonReprTest(unittest.TestCase):
415
471
 
416
472
  def test_class_definition(self):
417
473
  self.assertEqual(
418
- schema_lib.class_definition(Activity),
474
+ schema_lib.class_definition(Activity, allowed_dependencies=set()),
419
475
  'class Activity:\n description: str\n',
420
476
  )
421
477
  self.assertEqual(
422
478
  schema_lib.class_definition(Itinerary),
423
479
  inspect.cleandoc("""
424
- class Itinerary:
480
+ class Itinerary(Object):
425
481
  \"\"\"A travel itinerary for a day.\"\"\"
426
482
  day: int(min=1)
427
483
  type: Literal['daytime', 'nighttime']
@@ -431,7 +487,9 @@ class SchemaPythonReprTest(unittest.TestCase):
431
487
  """) + '\n',
432
488
  )
433
489
  self.assertEqual(
434
- schema_lib.class_definition(PlaceOfInterest),
490
+ schema_lib.class_definition(
491
+ PlaceOfInterest, allowed_dependencies=set()
492
+ ),
435
493
  inspect.cleandoc("""
436
494
  class PlaceOfInterest:
437
495
  \"\"\"The name of a place of interest.
@@ -447,11 +505,11 @@ class SchemaPythonReprTest(unittest.TestCase):
447
505
  pass
448
506
 
449
507
  self.assertEqual(
450
- schema_lib.class_definition(A),
508
+ schema_lib.class_definition(A, allowed_dependencies=set()),
451
509
  'class A:\n pass\n',
452
510
  )
453
511
  self.assertEqual(
454
- schema_lib.class_definition(A, include_pg_object_as_base=True),
512
+ schema_lib.class_definition(A),
455
513
  'class A(Object):\n pass\n',
456
514
  )
457
515
 
@@ -459,7 +517,27 @@ class SchemaPythonReprTest(unittest.TestCase):
459
517
  x: str
460
518
  __kwargs__: typing.Any
461
519
 
462
- self.assertEqual(schema_lib.class_definition(C), 'class C:\n x: str\n')
520
+ self.assertEqual(
521
+ schema_lib.class_definition(C), 'class C(Object):\n x: str\n'
522
+ )
523
+
524
+ class D(pg.Object):
525
+ x: str
526
+ @schema_lib.include_method_in_prompt
527
+ def __call__(self, y: int) -> int:
528
+ return len(self.x) + y
529
+
530
+ self.assertEqual(
531
+ schema_lib.class_definition(D),
532
+ inspect.cleandoc(
533
+ """
534
+ class D(Object):
535
+ x: str
536
+
537
+ def __call__(self, y: int) -> int:
538
+ return len(self.x) + y
539
+ """) + '\n'
540
+ )
463
541
 
464
542
  def test_repr(self):
465
543
  class Foo(pg.Object):
@@ -477,10 +555,21 @@ class SchemaPythonReprTest(unittest.TestCase):
477
555
  class A(pg.Object):
478
556
  foo: Foo
479
557
 
558
+ @schema_lib.include_method_in_prompt
559
+ def foo_value(self) -> int:
560
+ return self.foo.x
561
+
562
+ def baz_value(self) -> str:
563
+ return 'baz'
564
+
480
565
  class B(A):
481
566
  bar: Bar
482
567
  foo2: Foo
483
568
 
569
+ @schema_lib.include_method_in_prompt
570
+ def bar_value(self) -> str:
571
+ return self.bar.y
572
+
484
573
  schema = schema_lib.Schema([B])
485
574
  self.assertEqual(
486
575
  schema_lib.SchemaPythonRepr().class_definitions(schema),
@@ -488,9 +577,6 @@ class SchemaPythonReprTest(unittest.TestCase):
488
577
  class Foo:
489
578
  x: int
490
579
 
491
- class A:
492
- foo: Foo
493
-
494
580
  class Bar:
495
581
  """Class Bar."""
496
582
  y: str
@@ -499,10 +585,16 @@ class SchemaPythonReprTest(unittest.TestCase):
499
585
  """Baz(y: str)"""
500
586
  y: str
501
587
 
502
- class B(A):
588
+ class B:
503
589
  foo: Foo
504
590
  bar: Bar
505
591
  foo2: Foo
592
+
593
+ def bar_value(self) -> str:
594
+ return self.bar.y
595
+
596
+ def foo_value(self) -> int:
597
+ return self.foo.x
506
598
  ''') + '\n',
507
599
  )
508
600
 
@@ -519,9 +611,6 @@ class SchemaPythonReprTest(unittest.TestCase):
519
611
  class Foo:
520
612
  x: int
521
613
 
522
- class A:
523
- foo: Foo
524
-
525
614
  class Bar:
526
615
  """Class Bar."""
527
616
  y: str
@@ -530,10 +619,16 @@ class SchemaPythonReprTest(unittest.TestCase):
530
619
  """Baz(y: str)"""
531
620
  y: str
532
621
 
533
- class B(A):
622
+ class B:
534
623
  foo: Foo
535
624
  bar: Bar
536
625
  foo2: Foo
626
+
627
+ def bar_value(self) -> str:
628
+ return self.bar.y
629
+
630
+ def foo_value(self) -> int:
631
+ return self.foo.x
537
632
  ```
538
633
  '''),
539
634
  )
@@ -541,16 +636,12 @@ class SchemaPythonReprTest(unittest.TestCase):
541
636
  schema_lib.SchemaPythonRepr().repr(
542
637
  schema,
543
638
  include_result_definition=False,
544
- include_pg_object_as_base=True,
545
639
  markdown=False,
546
640
  ),
547
641
  inspect.cleandoc('''
548
- class Foo(Object):
642
+ class Foo:
549
643
  x: int
550
644
 
551
- class A(Object):
552
- foo: Foo
553
-
554
645
  class Bar:
555
646
  """Class Bar."""
556
647
  y: str
@@ -559,10 +650,16 @@ class SchemaPythonReprTest(unittest.TestCase):
559
650
  """Baz(y: str)"""
560
651
  y: str
561
652
 
562
- class B(A):
653
+ class B:
563
654
  foo: Foo
564
655
  bar: Bar
565
656
  foo2: Foo
657
+
658
+ def bar_value(self) -> str:
659
+ return self.bar.y
660
+
661
+ def foo_value(self) -> int:
662
+ return self.foo.x
566
663
  '''),
567
664
  )
568
665
 
@@ -598,6 +695,10 @@ class ValuePythonReprTest(unittest.TestCase):
598
695
  schema_lib.ValuePythonRepr().repr(1, schema_lib.Schema(int)),
599
696
  '```python\n1\n```'
600
697
  )
698
+ self.assertEqual(
699
+ schema_lib.ValuePythonRepr().repr(lf.Template('hi, {{a}}', a='foo')),
700
+ 'hi, foo'
701
+ )
601
702
  self.assertEqual(
602
703
  schema_lib.ValuePythonRepr().repr(
603
704
  A([Foo(1), Foo(2)], 'bar'), schema_lib.Schema(A), markdown=False,
@@ -610,7 +711,7 @@ class ValuePythonReprTest(unittest.TestCase):
610
711
  ```python
611
712
  class Foo(Object):
612
713
  x: int
613
-
714
+
614
715
  class A(Object):
615
716
  foo: list[Foo]
616
717
  y: str | None
@@ -17,13 +17,13 @@ from typing import Any, Type, Union
17
17
 
18
18
  import langfun.core as lf
19
19
  from langfun.core.structured import mapping
20
- from langfun.core.structured import prompting
20
+ from langfun.core.structured import querying
21
21
  from langfun.core.structured import schema as schema_lib
22
22
  import pyglove as pg
23
23
 
24
24
 
25
25
  def score(
26
- prompt: Union[str, pg.Symbolic],
26
+ prompt: Union[str, pg.Symbolic] | list[str | pg.Symbolic],
27
27
  completions: list[str | pg.Symbolic],
28
28
  schema: Union[
29
29
  schema_lib.Schema, Type[Any], list[Type[Any]], dict[str, Any], None
@@ -32,9 +32,58 @@ def score(
32
32
  lm: lf.LanguageModel | None = None,
33
33
  examples: list[mapping.MappingExample] | None = None,
34
34
  protocol: schema_lib.SchemaProtocol = 'python',
35
+ return_scoring_results: bool = False,
35
36
  **kwargs,
36
- ) -> list[float]:
37
- """Scores the outputs based on the prompt."""
37
+ ) -> list[float] | list[lf.LMScoringResult]:
38
+ """Scores the outputs based on the prompt.
39
+
40
+ Examples:
41
+ ```
42
+ # Example 1: Scoring text output based on the user prompt.
43
+ scores = lf.score('{{x}} + {{y}} =', ['1', '2', '3'], lm=lm, x=1, y=2)
44
+ assert len(scores) == 3
45
+
46
+ # Example 2: Scoring int output based on the formulated OOP prompt.
47
+ scores = lf.score('1 + 1 =', [1, 2, 3], lm=lm)
48
+ assert len(scores) == 3
49
+
50
+ class Answer(pg.Object):
51
+ result: int
52
+
53
+ # Example 3: Scoring object output based on the formulated OOP prompt.
54
+ scores = lf.score('1 + 1 =', [Answer(1), Answer(2), Answer(3)], lm=lm)
55
+ assert len(scores) == 3
56
+
57
+ # Example 4: Scoring object field value based on the formulated OOP prompt
58
+ # and the generated tokens before the first `pg.oneof`.
59
+ scores = lf.score('1 + 1 =', [Answer(pg.oneof([1, 2, 3]))], lm=lm)
60
+ assert len(scores) == 3
61
+
62
+ # Example 5: Scoring multiple prompt/completion pairs.
63
+ scores = lf.score(
64
+ ['1 + 1=', '2 + 3='],
65
+ ['2', '4'],
66
+ lm=lm
67
+ )
68
+ assert len(scores) == 2
69
+ ```
70
+
71
+ Args:
72
+ prompt: The prompt(s) based on which each completion will be scored.
73
+ completions: A list of strings or symbolic objects as the output.
74
+ schema: The schema as the output type. If None, it will be inferred from
75
+ the completions.
76
+ lm: The language model used for scoring.
77
+ examples: Fewshot exemplars used together with the prompt in getting the
78
+ completions.
79
+ protocol: The protocol for formulating the prompt based on objects.
80
+ return_scoring_results: If True, returns a list of `lf.LMScoringResult`,
81
+ otherwise returns a list of floats as the scores of each completion.
82
+ **kwargs: Keyword arguments that are referred by the prompt.
83
+
84
+ Returns:
85
+ A list of floats or `lf.LMScoringResult` as the score of each completion.
86
+ """
38
87
  if not completions:
39
88
  raise ValueError('`completions` must not be empty.')
40
89
 
@@ -48,28 +97,85 @@ def score(
48
97
  f'{[type(c) for c in completions]}.'
49
98
  )
50
99
 
51
- input_message = prompting.query(
52
- prompt,
53
- schema,
54
- examples=examples,
55
- protocol=protocol,
56
- skip_lm=True,
57
- returns_message=True,
58
- **kwargs,
59
- )
100
+ if isinstance(prompt, list):
101
+ prompts = []
102
+ for p in prompt:
103
+ prompts.append(
104
+ querying.query_prompt(
105
+ p,
106
+ schema,
107
+ examples=examples,
108
+ protocol=protocol,
109
+ **kwargs,
110
+ )
111
+ )
112
+ input_message = prompts
113
+ else:
114
+ input_message = querying.query_prompt(
115
+ prompt,
116
+ schema,
117
+ examples=examples,
118
+ protocol=protocol,
119
+ **kwargs,
120
+ )
60
121
  if lm is None:
61
122
  lm_override = lf.get_contextual_override('lm')
62
123
  if lm_override is None:
63
124
  raise ValueError('`lm` must be specified or provided from `lf.context`.')
64
125
  lm = lm_override.value
65
126
 
127
+ completion_reprs = []
128
+ for c in completions:
129
+ if isinstance(c, mapping.MappingError):
130
+ completion_reprs.append(c.lm_response)
131
+ else:
132
+ rep = mapping.MappingExample.value_repr(
133
+ c, protocol=protocol, compact=False, verbose=False
134
+ )
135
+
136
+ # NOTE(daiyip): supporting scenario of scoring object field with
137
+ # `pg.oneof`.
138
+ oneof_pos = rep.find('OneOf(')
139
+ if oneof_pos == -1:
140
+ completion_reprs.append(rep)
141
+ else:
142
+ assert protocol == 'python', protocol
143
+ if isinstance(input_message, list):
144
+ raise ValueError(
145
+ 'Scoring on object fields using `pg.oneof` must share the '
146
+ f'same prompt. Encountered: {prompt}'
147
+ )
148
+ input_message.text += '\n' + rep[:oneof_pos]
149
+ oneof = _get_first_oneof(c)
150
+ for v in oneof.candidates:
151
+ completion_reprs.append(
152
+ pg.format(
153
+ v,
154
+ python_format=True,
155
+ compact=False,
156
+ verbose=False,
157
+ root_indent=oneof.sym_path.depth
158
+ )
159
+ )
160
+
66
161
  results = lm.score(
67
162
  input_message,
68
- [
69
- mapping.MappingExample.value_repr(
70
- c, protocol=protocol, compact=False, verbose=False
71
- )
72
- for c in completions
73
- ],
163
+ completion_reprs,
74
164
  )
165
+ if return_scoring_results:
166
+ return results
75
167
  return [r.score for r in results]
168
+
169
+
170
+ def _get_first_oneof(value: Any) -> pg.hyper.OneOf:
171
+ """Returns the first `pg.hyper.OneOf` visited by `pg.traverse` over `value`."""
172
+ oneofs = []
173
+ def select_oneofs(k, v, p):  # Traversal callback collecting OneOf nodes.
174
+ del k, p
175
+ if isinstance(v, pg.hyper.OneOf):
176
+ oneofs.append(v)
177
+ return pg.TraverseAction.CONTINUE  # Presumably: do not enter the OneOf's own children.
178
+ return pg.TraverseAction.ENTER
179
+ pg.traverse(value, select_oneofs)
180
+ assert oneofs  # Only guard against "no OneOf found"; stripped under `python -O`.
181
+ return oneofs[0]
@@ -16,6 +16,11 @@ import unittest
16
16
  import langfun.core as lf
17
17
  from langfun.core.llms import fake
18
18
  from langfun.core.structured import scoring
19
+ import pyglove as pg
20
+
21
+
22
+ class Answer(pg.Object):
23
+ result: int
19
24
 
20
25
 
21
26
  class ScoringTest(unittest.TestCase):
@@ -32,9 +37,34 @@ class ScoringTest(unittest.TestCase):
32
37
  with self.assertRaisesRegex(ValueError, '`lm` must be specified'):
33
38
  scoring.score('hi', [1, 2])
34
39
 
40
+ with self.assertRaisesRegex(
41
+ ValueError,
42
+ 'Scoring on object fields using `pg.oneof` must share the same prompt',
43
+ ):
44
+ scoring.score(
45
+ ['1 + 1=', '2 + 3='],
46
+ [Answer(pg.oneof([1, 2, 3]))],
47
+ lm=fake.Echo(),
48
+ )
49
+
35
50
  def test_score(self):
36
51
  self.assertEqual(scoring.score('hi', [1, 2], lm=fake.Echo()), [0.0, -1.0])
37
52
 
53
+ def test_score_on_field_values(self):
54
+ self.assertEqual(
55
+ scoring.score(
56
+ '1 + 1=',
57
+ [Answer(pg.oneof([1, 2, 3]))], lm=fake.Echo()
58
+ ),
59
+ [0.0, -1.0, -2.0]
60
+ )
61
+
62
+ def test_score_returning_scoring_results(self):
63
+ self.assertEqual(scoring.score(
64
+ 'hi', [1, 2], lm=fake.Echo(), return_scoring_results=True),
65
+ [lf.LMScoringResult(score=0.0, gradients=None),
66
+ lf.LMScoringResult(score=-1.0, gradients=None)])
67
+
38
68
  def test_scope_with_lm_from_the_context(self):
39
69
  with lf.context(lm=fake.Echo()):
40
70
  self.assertEqual(scoring.score('hi', [1, 2]), [0.0, -1.0])
@@ -0,0 +1,64 @@
1
+ # Copyright 2023 The Langfun Authors
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ """Tokenize the prompt for `lf.query`."""
15
+
16
+ from typing import Any, Type, Union
17
+
18
+ import langfun.core as lf
19
+ from langfun.core.structured import mapping
20
+ from langfun.core.structured import querying
21
+ from langfun.core.structured import schema as schema_lib
22
+ import pyglove as pg
23
+
24
+
25
+ def tokenize(
26
+ prompt: Union[str, pg.Symbolic] | list[str | pg.Symbolic],
27
+ schema: Union[
28
+ schema_lib.Schema, Type[Any], list[Type[Any]], dict[str, Any], None
29
+ ] = None,
30
+ *,
31
+ lm: lf.LanguageModel | None = None,
32
+ examples: list[mapping.MappingExample] | None = None,
33
+ protocol: schema_lib.SchemaProtocol = 'python',
34
+ **kwargs,
35
+ ) -> list[tuple[str | bytes, int]]:
36
+ """Tokenize the prompt for `lf.query`.
37
+
38
+ Args:
39
+ prompt: The prompt to formulate via `querying.query_prompt` and tokenize.
40
+ schema: The output schema used when formulating the prompt. It is passed
41
+ through to `querying.query_prompt` unchanged.
42
+ lm: The language model used for tokenization. If None, taken from `lf.context`.
43
+ examples: Few-shot exemplars passed to `querying.query_prompt` when
44
+ formulating the prompt.
45
+ protocol: The protocol for formulating the prompt based on objects.
46
+ **kwargs: Keyword arguments referred to by the prompt; forwarded to `querying.query_prompt`.
47
+
48
+ Returns:
49
+ A list of (text, token_id) tuples from `lm.tokenize`; text may be str or bytes.
50
+ """
51
+ input_message = querying.query_prompt(
52
+ prompt,
53
+ schema,
54
+ examples=examples,
55
+ protocol=protocol,
56
+ **kwargs,
57
+ )
58
+ if lm is None:
59
+ lm_override = lf.get_contextual_override('lm')
60
+ if lm_override is None:
61
+ raise ValueError('`lm` must be specified or provided from `lf.context`.')
62
+ lm = lm_override.value
63
+
64
+ return lm.tokenize(input_message)
@@ -0,0 +1,48 @@
1
+ # Copyright 2023 The Langfun Authors
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import unittest
16
+ import langfun.core as lf
17
+ from langfun.core.llms import fake
18
+ from langfun.core.structured import tokenization
19
+ import pyglove as pg
20
+
21
+
22
+ class Answer(pg.Object):  # NOTE(review): not referenced by TokenizationTest below.
23
+ result: int
24
+
25
+
26
+ class TokenizationTest(unittest.TestCase):
27
+ """Tests for `tokenization.tokenize`."""
28
+ def test_bad_call(self):
29
+ # Neither an `lm` argument nor an `lf.context` override is provided.
30
+ with self.assertRaisesRegex(ValueError, '`lm` must be specified'):
31
+ tokenization.tokenize('hi')
32
+
33
+ def test_tokenize(self):
34
+ self.assertEqual(
35
+ tokenization.tokenize('hi', lm=fake.Echo()),
36
+ [('hi', 0)]  # Expected (text, token_id) output from fake.Echo for 'hi'.
37
+ )
38
+
39
+ def test_tokenize_with_lm_from_the_context(self):
40
+ with lf.context(lm=fake.Echo()):  # `lm` resolved from the context override.
41
+ self.assertEqual(
42
+ tokenization.tokenize('hi'),
43
+ [('hi', 0)]
44
+ )
45
+
46
+
47
+ if __name__ == '__main__':
48
+ unittest.main()  # Run this test module directly as a script.