langfun 0.0.2.dev20240202__py3-none-any.whl → 0.0.2.dev20240203__py3-none-any.whl

This diff is generated from the contents of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects the changes between package versions as they appear in their respective public registries.
@@ -101,7 +101,7 @@ class EvaluationTest(unittest.TestCase):
101
101
  self.assertEqual(s.dir, os.path.join(s.root_dir, s.id))
102
102
  self.assertEqual(s.hash, s.clone().hash)
103
103
  # Test persistent hash.
104
- self.assertEqual(s.hash, 'c76d4fe6')
104
+ self.assertEqual(s.hash, 'abc7c29a')
105
105
  self.assertEqual(
106
106
  s.hash, s.clone(override={'max_workers': 2, 'lm.timeout': 20}).hash
107
107
  )
@@ -195,6 +195,7 @@ class EvaluationTest(unittest.TestCase):
195
195
  result=Solution(2),
196
196
  cache_seed=0,
197
197
  score=1.0,
198
+ logprobs=None,
198
199
  tags=['lm-response', 'lm-output', 'transformed'],
199
200
  ),
200
201
  )
@@ -323,7 +324,7 @@ class EvaluationTest(unittest.TestCase):
323
324
  s.children[0].dir, os.path.join(s.root_dir, s.children[0].id)
324
325
  )
325
326
  # Test persistent hash.
326
- self.assertEqual(s.hash, 'e987475a')
327
+ self.assertEqual(s.hash, 'ca7f722b')
327
328
 
328
329
  summary = s.run(verbose=True)
329
330
  self.assertEqual(len(summary.evaluations), 2)
@@ -451,7 +452,7 @@ class SuiteTest(unittest.TestCase):
451
452
  ],
452
453
  )
453
454
  # Test for persistent hash.
454
- self.assertEqual(s.hash, 'bb86a963')
455
+ self.assertEqual(s.hash, '7285e52b')
455
456
  s.run()
456
457
  expected = {
457
458
  s.children[0].id: dict(
@@ -82,7 +82,7 @@ class LangFuncCallTest(unittest.TestCase):
82
82
  self.assertEqual(i.tags, ['rendered'])
83
83
 
84
84
  r = l()
85
- self.assertEqual(r, message.AIMessage('Hello!!!', score=0.0))
85
+ self.assertEqual(r, message.AIMessage('Hello!!!', score=0.0, logprobs=None))
86
86
  self.assertEqual(r.tags, ['lm-response', 'lm-output'])
87
87
  self.assertEqual(r.source, message.UserMessage('Hello'))
88
88
  self.assertEqual(r.source.tags, ['rendered', 'lm-input'])
@@ -94,8 +94,9 @@ class LangFuncCallTest(unittest.TestCase):
94
94
  "LangFunc(template_str='Hello', clean=True,"
95
95
  ' lm=ExcitedEchoer(sampling_options=LMSamplingOptions(temperature=0.0,'
96
96
  ' max_tokens=1024, n=1, top_k=40, top_p=None, stop=None,'
97
- ' random_seed=None), cache=None, timeout=120.0, max_attempts=5,'
98
- ' retry_interval=(5, 60), exponential_backoff=True, debug=False))',
97
+ ' random_seed=None, logprobs=False, top_logprobs=None), cache=None,'
98
+ ' timeout=120.0, max_attempts=5, retry_interval=(5, 60),'
99
+ ' exponential_backoff=True, debug=False))',
99
100
  )
100
101
 
101
102
  l = LangFunc('Hello')
@@ -104,7 +105,9 @@ class LangFuncCallTest(unittest.TestCase):
104
105
  self.assertEqual(l.natural_language_format(), 'Hello')
105
106
  self.assertEqual(l.render(), 'Hello')
106
107
  r = l()
107
- self.assertEqual(r, message.AIMessage('Hello!!!', score=0.0))
108
+ self.assertEqual(
109
+ r, message.AIMessage('Hello!!!', score=0.0, logprobs=None)
110
+ )
108
111
  self.assertEqual(r.tags, ['lm-response', 'lm-output'])
109
112
 
110
113
  self.assertEqual(str(l), 'Hello')
@@ -31,15 +31,20 @@ class LMSample(pg.Object):
31
31
  pg.typing.Object(
32
32
  message_lib.Message,
33
33
  # Allowing automatic conversion from text to AIMessage.
34
- transform=message_lib.AIMessage.from_value
34
+ transform=message_lib.AIMessage.from_value,
35
35
  ),
36
- 'The natural language response of LM.'
36
+ 'The natural language response of LM.',
37
37
  ]
