langfun 0.0.2.dev20240215__py3-none-any.whl → 0.0.2.dev20240217__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
langfun/__init__.py CHANGED
@@ -30,6 +30,7 @@ parse = structured.parse
30
30
  query = structured.query
31
31
  describe = structured.describe
32
32
  complete = structured.complete
33
+ score = structured.score
33
34
 
34
35
  from langfun.core import eval # pylint: disable=redefined-builtin
35
36
  from langfun.core import templates
langfun/core/__init__.py CHANGED
@@ -100,6 +100,7 @@ from langfun.core.language_model import LanguageModel
100
100
  from langfun.core.language_model import LMSample
101
101
  from langfun.core.language_model import LMSamplingOptions
102
102
  from langfun.core.language_model import LMSamplingResult
103
+ from langfun.core.language_model import LMScoringResult
103
104
  from langfun.core.language_model import LMCache
104
105
  from langfun.core.language_model import LMDebugMode
105
106
 
langfun/core/langfunc.py CHANGED
@@ -210,6 +210,7 @@ class LangFunc(
210
210
  lm: language_model.LanguageModel | None = None,
211
211
  lm_input: message_lib.Message | None = None,
212
212
  cache_seed: int | None = 0,
213
+ skip_lm: bool = False,
213
214
  **variables,
214
215
  ) -> message_lib.Message:
