chainlit-1.0.401-py3-none-any.whl → chainlit-2.0.3-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of chainlit has been flagged for review.
Files changed (112)
  1. chainlit/__init__.py +98 -279
  2. chainlit/_utils.py +8 -0
  3. chainlit/action.py +12 -10
  4. chainlit/{auth.py → auth/__init__.py} +28 -36
  5. chainlit/auth/cookie.py +122 -0
  6. chainlit/auth/jwt.py +39 -0
  7. chainlit/cache.py +4 -6
  8. chainlit/callbacks.py +362 -0
  9. chainlit/chat_context.py +64 -0
  10. chainlit/chat_settings.py +3 -1
  11. chainlit/cli/__init__.py +77 -8
  12. chainlit/config.py +181 -101
  13. chainlit/context.py +42 -13
  14. chainlit/copilot/dist/index.js +8750 -903
  15. chainlit/data/__init__.py +101 -416
  16. chainlit/data/acl.py +6 -2
  17. chainlit/data/base.py +107 -0
  18. chainlit/data/chainlit_data_layer.py +608 -0
  19. chainlit/data/dynamodb.py +590 -0
  20. chainlit/data/literalai.py +500 -0
  21. chainlit/data/sql_alchemy.py +721 -0
  22. chainlit/data/storage_clients/__init__.py +0 -0
  23. chainlit/data/storage_clients/azure.py +81 -0
  24. chainlit/data/storage_clients/azure_blob.py +89 -0
  25. chainlit/data/storage_clients/base.py +26 -0
  26. chainlit/data/storage_clients/gcs.py +88 -0
  27. chainlit/data/storage_clients/s3.py +75 -0
  28. chainlit/data/utils.py +29 -0
  29. chainlit/discord/__init__.py +6 -0
  30. chainlit/discord/app.py +354 -0
  31. chainlit/element.py +91 -33
  32. chainlit/emitter.py +80 -29
  33. chainlit/frontend/dist/assets/DailyMotion-C_XC7xJI.js +1 -0
  34. chainlit/frontend/dist/assets/Dataframe-Cs4l4hA1.js +22 -0
  35. chainlit/frontend/dist/assets/Facebook-CUeCH7hk.js +1 -0
  36. chainlit/frontend/dist/assets/FilePlayer-CB-fYkx8.js +1 -0
  37. chainlit/frontend/dist/assets/Kaltura-YX6qaq72.js +1 -0
  38. chainlit/frontend/dist/assets/Mixcloud-DGV0ldjP.js +1 -0
  39. chainlit/frontend/dist/assets/Mux-CmRss5oc.js +1 -0
  40. chainlit/frontend/dist/assets/Preview-DBVJn7-H.js +1 -0
  41. chainlit/frontend/dist/assets/SoundCloud-qLUb18oY.js +1 -0
  42. chainlit/frontend/dist/assets/Streamable-BvYP7bFp.js +1 -0
  43. chainlit/frontend/dist/assets/Twitch-CTHt-sGZ.js +1 -0
  44. chainlit/frontend/dist/assets/Vidyard-B-0mCJbm.js +1 -0
  45. chainlit/frontend/dist/assets/Vimeo-Dnp7ri8q.js +1 -0
  46. chainlit/frontend/dist/assets/Wistia-DW0x_UBn.js +1 -0
  47. chainlit/frontend/dist/assets/YouTube--98FipvA.js +1 -0
  48. chainlit/frontend/dist/assets/index-D71nZ46o.js +8665 -0
  49. chainlit/frontend/dist/assets/index-g8LTJwwr.css +1 -0
  50. chainlit/frontend/dist/assets/react-plotly-Cn_BQTQw.js +3484 -0
  51. chainlit/frontend/dist/index.html +2 -4
  52. chainlit/haystack/callbacks.py +4 -7
  53. chainlit/input_widget.py +8 -4
  54. chainlit/langchain/callbacks.py +103 -68
  55. chainlit/langflow/__init__.py +1 -0
  56. chainlit/llama_index/callbacks.py +65 -40
  57. chainlit/markdown.py +22 -6
  58. chainlit/message.py +54 -56
  59. chainlit/mistralai/__init__.py +50 -0
  60. chainlit/oauth_providers.py +266 -8
  61. chainlit/openai/__init__.py +10 -18
  62. chainlit/secret.py +1 -1
  63. chainlit/server.py +789 -228
  64. chainlit/session.py +108 -90
  65. chainlit/slack/__init__.py +6 -0
  66. chainlit/slack/app.py +397 -0
  67. chainlit/socket.py +199 -116
  68. chainlit/step.py +141 -89
  69. chainlit/sync.py +2 -1
  70. chainlit/teams/__init__.py +6 -0
  71. chainlit/teams/app.py +338 -0
  72. chainlit/translations/bn.json +235 -0
  73. chainlit/translations/en-US.json +83 -4
  74. chainlit/translations/gu.json +235 -0
  75. chainlit/translations/he-IL.json +235 -0
  76. chainlit/translations/hi.json +235 -0
  77. chainlit/translations/kn.json +235 -0
  78. chainlit/translations/ml.json +235 -0
  79. chainlit/translations/mr.json +235 -0
  80. chainlit/translations/nl-NL.json +233 -0
  81. chainlit/translations/ta.json +235 -0
  82. chainlit/translations/te.json +235 -0
  83. chainlit/translations/zh-CN.json +233 -0
  84. chainlit/translations.py +60 -0
  85. chainlit/types.py +133 -28
  86. chainlit/user.py +14 -3
  87. chainlit/user_session.py +6 -3
  88. chainlit/utils.py +52 -5
  89. chainlit/version.py +3 -2
  90. {chainlit-1.0.401.dist-info → chainlit-2.0.3.dist-info}/METADATA +48 -50
  91. chainlit-2.0.3.dist-info/RECORD +106 -0
  92. chainlit/cli/utils.py +0 -24
  93. chainlit/frontend/dist/assets/index-9711593e.js +0 -723
  94. chainlit/frontend/dist/assets/index-d088547c.css +0 -1
  95. chainlit/frontend/dist/assets/react-plotly-d8762cc2.js +0 -3602
  96. chainlit/playground/__init__.py +0 -2
  97. chainlit/playground/config.py +0 -40
  98. chainlit/playground/provider.py +0 -108
  99. chainlit/playground/providers/__init__.py +0 -13
  100. chainlit/playground/providers/anthropic.py +0 -118
  101. chainlit/playground/providers/huggingface.py +0 -75
  102. chainlit/playground/providers/langchain.py +0 -89
  103. chainlit/playground/providers/openai.py +0 -408
  104. chainlit/playground/providers/vertexai.py +0 -171
  105. chainlit/translations/pt-BR.json +0 -155
  106. chainlit-1.0.401.dist-info/RECORD +0 -66
  107. /chainlit/copilot/dist/assets/{logo_dark-2a3cf740.svg → logo_dark-IkGJ_IwC.svg} +0 -0
  108. /chainlit/copilot/dist/assets/{logo_light-b078e7bc.svg → logo_light-Bb_IPh6r.svg} +0 -0
  109. /chainlit/frontend/dist/assets/{logo_dark-2a3cf740.svg → logo_dark-IkGJ_IwC.svg} +0 -0
  110. /chainlit/frontend/dist/assets/{logo_light-b078e7bc.svg → logo_light-Bb_IPh6r.svg} +0 -0
  111. {chainlit-1.0.401.dist-info → chainlit-2.0.3.dist-info}/WHEEL +0 -0
  112. {chainlit-1.0.401.dist-info → chainlit-2.0.3.dist-info}/entry_points.txt +0 -0
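
The headline structural change is the data layer split: the monolithic chainlit/data/__init__.py shrinks by roughly 400 lines and the concrete backends move into dedicated modules (chainlit_data_layer.py, dynamodb.py, literalai.py, sql_alchemy.py) behind the new chainlit/data/base.py interface, with blob storage factored into chainlit/data/storage_clients/. A minimal wiring sketch for the SQLAlchemy backend in 2.x, assuming the @cl.data_layer hook and the conninfo keyword from the Chainlit docs (the connection string is illustrative):

import chainlit as cl
from chainlit.data.sql_alchemy import SQLAlchemyDataLayer

@cl.data_layer
def get_data_layer():
    # Requires an async driver; swap in your own database URL.
    return SQLAlchemyDataLayer(conninfo="postgresql+asyncpg://user:pass@localhost:5432/chainlit")

Elsewhere in the list: the playground provider modules (chainlit/playground/*) are removed outright, and new Discord, Slack, and Teams entry points land alongside the existing web frontend and copilot bundles.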
chainlit/frontend/dist/index.html CHANGED
@@ -4,7 +4,6 @@
     <meta charset="UTF-8" />
     <meta name="viewport" content="width=device-width, initial-scale=1.0" />
     <!-- TAG INJECTION PLACEHOLDER -->
-    <link rel="icon" href="/favicon" />
     <link rel="preconnect" href="https://fonts.googleapis.com" />
     <link rel="preconnect" href="https://fonts.gstatic.com" crossorigin />
     <!-- FONT START -->
@@ -22,11 +21,10 @@
     <script>
       const global = globalThis;
     </script>
-    <script type="module" crossorigin src="/assets/index-9711593e.js"></script>
-    <link rel="stylesheet" href="/assets/index-d088547c.css">
+    <script type="module" crossorigin src="/assets/index-D71nZ46o.js"></script>
+    <link rel="stylesheet" crossorigin href="/assets/index-g8LTJwwr.css">
   </head>
   <body>
     <div id="root"></div>
-
   </body>
 </html>
chainlit/haystack/callbacks.py CHANGED
@@ -1,14 +1,13 @@
 import re
 from typing import Any, Generic, List, Optional, TypeVar
 
-from chainlit.context import context
-from chainlit.step import Step
-from chainlit.sync import run_sync
 from haystack.agents import Agent, Tool
 from haystack.agents.agent_step import AgentStep
 from literalai.helper import utc_now
 
 from chainlit import Message
+from chainlit.step import Step
+from chainlit.sync import run_sync
 
 T = TypeVar("T")
 
@@ -68,9 +67,7 @@ class HaystackAgentCallbackHandler:
         self.last_tokens: List[str] = []
         self.answer_reached = False
 
-        root_message = context.session.root_message
-        parent_id = root_message.id if root_message else None
-        run_step = Step(name=self.agent_name, type="run", parent_id=parent_id)
+        run_step = Step(name=self.agent_name, type="run")
         run_step.start = utc_now()
         run_step.input = kwargs
 
@@ -135,7 +132,7 @@ class HaystackAgentCallbackHandler:
         tool_result: str,
         tool_name: Optional[str] = None,
         tool_input: Optional[str] = None,
-        **kwargs: Any
+        **kwargs: Any,
     ) -> None:
         # Tool finished, send step with tool_result
         tool_step = self.stack.pop()
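
With root_message gone from the session, the run step no longer needs an explicit parent_id; parenting is resolved from the current context. A minimal usage sketch, assuming the handler attaches itself to the agent's callback manager on construction as in Chainlit's Haystack integration docs (agent is a user-configured haystack.agents.Agent):

import chainlit as cl
from chainlit.haystack.callbacks import HaystackAgentCallbackHandler

@cl.on_chat_start
async def start():
    HaystackAgentCallbackHandler(agent)  # streams agent and tool steps into the Chainlit UI

@cl.on_message
async def main(message: cl.Message):
    # Run the synchronous Haystack agent off the event loop.
    await cl.make_async(agent.run)(message.content)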
chainlit/input_widget.py CHANGED
@@ -2,8 +2,10 @@ from abc import abstractmethod
 from collections import defaultdict
 from typing import Any, Dict, List, Optional
 
+from pydantic import Field
+from pydantic.dataclasses import dataclass
+
 from chainlit.types import InputWidgetType
-from pydantic.dataclasses import Field, dataclass
 
 
 @dataclass
@@ -75,7 +77,7 @@ class Select(InputWidget):
     initial: Optional[str] = None
     initial_index: Optional[int] = None
     initial_value: Optional[str] = None
-    values: List[str] = Field(default_factory=lambda: [])
+    values: List[str] = Field(default_factory=list)
     items: Dict[str, str] = Field(default_factory=lambda: defaultdict(dict))
 
     def __post_init__(
@@ -127,6 +129,7 @@ class TextInput(InputWidget):
     type: InputWidgetType = "textinput"
     initial: Optional[str] = None
     placeholder: Optional[str] = None
+    multiline: bool = False
 
     def to_dict(self) -> Dict[str, Any]:
         return {
@@ -137,6 +140,7 @@ class TextInput(InputWidget):
             "placeholder": self.placeholder,
             "tooltip": self.tooltip,
             "description": self.description,
+            "multiline": self.multiline,
         }
 
 
@@ -165,8 +169,8 @@ class Tags(InputWidget):
     """Useful to create an input for an array of strings."""
 
     type: InputWidgetType = "tags"
-    initial: List[str] = Field(default_factory=lambda: [])
-    values: List[str] = Field(default_factory=lambda: [])
+    initial: List[str] = Field(default_factory=list)
+    values: List[str] = Field(default_factory=list)
 
     def to_dict(self) -> Dict[str, Any]:
         return {
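
The new multiline flag is the only schema addition here; it flows straight through to_dict() to the frontend. A short sketch of exposing it through chat settings (the widget id, label, and initial value are illustrative):

import chainlit as cl
from chainlit.input_widget import TextInput

@cl.on_chat_start
async def start():
    settings = await cl.ChatSettings(
        [
            TextInput(
                id="system_prompt",
                label="System prompt",
                initial="You are a helpful assistant.",
                multiline=True,  # renders a multi-line text area instead of a single-line input
            )
        ]
    ).send()

The Field(default_factory=list) change is behavior-preserving: pydantic only needs a zero-argument callable, and list is the idiomatic spelling of lambda: [].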
chainlit/langchain/callbacks.py CHANGED
@@ -1,19 +1,20 @@
 import json
 import time
-from typing import Any, Dict, List, Optional, TypedDict, Union
+from typing import Any, Dict, List, Optional, Sequence, Tuple, TypedDict, Union
 from uuid import UUID
 
-from chainlit.context import context_var
-from chainlit.message import Message
-from chainlit.step import Step
-from langchain.callbacks.tracers.base import BaseTracer
+import pydantic
 from langchain.callbacks.tracers.schemas import Run
 from langchain.schema import BaseMessage
-from langchain.schema.output import ChatGenerationChunk, GenerationChunk
 from langchain_core.outputs import ChatGenerationChunk, GenerationChunk
+from langchain_core.tracers.base import AsyncBaseTracer
 from literalai import ChatGeneration, CompletionGeneration, GenerationMessage
 from literalai.helper import utc_now
-from literalai.step import TrueStepType
+from literalai.observability.step import TrueStepType
+
+from chainlit.context import context_var
+from chainlit.message import Message
+from chainlit.step import Step
 
 DEFAULT_ANSWER_PREFIX_TOKENS = ["Final", "Answer", ":"]
 
@@ -123,6 +124,14 @@ class GenerationHelper:
                 key: self.ensure_values_serializable(value)
                 for key, value in data.items()
             }
+        elif isinstance(data, pydantic.BaseModel):
+            # Fallback to support pydantic v1
+            # https://docs.pydantic.dev/latest/migration/#changes-to-pydanticbasemodel
+            if pydantic.VERSION.startswith("1"):
+                return data.dict()
+
+            # pydantic v2
+            return data.model_dump()  # pyright: ignore reportAttributeAccessIssue
         elif isinstance(data, list):
             return [self.ensure_values_serializable(item) for item in data]
         elif isinstance(data, (str, int, float, bool, type(None))):
@@ -229,11 +238,28 @@ class GenerationHelper:
         return provider, model, tools, settings
 
 
-DEFAULT_TO_IGNORE = ["Runnable", "<lambda>"]
+def process_content(content: Any) -> Tuple[Dict, Optional[str]]:
+    if content is None:
+        return {}, None
+    if isinstance(content, dict):
+        return content, "json"
+    elif isinstance(content, str):
+        return {"content": content}, "text"
+    else:
+        return {"content": str(content)}, "text"
+
+
+DEFAULT_TO_IGNORE = [
+    "RunnableSequence",
+    "RunnableParallel",
+    "RunnableAssign",
+    "RunnableLambda",
+    "<lambda>",
+]
 DEFAULT_TO_KEEP = ["retriever", "llm", "agent", "chain", "tool"]
 
 
-class LangchainTracer(BaseTracer, GenerationHelper, FinalStreamHelper):
+class LangchainTracer(AsyncBaseTracer, GenerationHelper, FinalStreamHelper):
     steps: Dict[str, Step]
     parent_id_map: Dict[str, str]
     ignored_runs: set
@@ -252,7 +278,7 @@ class LangchainTracer(BaseTracer, GenerationHelper, FinalStreamHelper):
         to_keep: Optional[List[str]] = None,
         **kwargs: Any,
     ) -> None:
-        BaseTracer.__init__(self, **kwargs)
+        AsyncBaseTracer.__init__(self, **kwargs)
         GenerationHelper.__init__(self)
         FinalStreamHelper.__init__(
             self,
@@ -267,8 +293,6 @@ class LangchainTracer(BaseTracer, GenerationHelper, FinalStreamHelper):
 
         if self.context.current_step:
             self.root_parent_id = self.context.current_step.id
-        elif self.context.session.root_message:
-            self.root_parent_id = self.context.session.root_message.id
         else:
             self.root_parent_id = None
 
@@ -282,7 +306,7 @@ class LangchainTracer(BaseTracer, GenerationHelper, FinalStreamHelper):
         else:
             self.to_keep = to_keep
 
-    def on_chat_model_start(
+    async def on_chat_model_start(
        self,
        serialized: Dict[str, Any],
        messages: List[List[BaseMessage]],
@@ -291,8 +315,9 @@ class LangchainTracer(BaseTracer, GenerationHelper, FinalStreamHelper):
         parent_run_id: Optional["UUID"] = None,
         tags: Optional[List[str]] = None,
         metadata: Optional[Dict[str, Any]] = None,
+        name: Optional[str] = None,
         **kwargs: Any,
-    ) -> Any:
+    ) -> Run:
         lc_messages = messages[0]
         self.chat_generations[str(run_id)] = {
             "input_messages": lc_messages,
@@ -301,46 +326,48 @@ class LangchainTracer(BaseTracer, GenerationHelper, FinalStreamHelper):
             "tt_first_token": None,
         }
 
-        return super().on_chat_model_start(
+        return await super().on_chat_model_start(
             serialized,
             messages,
             run_id=run_id,
             parent_run_id=parent_run_id,
             tags=tags,
             metadata=metadata,
+            name=name,
             **kwargs,
         )
 
-    def on_llm_start(
+    async def on_llm_start(
         self,
         serialized: Dict[str, Any],
         prompts: List[str],
         *,
         run_id: "UUID",
+        parent_run_id: Optional[UUID] = None,
         tags: Optional[List[str]] = None,
-        parent_run_id: Optional["UUID"] = None,
         metadata: Optional[Dict[str, Any]] = None,
-        name: Optional[str] = None,
         **kwargs: Any,
-    ) -> Run:
-        self.completion_generations[str(run_id)] = {
-            "prompt": prompts[0],
-            "start": time.time(),
-            "token_count": 0,
-            "tt_first_token": None,
-        }
-        return super().on_llm_start(
+    ) -> None:
+        await super().on_llm_start(
             serialized,
             prompts,
             run_id=run_id,
             parent_run_id=parent_run_id,
             tags=tags,
             metadata=metadata,
-            name=name,
             **kwargs,
         )
 
-    def on_llm_new_token(
+        self.completion_generations[str(run_id)] = {
+            "prompt": prompts[0],
+            "start": time.time(),
+            "token_count": 0,
+            "tt_first_token": None,
+        }
+
+        return None
+
+    async def on_llm_new_token(
         self,
         token: str,
         *,
@@ -348,7 +375,14 @@ class LangchainTracer(BaseTracer, GenerationHelper, FinalStreamHelper):
         run_id: "UUID",
         parent_run_id: Optional["UUID"] = None,
         **kwargs: Any,
-    ) -> Run:
+    ) -> None:
+        await super().on_llm_new_token(
+            token=token,
+            chunk=chunk,
+            run_id=run_id,
+            parent_run_id=parent_run_id,
+            **kwargs,
+        )
         if isinstance(chunk, ChatGenerationChunk):
             start = self.chat_generations[str(run_id)]
         else:
@@ -363,24 +397,13 @@ class LangchainTracer(BaseTracer, GenerationHelper, FinalStreamHelper):
         if self.answer_reached:
             if not self.final_stream:
                 self.final_stream = Message(content="")
-                self._run_sync(self.final_stream.send())
-            self._run_sync(self.final_stream.stream_token(token))
+                await self.final_stream.send()
+            await self.final_stream.stream_token(token)
             self.has_streamed_final_answer = True
         else:
             self.answer_reached = self._check_if_answer_reached()
 
-        return super().on_llm_new_token(
-            token,
-            chunk=chunk,
-            run_id=run_id,
-            parent_run_id=parent_run_id,
-        )
-
-    def _run_sync(self, co):  # TODO: WHAT TO DO WITH THIS?
-        context_var.set(self.context)
-        self.context.loop.create_task(co)
-
-    def _persist_run(self, run: Run) -> None:
+    async def _persist_run(self, run: Run) -> None:
         pass
 
     def _get_run_parent_id(self, run: Run):
@@ -431,11 +454,8 @@ class LangchainTracer(BaseTracer, GenerationHelper, FinalStreamHelper):
             self.ignored_runs.add(str(run.id))
         return ignore, parent_id
 
-    def _is_annotable(self, run: Run):
-        return run.run_type in ["retriever", "llm"]
-
-    def _start_trace(self, run: Run) -> None:
-        super()._start_trace(run)
+    async def _start_trace(self, run: Run) -> None:
+        await super()._start_trace(run)
         context_var.set(self.context)
 
         ignore, parent_id = self._should_ignore_run(run)
@@ -448,40 +468,39 @@ class LangchainTracer(BaseTracer, GenerationHelper, FinalStreamHelper):
         if ignore:
             return
 
-        step_type: "TrueStepType" = "undefined"
+        step_type: TrueStepType = "undefined"
         if run.run_type == "agent":
             step_type = "run"
         elif run.run_type == "chain":
-            pass
+            if not self.steps:
+                step_type = "run"
         elif run.run_type == "llm":
             step_type = "llm"
         elif run.run_type == "retriever":
-            step_type = "retrieval"
+            step_type = "tool"
         elif run.run_type == "tool":
             step_type = "tool"
         elif run.run_type == "embedding":
             step_type = "embedding"
 
-        if not self.steps:
-            step_type = "run"
-
-        disable_feedback = not self._is_annotable(run)
-
         step = Step(
             id=str(run.id),
             name=run.name,
             type=step_type,
             parent_id=parent_id,
-            disable_feedback=disable_feedback,
         )
         step.start = utc_now()
-        step.input = run.inputs
+        step.input, language = process_content(run.inputs)
+        if language is not None:
+            if step.metadata is None:
+                step.metadata = {}
+            step.metadata["language"] = language
 
         self.steps[str(run.id)] = step
 
-        self._run_sync(step.send())
+        await step.send()
 
-    def _on_run_update(self, run: Run) -> None:
+    async def _on_run_update(self, run: Run) -> None:
         """Process a run upon update."""
         context_var.set(self.context)
 
@@ -499,6 +518,9 @@ class LangchainTracer(BaseTracer, GenerationHelper, FinalStreamHelper):
         generations = (run.outputs or {}).get("generations", [])
         generation = generations[0][0]
         variables = self.generation_inputs.get(str(run.parent_run_id), {})
+        variables = {
+            k: process_content(v) for k, v in variables.items() if v is not None
+        }
         if message := generation.get("message"):
             chat_start = self.chat_generations[str(run.id)]
             duration = time.time() - chat_start["start"]
@@ -529,11 +551,17 @@ class LangchainTracer(BaseTracer, GenerationHelper, FinalStreamHelper):
                             "prompt_id"
                         ]
                         if custom_variables := m.additional_kwargs.get("variables"):
-                            current_step.generation.variables = custom_variables
+                            current_step.generation.variables = {
+                                k: process_content(v)
+                                for k, v in custom_variables.items()
+                                if v is not None
+                            }
                         break
 
             current_step.language = "json"
-            current_step.output = json.dumps(message_completion)
+            current_step.output = json.dumps(
+                message_completion, indent=4, ensure_ascii=False
+            )
         else:
             completion_start = self.completion_generations[str(run.id)]
             completion = generation.get("text", "")
@@ -557,32 +585,39 @@ class LangchainTracer(BaseTracer, GenerationHelper, FinalStreamHelper):
 
             if current_step:
                 current_step.end = utc_now()
-                self._run_sync(current_step.update())
+                await current_step.update()
 
             if self.final_stream and self.has_streamed_final_answer:
-                self._run_sync(self.final_stream.update())
+                await self.final_stream.update()
 
             return
 
         outputs = run.outputs or {}
         output_keys = list(outputs.keys())
         output = outputs
+
        if output_keys:
            output = outputs.get(output_keys[0], outputs)
 
         if current_step:
-            current_step.output = output
+            current_step.output = (
+                output[0]
+                if isinstance(output, Sequence)
+                and not isinstance(output, str)
+                and len(output)
+                else output
+            )
             current_step.end = utc_now()
-            self._run_sync(current_step.update())
+            await current_step.update()
 
-    def _on_error(self, error: BaseException, *, run_id: UUID, **kwargs: Any):
+    async def _on_error(self, error: BaseException, *, run_id: UUID, **kwargs: Any):
         context_var.set(self.context)
 
         if current_step := self.steps.get(str(run_id), None):
             current_step.is_error = True
             current_step.output = str(error)
             current_step.end = utc_now()
-            self._run_sync(current_step.update())
+            await current_step.update()
 
     on_llm_error = _on_error
     on_chain_error = _on_error
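
Because the tracer now subclasses AsyncBaseTracer and awaits step updates directly, the _run_sync shim is gone and the handler is meant to ride along on async invocations. A minimal sketch, assuming cl.LangchainCallbackHandler is the exported alias for LangchainTracer and chain is a user-defined LCEL runnable:

import chainlit as cl
from langchain_core.runnables import RunnableConfig

@cl.on_message
async def on_message(message: cl.Message):
    # Passing the handler per-invocation lets the tracer attach steps to this message.
    res = await chain.ainvoke(
        message.content,
        config=RunnableConfig(callbacks=[cl.LangchainCallbackHandler()]),
    )
    await cl.Message(content=str(res)).send()

Note the semantic shifts buried in this diff: retriever runs are now typed "tool" rather than "retrieval", per-run feedback toggling (disable_feedback) is removed, and run inputs pass through process_content() so dicts render as JSON and everything else as text.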
chainlit/langflow/__init__.py CHANGED
@@ -8,6 +8,7 @@ if not check_module_version("langflow", "0.1.4"):
 from typing import Dict, Optional, Union
 
 import httpx
+
 from chainlit.telemetry import trace_event
 
 
chainlit/llama_index/callbacks.py CHANGED
@@ -1,20 +1,21 @@
 from typing import Any, Dict, List, Optional
 
-from chainlit.context import context_var
-from chainlit.element import Text
-from chainlit.step import Step, StepType
 from literalai import ChatGeneration, CompletionGeneration, GenerationMessage
 from literalai.helper import utc_now
 from llama_index.core.callbacks import TokenCountingHandler
 from llama_index.core.callbacks.schema import CBEventType, EventPayload
 from llama_index.core.llms import ChatMessage, ChatResponse, CompletionResponse
+from llama_index.core.tools.types import ToolMetadata
+
+from chainlit.context import context_var
+from chainlit.element import Text
+from chainlit.step import Step, StepType
 
 DEFAULT_IGNORE = [
     CBEventType.CHUNKING,
     CBEventType.SYNTHESIZE,
     CBEventType.EMBEDDING,
     CBEventType.NODE_PARSING,
-    CBEventType.QUERY,
     CBEventType.TREE,
 ]
 
@@ -34,33 +35,17 @@ class LlamaIndexCallbackHandler(TokenCountingHandler):
             event_starts_to_ignore=event_starts_to_ignore,
             event_ends_to_ignore=event_ends_to_ignore,
         )
-        self.context = context_var.get()
 
         self.steps = {}
 
     def _get_parent_id(self, event_parent_id: Optional[str] = None) -> Optional[str]:
         if event_parent_id and event_parent_id in self.steps:
             return event_parent_id
-        elif self.context.current_step:
-            return self.context.current_step.id
-        elif self.context.session.root_message:
-            return self.context.session.root_message.id
+        elif context_var.get().current_step:
+            return context_var.get().current_step.id
         else:
             return None
 
-    def _restore_context(self) -> None:
-        """Restore Chainlit context in the current thread
-
-        Chainlit context is local to the main thread, and LlamaIndex
-        runs the callbacks in its own threads, so they don't have a
-        Chainlit context by default.
-
-        This method restores the context in which the callback handler
-        has been created (it's always created in the main thread), so
-        that we can actually send messages.
-        """
-        context_var.set(self.context)
-
     def on_event_start(
         self,
         event_type: CBEventType,
@@ -70,26 +55,36 @@ class LlamaIndexCallbackHandler(TokenCountingHandler):
         **kwargs: Any,
     ) -> str:
         """Run when an event starts and return id of event."""
-        self._restore_context()
         step_type: StepType = "undefined"
-        if event_type == CBEventType.RETRIEVE:
-            step_type = "retrieval"
+        step_name: str = event_type.value
+        step_input: Optional[Dict[str, Any]] = payload
+        if event_type == CBEventType.FUNCTION_CALL:
+            step_type = "tool"
+            if payload:
+                metadata: Optional[ToolMetadata] = payload.get(EventPayload.TOOL)
+                if metadata:
+                    step_name = getattr(metadata, "name", step_name)
+                step_input = payload.get(EventPayload.FUNCTION_CALL)
+        elif event_type == CBEventType.RETRIEVE:
+            step_type = "tool"
+        elif event_type == CBEventType.QUERY:
+            step_type = "tool"
         elif event_type == CBEventType.LLM:
             step_type = "llm"
         else:
             return event_id
 
         step = Step(
-            name=event_type.value,
+            name=step_name,
             type=step_type,
             parent_id=self._get_parent_id(parent_id),
             id=event_id,
-            disable_feedback=False,
         )
+
         self.steps[event_id] = step
         step.start = utc_now()
-        step.input = payload or {}
-        self.context.loop.create_task(step.send())
+        step.input = step_input or {}
+        context_var.get().loop.create_task(step.send())
         return event_id
 
     def on_event_end(
@@ -105,37 +100,59 @@ class LlamaIndexCallbackHandler(TokenCountingHandler):
         if payload is None or step is None:
             return
 
-        self._restore_context()
-
         step.end = utc_now()
 
-        if event_type == CBEventType.RETRIEVE:
+        if event_type == CBEventType.FUNCTION_CALL:
+            response = payload.get(EventPayload.FUNCTION_OUTPUT)
+            if response:
+                step.output = f"{response}"
+            context_var.get().loop.create_task(step.update())
+
+        elif event_type == CBEventType.QUERY:
+            response = payload.get(EventPayload.RESPONSE)
+            source_nodes = getattr(response, "source_nodes", None)
+            if source_nodes:
+                source_refs = ", ".join(
+                    [f"Source {idx}" for idx, _ in enumerate(source_nodes)]
+                )
+                step.elements = [
+                    Text(
+                        name=f"Source {idx}",
+                        content=source.text or "Empty node",
+                        display="side",
+                    )
+                    for idx, source in enumerate(source_nodes)
+                ]
+                step.output = f"Retrieved the following sources: {source_refs}"
+            context_var.get().loop.create_task(step.update())
+
+        elif event_type == CBEventType.RETRIEVE:
             sources = payload.get(EventPayload.NODES)
             if sources:
-                source_refs = "\, ".join(
+                source_refs = ", ".join(
                     [f"Source {idx}" for idx, _ in enumerate(sources)]
                 )
                 step.elements = [
                     Text(
                         name=f"Source {idx}",
+                        display="side",
                         content=source.node.get_text() or "Empty node",
                     )
                     for idx, source in enumerate(sources)
                 ]
                 step.output = f"Retrieved the following sources: {source_refs}"
-            self.context.loop.create_task(step.update())
+            context_var.get().loop.create_task(step.update())
 
-        if event_type == CBEventType.LLM:
-            formatted_messages = payload.get(
-                EventPayload.MESSAGES
-            )  # type: Optional[List[ChatMessage]]
+        elif event_type == CBEventType.LLM:
+            formatted_messages = payload.get(EventPayload.MESSAGES)  # type: Optional[List[ChatMessage]]
             formatted_prompt = payload.get(EventPayload.PROMPT)
             response = payload.get(EventPayload.RESPONSE)
 
             if formatted_messages:
                 messages = [
                     GenerationMessage(
-                        role=m.role.value, content=m.content or ""  # type: ignore
+                        role=m.role.value,  # type: ignore
+                        content=m.content or "",
                     )
                     for m in formatted_messages
                 ]
@@ -152,10 +169,13 @@ class LlamaIndexCallbackHandler(TokenCountingHandler):
             step.output = content
 
             token_count = self.total_llm_token_count or None
+            raw_response = response.raw if response else None
+            model = getattr(raw_response, "model", None)
 
             if messages and isinstance(response, ChatResponse):
                 msg: ChatMessage = response.message
                 step.generation = ChatGeneration(
+                    model=model,
                     messages=messages,
                     message_completion=GenerationMessage(
                         role=msg.role.value,  # type: ignore
@@ -165,12 +185,17 @@ class LlamaIndexCallbackHandler(TokenCountingHandler):
                 )
             elif formatted_prompt:
                 step.generation = CompletionGeneration(
+                    model=model,
                     prompt=formatted_prompt,
                     completion=content,
                     token_count=token_count,
                 )
 
-            self.context.loop.create_task(step.update())
+            context_var.get().loop.create_task(step.update())
+
+        else:
+            step.output = payload
+            context_var.get().loop.create_task(step.update())
 
         self.steps.pop(event_id, None)
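
The handler no longer pins the context at construction time (each callback reads context_var.get() on the fly), and QUERY and FUNCTION_CALL events now produce tool steps with source elements and tool output attached. A minimal registration sketch, assuming cl.LlamaIndexCallbackHandler is the exported alias and using the llama_index Settings singleton:

import chainlit as cl
from llama_index.core import Settings
from llama_index.core.callbacks import CallbackManager

@cl.on_chat_start
async def start():
    # Route LlamaIndex instrumentation events into Chainlit steps for this session.
    Settings.callback_manager = CallbackManager([cl.LlamaIndexCallbackHandler()])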