38
38
 
39
39
  score: Annotated[
40
40
  float, 'The score of sampled response. The larger is better'
41
41
  ] = 0.0
42
42
 
43
+ logprobs: Annotated[
44
+ list[tuple[str, float, list[tuple[str, float]]]] | None,
45
+ '(token, log prob, top tokens and their probs).',
46
+ ] = None
47
+
43
48
 
44
49
  class LMSamplingResult(pg.Object):
45
50
  """Language model response."""
@@ -92,6 +97,23 @@ class LMSamplingOptions(component.Component):
92
97
  random_seed: Annotated[
93
98
  int | None, 'A fixed random seed used during model inference.'
94
99
  ] = None
100
+ logprobs: Annotated[
101
+ bool,
102
+ (
103
+ 'Whether to return log probabilities of the output tokens or not. If '
104
+ 'true, returns the log probabilities of each output token returned '
105
+ 'in the content of message.'
106
+ ),
107
+ ] = False
108
+ top_logprobs: Annotated[
109
+ int | None,
110
+ (
111
+ 'An integer between 0 and 5 specifying the number of most likely '
112
+ 'tokens to return at each token position, each with an associated '
113
+ 'log probability. logprobs must be set to true if this parameter is '
114
+ 'used.'
115
+ ),
116
+ ] = None
95
117
 
96
118
  def cache_key(self) -> tuple[Any, ...]:
97
119
  """Returns a tuple of current values as cache key."""
@@ -339,7 +361,9 @@ class LanguageModel(component.Component):
339
361
  [prompt], sampling_options=sampling_options, cache_seed=cache_seed
340
362
  )[0]
341
363
  response = result.samples[0].response
364
+ logprobs = result.samples[0].logprobs
342
365
  response.set('score', result.samples[0].score)
366
+ response.metadata.logprobs = logprobs
343
367
  elapse = time.time() - request_start
344
368
  self._debug(prompt, response, call_counter, elapse)
345
369
  return response
@@ -167,6 +167,8 @@ class OpenAI(lf.LanguageModel):
167
167
  max_tokens=options.max_tokens,
168
168
  stream=False,
169
169
  timeout=self.timeout,
170
+ logprobs=options.logprobs,
171
+ top_logprobs=options.top_logprobs,
170
172
  )
171
173
  # Completion and ChatCompletion uses different parameter name for model.
172
174
  args['model' if self.is_chat_model else 'engine'] = self.model
@@ -249,11 +251,28 @@ class OpenAI(lf.LanguageModel):
249
251
  **self._get_request_args(self.sampling_options),
250
252
  )
251
253
  response = cast(openai_object.OpenAIObject, response)
