langfun 0.1.2.dev202412070804__py3-none-any.whl → 0.1.2.dev202412110804__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -18,6 +18,8 @@ import unittest
18
18
  import langfun.core as lf
19
19
  from langfun.core.agentic import action as action_lib
20
20
  from langfun.core.llms import fake
21
+ import langfun.core.structured as lf_structured
22
+ import pyglove as pg
21
23
 
22
24
 
23
25
  class SessionTest(unittest.TestCase):
@@ -28,60 +30,94 @@ class SessionTest(unittest.TestCase):
28
30
  class Bar(action_lib.Action):
29
31
 
30
32
  def call(self, session, *, lm, **kwargs):
31
- test.assertIs(session.current_invocation.action, self)
33
+ test.assertIs(session.current_action.action, self)
32
34
  session.info('Begin Bar')
33
35
  session.query('bar', lm=lm)
36
+ session.add_metadata(note='bar')
34
37
  return 2
35
38
 
36
39
  class Foo(action_lib.Action):
37
40
  x: int
38
41
 
39
42
  def call(self, session, *, lm, **kwargs):
40
- test.assertIs(session.current_invocation.action, self)
41
- session.info('Begin Foo', x=1)
42
- session.query('foo', lm=lm)
43
+ test.assertIs(session.current_action.action, self)
44
+ with session.phase('prepare'):
45
+ session.info('Begin Foo', x=1)
46
+ session.query('foo', lm=lm)
47
+ with session.track_queries():
48
+ self.make_additional_query(lm)
49
+ session.add_metadata(note='foo')
43
50
  return self.x + Bar()(session, lm=lm)
44
51
 
52
+ def make_additional_query(self, lm):
53
+ lf_structured.query('additional query', lm=lm)
54
+
45
55
  lm = fake.StaticResponse('lm response')
