langfun 0.1.2.dev202412090805__py3-none-any.whl → 0.1.2.dev202412130804__py3-none-any.whl

This diff shows the changes between two publicly released versions of the package, as they appear in their respective public registries. It is provided for informational purposes only.
@@ -18,6 +18,8 @@ import unittest
 import langfun.core as lf
 from langfun.core.agentic import action as action_lib
 from langfun.core.llms import fake
+import langfun.core.structured as lf_structured
+import pyglove as pg


 class SessionTest(unittest.TestCase):
@@ -28,60 +30,94 @@ class SessionTest(unittest.TestCase):
     class Bar(action_lib.Action):

       def call(self, session, *, lm, **kwargs):
-        test.assertIs(session.current_invocation.action, self)
+        test.assertIs(session.current_action.action, self)
         session.info('Begin Bar')
         session.query('bar', lm=lm)
+        session.add_metadata(note='bar')
         return 2

     class Foo(action_lib.Action):
       x: int

       def call(self, session, *, lm, **kwargs):
-        test.assertIs(session.current_invocation.action, self)
-        session.info('Begin Foo', x=1)
-        session.query('foo', lm=lm)
+        test.assertIs(session.current_action.action, self)
+        with session.phase('prepare'):
+          session.info('Begin Foo', x=1)
+          session.query('foo', lm=lm)
+        with session.track_queries():
+          self.make_additional_query(lm)
+        session.add_metadata(note='foo')
         return self.x + Bar()(session, lm=lm)

+      def make_additional_query(self, lm):
+        lf_structured.query('additional query', lm=lm)
+
     lm = fake.StaticResponse('lm response')
-    session = action_lib.Session()
-    root = session.root_invocation
-    self.assertIsInstance(root.action, action_lib.RootAction)
-    self.assertIs(session.current_invocation, session.root_invocation)
-    self.assertEqual(Foo(1)(session, lm=lm), 3)
-    self.assertEqual(len(session.root_invocation.child_invocations), 1)
-    self.assertEqual(len(list(session.root_invocation.queries())), 0)
-    self.assertEqual(
-        len(list(session.root_invocation.queries(include_children=True))), 2
-    )
-    self.assertEqual(
-        len(list(session.root_invocation.child_invocations[0].queries())), 1
-    )
-    self.assertEqual(len(session.root_invocation.child_invocations[0].logs), 1)
-    self.assertEqual(
-        len(session.root_invocation.child_invocations[0].child_invocations),
-        1
-    )
-    self.assertEqual(
-        len(session.root_invocation
-            .child_invocations[0].child_invocations[0].logs),
-        1
-    )
-    self.assertEqual(
-        len(list(session.root_invocation
-                 .child_invocations[0].child_invocations[0].queries())),
-        1
-    )
-    self.assertEqual(
-        len(session.root_invocation
-            .child_invocations[0].child_invocations[0].child_invocations),
-        0
-    )
-    self.assertIs(session.current_invocation, session.root_invocation)
-    self.assertIs(session.final_result, 3)
-    self.assertIn(
-        'invocation-final-result',
-        session.to_html().content,
+    foo = Foo(1)
+    self.assertEqual(foo(lm=lm), 3)
+
+    session = foo.session
+    self.assertIsNotNone(session)
+    self.assertIsInstance(session.root.action, action_lib.RootAction)
+    self.assertIs(session.current_action, session.root)
+
+    #
+    # Inspecting the root invocation.
+    #
+
+    root = session.root
+    self.assertEqual(len(root.execution.items), 1)
+    self.assertIs(root.execution.items[0].action, foo)
+
+    self.assertTrue(root.execution.has_started)
+    self.assertTrue(root.execution.has_stopped)
+    self.assertGreater(root.execution.elapse, 0)
+    self.assertEqual(root.result, 3)
+    self.assertEqual(root.metadata, dict(note='foo'))
+
+    # The root space should have one action (foo), no queries, and no logs.
+    self.assertEqual(len(list(root.actions)), 1)
+    self.assertEqual(len(list(root.queries)), 0)
+    self.assertEqual(len(list(root.logs)), 0)
+    # 1 query from Bar and 2 from Foo.
+    self.assertEqual(len(list(root.all_queries)), 3)
+    # 1 log from Bar and 1 from Foo.
+    self.assertEqual(len(list(root.all_logs)), 2)
+    self.assertEqual(root.usage_summary.total.num_requests, 3)
+
+    # Inspecting the top-level action (Foo)
+    foo_invocation = root.execution.items[0]
+    self.assertEqual(len(foo_invocation.execution.items), 3)
+
+    # Prepare phase.
+    prepare_phase = foo_invocation.execution.items[0]
+    self.assertIsInstance(
+        prepare_phase, action_lib.ExecutionTrace
     )
+    self.assertEqual(len(prepare_phase.items), 2)
+    self.assertTrue(prepare_phase.has_started)
+    self.assertTrue(prepare_phase.has_stopped)
+    self.assertEqual(prepare_phase.usage_summary.total.num_requests, 1)
+
+    # Tracked queries.
+    query_invocation = foo_invocation.execution.items[1]
+    self.assertIsInstance(query_invocation, lf_structured.QueryInvocation)
+    self.assertIs(query_invocation.lm, lm)
+
+    # Invocation to Bar.
+    bar_invocation = foo_invocation.execution.items[2]
+    self.assertIsInstance(bar_invocation, action_lib.ActionInvocation)
+    self.assertIsInstance(bar_invocation.action, Bar)
+    self.assertEqual(bar_invocation.result, 2)
+    self.assertEqual(bar_invocation.metadata, dict(note='bar'))
+    self.assertEqual(len(bar_invocation.execution.items), 2)
+
+    # Save to HTML
+    self.assertIn('result', session.to_html().content)
+
+    # Save session to JSON
+    json_str = session.to_json_str(save_ref_value=True)
+    self.assertIsInstance(pg.from_json_str(json_str), action_lib.Session)

   def test_log(self):
     session = action_lib.Session()
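The test above exercises the reworked agentic Session API: a session is now created implicitly when an action is called, invocations are reached through `session.root` and `session.current_action`, each invocation carries an `execution` trace holding phase traces, tracked queries, and child action invocations, and `add_metadata(...)` attaches metadata to the current invocation. A condensed inspection sketch, derived only from the assertions above (`Foo` and `Bar` are the test's own local actions, so this is illustrative rather than library documentation):

    # Sketch of the inspection pattern shown in the test above.
    foo = Foo(1)
    result = foo(lm=lm)                        # running the action creates foo.session

    session = foo.session
    root = session.root                        # root invocation (RootAction)
    foo_invocation = root.execution.items[0]   # top-level Foo invocation
    print(root.result, root.metadata)          # 3, {'note': 'foo'}
    print(len(list(root.all_queries)))         # queries made anywhere beneath the root
    print(root.usage_summary.total.num_requests)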
@@ -556,6 +556,9 @@ class Evaluation(experiment_lib.Experiment):
   border: 0px;
   margin: 0px;
 }
+.eval-details .tab-control {
+  width: 100%;
+}
 .eval-details .tab-button {
   font-size: large;
 }
@@ -465,6 +465,33 @@ class Experiment(lf.Component, pg.views.HtmlTreeView.Extension):
     runner.run()
     return runner.current_run

+  def run_preconfigured(
+      self,
+      root_dir: str | None = None,
+      id: str | None = None,  # pylint: disable=redefined-builtin
+      **kwargs
+  ) -> 'Run':
+    """Runs the experiment with pre-configured kwargs from `cls.RUN_ARGS`.
+
+    This helper method allows users to config running arguments as a part of
+    the class.
+
+    Args:
+      root_dir: root directory of the experiment.
+      id: ID of the current run.
+      **kwargs: Keyword arguments to override the RUN_CONFIG.
+
+    Returns:
+      The current run.
+    """
+    run_config = getattr(self, 'RUN_ARGS', {})
+    run_config.update(kwargs)
+    if root_dir is not None:
+      run_config['root_dir'] = root_dir
+    if id is not None:
+      run_config['id'] = id
+    return self.run(**run_config)
+
   #
   # HTML views.
   #
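For orientation, a minimal usage sketch of the new `run_preconfigured` helper, assuming a subclass that defines `RUN_ARGS` as in the test changes below (the `MyEvaluation` class and `runner='test'` come from this diff; the other attribute values and the root directory are illustrative):

    import os
    import tempfile

    class MyEvaluation(Evaluation):
      NAME = 'my_eval'
      # Default run arguments picked up by run_preconfigured().
      RUN_ARGS = dict(runner='test')
      ...  # replica_id, inputs and metrics as in the test below.

    # RUN_ARGS supplies runner='test'; explicit keyword arguments override it.
    MyEvaluation(replica_id=0).run_preconfigured(
        root_dir=os.path.join(tempfile.gettempdir(), 'my_eval'),
        id='20241101_1',
    )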
@@ -39,6 +39,10 @@ def sample_inputs():

 class MyEvaluation(Evaluation):
   NAME = 'my_eval'
+  RUN_ARGS = dict(
+      runner='test'
+  )
+
   replica_id: int = 0
   inputs = sample_inputs()
   metrics = [metrics_lib.Match()]
@@ -288,10 +292,17 @@ class RunnerTest(unittest.TestCase):
         TestRunner
     )
     root_dir = os.path.join(tempfile.gettempdir(), 'my_eval')
+
+    # Test standard run.
     MyEvaluation(replica_id=0).run(
         root_dir, id='20241101_0', runner='test'
     )

+    # Test run preconfigured.
+    MyEvaluation(replica_id=0).run_preconfigured(
+        root_dir=root_dir, id='20241101_1'
+    )
+
     with self.assertRaisesRegex(
         ValueError, 'Runner class must define a NAME constant'
     ):
@@ -32,10 +32,12 @@ from langfun.core.llms.rest import REST

 # Gemini models.
 from langfun.core.llms.google_genai import GenAI
+from langfun.core.llms.google_genai import GeminiFlash2_0Exp
 from langfun.core.llms.google_genai import GeminiExp_20241114
+from langfun.core.llms.google_genai import GeminiExp_20241206
 from langfun.core.llms.google_genai import GeminiFlash1_5
-from langfun.core.llms.google_genai import GeminiPro
 from langfun.core.llms.google_genai import GeminiPro1_5
+from langfun.core.llms.google_genai import GeminiPro
 from langfun.core.llms.google_genai import GeminiProVision
 from langfun.core.llms.google_genai import Palm2
 from langfun.core.llms.google_genai import Palm2_IT
@@ -120,6 +122,8 @@ from langfun.core.llms.groq import GroqWhisper_Large_v3
 from langfun.core.llms.groq import GroqWhisper_Large_v3Turbo

 from langfun.core.llms.vertexai import VertexAI
+from langfun.core.llms.vertexai import VertexAIGemini2_0
+from langfun.core.llms.vertexai import VertexAIGeminiFlash2_0Exp
 from langfun.core.llms.vertexai import VertexAIGemini1_5
 from langfun.core.llms.vertexai import VertexAIGeminiPro1_5
 from langfun.core.llms.vertexai import VertexAIGeminiPro1_5_001
@@ -35,7 +35,7 @@ except ImportError:
   google_auth = None
   auth_requests = None
   credentials_lib = None
-  Credentials = None  # pylint: disable=invalid-name
+  Credentials = Any  # pylint: disable=invalid-name


 SUPPORTED_MODELS_AND_SETTINGS = {
@@ -20,6 +20,7 @@ from typing import Annotated, Any, Literal

 import langfun.core as lf
 from langfun.core import modalities as lf_modalities
+from langfun.core.llms import vertexai
 import pyglove as pg


@@ -47,13 +48,15 @@ class GenAI(lf.LanguageModel):

   model: Annotated[
       Literal[
+          'gemini-2.0-flash-exp',
+          'gemini-exp-1206',
+          'gemini-exp-1114',
+          'gemini-1.5-pro-latest',
+          'gemini-1.5-flash-latest',
           'gemini-pro',
           'gemini-pro-vision',
           'text-bison-001',
           'chat-bison-001',
-          'gemini-1.5-pro-latest',
-          'gemini-1.5-flash-latest',
-          'gemini-exp-1114',
       ],
       'Model name.',
   ]
@@ -306,64 +309,64 @@ _GOOGLE_GENAI_MODEL_HUB = _ModelHub()
 #


-_IMAGE_TYPES = [
-    'image/png',
-    'image/jpeg',
-    'image/webp',
-    'image/heic',
-    'image/heif',
-]
-
-_AUDIO_TYPES = [
-    'audio/aac',
-    'audio/flac',
-    'audio/mp3',
-    'audio/m4a',
-    'audio/mpeg',
-    'audio/mpga',
-    'audio/mp4',
-    'audio/opus',
-    'audio/pcm',
-    'audio/wav',
-    'audio/webm'
-]
-
-_VIDEO_TYPES = [
-    'video/mov',
-    'video/mpeg',
-    'video/mpegps',
-    'video/mpg',
-    'video/mp4',
-    'video/webm',
-    'video/wmv',
-    'video/x-flv',
-    'video/3gpp',
-]
-
-_PDF = [
-    'application/pdf',
-]
+class GeminiFlash2_0Exp(GenAI):  # pylint: disable=invalid-name
+  """Gemini Experimental model launched on 12/06/2024."""
+
+  model = 'gemini-2.0-flash-exp'
+  supported_modalities = (
+      vertexai.DOCUMENT_TYPES
+      + vertexai.IMAGE_TYPES
+      + vertexai.AUDIO_TYPES
+      + vertexai.VIDEO_TYPES
+  )
+
+
+class GeminiExp_20241206(GenAI):  # pylint: disable=invalid-name
+  """Gemini Experimental model launched on 12/06/2024."""
+
+  model = 'gemini-exp-1206'
+  supported_modalities = (
+      vertexai.DOCUMENT_TYPES
+      + vertexai.IMAGE_TYPES
+      + vertexai.AUDIO_TYPES
+      + vertexai.VIDEO_TYPES
+  )


 class GeminiExp_20241114(GenAI):  # pylint: disable=invalid-name
   """Gemini Experimental model launched on 11/14/2024."""

   model = 'gemini-exp-1114'
-  supported_modalities = _PDF + _IMAGE_TYPES + _AUDIO_TYPES + _VIDEO_TYPES
+  supported_modalities = (
+      vertexai.DOCUMENT_TYPES
+      + vertexai.IMAGE_TYPES
+      + vertexai.AUDIO_TYPES
+      + vertexai.VIDEO_TYPES
+  )


 class GeminiPro1_5(GenAI):  # pylint: disable=invalid-name
   """Gemini Pro latest model."""

   model = 'gemini-1.5-pro-latest'
-  supported_modalities = _PDF + _IMAGE_TYPES + _AUDIO_TYPES + _VIDEO_TYPES
+  supported_modalities = (
+      vertexai.DOCUMENT_TYPES
+      + vertexai.IMAGE_TYPES
+      + vertexai.AUDIO_TYPES
+      + vertexai.VIDEO_TYPES
+  )


 class GeminiFlash1_5(GenAI):  # pylint: disable=invalid-name
   """Gemini Flash latest model."""

   model = 'gemini-1.5-flash-latest'
-  supported_modalities = _PDF + _IMAGE_TYPES + _AUDIO_TYPES + _VIDEO_TYPES
+  supported_modalities = (
+      vertexai.DOCUMENT_TYPES
+      + vertexai.IMAGE_TYPES
+      + vertexai.AUDIO_TYPES
+      + vertexai.VIDEO_TYPES
+  )


 class GeminiPro(GenAI):
@@ -376,7 +379,7 @@ class GeminiProVision(GenAI):
   """Gemini Pro vision model."""

   model = 'gemini-pro-vision'
-  supported_modalities = _IMAGE_TYPES + _VIDEO_TYPES
+  supported_modalities = vertexai.IMAGE_TYPES + vertexai.VIDEO_TYPES


 class Palm2(GenAI):
@@ -40,7 +40,7 @@ except ImportError:

 # https://cloud.google.com/vertex-ai/generative-ai/pricing
 # describes that the average number of characters per token is about 4.
-AVGERAGE_CHARS_PER_TOEKN = 4
+AVGERAGE_CHARS_PER_TOKEN = 4


 # Price in US dollars,
@@ -102,6 +102,18 @@ SUPPORTED_MODELS_AND_SETTINGS = {
         cost_per_1k_input_chars=0.000125,
         cost_per_1k_output_chars=0.000375,
     ),
+    # TODO(sharatsharat): Update costs when published
+    'gemini-exp-1206': pg.Dict(
+        rpm=20,
+        cost_per_1k_input_chars=0.000,
+        cost_per_1k_output_chars=0.000,
+    ),
+    # TODO(sharatsharat): Update costs when published
+    'gemini-2.0-flash-exp': pg.Dict(
+        rpm=20,
+        cost_per_1k_input_chars=0.000,
+        cost_per_1k_output_chars=0.000,
+    ),
     # TODO(chengrun): Set a more appropriate rpm for endpoint.
     'vertexai-endpoint': pg.Dict(
         rpm=20,
@@ -215,7 +227,7 @@ class VertexAI(rest.REST):
     return (
         cost_per_1k_input_chars * num_input_tokens
         + cost_per_1k_output_chars * num_output_tokens
-    ) * AVGERAGE_CHARS_PER_TOEKN / 1000
+    ) * AVGERAGE_CHARS_PER_TOKEN / 1000

   @functools.cached_property
   def _session(self):
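The cost estimate above is character-based: Vertex AI pricing is quoted per 1,000 characters, so token counts are converted using the roughly 4 characters-per-token average noted in the comment (the constant keeps the package's `AVGERAGE_CHARS_PER_TOKEN` spelling). A standalone sketch of the same arithmetic, using the 0.000125 / 0.000375 per-1k-character rates that appear as context earlier in this diff; the token counts and the free-standing function are illustrative, since the real method lives on the VertexAI class and reads its rates from SUPPORTED_MODELS_AND_SETTINGS:

    AVGERAGE_CHARS_PER_TOKEN = 4  # spelling as in the package

    def estimate_cost(num_input_tokens, num_output_tokens,
                      cost_per_1k_input_chars=0.000125,
                      cost_per_1k_output_chars=0.000375):
      # Price per 1k characters, scaled by ~4 characters per token.
      return (
          cost_per_1k_input_chars * num_input_tokens
          + cost_per_1k_output_chars * num_output_tokens
      ) * AVGERAGE_CHARS_PER_TOKEN / 1000

    # 1,000 input tokens and 500 output tokens:
    # (0.000125 * 1000 + 0.000375 * 500) * 4 / 1000 = $0.00125
    print(estimate_cost(1000, 500))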
@@ -343,7 +355,7 @@ class VertexAI(rest.REST):
     return lf.AIMessage.from_chunks(chunks)


-_IMAGE_TYPES = [
+IMAGE_TYPES = [
     'image/png',
     'image/jpeg',
     'image/webp',
@@ -351,7 +363,7 @@ _IMAGE_TYPES = [
     'image/heif',
 ]

-_AUDIO_TYPES = [
+AUDIO_TYPES = [
     'audio/aac',
     'audio/flac',
     'audio/mp3',
@@ -362,10 +374,10 @@ _AUDIO_TYPES = [
     'audio/opus',
     'audio/pcm',
     'audio/wav',
-    'audio/webm'
+    'audio/webm',
 ]

-_VIDEO_TYPES = [
+VIDEO_TYPES = [
     'video/mov',
     'video/mpeg',
     'video/mpegps',
@@ -375,9 +387,10 @@ _VIDEO_TYPES = [
     'video/wmv',
     'video/x-flv',
     'video/3gpp',
+    'video/quicktime',
 ]

-_DOCUMENT_TYPES = [
+DOCUMENT_TYPES = [
     'application/pdf',
     'text/plain',
     'text/csv',
@@ -388,11 +401,25 @@ _DOCUMENT_TYPES = [
 ]


+class VertexAIGemini2_0(VertexAI):  # pylint: disable=invalid-name
+  """Vertex AI Gemini 2.0 model."""
+
+  supported_modalities: pg.typing.List(str).freeze(  # pytype: disable=invalid-annotation
+      DOCUMENT_TYPES + IMAGE_TYPES + AUDIO_TYPES + VIDEO_TYPES
+  )
+
+
+class VertexAIGeminiFlash2_0Exp(VertexAI):  # pylint: disable=invalid-name
+  """Vertex AI Gemini 2.0 Flash model."""
+
+  model = 'gemini-2.0-flash-exp'
+
+
 class VertexAIGemini1_5(VertexAI):  # pylint: disable=invalid-name
   """Vertex AI Gemini 1.5 model."""

-  supported_modalities: pg.typing.List(str).freeze(  # pytype: disable=invalid-annotation
-      _DOCUMENT_TYPES + _IMAGE_TYPES + _AUDIO_TYPES + _VIDEO_TYPES
+  supported_modalities: pg.typing.List(str).freeze(  # pytype: disable=invalid-annotation
+      DOCUMENT_TYPES + IMAGE_TYPES + AUDIO_TYPES + VIDEO_TYPES
   )

@@ -460,8 +487,8 @@ class VertexAIGeminiPro1Vision(VertexAI):  # pylint: disable=invalid-name
   """Vertex AI Gemini 1.0 Pro Vision model."""

   model = 'gemini-1.0-pro-vision'
-  supported_modalities: pg.typing.List(str).freeze(  # pytype: disable=invalid-annotation
-      _IMAGE_TYPES + _VIDEO_TYPES
+  supported_modalities: pg.typing.List(str).freeze(  # pytype: disable=invalid-annotation
+      IMAGE_TYPES + VIDEO_TYPES
   )

langfun/core/message.py CHANGED
@@ -769,7 +769,6 @@ class Message(
   padding: 20px;
   margin: 10px 5px 10px 5px;
   font-style: italic;
-  font-size: 1.1em;
   white-space: pre-wrap;
   border: 1px solid #EEE;
   border-radius: 5px;
@@ -778,7 +777,7 @@ class Message(
 .modality-in-text {
   display: inline-block;
 }
-.modality-in-text > details {
+.modality-in-text > details.pyglove {
   display: inline-block;
   font-size: 0.8em;
   border: 0;
@@ -380,7 +380,6 @@ class MessageTest(unittest.TestCase):
   padding: 20px;
   margin: 10px 5px 10px 5px;
   font-style: italic;
-  font-size: 1.1em;
   white-space: pre-wrap;
   border: 1px solid #EEE;
   border-radius: 5px;
@@ -389,7 +388,7 @@ class MessageTest(unittest.TestCase):
 .modality-in-text {
   display: inline-block;
 }
-.modality-in-text > details {
+.modality-in-text > details.pyglove {
   display: inline-block;
   font-size: 0.8em;
   border: 0;