254
+ samples = []
255
+ for choice in response.choices:
256
+ logprobs = None
257
+ if choice.logprobs:
258
+ logprobs = [
259
+ (
260
+ t.token,
261
+ t.logprob,
262
+ [(tt.token, tt.logprob) for tt in t.top_logprobs],
263
+ )
264
+ for t in choice.logprobs.content
265
+ ]
266
+ samples.append(
267
+ lf.LMSample(
268
+ choice.message.content,
269
+ score=0.0,
270
+ logprobs=logprobs,
271
+ )
272
+ )
273
+
252
274
  return LMSamplingResult(
253
- [
254
- lf.LMSample(choice.message.content, score=0.0)
255
- for choice in response.choices
256
- ],
275
+ samples=samples,
257
276
  usage=Usage(
258
277
  prompt_tokens=response.usage.prompt_tokens,
259
278
  completion_tokens=response.usage.completion_tokens,
@@ -46,7 +46,8 @@ def mock_chat_completion_query(messages, *, n=1, **kwargs):
46
46
  choices.append(pg.Dict(
47
47
  message=pg.Dict(
48
48
  content=f'Sample {k} for message.'
49
- )
49
+ ),
50
+ logprobs=None,
50
51
  ))
51
52
  return pg.Dict(choices=choices, usage=openai.Usage(
52
53
  prompt_tokens=100,
@@ -65,7 +66,8 @@ def mock_chat_completion_query_vision(messages, *, n=1, **kwargs):
65
66
  choices.append(pg.Dict(
66
67
  message=pg.Dict(
67
68
  content=f'Sample {k} for message: {"".join(urls)}'
68
- )
69
+ ),
70
+ logprobs=None,
69
71
  ))
70
72
  return pg.Dict(choices=choices, usage=openai.Usage(
71
73
  prompt_tokens=100,
@@ -99,6 +101,8 @@ class OpenaiTest(unittest.TestCase):
99
101
  top_p=1.0)),
100
102
  dict(
101
103
  engine='text-davinci-003',
104
+ logprobs=False,
105
+ top_logprobs=None,
102
106
  n=2,
103
107
  temperature=2.0,
104
108
  max_tokens=4096,
@@ -113,6 +117,8 @@ class OpenaiTest(unittest.TestCase):
113
117
  ),
114
118
  dict(
115
119
  model='gpt-4',
120
+ logprobs=False,
121
+ top_logprobs=None,
116
122
  n=1,
117
123
  temperature=1.0,
118
124
  max_tokens=1024,
@@ -582,6 +582,7 @@ class CompleteStructureTest(unittest.TestCase):
582
582
  text='Activity(description="foo")',
583
583
  result=Activity(description='foo'),
584
584
  score=1.0,
585
+ logprobs=None,
585
586
  tags=['lm-response', 'lm-output', 'transformed']
586
587
  )
587
588
  )
@@ -286,7 +286,7 @@ class ParseStructurePythonTest(unittest.TestCase):
286
286
  returns_message=True
287
287
  ),
288
288
  lf.AIMessage(
289
- '1', score=1.0, result=1,
289
+ '1', score=1.0, result=1, logprobs=None,
290
290
  tags=['lm-response', 'lm-output', 'transformed']
291
291
  ),
292
292
  )
@@ -640,7 +640,7 @@ class CallTest(unittest.TestCase):
640
640
  returns_message=True
641
641
  ),
642
642
  lf.AIMessage(
643
- '3', result=3, score=1.0,
643
+ '3', result=3, score=1.0, logprobs=None,
644
644
  tags=['lm-response', 'lm-output', 'transformed']
645
645
  ),
646
646
  )
@@ -76,6 +76,7 @@ class QueryTest(unittest.TestCase):
76
76
  '1',
77
77
  result=1,
78
78
  score=1.0,
79
+ logprobs=None,
79
80
  tags=['lm-response', 'lm-output', 'transformed'],
80
81
  ),
81
82
  )
@@ -56,7 +56,7 @@ class SelfPlayTest(unittest.TestCase):
56
56
  g = NumberGuess(target_num=10)
57
57
 
58
58
  with lf.context(lm=NumberGuesser(guesses=[50, 20, 5, 10])):
59
- self.assertEqual(g(), lf.AIMessage('10', score=0.0))
59
+ self.assertEqual(g(), lf.AIMessage('10', score=0.0, logprobs=None))
60
60
 
61
61
  self.assertEqual(g.num_turns, 4)
62
62
 
@@ -64,7 +64,7 @@ class SelfPlayTest(unittest.TestCase):
64
64
  g = NumberGuess(target_num=10, max_turns=10)
65
65
 
66
66
  with lf.context(lm=NumberGuesser(guesses=[50, 20, 5, 2, 5, 4])):