215
216
  """Calls language model with `lm_input` or rendered text.
@@ -223,6 +224,8 @@ class LangFunc(
223
224
  cache_seed: Seed for computing cache key. The cache key is determined by a
224
225
  tuple of (lm, prompt, cache seed). If None, cache will be disabled for
225
226
  the query even if cache is configured by the LM.
227
+ skip_lm: If True, returns the rendered prompt as a UserMessage object.
228
+ Otherwise, returns the LLM response based on the rendered prompt.
226
229
  **variables: Template variables applicable to this or child LangFunc.
227
230
 
228
231
  Returns:
@@ -232,6 +235,7 @@ class LangFunc(
232
235
  lm=lm,
233
236
  lm_input=lm_input,
234
237
  cache_seed=cache_seed,
238
+ skip_lm=skip_lm,
235
239
  **variables,
236
240
  )
237
241
 
@@ -241,6 +245,7 @@ class LangFunc(
241
245
  lm: language_model.LanguageModel | None = None,
242
246
  lm_input: message_lib.Message | None = None,
243
247
  cache_seed: int | None = 0,
248
+ skip_lm: bool = False,
244
249
  **variables,
245
250
  ) -> message_lib.Message:
246
251
  """Call the language model once, with invoking the output transform."""
@@ -256,10 +261,13 @@ class LangFunc(
256
261
  if lm_input is None:
257
262
  lm_input = self.render(**kwargs)
258
263
 
264
+ lm_input.tag(message_lib.Message.TAG_LM_INPUT)
265
+ if skip_lm:
266
+ return lm_input
267
+
259
268
  self._cached_lm_input = lm_input
260
269
 
261
270
  # Send rendered text to LM.
262
- lm_input.tag(message_lib.Message.TAG_LM_INPUT)
263
271
  lm_output = self.lm(lm_input, cache_seed=cache_seed)
264
272
 
265
273
  # Track the input as the source of the output.
@@ -194,6 +194,11 @@ class LangFuncCallTest(unittest.TestCase):
194
194
  self.assertEqual(l(x=1, cache_seed=None), 'd')
195
195
  self.assertEqual(l(x=2), 'b')
196
196
 
197
+ def test_call_with_skip_lm(self):
198
+ l = LangFunc('hi')
199
+ with component.context(lm=ExcitedEchoer()):
200
+ self.assertEqual(l(skip_lm=True), 'hi')
201
+
197
202
 
198
203
  class CallEventTest(unittest.TestCase):
199
204
 
@@ -127,6 +127,15 @@ class LMSamplingOptions(component.Component):
127
127
  )
128
128
 
129
129
 
130
+ class LMScoringResult(pg.Object):
131
+ """Language model scoring result."""
132
+
133
+ score: Annotated[
134
+ float,
135
+ 'The log likelyhood of the requested completion towards the prompt.',
136
+ ]
137
+
138
+
130
139
  class LMCache(pg.Object):
131
140
  """Interface for LM cache."""
132
141
 
@@ -425,3 +434,83 @@ class LanguageModel(component.Component):
425
434
  title=f'\n[{call_counter}] LM RESPONSE (in {elapse:.2f} seconds):',
426
435
  color='blue',
427
436
  )
437
+
438
+ def score(
439
+ self,
440
+ prompt: str | message_lib.Message,
441
+ completions: list[str | message_lib.Message],
442
+ **kwargs,
443
+ ) -> list[LMScoringResult]:
444
+ """Scores the given completions against the prompt."""
445
+ prompt = message_lib.UserMessage.from_value(prompt)
446
+ completions = [message_lib.UserMessage.from_value(c) for c in completions]
447
+
448
+ call_counter = self._call_counter
449
+ self._call_counter += 1
450
+ request_start = time.time()
451
+
452
+ with component.context(override_attrs=True, **kwargs):
453
+ scoring_results = self._score(prompt, completions)
454
+ elapse = time.time() - request_start
455
+ self._debug_score(
456
+ prompt, completions, scoring_results, call_counter, elapse
457
+ )
458
+ return scoring_results
459
+
460
+ def _score(
461
+ self, prompt: message_lib.Message, completions: list[message_lib.Message]
462
+ ) -> list[LMScoringResult]:
463
+ """Subclass to implement."""
464
+ raise NotImplementedError(
465
+ f'{self.__class__.__name__} does not support scoring.'
466
+ )
467
+
468
+ def _debug_score(
469
+ self,
470
+ prompt: message_lib.Message,
471
+ completions: list[message_lib.Message],
472
+ scoring_results: list[LMScoringResult],
473
+ call_counter: int,
474
+ elapse: float,
475
+ ):
476
+ debug = self.debug
477
+ if isinstance(debug, bool):
478
+ debug = LMDebugMode.ALL if debug else LMDebugMode.NONE
479
+
480
+ if debug & LMDebugMode.INFO:
481
+ self._debug_model_info(call_counter)
482
+
483
+ if debug & LMDebugMode.PROMPT:
484
+ console.write(
485
+ prompt,
486
+ title=f'\n[{call_counter}] SCORING LM WITH PROMPT:',
487
+ color='green',
488
+ )
489
+ referred_modalities = prompt.referred_modalities()
490
+ if referred_modalities:
491
+ console.write(
492
+ pg.object_utils.kvlist_str(
493
+ [(k, repr(v), None) for k, v in referred_modalities.items()]
494
+ ),
495
+ title=f'\n[{call_counter}] MODALITY OBJECTS SENT TO LM:',
496
+ color='green',
497
+ )
498
+
499
+ if debug & LMDebugMode.RESPONSE:
500
+ console.write(
501
+ '',
502
+ title=(
503
+ f'\n[{call_counter}] SCORING COMPLETED (in {elapse:.2f} seconds):'
504
+ ),
505
+ color='blue',
506
+ )
507
+ for i, (c, r) in enumerate(zip(completions, scoring_results)):
508
+ console.write(
509
+ c,
510
+ title=f'COMPLETION #{i}',
511
+ color='green',
512
+ )
513
+ console.write(
514
+ f'score: {r.score}',
515
+ color='blue',
516
+ )
@@ -54,6 +54,19 @@ class MockModel(lm_lib.LanguageModel):
54
54
  )(prompts)
55
55
 
56
56
 
57
+ class MockScoringModel(MockModel):
58
+
59
+ def _score(
60
+ self,
61
+ prompt: message_lib.Message,
62
+ completions: list[message_lib.Message],
63
+ **kwargs
64
+ ) -> list[lm_lib.LMScoringResult]:
65
+ return [
66
+ lm_lib.LMScoringResult(score=-i * 1.0) for i in range(len(completions))
67
+ ]
68
+
69
+
57
70
  class LMSamplingOptionsTest(unittest.TestCase):
58
71
  """Tests for LMSamplingOptions."""
59
72
 
@@ -266,6 +279,68 @@ class LanguageModelTest(unittest.TestCase):
266
279
  for expected_exclude in expected_excluded:
267
280
  self.assertNotIn('[0] ' + expected_exclude, debug_info)
268
281
 
282
+ def test_score(self):
283
+ info_flag = lm_lib.LMDebugMode.INFO
284
+ prompt_flag = lm_lib.LMDebugMode.PROMPT
285
+ response_flag = lm_lib.LMDebugMode.RESPONSE
286
+ debug_prints = {
287
+ info_flag: 'LM INFO',
288
+ prompt_flag: 'SCORING LM WITH PROMPT',
289
+ response_flag: 'SCORING COMPLETED',
290
+ }
291
+ debug_modes = [
292
+ info_flag,
293
+ prompt_flag,
294
+ response_flag,
295
+ info_flag | prompt_flag,
296
+ info_flag | response_flag,
297
+ prompt_flag | response_flag,
298
+ info_flag | prompt_flag | response_flag,
299
+ ]
300
+
301
+ class Image(modality.Modality):
302
+ def to_bytes(self):
303
+ return b'fake_image'
304
+
305
+ for debug_mode in debug_modes:
306
+ string_io = io.StringIO()
307
+ lm = MockScoringModel()
308
+
309
+ with contextlib.redirect_stdout(string_io):
310
+ self.assertEqual(
311
+ lm.score(
312
+ message_lib.UserMessage('hi {{image}}', image=Image()),
313
+ ['1', '2'], debug=debug_mode),
314
+ [
315
+ lm_lib.LMScoringResult(score=-0.0),
316
+ lm_lib.LMScoringResult(score=-1.0),
317
+ ],
318
+ )
319
+
320
+ debug_info = string_io.getvalue()
321
+ expected_included = [
322
+ debug_prints[f]
323
+ for f in lm_lib.LMDebugMode
324
+ if f != lm_lib.LMDebugMode.NONE and f in debug_mode
325
+ ]
326
+ expected_excluded = [
327
+ debug_prints[f]
328
+ for f in lm_lib.LMDebugMode
329
+ if f != lm_lib.LMDebugMode.NONE and f not in debug_mode
330
+ ]
331
+
332
+ for expected_include in expected_included:
333
+ self.assertIn('[0] ' + expected_include, debug_info)
334
+ for expected_exclude in expected_excluded:
335
+ self.assertNotIn('[0] ' + expected_exclude, debug_info)
336
+
337
+ if debug_mode & lm_lib.LMDebugMode.PROMPT:
338
+ self.assertIn('[0] MODALITY OBJECTS SENT TO LM', debug_info)
339
+
340
+ def test_score_with_unsupported_model(self):
341
+ with self.assertRaises(NotImplementedError):
342
+ MockModel().score('hi', ['1', '2'])
343
+
269
344
 
270
345
  if __name__ == '__main__':
271
346
  unittest.main()
@@ -18,6 +18,7 @@
18
18
  # pylint: disable=g-import-not-at-top
19
19
 
20
20
  # LMs for testing.
21
+ from langfun.core.llms.fake import Fake
21
22
  from langfun.core.llms.fake import Echo
22
23
  from langfun.core.llms.fake import StaticMapping
23
24
  from langfun.core.llms.fake import StaticResponse
langfun/core/llms/fake.py CHANGED
@@ -17,7 +17,14 @@ from typing import Annotated
17
17
  import langfun.core as lf
18
18
 
19
19
 
20
- class Echo(lf.LanguageModel):
20
+ class Fake(lf.LanguageModel):
21
+ """The base class for all fake language models."""
22
+
23
+ def _score(self, prompt: lf.Message, completions: list[lf.Message]):
24
+ return [lf.LMScoringResult(score=-i * 1.0) for i in range(len(completions))]
25
+
26
+
27
+ class Echo(Fake):
21
28
  """A simple echo language model for testing."""
22
29
 
23
30
  def _sample(self, prompts: list[lf.Message]) -> list[lf.LMSamplingResult]:
@@ -28,7 +35,7 @@ class Echo(lf.LanguageModel):
28
35
 
29
36
 
30
37
  @lf.use_init_args(['response'])
31
- class StaticResponse(lf.LanguageModel):
38
+ class StaticResponse(Fake):
32
39
  """Language model that always gives the same canned response."""
33
40
 
34
41
  response: Annotated[
@@ -44,7 +51,7 @@ class StaticResponse(lf.LanguageModel):
44
51
 
45
52
 
46
53
  @lf.use_init_args(['mapping'])
47
- class StaticMapping(lf.LanguageModel):
54
+ class StaticMapping(Fake):
48
55
  """A static mapping from prompt to response."""
49
56
 
50
57
  mapping: Annotated[
@@ -60,7 +67,7 @@ class StaticMapping(lf.LanguageModel):
60
67
 
61
68
 
62
69
  @lf.use_init_args(['sequence'])
63
- class StaticSequence(lf.LanguageModel):
70
+ class StaticSequence(Fake):
64
71
  """A static sequence of responses to use."""
65
72
 
66
73
  sequence: Annotated[
@@ -38,6 +38,13 @@ class EchoTest(unittest.TestCase):
38
38
  self.assertIn('[0] PROMPT SENT TO LM:', debug_info)
39
39
  self.assertIn('[0] LM RESPONSE', debug_info)
40
40
 
41
+ def test_score(self):
42
+ lm = fakelm.Echo()
43
+ self.assertEqual(
44
+ lm.score('hi', ['hello', 'how are you']),
45
+ [lf.LMScoringResult(0.0), lf.LMScoringResult(-1.0)],
46
+ )
47
+
41
48
 
42
49
  class StaticResponseTest(unittest.TestCase):
43
50
 
@@ -64,6 +64,8 @@ from langfun.core.structured.description import describe
64
64
  from langfun.core.structured.completion import CompleteStructure
65
65
  from langfun.core.structured.completion import complete
66
66
 
67
+ from langfun.core.structured.scoring import score
68
+
67
69
  # Expose default examples for structured operations so users could refer to
68
70
  # them.
69
71
  from langfun.core.structured.parsing import DEFAULT_PARSE_EXAMPLES
@@ -153,6 +153,7 @@ def complete(
153
153
  *,
154
154
  lm: lf.LanguageModel | None = None,
155
155
  examples: list[mapping.MappingExample] | None = None,
156
+ cache_seed: int | None = 0,
156
157
  autofix: int = 0,
157
158
  autofix_lm: lf.LanguageModel | None = None,
158
159
  returns_message: bool = False,
@@ -197,6 +198,9 @@ def complete(
197
198
  `lf.context` context manager will be used.
198
199
  examples: An optional list of fewshot examples for helping parsing. If None,
199
200
  the default one-shot example will be added.
201
+ cache_seed: Seed for computing cache key. The cache key is determined by a
202
+ tuple of (lm, prompt, cache seed). If None, cache will be disabled for
203
+ the query even if cache is configured by the LM.
200
204
  autofix: Number of attempts to auto fix the generated code. If 0, autofix is
201
205
  disabled.
202
206
  autofix_lm: The language model to use for autofix. If not specified, the
@@ -218,5 +222,5 @@ def complete(
218
222
  **kwargs,
219
223
  )
220
224
 
221
- output = t(lm=lm, autofix_lm=autofix_lm or lm)
225
+ output = t(lm=lm, cache_seed=cache_seed, autofix_lm=autofix_lm or lm)
222
226
  return output if returns_message else output.result
@@ -44,6 +44,7 @@ def describe(
44
44
  *,
45
45
  lm: lf.LanguageModel | None = None,
46
46
  examples: list[mapping.MappingExample] | None = None,
47
+ cache_seed: int | None = 0,
47
48
  **kwargs,
48
49
  ) -> str:
49
50
  """Describes a structured value using natural language.
