langfun 0.0.2.dev20240429__py3-none-any.whl → 0.1.2.dev202501150804__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- langfun/__init__.py +20 -2
- langfun/core/__init__.py +16 -5
- langfun/core/agentic/__init__.py +30 -0
- langfun/core/agentic/action.py +854 -0
- langfun/core/agentic/action_eval.py +150 -0
- langfun/core/agentic/action_eval_test.py +109 -0
- langfun/core/agentic/action_test.py +136 -0
- langfun/core/coding/python/__init__.py +5 -11
- langfun/core/coding/python/correction.py +37 -21
- langfun/core/coding/python/correction_test.py +29 -3
- langfun/core/coding/python/execution.py +40 -216
- langfun/core/coding/python/execution_test.py +29 -89
- langfun/core/coding/python/generation.py +21 -11
- langfun/core/coding/python/generation_test.py +2 -2
- langfun/core/coding/python/parsing.py +108 -193
- langfun/core/coding/python/parsing_test.py +2 -105
- langfun/core/component.py +63 -2
- langfun/core/component_test.py +53 -0
- langfun/core/concurrent.py +414 -117
- langfun/core/concurrent_test.py +111 -24
- langfun/core/console.py +17 -5
- langfun/core/console_test.py +17 -0
- langfun/core/eval/__init__.py +16 -1
- langfun/core/eval/base.py +622 -174
- langfun/core/eval/base_test.py +200 -54
- langfun/core/eval/matching.py +63 -76
- langfun/core/eval/matching_test.py +17 -8
- langfun/core/eval/patching.py +130 -0
- langfun/core/eval/patching_test.py +170 -0
- langfun/core/eval/scoring.py +26 -26
- langfun/core/eval/scoring_test.py +19 -2
- langfun/core/eval/v2/__init__.py +42 -0
- langfun/core/eval/v2/checkpointing.py +380 -0
- langfun/core/eval/v2/checkpointing_test.py +228 -0
- langfun/core/eval/v2/eval_test_helper.py +136 -0
- langfun/core/eval/v2/evaluation.py +725 -0
- langfun/core/eval/v2/evaluation_test.py +180 -0
- langfun/core/eval/v2/example.py +305 -0
- langfun/core/eval/v2/example_test.py +128 -0
- langfun/core/eval/v2/experiment.py +1048 -0
- langfun/core/eval/v2/experiment_test.py +433 -0
- langfun/core/eval/v2/metric_values.py +156 -0
- langfun/core/eval/v2/metric_values_test.py +80 -0
- langfun/core/eval/v2/metrics.py +357 -0
- langfun/core/eval/v2/metrics_test.py +203 -0
- langfun/core/eval/v2/progress.py +348 -0
- langfun/core/eval/v2/progress_test.py +82 -0
- langfun/core/eval/v2/progress_tracking.py +210 -0
- langfun/core/eval/v2/progress_tracking_test.py +66 -0
- langfun/core/eval/v2/reporting.py +270 -0
- langfun/core/eval/v2/reporting_test.py +158 -0
- langfun/core/eval/v2/runners.py +488 -0
- langfun/core/eval/v2/runners_test.py +334 -0
- langfun/core/langfunc.py +4 -17
- langfun/core/langfunc_test.py +22 -6
- langfun/core/language_model.py +577 -39
- langfun/core/language_model_test.py +470 -56
- langfun/core/llms/__init__.py +87 -16
- langfun/core/llms/anthropic.py +312 -87
- langfun/core/llms/anthropic_test.py +71 -3
- langfun/core/llms/cache/base.py +21 -2
- langfun/core/llms/cache/in_memory.py +13 -0
- langfun/core/llms/cache/in_memory_test.py +53 -2
- langfun/core/llms/compositional.py +101 -0
- langfun/core/llms/compositional_test.py +73 -0
- langfun/core/llms/deepseek.py +117 -0
- langfun/core/llms/deepseek_test.py +61 -0
- langfun/core/llms/fake.py +11 -7
- langfun/core/llms/fake_test.py +14 -0
- langfun/core/llms/gemini.py +507 -0
- langfun/core/llms/gemini_test.py +195 -0
- langfun/core/llms/google_genai.py +62 -218
- langfun/core/llms/google_genai_test.py +9 -202
- langfun/core/llms/groq.py +160 -144
- langfun/core/llms/groq_test.py +31 -137
- langfun/core/llms/llama_cpp.py +15 -42
- langfun/core/llms/llama_cpp_test.py +4 -30
- langfun/core/llms/openai.py +395 -203
- langfun/core/llms/openai_compatible.py +179 -0
- langfun/core/llms/openai_compatible_test.py +495 -0
- langfun/core/llms/openai_test.py +30 -395
- langfun/core/llms/rest.py +113 -0
- langfun/core/llms/rest_test.py +111 -0
- langfun/core/llms/vertexai.py +192 -0
- langfun/core/llms/vertexai_test.py +52 -0
- langfun/core/logging.py +284 -0
- langfun/core/logging_test.py +125 -0
- langfun/core/message.py +319 -9
- langfun/core/message_test.py +190 -13
- langfun/core/modalities/__init__.py +6 -2
- langfun/core/modalities/audio.py +30 -0
- langfun/core/modalities/audio_test.py +63 -0
- langfun/core/modalities/image.py +39 -20
- langfun/core/modalities/image_test.py +52 -9
- langfun/core/modalities/mime.py +206 -29
- langfun/core/modalities/mime_test.py +90 -9
- langfun/core/modalities/ms_office.py +117 -0
- langfun/core/modalities/ms_office_test.py +389 -0
- langfun/core/modalities/pdf.py +22 -0
- langfun/core/modalities/pdf_test.py +57 -0
- langfun/core/modalities/video.py +9 -26
- langfun/core/modalities/video_test.py +3 -3
- langfun/core/modality.py +26 -3
- langfun/core/modality_test.py +2 -2
- langfun/core/sampling.py +11 -11
- langfun/core/structured/__init__.py +12 -16
- langfun/core/structured/completion.py +32 -5
- langfun/core/structured/completion_test.py +7 -6
- langfun/core/structured/description.py +2 -2
- langfun/core/structured/description_test.py +3 -3
- langfun/core/structured/function_generation.py +60 -27
- langfun/core/structured/function_generation_test.py +72 -2
- langfun/core/structured/mapping.py +97 -47
- langfun/core/structured/mapping_test.py +90 -2
- langfun/core/structured/parsing.py +33 -21
- langfun/core/structured/parsing_test.py +53 -9
- langfun/core/structured/querying.py +746 -0
- langfun/core/structured/{prompting_test.py → querying_test.py} +469 -51
- langfun/core/structured/schema.py +204 -97
- langfun/core/structured/schema_generation.py +1 -1
- langfun/core/structured/schema_test.py +130 -29
- langfun/core/structured/scoring.py +125 -19
- langfun/core/structured/scoring_test.py +30 -0
- langfun/core/structured/tokenization.py +64 -0
- langfun/core/structured/tokenization_test.py +48 -0
- langfun/core/template.py +115 -1
- langfun/core/template_test.py +71 -1
- langfun/core/templates/conversation.py +9 -0
- langfun/core/templates/conversation_test.py +4 -3
- langfun/core/templates/selfplay_test.py +10 -2
- langfun-0.1.2.dev202501150804.dist-info/METADATA +225 -0
- langfun-0.1.2.dev202501150804.dist-info/RECORD +153 -0
- {langfun-0.0.2.dev20240429.dist-info → langfun-0.1.2.dev202501150804.dist-info}/WHEEL +1 -1
- langfun/core/coding/python/errors.py +0 -108
- langfun/core/coding/python/errors_test.py +0 -99
- langfun/core/coding/python/permissions.py +0 -90
- langfun/core/coding/python/permissions_test.py +0 -86
- langfun/core/structured/prompting.py +0 -238
- langfun/core/text_formatting.py +0 -162
- langfun/core/text_formatting_test.py +0 -47
- langfun-0.0.2.dev20240429.dist-info/METADATA +0 -100
- langfun-0.0.2.dev20240429.dist-info/RECORD +0 -108
- {langfun-0.0.2.dev20240429.dist-info → langfun-0.1.2.dev202501150804.dist-info}/LICENSE +0 -0
- {langfun-0.0.2.dev20240429.dist-info → langfun-0.1.2.dev202501150804.dist-info}/top_level.txt +0 -0
@@ -18,6 +18,7 @@ import inspect
|
|
18
18
|
import typing
|
19
19
|
import unittest
|
20
20
|
|
21
|
+
import langfun.core as lf
|
21
22
|
from langfun.core.llms import fake
|
22
23
|
from langfun.core.structured import schema as schema_lib
|
23
24
|
import pyglove as pg
|
@@ -260,15 +261,18 @@ class ClassDependenciesTest(unittest.TestCase):
|
|
260
261
|
class A(pg.Object):
|
261
262
|
foo: tuple[Foo, int]
|
262
263
|
|
264
|
+
class B(pg.Object):
|
265
|
+
pass
|
266
|
+
|
263
267
|
class X(pg.Object):
|
264
|
-
k:
|
268
|
+
k: dict[str, B]
|
265
269
|
|
266
|
-
class
|
270
|
+
class C(A):
|
267
271
|
bar: Bar
|
268
272
|
foo2: Foo | X
|
269
273
|
|
270
274
|
a = A(foo=(Foo(1), 0))
|
271
|
-
self.assertEqual(schema_lib.class_dependencies(a), [Foo, A, Bar, X,
|
275
|
+
self.assertEqual(schema_lib.class_dependencies(a), [Foo, A, Bar, B, X, C])
|
272
276
|
|
273
277
|
self.assertEqual(schema_lib.class_dependencies(1), [])
|
274
278
|
|
@@ -280,9 +284,10 @@ class SchemaPythonReprTest(unittest.TestCase):
|
|
280
284
|
value_spec: pg.typing.ValueSpec,
|
281
285
|
expected_annotation: str,
|
282
286
|
strict: bool = False,
|
287
|
+
**kwargs,
|
283
288
|
) -> None:
|
284
289
|
self.assertEqual(
|
285
|
-
schema_lib.annotation(value_spec, strict=strict),
|
290
|
+
schema_lib.annotation(value_spec, strict=strict, **kwargs),
|
286
291
|
expected_annotation,
|
287
292
|
)
|
288
293
|
|
@@ -358,11 +363,27 @@ class SchemaPythonReprTest(unittest.TestCase):
|
|
358
363
|
self.assert_annotation(
|
359
364
|
pg.typing.Object(Activity).noneable(), 'Activity | None'
|
360
365
|
)
|
366
|
+
self.assert_annotation(
|
367
|
+
pg.typing.Object(Activity).noneable(), 'Activity | None',
|
368
|
+
allowed_dependencies=set([Activity]),
|
369
|
+
)
|
370
|
+
self.assert_annotation(
|
371
|
+
pg.typing.Object(Activity).noneable(), 'Any | None',
|
372
|
+
allowed_dependencies=set(),
|
373
|
+
)
|
361
374
|
|
362
375
|
# List.
|
363
376
|
self.assert_annotation(
|
364
377
|
pg.typing.List(pg.typing.Object(Activity)), 'list[Activity]'
|
365
378
|
)
|
379
|
+
self.assert_annotation(
|
380
|
+
pg.typing.List(pg.typing.Object(Activity)), 'list[Activity]',
|
381
|
+
allowed_dependencies=set([Activity]),
|
382
|
+
)
|
383
|
+
self.assert_annotation(
|
384
|
+
pg.typing.List(pg.typing.Object(Activity)), 'list[Any]',
|
385
|
+
allowed_dependencies=set(),
|
386
|
+
)
|
366
387
|
self.assert_annotation(
|
367
388
|
pg.typing.List(pg.typing.Object(Activity)).noneable(),
|
368
389
|
'list[Activity] | None',
|
@@ -374,16 +395,35 @@ class SchemaPythonReprTest(unittest.TestCase):
|
|
374
395
|
|
375
396
|
# Tuple.
|
376
397
|
self.assert_annotation(
|
377
|
-
pg.typing.Tuple([
|
398
|
+
pg.typing.Tuple([Activity, pg.typing.Str()]), 'tuple[Activity, str]'
|
399
|
+
)
|
400
|
+
self.assert_annotation(
|
401
|
+
pg.typing.Tuple([Activity, pg.typing.Str()]), 'tuple[Activity, str]',
|
402
|
+
allowed_dependencies=set([Activity]),
|
403
|
+
)
|
404
|
+
self.assert_annotation(
|
405
|
+
pg.typing.Tuple([Activity, pg.typing.Str()]), 'tuple[Any, str]',
|
406
|
+
allowed_dependencies=set(),
|
378
407
|
)
|
379
408
|
self.assert_annotation(
|
380
|
-
pg.typing.Tuple([
|
381
|
-
'tuple[
|
409
|
+
pg.typing.Tuple([Activity, pg.typing.Str()]).noneable(),
|
410
|
+
'tuple[Activity, str] | None',
|
382
411
|
)
|
383
412
|
|
384
413
|
# Dict.
|
385
414
|
self.assert_annotation(
|
386
|
-
pg.typing.Dict({'x':
|
415
|
+
pg.typing.Dict({'x': Activity, 'y': str}),
|
416
|
+
'{\'x\': Activity, \'y\': str}'
|
417
|
+
)
|
418
|
+
self.assert_annotation(
|
419
|
+
pg.typing.Dict({'x': Activity, 'y': str}),
|
420
|
+
'{\'x\': Activity, \'y\': str}',
|
421
|
+
allowed_dependencies=set([Activity]),
|
422
|
+
)
|
423
|
+
self.assert_annotation(
|
424
|
+
pg.typing.Dict({'x': Activity, 'y': str}),
|
425
|
+
'{\'x\': Any, \'y\': str}',
|
426
|
+
allowed_dependencies=set(),
|
387
427
|
)
|
388
428
|
self.assert_annotation(
|
389
429
|
pg.typing.Dict({'x': int, 'y': str}),
|
@@ -395,6 +435,15 @@ class SchemaPythonReprTest(unittest.TestCase):
|
|
395
435
|
'dict[str, Any]',
|
396
436
|
strict=False,
|
397
437
|
)
|
438
|
+
|
439
|
+
class DictValue(pg.Object):
|
440
|
+
pass
|
441
|
+
|
442
|
+
self.assert_annotation(
|
443
|
+
pg.typing.Dict([(pg.typing.StrKey(), DictValue)]),
|
444
|
+
'dict[str, DictValue]',
|
445
|
+
strict=False,
|
446
|
+
)
|
398
447
|
self.assert_annotation(
|
399
448
|
pg.typing.Dict(),
|
400
449
|
'dict[str, Any]',
|
@@ -408,6 +457,13 @@ class SchemaPythonReprTest(unittest.TestCase):
|
|
408
457
|
).noneable(),
|
409
458
|
'Union[Activity, Itinerary, None]',
|
410
459
|
)
|
460
|
+
self.assert_annotation(
|
461
|
+
pg.typing.Union(
|
462
|
+
[pg.typing.Object(Activity), pg.typing.Object(Itinerary)]
|
463
|
+
).noneable(),
|
464
|
+
'Union[Activity, Any, None]',
|
465
|
+
allowed_dependencies=set([Activity]),
|
466
|
+
)
|
411
467
|
|
412
468
|
# Any.
|
413
469
|
self.assert_annotation(pg.typing.Any(), 'Any')
|
@@ -415,13 +471,13 @@ class SchemaPythonReprTest(unittest.TestCase):
|
|
415
471
|
|
416
472
|
def test_class_definition(self):
|
417
473
|
self.assertEqual(
|
418
|
-
schema_lib.class_definition(Activity),
|
474
|
+
schema_lib.class_definition(Activity, allowed_dependencies=set()),
|
419
475
|
'class Activity:\n description: str\n',
|
420
476
|
)
|
421
477
|
self.assertEqual(
|
422
478
|
schema_lib.class_definition(Itinerary),
|
423
479
|
inspect.cleandoc("""
|
424
|
-
class Itinerary:
|
480
|
+
class Itinerary(Object):
|
425
481
|
\"\"\"A travel itinerary for a day.\"\"\"
|
426
482
|
day: int(min=1)
|
427
483
|
type: Literal['daytime', 'nighttime']
|
@@ -431,7 +487,9 @@ class SchemaPythonReprTest(unittest.TestCase):
|
|
431
487
|
""") + '\n',
|
432
488
|
)
|
433
489
|
self.assertEqual(
|
434
|
-
schema_lib.class_definition(
|
490
|
+
schema_lib.class_definition(
|
491
|
+
PlaceOfInterest, allowed_dependencies=set()
|
492
|
+
),
|
435
493
|
inspect.cleandoc("""
|
436
494
|
class PlaceOfInterest:
|
437
495
|
\"\"\"The name of a place of interest.
|
@@ -447,11 +505,11 @@ class SchemaPythonReprTest(unittest.TestCase):
|
|
447
505
|
pass
|
448
506
|
|
449
507
|
self.assertEqual(
|
450
|
-
schema_lib.class_definition(A),
|
508
|
+
schema_lib.class_definition(A, allowed_dependencies=set()),
|
451
509
|
'class A:\n pass\n',
|
452
510
|
)
|
453
511
|
self.assertEqual(
|
454
|
-
schema_lib.class_definition(A
|
512
|
+
schema_lib.class_definition(A),
|
455
513
|
'class A(Object):\n pass\n',
|
456
514
|
)
|
457
515
|
|
@@ -459,7 +517,27 @@ class SchemaPythonReprTest(unittest.TestCase):
|
|
459
517
|
x: str
|
460
518
|
__kwargs__: typing.Any
|
461
519
|
|
462
|
-
self.assertEqual(
|
520
|
+
self.assertEqual(
|
521
|
+
schema_lib.class_definition(C), 'class C(Object):\n x: str\n'
|
522
|
+
)
|
523
|
+
|
524
|
+
class D(pg.Object):
|
525
|
+
x: str
|
526
|
+
@schema_lib.include_method_in_prompt
|
527
|
+
def __call__(self, y: int) -> int:
|
528
|
+
return len(self.x) + y
|
529
|
+
|
530
|
+
self.assertEqual(
|
531
|
+
schema_lib.class_definition(D),
|
532
|
+
inspect.cleandoc(
|
533
|
+
"""
|
534
|
+
class D(Object):
|
535
|
+
x: str
|
536
|
+
|
537
|
+
def __call__(self, y: int) -> int:
|
538
|
+
return len(self.x) + y
|
539
|
+
""") + '\n'
|
540
|
+
)
|
463
541
|
|
464
542
|
def test_repr(self):
|
465
543
|
class Foo(pg.Object):
|
@@ -477,10 +555,21 @@ class SchemaPythonReprTest(unittest.TestCase):
|
|
477
555
|
class A(pg.Object):
|
478
556
|
foo: Foo
|
479
557
|
|
558
|
+
@schema_lib.include_method_in_prompt
|
559
|
+
def foo_value(self) -> int:
|
560
|
+
return self.foo.x
|
561
|
+
|
562
|
+
def baz_value(self) -> str:
|
563
|
+
return 'baz'
|
564
|
+
|
480
565
|
class B(A):
|
481
566
|
bar: Bar
|
482
567
|
foo2: Foo
|
483
568
|
|
569
|
+
@schema_lib.include_method_in_prompt
|
570
|
+
def bar_value(self) -> str:
|
571
|
+
return self.bar.y
|
572
|
+
|
484
573
|
schema = schema_lib.Schema([B])
|
485
574
|
self.assertEqual(
|
486
575
|
schema_lib.SchemaPythonRepr().class_definitions(schema),
|
@@ -488,9 +577,6 @@ class SchemaPythonReprTest(unittest.TestCase):
|
|
488
577
|
class Foo:
|
489
578
|
x: int
|
490
579
|
|
491
|
-
class A:
|
492
|
-
foo: Foo
|
493
|
-
|
494
580
|
class Bar:
|
495
581
|
"""Class Bar."""
|
496
582
|
y: str
|
@@ -499,10 +585,16 @@ class SchemaPythonReprTest(unittest.TestCase):
|
|
499
585
|
"""Baz(y: str)"""
|
500
586
|
y: str
|
501
587
|
|
502
|
-
class B
|
588
|
+
class B:
|
503
589
|
foo: Foo
|
504
590
|
bar: Bar
|
505
591
|
foo2: Foo
|
592
|
+
|
593
|
+
def bar_value(self) -> str:
|
594
|
+
return self.bar.y
|
595
|
+
|
596
|
+
def foo_value(self) -> int:
|
597
|
+
return self.foo.x
|
506
598
|
''') + '\n',
|
507
599
|
)
|
508
600
|
|
@@ -519,9 +611,6 @@ class SchemaPythonReprTest(unittest.TestCase):
|
|
519
611
|
class Foo:
|
520
612
|
x: int
|
521
613
|
|
522
|
-
class A:
|
523
|
-
foo: Foo
|
524
|
-
|
525
614
|
class Bar:
|
526
615
|
"""Class Bar."""
|
527
616
|
y: str
|
@@ -530,10 +619,16 @@ class SchemaPythonReprTest(unittest.TestCase):
|
|
530
619
|
"""Baz(y: str)"""
|
531
620
|
y: str
|
532
621
|
|
533
|
-
class B
|
622
|
+
class B:
|
534
623
|
foo: Foo
|
535
624
|
bar: Bar
|
536
625
|
foo2: Foo
|
626
|
+
|
627
|
+
def bar_value(self) -> str:
|
628
|
+
return self.bar.y
|
629
|
+
|
630
|
+
def foo_value(self) -> int:
|
631
|
+
return self.foo.x
|
537
632
|
```
|
538
633
|
'''),
|
539
634
|
)
|
@@ -541,16 +636,12 @@ class SchemaPythonReprTest(unittest.TestCase):
|
|
541
636
|
schema_lib.SchemaPythonRepr().repr(
|
542
637
|
schema,
|
543
638
|
include_result_definition=False,
|
544
|
-
include_pg_object_as_base=True,
|
545
639
|
markdown=False,
|
546
640
|
),
|
547
641
|
inspect.cleandoc('''
|
548
|
-
class Foo
|
642
|
+
class Foo:
|
549
643
|
x: int
|
550
644
|
|
551
|
-
class A(Object):
|
552
|
-
foo: Foo
|
553
|
-
|
554
645
|
class Bar:
|
555
646
|
"""Class Bar."""
|
556
647
|
y: str
|
@@ -559,10 +650,16 @@ class SchemaPythonReprTest(unittest.TestCase):
|
|
559
650
|
"""Baz(y: str)"""
|
560
651
|
y: str
|
561
652
|
|
562
|
-
class B
|
653
|
+
class B:
|
563
654
|
foo: Foo
|
564
655
|
bar: Bar
|
565
656
|
foo2: Foo
|
657
|
+
|
658
|
+
def bar_value(self) -> str:
|
659
|
+
return self.bar.y
|
660
|
+
|
661
|
+
def foo_value(self) -> int:
|
662
|
+
return self.foo.x
|
566
663
|
'''),
|
567
664
|
)
|
568
665
|
|
@@ -598,6 +695,10 @@ class ValuePythonReprTest(unittest.TestCase):
|
|
598
695
|
schema_lib.ValuePythonRepr().repr(1, schema_lib.Schema(int)),
|
599
696
|
'```python\n1\n```'
|
600
697
|
)
|
698
|
+
self.assertEqual(
|
699
|
+
schema_lib.ValuePythonRepr().repr(lf.Template('hi, {{a}}', a='foo')),
|
700
|
+
'hi, foo'
|
701
|
+
)
|
601
702
|
self.assertEqual(
|
602
703
|
schema_lib.ValuePythonRepr().repr(
|
603
704
|
A([Foo(1), Foo(2)], 'bar'), schema_lib.Schema(A), markdown=False,
|
@@ -610,7 +711,7 @@ class ValuePythonReprTest(unittest.TestCase):
|
|
610
711
|
```python
|
611
712
|
class Foo(Object):
|
612
713
|
x: int
|
613
|
-
|
714
|
+
|
614
715
|
class A(Object):
|
615
716
|
foo: list[Foo]
|
616
717
|
y: str | None
|
@@ -17,13 +17,13 @@ from typing import Any, Type, Union
|
|
17
17
|
|
18
18
|
import langfun.core as lf
|
19
19
|
from langfun.core.structured import mapping
|
20
|
-
from langfun.core.structured import
|
20
|
+
from langfun.core.structured import querying
|
21
21
|
from langfun.core.structured import schema as schema_lib
|
22
22
|
import pyglove as pg
|
23
23
|
|
24
24
|
|
25
25
|
def score(
|
26
|
-
prompt: Union[str, pg.Symbolic],
|
26
|
+
prompt: Union[str, pg.Symbolic] | list[str | pg.Symbolic],
|
27
27
|
completions: list[str | pg.Symbolic],
|
28
28
|
schema: Union[
|
29
29
|
schema_lib.Schema, Type[Any], list[Type[Any]], dict[str, Any], None
|
@@ -32,9 +32,58 @@ def score(
|
|
32
32
|
lm: lf.LanguageModel | None = None,
|
33
33
|
examples: list[mapping.MappingExample] | None = None,
|
34
34
|
protocol: schema_lib.SchemaProtocol = 'python',
|
35
|
+
return_scoring_results: bool = False,
|
35
36
|
**kwargs,
|
36
|
-
) -> list[float]:
|
37
|
-
"""Scores the outputs based on the prompt.
|
37
|
+
) -> list[float] | list[lf.LMScoringResult]:
|
38
|
+
"""Scores the outputs based on the prompt.
|
39
|
+
|
40
|
+
Examples:
|
41
|
+
```
|
42
|
+
# Example 1: Scoring text output based on the user prompt.
|
43
|
+
scores = lf.score('{{x}} + {{y}} =', ['1', '2', '3'], lm=lm, x=1, y=2)
|
44
|
+
assert len(scores) == 3
|
45
|
+
|
46
|
+
# Example 2: Scoring int output based on the formulated OOP prompt.
|
47
|
+
scores = lf.score('1 + 1 =', [1, 2, 3], lm=lm)
|
48
|
+
assert len(scores) == 3
|
49
|
+
|
50
|
+
class Answer(pg.Object):
|
51
|
+
result: int
|
52
|
+
|
53
|
+
# Example 3: Scoring object output based on the formulated OOP prompt.
|
54
|
+
scores = lf.score('1 + 1 =', [Answer(1), Answer(2), Answer(3)], lm=lm)
|
55
|
+
assert len(scores) == 3
|
56
|
+
|
57
|
+
# Example 4: Scoring object field value based on the formulated OOP prompt
|
58
|
+
# and the generated tokens before the first `pg.oneof`.
|
59
|
+
scores = lf.score('1 + 1 =', [Answer(pg.oneof([1, 2, 3]))], lm=lm)
|
60
|
+
assert len(scores) == 3
|
61
|
+
|
62
|
+
# Example 5: Scoring multiple prompt/completion pairs.
|
63
|
+
scores = lf.score(
|
64
|
+
['1 + 1=', '2 + 3='],
|
65
|
+
['2', '4'],
|
66
|
+
lm=lm
|
67
|
+
)
|
68
|
+
assert len(scores) == 2
|
69
|
+
```
|
70
|
+
|
71
|
+
Args:
|
72
|
+
prompt: The prompt(s) based on which each completion will be scored.
|
73
|
+
completions: A list of strings or symbolic objects as the output.
|
74
|
+
schema: The schema as the output type. If None, it will be inferred from
|
75
|
+
the completions.
|
76
|
+
lm: The language model used for scoring.
|
77
|
+
examples: Fewshot exemplars used together with the prompt in getting the
|
78
|
+
completions.
|
79
|
+
protocol: The protocol for formulating the prompt based on objects.
|
80
|
+
return_scoring_results: If True, returns a list of `lf.LMScoringResult`,
|
81
|
+
otherwise returns a list of floats as the scores of each completion.
|
82
|
+
**kwargs: Keyword arguments that are referred by the prompt.
|
83
|
+
|
84
|
+
Returns:
|
85
|
+
A list of floats or `lf.LMScoringResult` as the score of each completion.
|
86
|
+
"""
|
38
87
|
if not completions:
|
39
88
|
raise ValueError('`completions` must not be empty.')
|
40
89
|
|
@@ -48,28 +97,85 @@ def score(
|
|
48
97
|
f'{[type(c) for c in completions]}.'
|
49
98
|
)
|
50
99
|
|
51
|
-
|
52
|
-
|
53
|
-
|
54
|
-
|
55
|
-
|
56
|
-
|
57
|
-
|
58
|
-
|
59
|
-
|
100
|
+
if isinstance(prompt, list):
|
101
|
+
prompts = []
|
102
|
+
for p in prompt:
|
103
|
+
prompts.append(
|
104
|
+
querying.query_prompt(
|
105
|
+
p,
|
106
|
+
schema,
|
107
|
+
examples=examples,
|
108
|
+
protocol=protocol,
|
109
|
+
**kwargs,
|
110
|
+
)
|
111
|
+
)
|
112
|
+
input_message = prompts
|
113
|
+
else:
|
114
|
+
input_message = querying.query_prompt(
|
115
|
+
prompt,
|
116
|
+
schema,
|
117
|
+
examples=examples,
|
118
|
+
protocol=protocol,
|
119
|
+
**kwargs,
|
120
|
+
)
|
60
121
|
if lm is None:
|
61
122
|
lm_override = lf.get_contextual_override('lm')
|
62
123
|
if lm_override is None:
|
63
124
|
raise ValueError('`lm` must be specified or provided from `lf.context`.')
|
64
125
|
lm = lm_override.value
|
65
126
|
|
127
|
+
completion_reprs = []
|
128
|
+
for c in completions:
|
129
|
+
if isinstance(c, mapping.MappingError):
|
130
|
+
completion_reprs.append(c.lm_response)
|
131
|
+
else:
|
132
|
+
rep = mapping.MappingExample.value_repr(
|
133
|
+
c, protocol=protocol, compact=False, verbose=False
|
134
|
+
)
|
135
|
+
|
136
|
+
# NOTE(daiyip): supporting scenario of scoring object field with
|
137
|
+
# `pg.oneof`.
|
138
|
+
oneof_pos = rep.find('OneOf(')
|
139
|
+
if oneof_pos == -1:
|
140
|
+
completion_reprs.append(rep)
|
141
|
+
else:
|
142
|
+
assert protocol == 'python', protocol
|
143
|
+
if isinstance(input_message, list):
|
144
|
+
raise ValueError(
|
145
|
+
'Scoring on object fields using `pg.oneof` must share the '
|
146
|
+
f'same prompt. Encountered: {prompt}'
|
147
|
+
)
|
148
|
+
input_message.text += '\n' + rep[:oneof_pos]
|
149
|
+
oneof = _get_first_oneof(c)
|
150
|
+
for v in oneof.candidates:
|
151
|
+
completion_reprs.append(
|
152
|
+
pg.format(
|
153
|
+
v,
|
154
|
+
python_format=True,
|
155
|
+
compact=False,
|
156
|
+
verbose=False,
|
157
|
+
root_indent=oneof.sym_path.depth
|
158
|
+
)
|
159
|
+
)
|
160
|
+
|
66
161
|
results = lm.score(
|
67
162
|
input_message,
|
68
|
-
|
69
|
-
mapping.MappingExample.value_repr(
|
70
|
-
c, protocol=protocol, compact=False, verbose=False
|
71
|
-
)
|
72
|
-
for c in completions
|
73
|
-
],
|
163
|
+
completion_reprs,
|
74
164
|
)
|
165
|
+
if return_scoring_results:
|
166
|
+
return results
|
75
167
|
return [r.score for r in results]
|
168
|
+
|
169
|
+
|
170
|
+
def _get_first_oneof(value: Any) -> pg.hyper.OneOf:
|
171
|
+
"""Gets the first pg.oneof from a symbolic object."""
|
172
|
+
oneofs = []
|
173
|
+
def select_oneofs(k, v, p):
|
174
|
+
del k, p
|
175
|
+
if isinstance(v, pg.hyper.OneOf):
|
176
|
+
oneofs.append(v)
|
177
|
+
return pg.TraverseAction.CONTINUE
|
178
|
+
return pg.TraverseAction.ENTER
|
179
|
+
pg.traverse(value, select_oneofs)
|
180
|
+
assert oneofs
|
181
|
+
return oneofs[0]
|
@@ -16,6 +16,11 @@ import unittest
|
|
16
16
|
import langfun.core as lf
|
17
17
|
from langfun.core.llms import fake
|
18
18
|
from langfun.core.structured import scoring
|
19
|
+
import pyglove as pg
|
20
|
+
|
21
|
+
|
22
|
+
class Answer(pg.Object):
|
23
|
+
result: int
|
19
24
|
|
20
25
|
|
21
26
|
class ScoringTest(unittest.TestCase):
|
@@ -32,9 +37,34 @@ class ScoringTest(unittest.TestCase):
|
|
32
37
|
with self.assertRaisesRegex(ValueError, '`lm` must be specified'):
|
33
38
|
scoring.score('hi', [1, 2])
|
34
39
|
|
40
|
+
with self.assertRaisesRegex(
|
41
|
+
ValueError,
|
42
|
+
'Scoring on object fields using `pg.oneof` must share the same prompt',
|
43
|
+
):
|
44
|
+
scoring.score(
|
45
|
+
['1 + 1=', '2 + 3='],
|
46
|
+
[Answer(pg.oneof([1, 2, 3]))],
|
47
|
+
lm=fake.Echo(),
|
48
|
+
)
|
49
|
+
|
35
50
|
def test_score(self):
|
36
51
|
self.assertEqual(scoring.score('hi', [1, 2], lm=fake.Echo()), [0.0, -1.0])
|
37
52
|
|
53
|
+
def test_score_on_field_values(self):
|
54
|
+
self.assertEqual(
|
55
|
+
scoring.score(
|
56
|
+
'1 + 1=',
|
57
|
+
[Answer(pg.oneof([1, 2, 3]))], lm=fake.Echo()
|
58
|
+
),
|
59
|
+
[0.0, -1.0, -2.0]
|
60
|
+
)
|
61
|
+
|
62
|
+
def test_score_returning_scoring_results(self):
|
63
|
+
self.assertEqual(scoring.score(
|
64
|
+
'hi', [1, 2], lm=fake.Echo(), return_scoring_results=True),
|
65
|
+
[lf.LMScoringResult(score=0.0, gradients=None),
|
66
|
+
lf.LMScoringResult(score=-1.0, gradients=None)])
|
67
|
+
|
38
68
|
def test_scope_with_lm_from_the_context(self):
|
39
69
|
with lf.context(lm=fake.Echo()):
|
40
70
|
self.assertEqual(scoring.score('hi', [1, 2]), [0.0, -1.0])
|
@@ -0,0 +1,64 @@
|
|
1
|
+
# Copyright 2023 The Langfun Authors
|
2
|
+
#
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at
|
6
|
+
#
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
#
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
"""Tokenize the prompt for `lf.query`."""
|
15
|
+
|
16
|
+
from typing import Any, Type, Union
|
17
|
+
|
18
|
+
import langfun.core as lf
|
19
|
+
from langfun.core.structured import mapping
|
20
|
+
from langfun.core.structured import querying
|
21
|
+
from langfun.core.structured import schema as schema_lib
|
22
|
+
import pyglove as pg
|
23
|
+
|
24
|
+
|
25
|
+
def tokenize(
|
26
|
+
prompt: Union[str, pg.Symbolic] | list[str | pg.Symbolic],
|
27
|
+
schema: Union[
|
28
|
+
schema_lib.Schema, Type[Any], list[Type[Any]], dict[str, Any], None
|
29
|
+
] = None,
|
30
|
+
*,
|
31
|
+
lm: lf.LanguageModel | None = None,
|
32
|
+
examples: list[mapping.MappingExample] | None = None,
|
33
|
+
protocol: schema_lib.SchemaProtocol = 'python',
|
34
|
+
**kwargs,
|
35
|
+
) -> list[tuple[str | bytes, int]]:
|
36
|
+
"""Tokenize the prompt for `lf.query`.
|
37
|
+
|
38
|
+
Args:
|
39
|
+
prompt: The prompt(s) based on which each completion will be scored.
|
40
|
+
schema: The schema as the output type. If None, it will be inferred from
|
41
|
+
the completions.
|
42
|
+
lm: The language model used for scoring.
|
43
|
+
examples: Fewshot exemplars used together with the prompt in getting the
|
44
|
+
completions.
|
45
|
+
protocol: The protocol for formulating the prompt based on objects.
|
46
|
+
**kwargs: Keyword arguments that are referred by the prompt.
|
47
|
+
|
48
|
+
Returns:
|
49
|
+
A list of (text, token_id) tuples.
|
50
|
+
"""
|
51
|
+
input_message = querying.query_prompt(
|
52
|
+
prompt,
|
53
|
+
schema,
|
54
|
+
examples=examples,
|
55
|
+
protocol=protocol,
|
56
|
+
**kwargs,
|
57
|
+
)
|
58
|
+
if lm is None:
|
59
|
+
lm_override = lf.get_contextual_override('lm')
|
60
|
+
if lm_override is None:
|
61
|
+
raise ValueError('`lm` must be specified or provided from `lf.context`.')
|
62
|
+
lm = lm_override.value
|
63
|
+
|
64
|
+
return lm.tokenize(input_message)
|
@@ -0,0 +1,48 @@
|
|
1
|
+
# Copyright 2023 The Langfun Authors
|
2
|
+
#
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at
|
6
|
+
#
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
#
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
|
15
|
+
import unittest
|
16
|
+
import langfun.core as lf
|
17
|
+
from langfun.core.llms import fake
|
18
|
+
from langfun.core.structured import tokenization
|
19
|
+
import pyglove as pg
|
20
|
+
|
21
|
+
|
22
|
+
class Answer(pg.Object):
|
23
|
+
result: int
|
24
|
+
|
25
|
+
|
26
|
+
class TokenizationTest(unittest.TestCase):
|
27
|
+
|
28
|
+
def test_bad_call(self):
|
29
|
+
|
30
|
+
with self.assertRaisesRegex(ValueError, '`lm` must be specified'):
|
31
|
+
tokenization.tokenize('hi')
|
32
|
+
|
33
|
+
def test_tokenize(self):
|
34
|
+
self.assertEqual(
|
35
|
+
tokenization.tokenize('hi', lm=fake.Echo()),
|
36
|
+
[('hi', 0)]
|
37
|
+
)
|
38
|
+
|
39
|
+
def test_tokenize_with_lm_from_the_context(self):
|
40
|
+
with lf.context(lm=fake.Echo()):
|
41
|
+
self.assertEqual(
|
42
|
+
tokenization.tokenize('hi'),
|
43
|
+
[('hi', 0)]
|
44
|
+
)
|
45
|
+
|
46
|
+
|
47
|
+
if __name__ == '__main__':
|
48
|
+
unittest.main()
|