langfun 0.0.2.dev20240330__py3-none-any.whl → 0.1.2.dev202501140804__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- langfun/__init__.py +22 -2
- langfun/core/__init__.py +17 -5
- langfun/core/agentic/__init__.py +30 -0
- langfun/core/agentic/action.py +854 -0
- langfun/core/agentic/action_eval.py +150 -0
- langfun/core/agentic/action_eval_test.py +109 -0
- langfun/core/agentic/action_test.py +136 -0
- langfun/core/coding/python/__init__.py +5 -11
- langfun/core/coding/python/correction.py +37 -28
- langfun/core/coding/python/correction_test.py +29 -3
- langfun/core/coding/python/execution.py +40 -216
- langfun/core/coding/python/execution_test.py +29 -89
- langfun/core/coding/python/generation.py +21 -11
- langfun/core/coding/python/generation_test.py +2 -2
- langfun/core/coding/python/parsing.py +108 -193
- langfun/core/coding/python/parsing_test.py +2 -105
- langfun/core/component.py +69 -2
- langfun/core/component_test.py +54 -0
- langfun/core/concurrent.py +414 -117
- langfun/core/concurrent_test.py +111 -24
- langfun/core/console.py +18 -5
- langfun/core/console_test.py +17 -0
- langfun/core/eval/__init__.py +17 -0
- langfun/core/eval/base.py +767 -140
- langfun/core/eval/base_test.py +238 -53
- langfun/core/eval/matching.py +80 -76
- langfun/core/eval/matching_test.py +19 -9
- langfun/core/eval/patching.py +130 -0
- langfun/core/eval/patching_test.py +170 -0
- langfun/core/eval/scoring.py +37 -28
- langfun/core/eval/scoring_test.py +21 -3
- langfun/core/eval/v2/__init__.py +42 -0
- langfun/core/eval/v2/checkpointing.py +380 -0
- langfun/core/eval/v2/checkpointing_test.py +228 -0
- langfun/core/eval/v2/eval_test_helper.py +136 -0
- langfun/core/eval/v2/evaluation.py +725 -0
- langfun/core/eval/v2/evaluation_test.py +180 -0
- langfun/core/eval/v2/example.py +305 -0
- langfun/core/eval/v2/example_test.py +128 -0
- langfun/core/eval/v2/experiment.py +1048 -0
- langfun/core/eval/v2/experiment_test.py +433 -0
- langfun/core/eval/v2/metric_values.py +156 -0
- langfun/core/eval/v2/metric_values_test.py +80 -0
- langfun/core/eval/v2/metrics.py +357 -0
- langfun/core/eval/v2/metrics_test.py +203 -0
- langfun/core/eval/v2/progress.py +348 -0
- langfun/core/eval/v2/progress_test.py +82 -0
- langfun/core/eval/v2/progress_tracking.py +210 -0
- langfun/core/eval/v2/progress_tracking_test.py +66 -0
- langfun/core/eval/v2/reporting.py +270 -0
- langfun/core/eval/v2/reporting_test.py +158 -0
- langfun/core/eval/v2/runners.py +488 -0
- langfun/core/eval/v2/runners_test.py +334 -0
- langfun/core/langfunc.py +3 -21
- langfun/core/langfunc_test.py +26 -8
- langfun/core/language_model.py +686 -48
- langfun/core/language_model_test.py +681 -44
- langfun/core/llms/__init__.py +100 -12
- langfun/core/llms/anthropic.py +488 -0
- langfun/core/llms/anthropic_test.py +235 -0
- langfun/core/llms/cache/base.py +21 -2
- langfun/core/llms/cache/in_memory.py +13 -0
- langfun/core/llms/cache/in_memory_test.py +88 -28
- langfun/core/llms/compositional.py +101 -0
- langfun/core/llms/compositional_test.py +73 -0
- langfun/core/llms/deepseek.py +117 -0
- langfun/core/llms/deepseek_test.py +61 -0
- langfun/core/llms/fake.py +39 -26
- langfun/core/llms/fake_test.py +136 -11
- langfun/core/llms/gemini.py +507 -0
- langfun/core/llms/gemini_test.py +195 -0
- langfun/core/llms/google_genai.py +62 -218
- langfun/core/llms/google_genai_test.py +9 -197
- langfun/core/llms/groq.py +276 -0
- langfun/core/llms/groq_test.py +64 -0
- langfun/core/llms/llama_cpp.py +15 -40
- langfun/core/llms/llama_cpp_test.py +4 -30
- langfun/core/llms/openai.py +436 -226
- langfun/core/llms/openai_compatible.py +179 -0
- langfun/core/llms/openai_compatible_test.py +495 -0
- langfun/core/llms/openai_test.py +35 -174
- langfun/core/llms/rest.py +113 -0
- langfun/core/llms/rest_test.py +111 -0
- langfun/core/llms/vertexai.py +192 -0
- langfun/core/llms/vertexai_test.py +52 -0
- langfun/core/logging.py +284 -0
- langfun/core/logging_test.py +125 -0
- langfun/core/message.py +319 -9
- langfun/core/message_test.py +190 -13
- langfun/core/modalities/__init__.py +6 -2
- langfun/core/modalities/audio.py +30 -0
- langfun/core/modalities/audio_test.py +63 -0
- langfun/core/modalities/image.py +39 -20
- langfun/core/modalities/image_test.py +52 -9
- langfun/core/modalities/mime.py +206 -29
- langfun/core/modalities/mime_test.py +90 -9
- langfun/core/modalities/ms_office.py +117 -0
- langfun/core/modalities/ms_office_test.py +389 -0
- langfun/core/modalities/pdf.py +22 -0
- langfun/core/modalities/pdf_test.py +57 -0
- langfun/core/modalities/video.py +9 -23
- langfun/core/modalities/video_test.py +3 -3
- langfun/core/modality.py +26 -3
- langfun/core/modality_test.py +2 -2
- langfun/core/sampling.py +11 -11
- langfun/core/structured/__init__.py +15 -16
- langfun/core/structured/completion.py +32 -5
- langfun/core/structured/completion_test.py +9 -8
- langfun/core/structured/description.py +2 -2
- langfun/core/structured/description_test.py +3 -3
- langfun/core/structured/function_generation.py +278 -0
- langfun/core/structured/function_generation_test.py +399 -0
- langfun/core/structured/mapping.py +150 -46
- langfun/core/structured/mapping_test.py +105 -0
- langfun/core/structured/parsing.py +33 -21
- langfun/core/structured/parsing_test.py +71 -22
- langfun/core/structured/querying.py +746 -0
- langfun/core/structured/{prompting_test.py → querying_test.py} +545 -60
- langfun/core/structured/schema.py +208 -99
- langfun/core/structured/schema_generation.py +1 -1
- langfun/core/structured/schema_generation_test.py +2 -2
- langfun/core/structured/schema_test.py +133 -34
- langfun/core/structured/scoring.py +125 -19
- langfun/core/structured/scoring_test.py +30 -0
- langfun/core/structured/tokenization.py +64 -0
- langfun/core/structured/tokenization_test.py +48 -0
- langfun/core/template.py +240 -11
- langfun/core/template_test.py +146 -1
- langfun/core/templates/conversation.py +9 -0
- langfun/core/templates/conversation_test.py +4 -3
- langfun/core/templates/selfplay_test.py +14 -2
- langfun-0.1.2.dev202501140804.dist-info/METADATA +225 -0
- langfun-0.1.2.dev202501140804.dist-info/RECORD +153 -0
- {langfun-0.0.2.dev20240330.dist-info → langfun-0.1.2.dev202501140804.dist-info}/WHEEL +1 -1
- langfun/core/coding/python/errors.py +0 -108
- langfun/core/coding/python/errors_test.py +0 -99
- langfun/core/coding/python/permissions.py +0 -90
- langfun/core/coding/python/permissions_test.py +0 -86
- langfun/core/structured/prompting.py +0 -217
- langfun/core/text_formatting.py +0 -162
- langfun/core/text_formatting_test.py +0 -47
- langfun-0.0.2.dev20240330.dist-info/METADATA +0 -99
- langfun-0.0.2.dev20240330.dist-info/RECORD +0 -102
- {langfun-0.0.2.dev20240330.dist-info → langfun-0.1.2.dev202501140804.dist-info}/LICENSE +0 -0
- {langfun-0.0.2.dev20240330.dist-info → langfun-0.1.2.dev202501140804.dist-info}/top_level.txt +0 -0
@@ -18,6 +18,7 @@ import inspect
|
|
18
18
|
import typing
|
19
19
|
import unittest
|
20
20
|
|
21
|
+
import langfun.core as lf
|
21
22
|
from langfun.core.llms import fake
|
22
23
|
from langfun.core.structured import schema as schema_lib
|
23
24
|
import pyglove as pg
|
@@ -192,9 +193,9 @@ class SchemaTest(unittest.TestCase):
|
|
192
193
|
self.assertEqual(schema.parse('{"result": 1}'), 1)
|
193
194
|
schema = schema_lib.Schema(dict[str, int])
|
194
195
|
self.assertEqual(
|
195
|
-
schema.parse(
|
196
|
-
|
197
|
-
|
196
|
+
schema.parse('{"result": {"x": 1}}}'),
|
197
|
+
dict(x=1)
|
198
|
+
)
|
198
199
|
with self.assertRaisesRegex(
|
199
200
|
schema_lib.SchemaError, 'Expect .* but encountered .*'):
|
200
201
|
schema.parse('{"result": "def"}')
|
@@ -260,15 +261,18 @@ class ClassDependenciesTest(unittest.TestCase):
|
|
260
261
|
class A(pg.Object):
|
261
262
|
foo: tuple[Foo, int]
|
262
263
|
|
264
|
+
class B(pg.Object):
|
265
|
+
pass
|
266
|
+
|
263
267
|
class X(pg.Object):
|
264
|
-
k:
|
268
|
+
k: dict[str, B]
|
265
269
|
|
266
|
-
class
|
270
|
+
class C(A):
|
267
271
|
bar: Bar
|
268
272
|
foo2: Foo | X
|
269
273
|
|
270
274
|
a = A(foo=(Foo(1), 0))
|
271
|
-
self.assertEqual(schema_lib.class_dependencies(a), [Foo, A, Bar, X,
|
275
|
+
self.assertEqual(schema_lib.class_dependencies(a), [Foo, A, Bar, B, X, C])
|
272
276
|
|
273
277
|
self.assertEqual(schema_lib.class_dependencies(1), [])
|
274
278
|
|
@@ -280,9 +284,10 @@ class SchemaPythonReprTest(unittest.TestCase):
|
|
280
284
|
value_spec: pg.typing.ValueSpec,
|
281
285
|
expected_annotation: str,
|
282
286
|
strict: bool = False,
|
287
|
+
**kwargs,
|
283
288
|
) -> None:
|
284
289
|
self.assertEqual(
|
285
|
-
schema_lib.annotation(value_spec, strict=strict),
|
290
|
+
schema_lib.annotation(value_spec, strict=strict, **kwargs),
|
286
291
|
expected_annotation,
|
287
292
|
)
|
288
293
|
|
@@ -358,11 +363,27 @@ class SchemaPythonReprTest(unittest.TestCase):
|
|
358
363
|
self.assert_annotation(
|
359
364
|
pg.typing.Object(Activity).noneable(), 'Activity | None'
|
360
365
|
)
|
366
|
+
self.assert_annotation(
|
367
|
+
pg.typing.Object(Activity).noneable(), 'Activity | None',
|
368
|
+
allowed_dependencies=set([Activity]),
|
369
|
+
)
|
370
|
+
self.assert_annotation(
|
371
|
+
pg.typing.Object(Activity).noneable(), 'Any | None',
|
372
|
+
allowed_dependencies=set(),
|
373
|
+
)
|
361
374
|
|
362
375
|
# List.
|
363
376
|
self.assert_annotation(
|
364
377
|
pg.typing.List(pg.typing.Object(Activity)), 'list[Activity]'
|
365
378
|
)
|
379
|
+
self.assert_annotation(
|
380
|
+
pg.typing.List(pg.typing.Object(Activity)), 'list[Activity]',
|
381
|
+
allowed_dependencies=set([Activity]),
|
382
|
+
)
|
383
|
+
self.assert_annotation(
|
384
|
+
pg.typing.List(pg.typing.Object(Activity)), 'list[Any]',
|
385
|
+
allowed_dependencies=set(),
|
386
|
+
)
|
366
387
|
self.assert_annotation(
|
367
388
|
pg.typing.List(pg.typing.Object(Activity)).noneable(),
|
368
389
|
'list[Activity] | None',
|
@@ -374,16 +395,35 @@ class SchemaPythonReprTest(unittest.TestCase):
|
|
374
395
|
|
375
396
|
# Tuple.
|
376
397
|
self.assert_annotation(
|
377
|
-
pg.typing.Tuple([
|
398
|
+
pg.typing.Tuple([Activity, pg.typing.Str()]), 'tuple[Activity, str]'
|
399
|
+
)
|
400
|
+
self.assert_annotation(
|
401
|
+
pg.typing.Tuple([Activity, pg.typing.Str()]), 'tuple[Activity, str]',
|
402
|
+
allowed_dependencies=set([Activity]),
|
403
|
+
)
|
404
|
+
self.assert_annotation(
|
405
|
+
pg.typing.Tuple([Activity, pg.typing.Str()]), 'tuple[Any, str]',
|
406
|
+
allowed_dependencies=set(),
|
378
407
|
)
|
379
408
|
self.assert_annotation(
|
380
|
-
pg.typing.Tuple([
|
381
|
-
'tuple[
|
409
|
+
pg.typing.Tuple([Activity, pg.typing.Str()]).noneable(),
|
410
|
+
'tuple[Activity, str] | None',
|
382
411
|
)
|
383
412
|
|
384
413
|
# Dict.
|
385
414
|
self.assert_annotation(
|
386
|
-
pg.typing.Dict({'x':
|
415
|
+
pg.typing.Dict({'x': Activity, 'y': str}),
|
416
|
+
'{\'x\': Activity, \'y\': str}'
|
417
|
+
)
|
418
|
+
self.assert_annotation(
|
419
|
+
pg.typing.Dict({'x': Activity, 'y': str}),
|
420
|
+
'{\'x\': Activity, \'y\': str}',
|
421
|
+
allowed_dependencies=set([Activity]),
|
422
|
+
)
|
423
|
+
self.assert_annotation(
|
424
|
+
pg.typing.Dict({'x': Activity, 'y': str}),
|
425
|
+
'{\'x\': Any, \'y\': str}',
|
426
|
+
allowed_dependencies=set(),
|
387
427
|
)
|
388
428
|
self.assert_annotation(
|
389
429
|
pg.typing.Dict({'x': int, 'y': str}),
|
@@ -395,6 +435,15 @@ class SchemaPythonReprTest(unittest.TestCase):
|
|
395
435
|
'dict[str, Any]',
|
396
436
|
strict=False,
|
397
437
|
)
|
438
|
+
|
439
|
+
class DictValue(pg.Object):
|
440
|
+
pass
|
441
|
+
|
442
|
+
self.assert_annotation(
|
443
|
+
pg.typing.Dict([(pg.typing.StrKey(), DictValue)]),
|
444
|
+
'dict[str, DictValue]',
|
445
|
+
strict=False,
|
446
|
+
)
|
398
447
|
self.assert_annotation(
|
399
448
|
pg.typing.Dict(),
|
400
449
|
'dict[str, Any]',
|
@@ -408,6 +457,13 @@ class SchemaPythonReprTest(unittest.TestCase):
|
|
408
457
|
).noneable(),
|
409
458
|
'Union[Activity, Itinerary, None]',
|
410
459
|
)
|
460
|
+
self.assert_annotation(
|
461
|
+
pg.typing.Union(
|
462
|
+
[pg.typing.Object(Activity), pg.typing.Object(Itinerary)]
|
463
|
+
).noneable(),
|
464
|
+
'Union[Activity, Any, None]',
|
465
|
+
allowed_dependencies=set([Activity]),
|
466
|
+
)
|
411
467
|
|
412
468
|
# Any.
|
413
469
|
self.assert_annotation(pg.typing.Any(), 'Any')
|
@@ -415,13 +471,13 @@ class SchemaPythonReprTest(unittest.TestCase):
|
|
415
471
|
|
416
472
|
def test_class_definition(self):
|
417
473
|
self.assertEqual(
|
418
|
-
schema_lib.class_definition(Activity),
|
474
|
+
schema_lib.class_definition(Activity, allowed_dependencies=set()),
|
419
475
|
'class Activity:\n description: str\n',
|
420
476
|
)
|
421
477
|
self.assertEqual(
|
422
478
|
schema_lib.class_definition(Itinerary),
|
423
479
|
inspect.cleandoc("""
|
424
|
-
class Itinerary:
|
480
|
+
class Itinerary(Object):
|
425
481
|
\"\"\"A travel itinerary for a day.\"\"\"
|
426
482
|
day: int(min=1)
|
427
483
|
type: Literal['daytime', 'nighttime']
|
@@ -431,7 +487,9 @@ class SchemaPythonReprTest(unittest.TestCase):
|
|
431
487
|
""") + '\n',
|
432
488
|
)
|
433
489
|
self.assertEqual(
|
434
|
-
schema_lib.class_definition(
|
490
|
+
schema_lib.class_definition(
|
491
|
+
PlaceOfInterest, allowed_dependencies=set()
|
492
|
+
),
|
435
493
|
inspect.cleandoc("""
|
436
494
|
class PlaceOfInterest:
|
437
495
|
\"\"\"The name of a place of interest.
|
@@ -447,11 +505,11 @@ class SchemaPythonReprTest(unittest.TestCase):
|
|
447
505
|
pass
|
448
506
|
|
449
507
|
self.assertEqual(
|
450
|
-
schema_lib.class_definition(A),
|
508
|
+
schema_lib.class_definition(A, allowed_dependencies=set()),
|
451
509
|
'class A:\n pass\n',
|
452
510
|
)
|
453
511
|
self.assertEqual(
|
454
|
-
schema_lib.class_definition(A
|
512
|
+
schema_lib.class_definition(A),
|
455
513
|
'class A(Object):\n pass\n',
|
456
514
|
)
|
457
515
|
|
@@ -459,9 +517,27 @@ class SchemaPythonReprTest(unittest.TestCase):
|
|
459
517
|
x: str
|
460
518
|
__kwargs__: typing.Any
|
461
519
|
|
462
|
-
|
463
|
-
|
464
|
-
|
520
|
+
self.assertEqual(
|
521
|
+
schema_lib.class_definition(C), 'class C(Object):\n x: str\n'
|
522
|
+
)
|
523
|
+
|
524
|
+
class D(pg.Object):
|
525
|
+
x: str
|
526
|
+
@schema_lib.include_method_in_prompt
|
527
|
+
def __call__(self, y: int) -> int:
|
528
|
+
return len(self.x) + y
|
529
|
+
|
530
|
+
self.assertEqual(
|
531
|
+
schema_lib.class_definition(D),
|
532
|
+
inspect.cleandoc(
|
533
|
+
"""
|
534
|
+
class D(Object):
|
535
|
+
x: str
|
536
|
+
|
537
|
+
def __call__(self, y: int) -> int:
|
538
|
+
return len(self.x) + y
|
539
|
+
""") + '\n'
|
540
|
+
)
|
465
541
|
|
466
542
|
def test_repr(self):
|
467
543
|
class Foo(pg.Object):
|
@@ -479,10 +555,21 @@ class SchemaPythonReprTest(unittest.TestCase):
|
|
479
555
|
class A(pg.Object):
|
480
556
|
foo: Foo
|
481
557
|
|
558
|
+
@schema_lib.include_method_in_prompt
|
559
|
+
def foo_value(self) -> int:
|
560
|
+
return self.foo.x
|
561
|
+
|
562
|
+
def baz_value(self) -> str:
|
563
|
+
return 'baz'
|
564
|
+
|
482
565
|
class B(A):
|
483
566
|
bar: Bar
|
484
567
|
foo2: Foo
|
485
568
|
|
569
|
+
@schema_lib.include_method_in_prompt
|
570
|
+
def bar_value(self) -> str:
|
571
|
+
return self.bar.y
|
572
|
+
|
486
573
|
schema = schema_lib.Schema([B])
|
487
574
|
self.assertEqual(
|
488
575
|
schema_lib.SchemaPythonRepr().class_definitions(schema),
|
@@ -490,9 +577,6 @@ class SchemaPythonReprTest(unittest.TestCase):
|
|
490
577
|
class Foo:
|
491
578
|
x: int
|
492
579
|
|
493
|
-
class A:
|
494
|
-
foo: Foo
|
495
|
-
|
496
580
|
class Bar:
|
497
581
|
"""Class Bar."""
|
498
582
|
y: str
|
@@ -501,10 +585,16 @@ class SchemaPythonReprTest(unittest.TestCase):
|
|
501
585
|
"""Baz(y: str)"""
|
502
586
|
y: str
|
503
587
|
|
504
|
-
class B
|
588
|
+
class B:
|
505
589
|
foo: Foo
|
506
590
|
bar: Bar
|
507
591
|
foo2: Foo
|
592
|
+
|
593
|
+
def bar_value(self) -> str:
|
594
|
+
return self.bar.y
|
595
|
+
|
596
|
+
def foo_value(self) -> int:
|
597
|
+
return self.foo.x
|
508
598
|
''') + '\n',
|
509
599
|
)
|
510
600
|
|
@@ -521,9 +611,6 @@ class SchemaPythonReprTest(unittest.TestCase):
|
|
521
611
|
class Foo:
|
522
612
|
x: int
|
523
613
|
|
524
|
-
class A:
|
525
|
-
foo: Foo
|
526
|
-
|
527
614
|
class Bar:
|
528
615
|
"""Class Bar."""
|
529
616
|
y: str
|
@@ -532,10 +619,16 @@ class SchemaPythonReprTest(unittest.TestCase):
|
|
532
619
|
"""Baz(y: str)"""
|
533
620
|
y: str
|
534
621
|
|
535
|
-
class B
|
622
|
+
class B:
|
536
623
|
foo: Foo
|
537
624
|
bar: Bar
|
538
625
|
foo2: Foo
|
626
|
+
|
627
|
+
def bar_value(self) -> str:
|
628
|
+
return self.bar.y
|
629
|
+
|
630
|
+
def foo_value(self) -> int:
|
631
|
+
return self.foo.x
|
539
632
|
```
|
540
633
|
'''),
|
541
634
|
)
|
@@ -543,16 +636,12 @@ class SchemaPythonReprTest(unittest.TestCase):
|
|
543
636
|
schema_lib.SchemaPythonRepr().repr(
|
544
637
|
schema,
|
545
638
|
include_result_definition=False,
|
546
|
-
include_pg_object_as_base=True,
|
547
639
|
markdown=False,
|
548
640
|
),
|
549
641
|
inspect.cleandoc('''
|
550
|
-
class Foo
|
642
|
+
class Foo:
|
551
643
|
x: int
|
552
644
|
|
553
|
-
class A(Object):
|
554
|
-
foo: Foo
|
555
|
-
|
556
645
|
class Bar:
|
557
646
|
"""Class Bar."""
|
558
647
|
y: str
|
@@ -561,10 +650,16 @@ class SchemaPythonReprTest(unittest.TestCase):
|
|
561
650
|
"""Baz(y: str)"""
|
562
651
|
y: str
|
563
652
|
|
564
|
-
class B
|
653
|
+
class B:
|
565
654
|
foo: Foo
|
566
655
|
bar: Bar
|
567
656
|
foo2: Foo
|
657
|
+
|
658
|
+
def bar_value(self) -> str:
|
659
|
+
return self.bar.y
|
660
|
+
|
661
|
+
def foo_value(self) -> int:
|
662
|
+
return self.foo.x
|
568
663
|
'''),
|
569
664
|
)
|
570
665
|
|
@@ -600,6 +695,10 @@ class ValuePythonReprTest(unittest.TestCase):
|
|
600
695
|
schema_lib.ValuePythonRepr().repr(1, schema_lib.Schema(int)),
|
601
696
|
'```python\n1\n```'
|
602
697
|
)
|
698
|
+
self.assertEqual(
|
699
|
+
schema_lib.ValuePythonRepr().repr(lf.Template('hi, {{a}}', a='foo')),
|
700
|
+
'hi, foo'
|
701
|
+
)
|
603
702
|
self.assertEqual(
|
604
703
|
schema_lib.ValuePythonRepr().repr(
|
605
704
|
A([Foo(1), Foo(2)], 'bar'), schema_lib.Schema(A), markdown=False,
|
@@ -612,7 +711,7 @@ class ValuePythonReprTest(unittest.TestCase):
|
|
612
711
|
```python
|
613
712
|
class Foo(Object):
|
614
713
|
x: int
|
615
|
-
|
714
|
+
|
616
715
|
class A(Object):
|
617
716
|
foo: list[Foo]
|
618
717
|
y: str | None
|
@@ -17,13 +17,13 @@ from typing import Any, Type, Union
|
|
17
17
|
|
18
18
|
import langfun.core as lf
|
19
19
|
from langfun.core.structured import mapping
|
20
|
-
from langfun.core.structured import
|
20
|
+
from langfun.core.structured import querying
|
21
21
|
from langfun.core.structured import schema as schema_lib
|
22
22
|
import pyglove as pg
|
23
23
|
|
24
24
|
|
25
25
|
def score(
|
26
|
-
prompt: Union[str, pg.Symbolic],
|
26
|
+
prompt: Union[str, pg.Symbolic] | list[str | pg.Symbolic],
|
27
27
|
completions: list[str | pg.Symbolic],
|
28
28
|
schema: Union[
|
29
29
|
schema_lib.Schema, Type[Any], list[Type[Any]], dict[str, Any], None
|
@@ -32,9 +32,58 @@ def score(
|
|
32
32
|
lm: lf.LanguageModel | None = None,
|
33
33
|
examples: list[mapping.MappingExample] | None = None,
|
34
34
|
protocol: schema_lib.SchemaProtocol = 'python',
|
35
|
+
return_scoring_results: bool = False,
|
35
36
|
**kwargs,
|
36
|
-
) -> list[float]:
|
37
|
-
"""Scores the outputs based on the prompt.
|
37
|
+
) -> list[float] | list[lf.LMScoringResult]:
|
38
|
+
"""Scores the outputs based on the prompt.
|
39
|
+
|
40
|
+
Examples:
|
41
|
+
```
|
42
|
+
# Example 1: Scoring text output based on the user prompt.
|
43
|
+
scores = lf.score('{{x}} + {{y}} =', ['1', '2', '3'], lm=lm, x=1, y=2)
|
44
|
+
assert len(scores) == 3
|
45
|
+
|
46
|
+
# Example 2: Scoring int output based on the formulated OOP prompt.
|
47
|
+
scores = lf.score('1 + 1 =', [1, 2, 3], lm=lm)
|
48
|
+
assert len(scores) == 3
|
49
|
+
|
50
|
+
class Answer(pg.Object):
|
51
|
+
result: int
|
52
|
+
|
53
|
+
# Example 3: Scoring object output based on the formulated OOP prompt.
|
54
|
+
scores = lf.score('1 + 1 =', [Answer(1), Answer(2), Answer(3)], lm=lm)
|
55
|
+
assert len(scores) == 3
|
56
|
+
|
57
|
+
# Example 4: Scoring object field value based on the formulated OOP prompt
|
58
|
+
# and the generated tokens before the first `pg.oneof`.
|
59
|
+
scores = lf.score('1 + 1 =', [Answer(pg.oneof([1, 2, 3]))], lm=lm)
|
60
|
+
assert len(scores) == 3
|
61
|
+
|
62
|
+
# Example 5: Scoring multiple prompt/completion pairs.
|
63
|
+
scores = lf.score(
|
64
|
+
['1 + 1=', '2 + 3='],
|
65
|
+
['2', '4'],
|
66
|
+
lm=lm
|
67
|
+
)
|
68
|
+
assert len(scores) == 2
|
69
|
+
```
|
70
|
+
|
71
|
+
Args:
|
72
|
+
prompt: The prompt(s) based on which each completion will be scored.
|
73
|
+
completions: A list of strings or symbolic objects as the output.
|
74
|
+
schema: The schema as the output type. If None, it will be inferred from
|
75
|
+
the completions.
|
76
|
+
lm: The language model used for scoring.
|
77
|
+
examples: Fewshot exemplars used together with the prompt in getting the
|
78
|
+
completions.
|
79
|
+
protocol: The protocol for formulating the prompt based on objects.
|
80
|
+
return_scoring_results: If True, returns a list of `lf.LMScoringResult`,
|
81
|
+
otherwise returns a list of floats as the scores of each completion.
|
82
|
+
**kwargs: Keyword arguments that are referred by the prompt.
|
83
|
+
|
84
|
+
Returns:
|
85
|
+
A list of floats or `lf.LMScoringResult` as the score of each completion.
|
86
|
+
"""
|
38
87
|
if not completions:
|
39
88
|
raise ValueError('`completions` must not be empty.')
|
40
89
|
|
@@ -48,28 +97,85 @@ def score(
|
|
48
97
|
f'{[type(c) for c in completions]}.'
|
49
98
|
)
|
50
99
|
|
51
|
-
|
52
|
-
|
53
|
-
|
54
|
-
|
55
|
-
|
56
|
-
|
57
|
-
|
58
|
-
|
59
|
-
|
100
|
+
if isinstance(prompt, list):
|
101
|
+
prompts = []
|
102
|
+
for p in prompt:
|
103
|
+
prompts.append(
|
104
|
+
querying.query_prompt(
|
105
|
+
p,
|
106
|
+
schema,
|
107
|
+
examples=examples,
|
108
|
+
protocol=protocol,
|
109
|
+
**kwargs,
|
110
|
+
)
|
111
|
+
)
|
112
|
+
input_message = prompts
|
113
|
+
else:
|
114
|
+
input_message = querying.query_prompt(
|
115
|
+
prompt,
|
116
|
+
schema,
|
117
|
+
examples=examples,
|
118
|
+
protocol=protocol,
|
119
|
+
**kwargs,
|
120
|
+
)
|
60
121
|
if lm is None:
|
61
122
|
lm_override = lf.get_contextual_override('lm')
|
62
123
|
if lm_override is None:
|
63
124
|
raise ValueError('`lm` must be specified or provided from `lf.context`.')
|
64
125
|
lm = lm_override.value
|
65
126
|
|
127
|
+
completion_reprs = []
|
128
|
+
for c in completions:
|
129
|
+
if isinstance(c, mapping.MappingError):
|
130
|
+
completion_reprs.append(c.lm_response)
|
131
|
+
else:
|
132
|
+
rep = mapping.MappingExample.value_repr(
|
133
|
+
c, protocol=protocol, compact=False, verbose=False
|
134
|
+
)
|
135
|
+
|
136
|
+
# NOTE(daiyip): supporting scenario of scoring object field with
|
137
|
+
# `pg.oneof`.
|
138
|
+
oneof_pos = rep.find('OneOf(')
|
139
|
+
if oneof_pos == -1:
|
140
|
+
completion_reprs.append(rep)
|
141
|
+
else:
|
142
|
+
assert protocol == 'python', protocol
|
143
|
+
if isinstance(input_message, list):
|
144
|
+
raise ValueError(
|
145
|
+
'Scoring on object fields using `pg.oneof` must share the '
|
146
|
+
f'same prompt. Encountered: {prompt}'
|
147
|
+
)
|
148
|
+
input_message.text += '\n' + rep[:oneof_pos]
|
149
|
+
oneof = _get_first_oneof(c)
|
150
|
+
for v in oneof.candidates:
|
151
|
+
completion_reprs.append(
|
152
|
+
pg.format(
|
153
|
+
v,
|
154
|
+
python_format=True,
|
155
|
+
compact=False,
|
156
|
+
verbose=False,
|
157
|
+
root_indent=oneof.sym_path.depth
|
158
|
+
)
|
159
|
+
)
|
160
|
+
|
66
161
|
results = lm.score(
|
67
162
|
input_message,
|
68
|
-
|
69
|
-
mapping.MappingExample.value_repr(
|
70
|
-
c, protocol=protocol, compact=False, verbose=False
|
71
|
-
)
|
72
|
-
for c in completions
|
73
|
-
],
|
163
|
+
completion_reprs,
|
74
164
|
)
|
165
|
+
if return_scoring_results:
|
166
|
+
return results
|
75
167
|
return [r.score for r in results]
|
168
|
+
|
169
|
+
|
170
|
+
def _get_first_oneof(value: Any) -> pg.hyper.OneOf:
|
171
|
+
"""Gets the first pg.oneof from a symbolic object."""
|
172
|
+
oneofs = []
|
173
|
+
def select_oneofs(k, v, p):
|
174
|
+
del k, p
|
175
|
+
if isinstance(v, pg.hyper.OneOf):
|
176
|
+
oneofs.append(v)
|
177
|
+
return pg.TraverseAction.CONTINUE
|
178
|
+
return pg.TraverseAction.ENTER
|
179
|
+
pg.traverse(value, select_oneofs)
|
180
|
+
assert oneofs
|
181
|
+
return oneofs[0]
|
@@ -16,6 +16,11 @@ import unittest
|
|
16
16
|
import langfun.core as lf
|
17
17
|
from langfun.core.llms import fake
|
18
18
|
from langfun.core.structured import scoring
|
19
|
+
import pyglove as pg
|
20
|
+
|
21
|
+
|
22
|
+
class Answer(pg.Object):
|
23
|
+
result: int
|
19
24
|
|
20
25
|
|
21
26
|
class ScoringTest(unittest.TestCase):
|
@@ -32,9 +37,34 @@ class ScoringTest(unittest.TestCase):
|
|
32
37
|
with self.assertRaisesRegex(ValueError, '`lm` must be specified'):
|
33
38
|
scoring.score('hi', [1, 2])
|
34
39
|
|
40
|
+
with self.assertRaisesRegex(
|
41
|
+
ValueError,
|
42
|
+
'Scoring on object fields using `pg.oneof` must share the same prompt',
|
43
|
+
):
|
44
|
+
scoring.score(
|
45
|
+
['1 + 1=', '2 + 3='],
|
46
|
+
[Answer(pg.oneof([1, 2, 3]))],
|
47
|
+
lm=fake.Echo(),
|
48
|
+
)
|
49
|
+
|
35
50
|
def test_score(self):
|
36
51
|
self.assertEqual(scoring.score('hi', [1, 2], lm=fake.Echo()), [0.0, -1.0])
|
37
52
|
|
53
|
+
def test_score_on_field_values(self):
|
54
|
+
self.assertEqual(
|
55
|
+
scoring.score(
|
56
|
+
'1 + 1=',
|
57
|
+
[Answer(pg.oneof([1, 2, 3]))], lm=fake.Echo()
|
58
|
+
),
|
59
|
+
[0.0, -1.0, -2.0]
|
60
|
+
)
|
61
|
+
|
62
|
+
def test_score_returning_scoring_results(self):
|
63
|
+
self.assertEqual(scoring.score(
|
64
|
+
'hi', [1, 2], lm=fake.Echo(), return_scoring_results=True),
|
65
|
+
[lf.LMScoringResult(score=0.0, gradients=None),
|
66
|
+
lf.LMScoringResult(score=-1.0, gradients=None)])
|
67
|
+
|
38
68
|
def test_scope_with_lm_from_the_context(self):
|
39
69
|
with lf.context(lm=fake.Echo()):
|
40
70
|
self.assertEqual(scoring.score('hi', [1, 2]), [0.0, -1.0])
|
@@ -0,0 +1,64 @@
|
|
1
|
+
# Copyright 2023 The Langfun Authors
|
2
|
+
#
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at
|
6
|
+
#
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
#
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
"""Tokenize the prompt for `lf.query`."""
|
15
|
+
|
16
|
+
from typing import Any, Type, Union
|
17
|
+
|
18
|
+
import langfun.core as lf
|
19
|
+
from langfun.core.structured import mapping
|
20
|
+
from langfun.core.structured import querying
|
21
|
+
from langfun.core.structured import schema as schema_lib
|
22
|
+
import pyglove as pg
|
23
|
+
|
24
|
+
|
25
|
+
def tokenize(
|
26
|
+
prompt: Union[str, pg.Symbolic] | list[str | pg.Symbolic],
|
27
|
+
schema: Union[
|
28
|
+
schema_lib.Schema, Type[Any], list[Type[Any]], dict[str, Any], None
|
29
|
+
] = None,
|
30
|
+
*,
|
31
|
+
lm: lf.LanguageModel | None = None,
|
32
|
+
examples: list[mapping.MappingExample] | None = None,
|
33
|
+
protocol: schema_lib.SchemaProtocol = 'python',
|
34
|
+
**kwargs,
|
35
|
+
) -> list[tuple[str | bytes, int]]:
|
36
|
+
"""Tokenize the prompt for `lf.query`.
|
37
|
+
|
38
|
+
Args:
|
39
|
+
prompt: The prompt(s) based on which each completion will be scored.
|
40
|
+
schema: The schema as the output type. If None, it will be inferred from
|
41
|
+
the completions.
|
42
|
+
lm: The language model used for scoring.
|
43
|
+
examples: Fewshot exemplars used together with the prompt in getting the
|
44
|
+
completions.
|
45
|
+
protocol: The protocol for formulating the prompt based on objects.
|
46
|
+
**kwargs: Keyword arguments that are referred by the prompt.
|
47
|
+
|
48
|
+
Returns:
|
49
|
+
A list of (text, token_id) tuples.
|
50
|
+
"""
|
51
|
+
input_message = querying.query_prompt(
|
52
|
+
prompt,
|
53
|
+
schema,
|
54
|
+
examples=examples,
|
55
|
+
protocol=protocol,
|
56
|
+
**kwargs,
|
57
|
+
)
|
58
|
+
if lm is None:
|
59
|
+
lm_override = lf.get_contextual_override('lm')
|
60
|
+
if lm_override is None:
|
61
|
+
raise ValueError('`lm` must be specified or provided from `lf.context`.')
|
62
|
+
lm = lm_override.value
|
63
|
+
|
64
|
+
return lm.tokenize(input_message)
|
@@ -0,0 +1,48 @@
|
|
1
|
+
# Copyright 2023 The Langfun Authors
|
2
|
+
#
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at
|
6
|
+
#
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
#
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
|
15
|
+
import unittest
|
16
|
+
import langfun.core as lf
|
17
|
+
from langfun.core.llms import fake
|
18
|
+
from langfun.core.structured import tokenization
|
19
|
+
import pyglove as pg
|
20
|
+
|
21
|
+
|
22
|
+
class Answer(pg.Object):
|
23
|
+
result: int
|
24
|
+
|
25
|
+
|
26
|
+
class TokenizationTest(unittest.TestCase):
|
27
|
+
|
28
|
+
def test_bad_call(self):
|
29
|
+
|
30
|
+
with self.assertRaisesRegex(ValueError, '`lm` must be specified'):
|
31
|
+
tokenization.tokenize('hi')
|
32
|
+
|
33
|
+
def test_tokenize(self):
|
34
|
+
self.assertEqual(
|
35
|
+
tokenization.tokenize('hi', lm=fake.Echo()),
|
36
|
+
[('hi', 0)]
|
37
|
+
)
|
38
|
+
|
39
|
+
def test_tokenize_with_lm_from_the_context(self):
|
40
|
+
with lf.context(lm=fake.Echo()):
|
41
|
+
self.assertEqual(
|
42
|
+
tokenization.tokenize('hi'),
|
43
|
+
[('hi', 0)]
|
44
|
+
)
|
45
|
+
|
46
|
+
|
47
|
+
if __name__ == '__main__':
|
48
|
+
unittest.main()
|