46
- session = action_lib.Session()
47
- root = session.root_invocation
48
- self.assertIsInstance(root.action, action_lib.RootAction)
49
- self.assertIs(session.current_invocation, session.root_invocation)
50
- self.assertEqual(Foo(1)(session, lm=lm), 3)
51
- self.assertEqual(len(session.root_invocation.child_invocations), 1)
52
- self.assertEqual(len(list(session.root_invocation.queries())), 0)
53
- self.assertEqual(
54
- len(list(session.root_invocation.queries(include_children=True))), 2
55
- )
56
- self.assertEqual(
57
- len(list(session.root_invocation.child_invocations[0].queries())), 1
58
- )
59
- self.assertEqual(len(session.root_invocation.child_invocations[0].logs), 1)
60
- self.assertEqual(
61
- len(session.root_invocation.child_invocations[0].child_invocations),
62
- 1
63
- )
64
- self.assertEqual(
65
- len(session.root_invocation
66
- .child_invocations[0].child_invocations[0].logs),
67
- 1
68
- )
69
- self.assertEqual(
70
- len(list(session.root_invocation
71
- .child_invocations[0].child_invocations[0].queries())),
72
- 1
73
- )
74
- self.assertEqual(
75
- len(session.root_invocation
76
- .child_invocations[0].child_invocations[0].child_invocations),
77
- 0
78
- )
79
- self.assertIs(session.current_invocation, session.root_invocation)
80
- self.assertIs(session.final_result, 3)
81
- self.assertIn(
82
- 'invocation-final-result',
83
- session.to_html().content,
56
+ foo = Foo(1)
57
+ self.assertEqual(foo(lm=lm), 3)
58
+
59
+ session = foo.session
60
+ self.assertIsNotNone(session)
61
+ self.assertIsInstance(session.root.action, action_lib.RootAction)
62
+ self.assertIs(session.current_action, session.root)
63
+
64
+ #
65
+ # Inspecting the root invocation.
66
+ #
67
+
68
+ root = session.root
69
+ self.assertEqual(len(root.execution.items), 1)
70
+ self.assertIs(root.execution.items[0].action, foo)
71
+
72
+ self.assertTrue(root.execution.has_started)
73
+ self.assertTrue(root.execution.has_stopped)
74
+ self.assertGreater(root.execution.elapse, 0)
75
+ self.assertEqual(root.result, 3)
76
+ self.assertEqual(root.result_metadata, dict(note='foo'))
77
+
78
+ # The root space should have one action (foo), no queries, and no logs.
79
+ self.assertEqual(len(list(root.actions)), 1)
80
+ self.assertEqual(len(list(root.queries)), 0)
81
+ self.assertEqual(len(list(root.logs)), 0)
82
+ # 1 query from Bar and 2 from Foo.
83
+ self.assertEqual(len(list(root.all_queries)), 3)
84
+ # 1 log from Bar and 1 from Foo.
85
+ self.assertEqual(len(list(root.all_logs)), 2)
86
+ self.assertEqual(root.usage_summary.total.num_requests, 3)
87
+
88
+ # Inspecting the top-level action (Foo)
89
+ foo_invocation = root.execution.items[0]
90
+ self.assertEqual(len(foo_invocation.execution.items), 3)
91
+
92
+ # Prepare phase.
93
+ prepare_phase = foo_invocation.execution.items[0]
94
+ self.assertIsInstance(
95
+ prepare_phase, action_lib.ExecutionTrace
84
96
  )
97
+ self.assertEqual(len(prepare_phase.items), 2)
98
+ self.assertTrue(prepare_phase.has_started)
99
+ self.assertTrue(prepare_phase.has_stopped)
100
+ self.assertEqual(prepare_phase.usage_summary.total.num_requests, 1)
101
+
102
+ # Tracked queries.
103
+ query_invocation = foo_invocation.execution.items[1]
104
+ self.assertIsInstance(query_invocation, lf_structured.QueryInvocation)
105
+ self.assertIs(query_invocation.lm, lm)
106
+
107
+ # Invocation to Bar.
108
+ bar_invocation = foo_invocation.execution.items[2]
109
+ self.assertIsInstance(bar_invocation, action_lib.ActionInvocation)
110
+ self.assertIsInstance(bar_invocation.action, Bar)
111
+ self.assertEqual(bar_invocation.result, 2)
112
+ self.assertEqual(bar_invocation.result_metadata, dict(note='bar'))
113
+ self.assertEqual(len(bar_invocation.execution.items), 2)
114
+
115
+ # Save to HTML
116
+ self.assertIn('invocation-result', session.to_html().content)
117
+
118
+ # Save session to JSON
119
+ json_str = session.to_json_str(save_ref_value=True)
120
+ self.assertIsInstance(pg.from_json_str(json_str), action_lib.Session)
85
121
 
86
122
  def test_log(self):
87
123
  session = action_lib.Session()
@@ -556,6 +556,9 @@ class Evaluation(experiment_lib.Experiment):
556
556
  border: 0px;
557
557
  margin: 0px;
558
558
  }
559
+ .eval-details .tab-control {
560
+ width: 100%;
561
+ }
559
562
  .eval-details .tab-button {
560
563
  font-size: large;
561
564
  }
@@ -33,6 +33,7 @@ from langfun.core.llms.rest import REST
33
33
  # Gemini models.
34
34
  from langfun.core.llms.google_genai import GenAI
35
35
  from langfun.core.llms.google_genai import GeminiExp_20241114
36
+ from langfun.core.llms.google_genai import GeminiExp_20241206
36
37
  from langfun.core.llms.google_genai import GeminiFlash1_5
37
38
  from langfun.core.llms.google_genai import GeminiPro
38
39
  from langfun.core.llms.google_genai import GeminiPro1_5
@@ -35,7 +35,7 @@ except ImportError:
35
35
  google_auth = None
36
36
  auth_requests = None
37
37
  credentials_lib = None
38
- Credentials = None # pylint: disable=invalid-name
38
+ Credentials = Any # pylint: disable=invalid-name
39
39
 
40
40
 