@@ -97,6 +98,9 @@ def describe(
97
98
  `lf.context` context manager will be used.
98
99
  examples: An optional list of fewshot examples for helping parsing. If None,
99
100
  the default one-shot example will be added.
101
+ cache_seed: Seed for computing cache key. The cache key is determined by a
102
+ tuple of (lm, prompt, cache seed). If None, cache will be disabled for
103
+ the query even if cache is configured by the LM.
100
104
  **kwargs: Keyword arguments passed to the `lf.structured.DescribeStructure`.
101
105
 
102
106
  Returns:
@@ -106,7 +110,7 @@ def describe(
106
110
  examples = DEFAULT_DESCRIBE_EXAMPLES
107
111
  return DescribeStructure(
108
112
  input=value, context=context, examples=examples, **kwargs
109
- )(lm=lm).text
113
+ )(lm=lm, cache_seed=cache_seed).text
110
114
 
111
115
 
112
116
  class _Country(pg.Object):
@@ -79,6 +79,7 @@ def parse(
79
79
  lm: lf.LanguageModel | None = None,
80
80
  examples: list[mapping.MappingExample] | None = None,
81
81
  include_context: bool = False,
82
+ cache_seed: int | None = 0,
82
83
  autofix: int = 0,
83
84
  autofix_lm: lf.LanguageModel | None = None,
84
85
  protocol: schema_lib.SchemaProtocol = 'python',
@@ -134,6 +135,9 @@ def parse(
134
135
  the default one-shot example will be added.
135
136
  include_context: If True, include the request sent to LLM for obtaining the
136
137
  response to parse. Otherwise include only the response.
138
+ cache_seed: Seed for computing cache key. The cache key is determined by a
139
+ tuple of (lm, prompt, cache seed). If None, cache will be disabled for
140
+ the query even if cache is configured by the LM.
137
141
  autofix: Number of attempts to auto fix the generated code. If 0, autofix is
138
142
  disabled. Auto-fix is not supported for 'json' protocol.
139
143
  autofix_lm: The language model to use for autofix. If not specified, the
@@ -166,7 +170,7 @@ def parse(
166
170
  )
167
171
 
168
172
  # Setting up context.
169
- call_context = dict(autofix=autofix)
173
+ call_context = dict(cache_seed=cache_seed, autofix=autofix)
170
174
  if lm is not None:
171
175
  call_context['lm'] = lm
172
176
  autofix_lm = autofix_lm or lm
@@ -188,6 +192,7 @@ def call(
188
192
  parsing_lm: lf.LanguageModel | None = None,
189
193
  parsing_examples: list[mapping.MappingExample] | None = None,
190
194
  parsing_include_context: bool = False,
195
+ cache_seed: int | None = 0,
191
196
  autofix: int = 0,
192
197
  autofix_lm: lf.LanguageModel | None = None,
193
198
  response_postprocess: Callable[[str], str] | None = None,
@@ -231,6 +236,9 @@ def call(
231
236
  `lf.structured.DEFAULT_PARSE_EXAMPLES` will be used.
232
237
  parsing_include_context: If True, include the request sent to LLM for
233
238
  obtaining the response to parse. Otherwise include only the response.
239
+ cache_seed: Seed for computing cache key. The cache key is determined by a
240
+ tuple of (lm, prompt, cache seed). If None, cache will be disabled for
241
+ the query even if cache is configured by the LM.
234
242
  autofix: Number of attempts to auto fix the generated code. If 0, autofix is
235
243
  disabled. Auto-fix is not supported for 'json' protocol.
236
244
  autofix_lm: The language model to use for autofix. If not specified, the
@@ -253,7 +261,10 @@ def call(
253
261
  lm_output = lf.LangFunc.from_value(prompt, **kwargs)(lm=lm)
254
262
 
255
263
  if response_postprocess is not None:
256
- lm_output.set('text', response_postprocess(lm_output.text))
264
+ postprocessed_text = response_postprocess(lm_output.text)
265
+ if postprocessed_text != lm_output.text:
266
+ processed_lm_output = lf.AIMessage(postprocessed_text, source=lm_output)
267
+ lm_output = processed_lm_output
257
268
 
258
269
  if schema in (str, None):
259
270
  return lm_output if returns_message else lm_output.text
@@ -265,6 +276,7 @@ def call(
265
276
  examples=parsing_examples,
266
277
  lm=parsing_lm or lm,
267
278
  include_context=parsing_include_context,
279
+ cache_seed=cache_seed,
268
280
  autofix=autofix,
269
281
  autofix_lm=autofix_lm or lm,
270
282
  protocol=protocol,
@@ -106,10 +106,12 @@ def query(
106
106
  *,
107
107
  lm: lf.LanguageModel | None = None,
108
108
  examples: list[mapping.MappingExample] | None = None,
109
+ cache_seed: int | None = 0,
109
110
  autofix: int = 0,
110
111
  autofix_lm: lf.LanguageModel | None = None,
111
112
  protocol: schema_lib.SchemaProtocol = 'python',
112
113
  returns_message: bool = False,
114
+ skip_lm: bool = False,
113
115
  **kwargs,
114
116
  ) -> Any:
115
117
  """Parse a natural language message based on schema.
@@ -154,6 +156,9 @@ def query(
154
156
  `lf.context` context manager will be used.
155
157
  examples: An optional list of fewshot examples for helping parsing. If None,
156
158
  the default one-shot example will be added.
159
+ cache_seed: Seed for computing cache key. The cache key is determined by a
160
+ tuple of (lm, prompt, cache seed). If None, cache will be disabled for
161
+ the query even if cache is configured by the LM.
157
162
  autofix: Number of attempts to auto fix the generated code. If 0, autofix is
158
163
  disabled. Auto-fix is not supported for 'json' protocol.
159
164
  autofix_lm: The language model to use for autofix. If not specified, the
@@ -163,6 +168,8 @@ def query(
163
168
  are 'json' and 'python'. By default `python` will be used.
164
169
  returns_message: If True, returns `lf.Message` as the output, instead of
165
170
  returning the structured `message.result`.
171
+ skip_lm: If True, returns the rendered prompt as a UserMessage object.
172
+ Otherwise, returns the LLM response based on the rendered prompt.
166
173
  **kwargs: Keyword arguments passed to the
167
174
  `lf.structured.NaturalLanguageToStructureed` transform.
168
175
 
@@ -178,7 +185,9 @@ def query(
178
185
 
179
186
  if schema in (None, str):
180
187
  # Query with natural language output.
181
- output = lf.LangFunc.from_value(prompt, **kwargs)(lm=lm)
188
+ output = lf.LangFunc.from_value(prompt, **kwargs)(
189
+ lm=lm, cache_seed=cache_seed, skip_lm=skip_lm
190
+ )
182
191
  return output if returns_message else output.text
183
192
 
184
193
  # Query with structured output.
@@ -202,5 +211,7 @@ def query(
202
211
  )(
203
212
  lm=lm,
204
213
  autofix_lm=autofix_lm or lm,
214
+ cache_seed=cache_seed,
215
+ skip_lm=skip_lm,
205
216
  )
206
217
  return output if returns_message else output.result
@@ -235,6 +235,10 @@ class QueryTest(unittest.TestCase):
235
235
  expected_modalities=2,
236
236
  )
237
237
 
238
+ def test_bad_protocol(self):
239
+ with self.assertRaisesRegex(ValueError, 'Unknown protocol'):
240
+ prompting.query('what is 1 + 1', int, protocol='text')
241
+
238
242
 
239
243
  class QueryStructurePythonTest(unittest.TestCase):
240
244
 
@@ -0,0 +1,75 @@
1
+ # Copyright 2023 The Langfun Authors
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ """Scoring the output objects based on their inputs."""
15
+
16
+ from typing import Any, Type, Union
17
+
18
+ import langfun.core as lf
19
+ from langfun.core.structured import mapping
20
+ from langfun.core.structured import prompting
21
+ from langfun.core.structured import schema as schema_lib
22
+ import pyglove as pg
23
+
24
+
25
+ def score(
26
+ prompt: Union[str, pg.Symbolic],
27
+ completions: list[str | pg.Symbolic],
28
+ schema: Union[
29
+ schema_lib.Schema, Type[Any], list[Type[Any]], dict[str, Any], None
30
+ ] = None,
31
+ *,
32
+ lm: lf.LanguageModel | None = None,
33
+ examples: list[mapping.MappingExample] | None = None,
34
+ protocol: schema_lib.SchemaProtocol = 'python',
35
+ **kwargs,
36
+ ) -> list[float]:
37
+ """Scores the outputs based on the prompt."""
38
+ if not completions:
39
+ raise ValueError('`completions` must not be empty.')
40
+
41
+ if schema is None:
42
+ for c in completions:
43
+ if schema is None:
44
+ schema = type(c)
45
+ elif schema is not type(c):
46
+ raise ValueError(
47
+ '`schema` cannot be inferred from completions of different types: '
48
+ f'{[type(c) for c in completions]}.'
49
+ )
50
+
51
+ input_message = prompting.query(
52
+ prompt,
53
+ schema,
54
+ examples=examples,
55
+ protocol=protocol,
56
+ skip_lm=True,
57
+ returns_message=True,
58
+ **kwargs,
59
+ )
60
+ if lm is None:
61
+ lm_override = lf.get_contextual_override('lm')
62
+ if lm_override is None:
63
+ raise ValueError('`lm` must be specified or provided from `lf.context`.')
64
+ lm = lm_override.value
65
+
66
+ results = lm.score(
67
+ input_message,
68
+ [
69
+ mapping.MappingExample.value_repr(
70
+ c, protocol=protocol, compact=False, verbose=False
71
+ )
72
+ for c in completions
73
+ ],
74
+ )
75
+ return [r.score for r in results]
@@ -0,0 +1,44 @@
1
+ # Copyright 2023 The Langfun Authors
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import unittest
16
+ import langfun.core as lf
17
+ from langfun.core.llms import fake
18
+ from langfun.core.structured import scoring
19
+
20
+
21
+ class ScoringTest(unittest.TestCase):
22
+
23
+ def test_bad_call(self):
24
+ with self.assertRaisesRegex(ValueError, '`completions` must not be empty'):
25
+ scoring.score('hi', [])
26
+
27
+ with self.assertRaisesRegex(
28
+ ValueError, '`schema` cannot be inferred from completions'
29
+ ):
30
+ scoring.score('hi', [1, 'b'])
31
+
32
+ with self.assertRaisesRegex(ValueError, '`lm` must be specified'):
33
+ scoring.score('hi', [1, 2])
34
+
35
+ def test_score(self):
36
+ self.assertEqual(scoring.score('hi', [1, 2], lm=fake.Echo()), [0.0, -1.0])
37
+
38
+ def test_scope_with_lm_from_the_context(self):
39
+ with lf.context(lm=fake.Echo()):
40
+ self.assertEqual(scoring.score('hi', [1, 2]), [0.0, -1.0])
41
+
42
+
43
+ if __name__ == '__main__':
44
+ unittest.main()
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: langfun
3
- Version: 0.0.2.dev20240215
3
+ Version: 0.0.2.dev20240217
4
4
  Summary: Langfun: Language as Functions.
5
5
  Home-page: https://github.com/google/langfun
6
6
  Author: Langfun Authors
@@ -1,15 +1,15 @@
1
- langfun/__init__.py,sha256=2HUBxiByAEu63XqaF89hQfI4sqFG1qGffua-JPy4XIY,1689
2
- langfun/core/__init__.py,sha256=dl7itWvZUEvqDeK2EWd-9lGlZu8cLXCO45HcaZKWAo4,4136
1
+ langfun/__init__.py,sha256=8o5FY1mvt8gErV_AEyBBgQJxEC1cQnfXAvDphMvvS78,1714
2
+ langfun/core/__init__.py,sha256=sVcPl89lWYHQ1cUoaLaM8dErCovugJo5e2F3A_94Q3Y,4192
3
3
  langfun/core/component.py,sha256=VRPfDB_2jEnxcB3-HoiVjG4ID-SMenNPIsytb0uXMPg,9674
4
4
  langfun/core/component_test.py,sha256=VAPd6V_-odAe8rBvesW3ogYDd6OSqRq4FaPhfgOM4Zg,7949
5
5
  langfun/core/concurrent.py,sha256=HQJOseNZ-XZZR5VmC8lHoDNFzlkkCa_-ri7nOKJfV5s,24147
6
6
  langfun/core/concurrent_test.py,sha256=qQT6_Dq5NVz7qXFLzSf2Rhzkfkh07gocjHMBaT1nSeE,14928
7
7
  langfun/core/console.py,sha256=bk5rNPNm9rMGW5YT2HixxU04p2umnoabn5SDz6Dqe88,2317
8
8
  langfun/core/console_test.py,sha256=5SYJdxpJGLgdSSQqqMPoA1X6jpsLD8rgcyk-EgI65oE,1077
9
- langfun/core/langfunc.py,sha256=266xNz8Vgal7K4HSsrYt7z7_qPYV4bWWK626IbbohrE,11573
10
- langfun/core/langfunc_test.py,sha256=ukv5cnad5ZBckM2PhyIFq79BPN0Db4cszMrPqh_CZkA,8163
11
- langfun/core/language_model.py,sha256=JHIfW0GxFx1YVEM-drS_Iy4goFJt63LBosCM4CILWTY,12920
12
- langfun/core/language_model_test.py,sha256=gcW4OJJjB-V1b4kEF8zG91t36sVn3H0Yuj0LQxi83Ek,9122
9
+ langfun/core/langfunc.py,sha256=WXdTc3QsmGD_n80KD9dFRr5MHpGZ9E_y_Rhtk4t9-3w,11852
10
+ langfun/core/langfunc_test.py,sha256=8WeiyNauZPkbAA3HiLjVw5-pRSmiLlz-77lB_fjHGdA,8317
11
+ langfun/core/language_model.py,sha256=Qbm7wxgxW26bCVwtgpp-4aV3BKYAsb4IJrJuzhf3Q6o,15507
12
+ langfun/core/language_model_test.py,sha256=h5MWooOb9HubvOzxaBnH6WuDYBdxTetu7JZSWDzva3M,11368
13
13
  langfun/core/memory.py,sha256=f-asN1F7Vehgdn_fK84v73GrEUOxRtaW934keutTKjk,2416
14
14
  langfun/core/message.py,sha256=QhvV9t5qaryPcruyxxcXi3gm9QDInkSldwTtK6sVJ3c,15734
15
15
  langfun/core/message_test.py,sha256=Z23pUM5vPnDrYkIIibe2KL73D5HKur_awI0ut_EQFQA,9501
@@ -46,9 +46,9 @@ langfun/core/eval/matching.py,sha256=g2yuBb4FeOlAlB10hqdWvaIg4QVQlJbiViRDcD2Y8go
46
46
  langfun/core/eval/matching_test.py,sha256=IfuMF_dEmy4VzK6tIldRzD2Nqlml7SSh4u-baFNcZrw,4912
47
47
  langfun/core/eval/scoring.py,sha256=mshqbV_WM0zcp15TSR32ACMBDymlsbf6YH06PPx1Tw0,6139
48
48
  langfun/core/eval/scoring_test.py,sha256=_L_B40VZkyI2_PJce-jVKYC4llrO4jGUR5j86Gu6AT0,4046
49
- langfun/core/llms/__init__.py,sha256=zTTSz46M52wqJtgxg2lGvTgrTB1wl9xMaQvOxfi00bs,2346
50
- langfun/core/llms/fake.py,sha256=JH790_WDtlohL0leJMqd1F6a1YuM9XV3rgxHBsoILRg,2309
51
- langfun/core/llms/fake_test.py,sha256=nP3420LKGwTJJG1YH3y5XgH6yKmbFmmbonBwvMu-ZYA,3368
49
+ langfun/core/llms/__init__.py,sha256=T4mgT091BLA4mHrOjAvEGhZPHf0tiYgqD88l_JTp1dQ,2386
50
+ langfun/core/llms/fake.py,sha256=dVzOrW27RZ1p3DdQoRCRZs_vfoQcTcNrlWxia7oqmvw,2499
51
+ langfun/core/llms/fake_test.py,sha256=Qk_Yoi4Z7P9o6f8Q_BZkaSlvxH89ZVsDxnVIbSBRBXk,3555
52
52
  langfun/core/llms/gemini.py,sha256=p3d4Cl2uET-os1n_V3YNE6-6cYrZjndj7lxZIk2E8_4,5688
53
53
  langfun/core/llms/gemini_test.py,sha256=ybNNCn3JW3hYpMe0wT5ILGDrMPaYYU8PN2kSookM0jk,5433
54
54
  langfun/core/llms/llama_cpp.py,sha256=EIjJa1-Tg4_VaIxVR88oDWSWc_axc1r2KwSPpl4PSp0,2549
@@ -65,19 +65,21 @@ langfun/core/memories/conversation_history_test.py,sha256=AaW8aNoFjxNusanwJDV0r3
65
65
  langfun/core/modalities/__init__.py,sha256=VI96XGNfXqcJpBh2c17tkTs0gpO5ftc77Ep0jfLOztw,882
66
66
  langfun/core/modalities/image.py,sha256=HU0sV4ZTwRnAwQthmdWZwhFZRD86RyvqoS8JUW2Ia-A,2065
67
67
  langfun/core/modalities/image_test.py,sha256=YxDRvC49Bjwyyndd_P7y6XjyS7dOft0Zewwxk-7q4kE,2301
68
- langfun/core/structured/__init__.py,sha256=tGH0MYr5vzK0H2DpYQ2bcW2C5bpPUaLzMk2W2Fj29M4,3136
69
- langfun/core/structured/completion.py,sha256=XERoxtYPXOTlPdZ2bp4i9R4jl3kA3SOeyLmuSqHG9AM,7036
68
+ langfun/core/structured/__init__.py,sha256=LZ5BFLX6VXy1zH17yChWCdg8bvIDrhtL2lqtSCwtZ-M,3187
69
+ langfun/core/structured/completion.py,sha256=skBxt6V_fv2TBUKnzFgnPMbVY8HSYn8sY04MLok2yvs,7299
70
70
  langfun/core/structured/completion_test.py,sha256=98UCgA4gzfp6H6HgP2s2kcKs25YH3k4Nxj1rgAvmVBw,19249
71
- langfun/core/structured/description.py,sha256=vDiW1g2VbvG8ucNjV7Pp3VYCeAnLcp6vLQ0MfURcZFk,4825
71
+ langfun/core/structured/description.py,sha256=3MLTbpTpeiBqMRe3WfDNIxtrE6WQsKJsJdkbfcyPlsg,5088
72
72
  langfun/core/structured/description_test.py,sha256=UtZGjSFUaQ6130t1E5tcL7ODu0xIefkapb53TbnqsK8,7362
73
73
  langfun/core/structured/mapping.py,sha256=lGkjhmvVdhBGgJmc5KbfT2xQjC1MuU4OCcCfsAYJjaQ,10192
74
74
  langfun/core/structured/mapping_test.py,sha256=07DDCGbwytQHSMm7fCi5-Ly-JNgdV4ubHZq0wthX4A4,3338
75
- langfun/core/structured/parsing.py,sha256=XWo1UdG1A_c0v4OgQ1C_6nK0264_UAVrmJfFz4jHbRE,10690
75
+ langfun/core/structured/parsing.py,sha256=YKWl9ZQ2uFkt78SXiRISWHg8_cDMGMwAN3SeK-OqWt4,11382
76
76
  langfun/core/structured/parsing_test.py,sha256=2_Uf3LYNRON1-5ysEr75xiG_cAxR3ZiixSfvUQu6mOQ,20846
77
- langfun/core/structured/prompting.py,sha256=P8in3qHXCuwjfzLpplS5woQSHV5aheXgm2mFiqVQD4g,6384
78
- langfun/core/structured/prompting_test.py,sha256=5lPsxUzyHEjOh0D5V5GEYjFFJZvUrebLV1aCCJS4H3Y,18971
77
+ langfun/core/structured/prompting.py,sha256=0xRPC0K_RaFRv-j52x8_-1n1eRFSomJEpdZApVXsCV0,6902
78
+ langfun/core/structured/prompting_test.py,sha256=SwoYbPyKhUT1H2QbqHvl93biCiE9Ttn1aWixoHH-v9Y,19129
79
79
  langfun/core/structured/schema.py,sha256=5DKba0LrvXCJFRY-NVfER3p54BLOB7M3Yi2-u5IAJTw,24115
80
80
  langfun/core/structured/schema_test.py,sha256=LEtCST5Bfwoke59I6Q1mnOJLf2cFXQwKwTeAkI2hgqM,20912
81
+ langfun/core/structured/scoring.py,sha256=a3vfGnqf-DOWjD07MF54GCZTO_R1RTxTDVPzerXnU0s,2325
82
+ langfun/core/structured/scoring_test.py,sha256=TznLMl0x9QxzmhHz_3Vr44VOXuvFnUSeRQVhu33W5cA,1437
81
83
  langfun/core/templates/__init__.py,sha256=bO0eMsVJbi7sxEB2YlInKRQ2EVP-RyyKUwcD-8msuN4,927
82
84
  langfun/core/templates/completion.py,sha256=mUqZHOEV3ag6-A08XghpeEltcrBvCDxXP004eDDfeag,1931
83
85
  langfun/core/templates/completion_test.py,sha256=vGnjnM38UHyVDUyaUYtmp20s9KBGOdbPVsX-H-ET11E,1636
@@ -87,8 +89,8 @@ langfun/core/templates/demonstration.py,sha256=vCrgYubdZM5Umqcgp8NUVGXgr4P_c-fik
87
89
  langfun/core/templates/demonstration_test.py,sha256=SafcDQ0WgI7pw05EmPI2S4v1t3ABKzup8jReCljHeK4,2162
88
90
  langfun/core/templates/selfplay.py,sha256=yhgrJbiYwq47TgzThmHrDQTF4nDrTI09CWGhuQPNv-s,2273
89
91
  langfun/core/templates/selfplay_test.py,sha256=IB5rWbjK_9CTkqEo1BclQPzFAKcIiusJckH8J19HFgI,2096
90
- langfun-0.0.2.dev20240215.dist-info/LICENSE,sha256=WNHhf_5RCaeuKWyq_K39vmp9F28LxKsB4SpomwSZ2L0,11357
91
- langfun-0.0.2.dev20240215.dist-info/METADATA,sha256=oo4UvDeNdxk0glqHLdwP2tjejbXpEBnPfquoi3kzuOg,3368
92
- langfun-0.0.2.dev20240215.dist-info/WHEEL,sha256=oiQVh_5PnQM0E3gPdiz09WCNmwiHDMaGer_elqB3coM,92
93
- langfun-0.0.2.dev20240215.dist-info/top_level.txt,sha256=RhlEkHxs1qtzmmtWSwYoLVJAc1YrbPtxQ52uh8Z9VvY,8
94
- langfun-0.0.2.dev20240215.dist-info/RECORD,,
92
+ langfun-0.0.2.dev20240217.dist-info/LICENSE,sha256=WNHhf_5RCaeuKWyq_K39vmp9F28LxKsB4SpomwSZ2L0,11357
93
+ langfun-0.0.2.dev20240217.dist-info/METADATA,sha256=2D8wQbXh09hRN_TnmaxS2JeVLKKKqkbM9JDjqiBR7yg,3368
94
+ langfun-0.0.2.dev20240217.dist-info/WHEEL,sha256=oiQVh_5PnQM0E3gPdiz09WCNmwiHDMaGer_elqB3coM,92
95
+ langfun-0.0.2.dev20240217.dist-info/top_level.txt,sha256=RhlEkHxs1qtzmmtWSwYoLVJAc1YrbPtxQ52uh8Z9VvY,8
96
+ langfun-0.0.2.dev20240217.dist-info/RECORD,,