67
- self.assertEqual(g(), lf.AIMessage('2', score=0.0))
67
+ self.assertEqual(g(), lf.AIMessage('2', score=0.0, logprobs=None))
68
68
 
69
69
  self.assertEqual(g.num_turns, 10)
70
70
 
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: langfun
3
- Version: 0.0.2.dev20240202
3
+ Version: 0.0.2.dev20240203
4
4
  Summary: Langfun: Language as Functions.
5
5
  Home-page: https://github.com/google/langfun
6
6
  Author: Langfun Authors
@@ -7,8 +7,8 @@ langfun/core/concurrent_test.py,sha256=qQT6_Dq5NVz7qXFLzSf2Rhzkfkh07gocjHMBaT1nS
7
7
  langfun/core/console.py,sha256=bk5rNPNm9rMGW5YT2HixxU04p2umnoabn5SDz6Dqe88,2317
8
8
  langfun/core/console_test.py,sha256=5SYJdxpJGLgdSSQqqMPoA1X6jpsLD8rgcyk-EgI65oE,1077
9
9
  langfun/core/langfunc.py,sha256=266xNz8Vgal7K4HSsrYt7z7_qPYV4bWWK626IbbohrE,11573
10
- langfun/core/langfunc_test.py,sha256=rnU-7JQcY-fIXB5aoKbouPd2LkjdzIjB5g_aLPBHQCE,8069
11
- langfun/core/language_model.py,sha256=UppuqiwmZ6AfWCNkfom12XqEu1uf8qBB3yp-xezzQ0s,12056
10
+ langfun/core/langfunc_test.py,sha256=ukv5cnad5ZBckM2PhyIFq79BPN0Db4cszMrPqh_CZkA,8163
11
+ langfun/core/language_model.py,sha256=HG2XQhisyiyvMBws5dASH1OkbGJB3twn0p2e_As8hUo,12889
12
12
  langfun/core/language_model_test.py,sha256=gcW4OJJjB-V1b4kEF8zG91t36sVn3H0Yuj0LQxi83Ek,9122
13
13
  langfun/core/memory.py,sha256=f-asN1F7Vehgdn_fK84v73GrEUOxRtaW934keutTKjk,2416
14
14
  langfun/core/message.py,sha256=QhvV9t5qaryPcruyxxcXi3gm9QDInkSldwTtK6sVJ3c,15734
@@ -41,7 +41,7 @@ langfun/core/coding/python/permissions.py,sha256=1QWGHvzL8MM0Ok_auQ9tURqZHtdOfJa
41
41
  langfun/core/coding/python/permissions_test.py,sha256=w5EDb8QxpxgJyZkojyzVWQvDfg366zn99-g__6TbPQ0,2699
42
42
  langfun/core/eval/__init__.py,sha256=iDA2OcJ3kR6ixZizXIY3N9LsjkaVrfTbSClTiSP8ekY,1291
43
43
  langfun/core/eval/base.py,sha256=wWFDDrf0jBzs9H_5XfdZSeOBGXyUtXAJJouk7cLckSM,52602
44
- langfun/core/eval/base_test.py,sha256=bGs3VLchkAJFWYJ8FdR7mC6qoDestAvCHOQpClG6Mzw,21248
44
+ langfun/core/eval/base_test.py,sha256=uXu6EtTagolDIcSadnVyMmlrz6ixx943jkZhquCRQPI,21275
45
45
  langfun/core/eval/matching.py,sha256=g2yuBb4FeOlAlB10hqdWvaIg4QVQlJbiViRDcD2Y8go,9567
46
46
  langfun/core/eval/matching_test.py,sha256=IfuMF_dEmy4VzK6tIldRzD2Nqlml7SSh4u-baFNcZrw,4912
47
47
  langfun/core/eval/scoring.py,sha256=mshqbV_WM0zcp15TSR32ACMBDymlsbf6YH06PPx1Tw0,6139