41
41
  SUPPORTED_MODELS_AND_SETTINGS = {
@@ -20,6 +20,7 @@ from typing import Annotated, Any, Literal
20
20
 
21
21
  import langfun.core as lf
22
22
  from langfun.core import modalities as lf_modalities
23
+ from langfun.core.llms import vertexai
23
24
  import pyglove as pg
24
25
 
25
26
 
@@ -54,6 +55,7 @@ class GenAI(lf.LanguageModel):
54
55
  'gemini-1.5-pro-latest',
55
56
  'gemini-1.5-flash-latest',
56
57
  'gemini-exp-1114',
58
+ 'gemini-exp-1206',
57
59
  ],
58
60
  'Model name.',
59
61
  ]
@@ -306,64 +308,52 @@ _GOOGLE_GENAI_MODEL_HUB = _ModelHub()
306
308
  #
307
309
 
308
310
 
309
- _IMAGE_TYPES = [
310
- 'image/png',
311
- 'image/jpeg',
312
- 'image/webp',
313
- 'image/heic',
314
- 'image/heif',
315
- ]
316
-
317
- _AUDIO_TYPES = [
318
- 'audio/aac',
319
- 'audio/flac',
320
- 'audio/mp3',
321
- 'audio/m4a',
322
- 'audio/mpeg',
323
- 'audio/mpga',
324
- 'audio/mp4',
325
- 'audio/opus',
326
- 'audio/pcm',
327
- 'audio/wav',
328
- 'audio/webm'
329
- ]
330
-
331
- _VIDEO_TYPES = [
332
- 'video/mov',
333
- 'video/mpeg',
334
- 'video/mpegps',
335
- 'video/mpg',
336
- 'video/mp4',
337
- 'video/webm',
338
- 'video/wmv',
339
- 'video/x-flv',
340
- 'video/3gpp',
341
- ]
342
-
343
- _PDF = [
344
- 'application/pdf',
345
- ]
311
+ class GeminiExp_20241206(GenAI): # pylint: disable=invalid-name
312
+ """Gemini Experimental model launched on 12/06/2024."""
313
+
314
+ model = 'gemini-exp-1206'
315
+ supported_modalities = (
316
+ vertexai.DOCUMENT_TYPES
317
+ + vertexai.IMAGE_TYPES
318
+ + vertexai.AUDIO_TYPES
319
+ + vertexai.VIDEO_TYPES
320
+ )
346
321
 
347
322
 
348
323
  class GeminiExp_20241114(GenAI): # pylint: disable=invalid-name
349
324
  """Gemini Experimental model launched on 11/14/2024."""
350
325
 
351
326
  model = 'gemini-exp-1114'
352
- supported_modalities = _PDF + _IMAGE_TYPES + _AUDIO_TYPES + _VIDEO_TYPES
327
+ supported_modalities = (
328
+ vertexai.DOCUMENT_TYPES
329
+ + vertexai.IMAGE_TYPES
330
+ + vertexai.AUDIO_TYPES
331
+ + vertexai.VIDEO_TYPES
332
+ )
353
333
 
354
334
 
355
335
  class GeminiPro1_5(GenAI): # pylint: disable=invalid-name
356
336
  """Gemini Pro latest model."""
357
337
 
358
338
  model = 'gemini-1.5-pro-latest'
359
- supported_modalities = _PDF + _IMAGE_TYPES + _AUDIO_TYPES + _VIDEO_TYPES
339
+ supported_modalities = (
340
+ vertexai.DOCUMENT_TYPES
341
+ + vertexai.IMAGE_TYPES
342
+ + vertexai.AUDIO_TYPES
343
+ + vertexai.VIDEO_TYPES
344
+ )
360
345
 
361
346
 
362
347
  class GeminiFlash1_5(GenAI): # pylint: disable=invalid-name
363
348
  """Gemini Flash latest model."""
364
349
 
365
350
  model = 'gemini-1.5-flash-latest'
366
- supported_modalities = _PDF + _IMAGE_TYPES + _AUDIO_TYPES + _VIDEO_TYPES
351
+ supported_modalities = (
352
+ vertexai.DOCUMENT_TYPES
353
+ + vertexai.IMAGE_TYPES
354
+ + vertexai.AUDIO_TYPES
355
+ + vertexai.VIDEO_TYPES
356
+ )
367
357
 
368
358
 
369
359
  class GeminiPro(GenAI):
@@ -376,7 +366,7 @@ class GeminiProVision(GenAI):
376
366
  """Gemini Pro vision model."""
377
367
 
378
368
  model = 'gemini-pro-vision'
379
- supported_modalities = _IMAGE_TYPES + _VIDEO_TYPES
369
+ supported_modalities = vertexai.IMAGE_TYPES + vertexai.VIDEO_TYPES
380
370
 
381
371
 
382
372
  class Palm2(GenAI):
@@ -343,7 +343,7 @@ class VertexAI(rest.REST):
343
343
  return lf.AIMessage.from_chunks(chunks)
344
344
 
345
345
 
346
- _IMAGE_TYPES = [
346
+ IMAGE_TYPES = [
347
347
  'image/png',
348
348
  'image/jpeg',
349
349
  'image/webp',
@@ -351,7 +351,7 @@ _IMAGE_TYPES = [
351
351
  'image/heif',
352
352
  ]
353
353
 
354
- _AUDIO_TYPES = [
354
+ AUDIO_TYPES = [
355
355
  'audio/aac',
356
356
  'audio/flac',
357
357
  'audio/mp3',
@@ -362,10 +362,10 @@ _AUDIO_TYPES = [
362
362
  'audio/opus',
363
363
  'audio/pcm',
364
364
  'audio/wav',
365
- 'audio/webm'
365
+ 'audio/webm',
366
366
  ]
367
367
 
368
- _VIDEO_TYPES = [
368
+ VIDEO_TYPES = [
369
369
  'video/mov',
370
370
  'video/mpeg',
371
371
  'video/mpegps',
@@ -375,9 +375,10 @@ _VIDEO_TYPES = [
375
375
  'video/wmv',
376
376
  'video/x-flv',
377
377
  'video/3gpp',
378
+ 'video/quicktime',
378
379
  ]
379
380
 
380
- _DOCUMENT_TYPES = [
381
+ DOCUMENT_TYPES = [
381
382
  'application/pdf',
382
383
  'text/plain',
383
384
  'text/csv',
@@ -391,8 +392,8 @@ _DOCUMENT_TYPES = [
391
392
  class VertexAIGemini1_5(VertexAI): # pylint: disable=invalid-name
392
393
  """Vertex AI Gemini 1.5 model."""
393
394
 
394
- supported_modalities: pg.typing.List(str).freeze( # pytype: disable=invalid-annotation
395
- _DOCUMENT_TYPES + _IMAGE_TYPES + _AUDIO_TYPES + _VIDEO_TYPES
395
+ supported_modalities: pg.typing.List(str).freeze( # pytype: disable=invalid-annotation
396
+ DOCUMENT_TYPES + IMAGE_TYPES + AUDIO_TYPES + VIDEO_TYPES
396
397
  )
397
398
 
398
399
 
@@ -460,8 +461,8 @@ class VertexAIGeminiPro1Vision(VertexAI): # pylint: disable=invalid-name
460
461
  """Vertex AI Gemini 1.0 Pro Vision model."""
461
462
 
462
463
  model = 'gemini-1.0-pro-vision'
463
- supported_modalities: pg.typing.List(str).freeze( # pytype: disable=invalid-annotation
464
- _IMAGE_TYPES + _VIDEO_TYPES
464
+ supported_modalities: pg.typing.List(str).freeze( # pytype: disable=invalid-annotation
465
+ IMAGE_TYPES + VIDEO_TYPES
465
466
  )
466
467
 
467
468
 
langfun/core/message.py CHANGED
@@ -769,7 +769,6 @@ class Message(
769
769
  padding: 20px;
770
770
  margin: 10px 5px 10px 5px;
771
771
  font-style: italic;
772
- font-size: 1.1em;
773
772
  white-space: pre-wrap;
774
773
  border: 1px solid #EEE;
775
774
  border-radius: 5px;
@@ -778,7 +777,7 @@ class Message(
778
777
  .modality-in-text {
779
778
  display: inline-block;
780
779
  }
781
- .modality-in-text > details {
780
+ .modality-in-text > details.pyglove {
782
781
  display: inline-block;
783
782
  font-size: 0.8em;
784
783
  border: 0;
@@ -380,7 +380,6 @@ class MessageTest(unittest.TestCase):
380
380
  padding: 20px;
381
381
  margin: 10px 5px 10px 5px;
382
382
  font-style: italic;
383
- font-size: 1.1em;
384
383
  white-space: pre-wrap;
385
384
  border: 1px solid #EEE;
386
385
  border-radius: 5px;
@@ -389,7 +388,7 @@ class MessageTest(unittest.TestCase):
389
388
  .modality-in-text {
390
389
  display: inline-block;
391
390
  }
392
- .modality-in-text > details {
391
+ .modality-in-text > details.pyglove {
393
392
  display: inline-block;
394
393
  font-size: 0.8em;
395
394
  border: 0;
@@ -76,6 +76,7 @@ def unittest_with_test_cases(f, unittests):
76
76
 
77
77
  def _function_gen(
78
78
  func: Callable[..., Any],
79
+ context: dict[str, Any],
79
80
  signature: str,
80
81
  lm: language_model.LanguageModel,
81
82
  num_retries: int = 1,
@@ -141,21 +142,23 @@ def _function_gen(
141
142
  elif isinstance(unittest, list):
142
143
  unittest_examples = unittest
143
144
 
145
+ last_error = None
144
146
  for _ in range(num_retries):
145
147
  try:
146
148
  source_code = prompting.query(
147
149
  PythonFunctionPrompt(signature=signature), lm=lm
148
150
  )
149
- f = python.evaluate(source_code)
151
+ f = python.evaluate(source_code, global_vars=context)
150
152
 
151
153
  # Check whether the sigantures are the same.
152
154
  if inspect.signature(f) != inspect.signature(func):
153
- pg.logging.warning(
154
- "Signature mismatch. Expected: %s, Actual: %s",
155
- inspect.signature(func),
156
- inspect.signature(f),
155
+ raise python.CodeError(
156
+ code=source_code,
157
+ cause=TypeError(
158
+ f"Signature mismatch: Expected: {inspect.signature(func)}, "
159
+ f"Actual: {inspect.signature(f)}.",
160
+ ),
157
161
  )
158
- continue
159
162
 
160
163
  if callable(unittest):
161
164
  unittest(f)
@@ -163,10 +166,12 @@ def _function_gen(
163
166
  unittest_with_test_cases(f, unittest_examples)
164
167
 
165
168
  return f, source_code
166
- except Exception: # pylint: disable=broad-exception-caught
167
- pass
168
-
169
- return None, None
169
+ except python.CodeError as e:
170
+ last_error = e
171
+ pg.logging.warning(
172
+ f"Bad code generated: {e}",
173
+ )
174
+ raise last_error
170
175
 
171
176
 
172
177
  def _process_signature(signature):
@@ -220,6 +225,13 @@ def function_gen(
220
225
  setattr(func, "__function__", None)
221
226
  setattr(func, "__source_code__", None)
222
227
 
228
+ # Prepare the globals/locals for the generated code to be evaluated against.
229
+ callstack = inspect.stack()
230
+ assert len(callstack) > 1
231
+ context = dict(callstack[1][0].f_globals)
232
+ context.update(callstack[1][0].f_locals)
233
+ context.pop(func.__name__, None)
234
+
223
235
  @functools.wraps(func)
224
236
  def lm_generated_func(*args, **kwargs):
225
237
  if func.__function__ is not None:
@@ -238,20 +250,20 @@ def function_gen(
238
250
 
239
251
  if signature in cache:
240
252
  func.__source_code__ = cache[signature]
241
- func.__function__ = python.evaluate(func.__source_code__)
253
+ func.__function__ = python.evaluate(
254
+ func.__source_code__, global_vars=context
255
+ )
242
256
  return func.__function__(*args, **kwargs)
243
257
 
244
258
  func.__function__, func.__source_code__ = _function_gen(
245
259
  func,
260
+ context,
246
261
  signature,
247
262
  lm,
248
263
  num_retries=num_retries,
249
264
  unittest=unittest,
250
265
  unittest_num_retries=unittest_num_retries,
251
266
  )
252
- if func.__function__ is None:
253
- raise ValueError(f"Function generation failed. Signature:\n{signature}")
254
-
255
267
  if cache_filename is not None:
256
268
  cache[signature] = func.__source_code__
257
269
  cache.save(cache_filename)
@@ -311,6 +311,36 @@ class FunctionGenerationTest(unittest.TestCase):
311
311
 
312
312
  self.assertEqual(linear_search(['a', 'b', 'c'], 'c'), 2)
313
313
 
314
+ def test_context_passthrough(self):
315
+
316
+ class Number(pg.Object):
317
+ value: int
318
+
319
+ function_gen_lm_response = inspect.cleandoc("""
320
+ ```python
321
+ def add(a: Number, b: Number) -> Number:
322
+ \"\"\"Adds two numbers together.\"\"\"
323
+ return Number(a.value + b.value)
324
+ ```
325
+ """)
326
+
327
+ lm = fake.StaticSequence(
328
+ [function_gen_lm_response]
329
+ )
330
+
331
+ def _unittest_fn(func):
332
+ assert func(Number(1), Number(2)) == Number(3)
333
+
334
+ custom_unittest = _unittest_fn
335
+
336
+ @function_generation.function_gen(
337
+ lm=lm, unittest=custom_unittest, num_retries=1
338
+ )
339
+ def add(a: Number, b: Number) -> Number: # pylint: disable=unused-argument
340
+ """Adds two numbers together."""
341
+
342
+ self.assertEqual(add(Number(2), Number(3)), Number(5))
343
+
314
344
  def test_siganture_check(self):
315
345
  incorrect_signature_lm_response = inspect.cleandoc("""
316
346
  ```python
@@ -264,9 +264,9 @@ def query(
264
264
  schema_lib.Schema.from_value(schema)
265
265
  if schema not in (None, str) else None
266
266
  ),
267
- output=pg.Ref(_result(output_message)),
268
267
  lm=pg.Ref(lm),
269
268
  examples=pg.Ref(examples) if examples else [],
269
+ lm_response=lf.AIMessage(output_message.text),
270
270
  usage_summary=usage_summary,
271
271
  )
272
272
  for i, (tracker, include_child_scopes) in enumerate(trackers):
@@ -357,7 +357,7 @@ def _reward_fn(cls) -> Callable[
357
357
  return _reward
358
358
 
359
359
 
360
- class QueryInvocation(pg.Object):
360
+ class QueryInvocation(pg.Object, pg.views.HtmlTreeView.Extension):
361
361
  """A class to represent the invocation of `lf.query`."""
362
362
 
363
363
  input: Annotated[
@@ -368,9 +368,9 @@ class QueryInvocation(pg.Object):
368
368
  schema_lib.schema_spec(noneable=True),
369
369
  'Schema of `lf.query`.'
370
370
  ]
371
- output: Annotated[
372
- Any,
373
- 'Mapping output of `lf.query`.'
371
+ lm_response: Annotated[
372
+ lf.Message,
373
+ 'Raw LM response.'
374
374
  ]
375
375
  lm: Annotated[
376
376
  lf.LanguageModel,
@@ -385,6 +385,106 @@ class QueryInvocation(pg.Object):
385
385
  'Usage summary for `lf.query`.'
386
386
  ]
387
387
 
388
+ @functools.cached_property
389
+ def lm_request(self) -> lf.Message:
390
+ return query_prompt(self.input, self.schema)
391
+
392
+ @functools.cached_property
393
+ def output(self) -> Any:
394
+ return query_output(self.lm_response, self.schema)
395
+
396
+ def _on_bound(self):
397
+ super()._on_bound()
398
+ self.__dict__.pop('lm_request', None)
399
+ self.__dict__.pop('output', None)
400
+
401
+ def _html_tree_view_summary(
402
+ self,
403
+ *,
404
+ view: pg.views.HtmlTreeView,
405
+ **kwargs: Any
406
+ ) -> pg.Html | None:
407
+ return view.summary(
408
+ value=self,
409
+ title=pg.Html.element(
410
+ 'div',
411
+ [
412
+ pg.views.html.controls.Label(
413
+ 'lf.query',
414
+ css_classes=['query-invocation-type-name']
415
+ ),
416
+ pg.views.html.controls.Badge(
417
+ f'lm={self.lm.model_id}',
418
+ pg.format(
419
+ self.lm,
420
+ verbose=False,
421
+ python_format=True,
422
+ hide_default_values=True
423
+ ),
424
+ css_classes=['query-invocation-lm']
425
+ ),
426
+ self.usage_summary.to_html(extra_flags=dict(as_badge=True))
427
+ ],
428
+ css_classes=['query-invocation-title']
429
+ ),
430
+ enable_summary_tooltip=False
431
+ )
432
+
433
+ def _html_tree_view_content(
434
+ self,
435
+ *,
436
+ view: pg.views.HtmlTreeView,
437
+ **kwargs: Any
438
+ ) -> pg.Html:
439
+ return pg.views.html.controls.TabControl([
440
+ pg.views.html.controls.Tab(
441
+ 'input',
442
+ pg.view(self.input, collapse_level=None),
443
+ ),
444
+ pg.views.html.controls.Tab(
445
+ 'schema',
446
+ pg.view(self.schema),
447
+ ),
448
+ pg.views.html.controls.Tab(
449
+ 'output',
450
+ pg.view(self.output, collapse_level=None),
451
+ ),
452
+ pg.views.html.controls.Tab(
453
+ 'lm_request',
454
+ pg.view(
455
+ self.lm_request,
456
+ extra_flags=dict(include_message_metadata=False),
457
+ ),
458
+ ),
459
+ pg.views.html.controls.Tab(
460
+ 'lm_response',
461
+ pg.view(
462
+ self.lm_response,
463
+ extra_flags=dict(include_message_metadata=False)
464
+ ),
465
+ ),
466
+ ], tab_position='top').to_html()
467
+
468
+ @classmethod
469
+ def _html_tree_view_css_styles(cls) -> list[str]:
470
+ return super()._html_tree_view_css_styles() + [
471
+ """
472
+ .query-invocation-title {
473
+ display: inline-block;
474
+ font-weight: normal;
475
+ }
476
+ .query-invocation-type-name {
477
+ font-style: italic;
478
+ color: #888;
479
+ }
480
+ .query-invocation-lm.badge {
481
+ margin-left: 5px;
482
+ margin-right: 5px;
483
+ background-color: #fff0d6;
484
+ }
485
+ """
486
+ ]
487
+
388
488
 
389
489
  @contextlib.contextmanager
390
490
  def track_queries(
@@ -962,6 +962,18 @@ class QueryStructureJsonTest(unittest.TestCase):
962
962
  )
963
963
 
964
964
 
965
+ class QueryInvocationTest(unittest.TestCase):
966
+
967
+ def test_to_html(self):
968
+ lm = fake.StaticSequence([
969
+ 'Activity(description="hi")',
970
+ ])
971
+ with prompting.track_queries() as queries:
972
+ prompting.query('foo', Activity, lm=lm)
973
+
974
+ self.assertIn('schema', queries[0].to_html_str())
975
+
976
+
965
977
  class TrackQueriesTest(unittest.TestCase):
966
978
 
967
979
  def test_include_child_scopes(self):