@@ -53,8 +53,8 @@ langfun/core/llms/gemini.py,sha256=p3d4Cl2uET-os1n_V3YNE6-6cYrZjndj7lxZIk2E8_4,5
53
53
  langfun/core/llms/gemini_test.py,sha256=ybNNCn3JW3hYpMe0wT5ILGDrMPaYYU8PN2kSookM0jk,5433
54
54
  langfun/core/llms/llama_cpp.py,sha256=EIjJa1-Tg4_VaIxVR88oDWSWc_axc1r2KwSPpl4PSp0,2549
55
55
  langfun/core/llms/llama_cpp_test.py,sha256=ZxC6defGd_HX9SFRU9U4cJiQnBKundbOrchbXuC1Z2M,1683
56
- langfun/core/llms/openai.py,sha256=ao2sDDoh5ma1GWpLpNPZARIeLZK55gL1Ldc94h1EGtE,11119
57
- langfun/core/llms/openai_test.py,sha256=JWcMveifVVVEFWdtmNq1irc9wSFQRxXs-SnOF3Urg9Y,7433
56
+ langfun/core/llms/openai.py,sha256=ufhRf-fV-AFq_-pn5sKrcr_xIU7V0VnhjK3JfBG8sF8,11617
57
+ langfun/core/llms/openai_test.py,sha256=yfw7A-4Zo9u1cIkAMk39evE-tO7z6isNYTXiSnJXDQw,7599
58
58
  langfun/core/llms/cache/__init__.py,sha256=QAo3InUMDM_YpteNnVCSejI4zOsnjSMWKJKzkb3VY64,993
59
59
  langfun/core/llms/cache/base.py,sha256=cFfYvOIUae842pncqCAsRvqXCk2AnAsRYVx0mcIoAeY,3338
60
60
  langfun/core/llms/cache/in_memory.py,sha256=YfFyJEhLs73cUiB0ZfhMxYpdE8Iuxxw-dvMFwGHTSHw,4742
@@ -67,15 +67,15 @@ langfun/core/modalities/image.py,sha256=HU0sV4ZTwRnAwQthmdWZwhFZRD86RyvqoS8JUW2I
67
67
  langfun/core/modalities/image_test.py,sha256=YxDRvC49Bjwyyndd_P7y6XjyS7dOft0Zewwxk-7q4kE,2301
68
68
  langfun/core/structured/__init__.py,sha256=tGH0MYr5vzK0H2DpYQ2bcW2C5bpPUaLzMk2W2Fj29M4,3136
69
69
  langfun/core/structured/completion.py,sha256=XERoxtYPXOTlPdZ2bp4i9R4jl3kA3SOeyLmuSqHG9AM,7036
70
- langfun/core/structured/completion_test.py,sha256=i68c_eK-QJIAW964E2o4W8kNj2ZokEQOPesG5Hw78-E,19222
70
+ langfun/core/structured/completion_test.py,sha256=98UCgA4gzfp6H6HgP2s2kcKs25YH3k4Nxj1rgAvmVBw,19249
71
71
  langfun/core/structured/description.py,sha256=vDiW1g2VbvG8ucNjV7Pp3VYCeAnLcp6vLQ0MfURcZFk,4825
72
72
  langfun/core/structured/description_test.py,sha256=UtZGjSFUaQ6130t1E5tcL7ODu0xIefkapb53TbnqsK8,7362
73
73
  langfun/core/structured/mapping.py,sha256=lGkjhmvVdhBGgJmc5KbfT2xQjC1MuU4OCcCfsAYJjaQ,10192
74
74
  langfun/core/structured/mapping_test.py,sha256=07DDCGbwytQHSMm7fCi5-Ly-JNgdV4ubHZq0wthX4A4,3338
75
75
  langfun/core/structured/parsing.py,sha256=V3i8-AuBI4zdpsi_f6H-mbQ3h0HPgCle2OnDQolWJnA,10244
76
- langfun/core/structured/parsing_test.py,sha256=EC9i4hQkcmoneBX7p1fRIpgznB8lgGHFwtKvHlNrwyA,20816
76
+ langfun/core/structured/parsing_test.py,sha256=2_Uf3LYNRON1-5ysEr75xiG_cAxR3ZiixSfvUQu6mOQ,20846
77
77
  langfun/core/structured/prompting.py,sha256=OmI21qQQikoQLAmr5W8nBJ8PWCu9w0lmsCnmbQg9hMc,6632
78
- langfun/core/structured/prompting_test.py,sha256=LOOka3CaID03SXPOSNHVFrnk7Ymkt6GvSDS9pNa-y3M,19116
78
+ langfun/core/structured/prompting_test.py,sha256=8lakKCidYuRlf-U1KexTOqQdKrBXUt8fb2J7epCdt84,19143
79
79
  langfun/core/structured/schema.py,sha256=5DKba0LrvXCJFRY-NVfER3p54BLOB7M3Yi2-u5IAJTw,24115
80
80
  langfun/core/structured/schema_test.py,sha256=LEtCST5Bfwoke59I6Q1mnOJLf2cFXQwKwTeAkI2hgqM,20912
81
81
  langfun/core/templates/__init__.py,sha256=bO0eMsVJbi7sxEB2YlInKRQ2EVP-RyyKUwcD-8msuN4,927
@@ -86,9 +86,9 @@ langfun/core/templates/conversation_test.py,sha256=RryYyIhfc34dLWOs6GfPQ8HU8mXpK
86
86
  langfun/core/templates/demonstration.py,sha256=vCrgYubdZM5Umqcgp8NUVGXgr4P_c-fikKhwhzwhpKI,1460
87
87
  langfun/core/templates/demonstration_test.py,sha256=SafcDQ0WgI7pw05EmPI2S4v1t3ABKzup8jReCljHeK4,2162
88
88
  langfun/core/templates/selfplay.py,sha256=yhgrJbiYwq47TgzThmHrDQTF4nDrTI09CWGhuQPNv-s,2273
89
- langfun/core/templates/selfplay_test.py,sha256=ZkDfwiW9OtO_MOIdVTRPn6P6vOExQIszqlVQHg5iD3U,2066
90
- langfun-0.0.2.dev20240202.dist-info/LICENSE,sha256=WNHhf_5RCaeuKWyq_K39vmp9F28LxKsB4SpomwSZ2L0,11357
91
- langfun-0.0.2.dev20240202.dist-info/METADATA,sha256=YNH_28StpsuvajP3UyX4a4-RjS3ErqeooJE_l55L1cQ,3368
92
- langfun-0.0.2.dev20240202.dist-info/WHEEL,sha256=oiQVh_5PnQM0E3gPdiz09WCNmwiHDMaGer_elqB3coM,92
93
- langfun-0.0.2.dev20240202.dist-info/top_level.txt,sha256=RhlEkHxs1qtzmmtWSwYoLVJAc1YrbPtxQ52uh8Z9VvY,8
94
- langfun-0.0.2.dev20240202.dist-info/RECORD,,
89
+ langfun/core/templates/selfplay_test.py,sha256=IB5rWbjK_9CTkqEo1BclQPzFAKcIiusJckH8J19HFgI,2096
90
+ langfun-0.0.2.dev20240203.dist-info/LICENSE,sha256=WNHhf_5RCaeuKWyq_K39vmp9F28LxKsB4SpomwSZ2L0,11357
91
+ langfun-0.0.2.dev20240203.dist-info/METADATA,sha256=9wGw6eai5F4auzoA-J2DHC-OXjqeRessnRbVl2voxNY,3368
92
+ langfun-0.0.2.dev20240203.dist-info/WHEEL,sha256=oiQVh_5PnQM0E3gPdiz09WCNmwiHDMaGer_elqB3coM,92
93
+ langfun-0.0.2.dev20240203.dist-info/top_level.txt,sha256=RhlEkHxs1qtzmmtWSwYoLVJAc1YrbPtxQ52uh8Z9VvY,8
94
+ langfun-0.0.2.dev20240203.dist-info/RECORD,,