chainlit 0.7.700__py3-none-any.whl → 1.0.0rc0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of chainlit might be problematic according to the registry's advisory.

Files changed (38)
  1. chainlit/__init__.py +32 -23
  2. chainlit/auth.py +9 -10
  3. chainlit/cli/__init__.py +1 -2
  4. chainlit/config.py +13 -12
  5. chainlit/context.py +7 -3
  6. chainlit/data/__init__.py +375 -9
  7. chainlit/data/acl.py +6 -5
  8. chainlit/element.py +86 -123
  9. chainlit/emitter.py +117 -50
  10. chainlit/frontend/dist/assets/{index-71698725.js → index-6aee009a.js} +118 -292
  11. chainlit/frontend/dist/assets/{react-plotly-2c0acdf0.js → react-plotly-2f07c02a.js} +1 -1
  12. chainlit/frontend/dist/index.html +1 -1
  13. chainlit/haystack/callbacks.py +45 -43
  14. chainlit/hello.py +1 -1
  15. chainlit/langchain/callbacks.py +132 -120
  16. chainlit/llama_index/callbacks.py +68 -48
  17. chainlit/message.py +179 -207
  18. chainlit/oauth_providers.py +39 -34
  19. chainlit/playground/provider.py +44 -30
  20. chainlit/playground/providers/anthropic.py +4 -4
  21. chainlit/playground/providers/huggingface.py +2 -2
  22. chainlit/playground/providers/langchain.py +8 -10
  23. chainlit/playground/providers/openai.py +19 -13
  24. chainlit/server.py +155 -99
  25. chainlit/session.py +109 -40
  26. chainlit/socket.py +47 -36
  27. chainlit/step.py +393 -0
  28. chainlit/types.py +78 -21
  29. chainlit/user.py +32 -0
  30. chainlit/user_session.py +1 -5
  31. {chainlit-0.7.700.dist-info → chainlit-1.0.0rc0.dist-info}/METADATA +12 -31
  32. chainlit-1.0.0rc0.dist-info/RECORD +60 -0
  33. chainlit/client/base.py +0 -169
  34. chainlit/client/cloud.py +0 -502
  35. chainlit/prompt.py +0 -40
  36. chainlit-0.7.700.dist-info/RECORD +0 -61
  37. {chainlit-0.7.700.dist-info → chainlit-1.0.0rc0.dist-info}/WHEEL +0 -0
  38. {chainlit-0.7.700.dist-info → chainlit-1.0.0rc0.dist-info}/entry_points.txt +0 -0
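The headline change in this release candidate is the new chainlit/step.py module (item 27, +393 lines) together with the removal of chainlit/prompt.py and the old client modules: every integration below swaps cl.Message-based instrumentation for Step objects. A minimal sketch of the Step lifecycle, inferred purely from the call sites visible in this diff (the real constructor in 1.0.0rc0 may accept additional arguments):

```python
from datetime import datetime

from chainlit.step import Step
from chainlit.sync import run_sync

# Illustrative step lifecycle, mirroring the call sites in the callbacks below.
# Must run inside a Chainlit request context; names and values are examples.
step = Step(name="my_tool", type="tool", parent_id=None)
step.start = datetime.utcnow().isoformat()
step.input = {"query": "example input"}
run_sync(step.send())      # emit the step to the UI

step.output = "example result"
step.end = datetime.utcnow().isoformat()
run_sync(step.update())    # push the final state
```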
chainlit/frontend/dist/index.html CHANGED
@@ -20,7 +20,7 @@
     <script>
       const global = globalThis;
     </script>
-    <script type="module" crossorigin src="/assets/index-71698725.js"></script>
+    <script type="module" crossorigin src="/assets/index-6aee009a.js"></script>
     <link rel="stylesheet" href="/assets/index-8942cb2d.css">
   </head>
   <body>
chainlit/haystack/callbacks.py CHANGED
@@ -1,12 +1,12 @@
+from datetime import datetime
 from typing import Any, Generic, List, Optional, TypeVar
 
-from chainlit.config import config
 from chainlit.context import context
+from chainlit.step import Step
+from chainlit.sync import run_sync
 from haystack.agents import Agent, Tool
 from haystack.agents.agent_step import AgentStep
 
-import chainlit as cl
-
 T = TypeVar("T")
 
 
@@ -31,8 +31,8 @@ class Stack(Generic[T]):
 
 
 class HaystackAgentCallbackHandler:
-    stack: Stack[cl.Message]
-    latest_agent_message: Optional[cl.Message]
+    stack: Stack[Step]
+    last_step: Optional[Step]
 
     def __init__(self, agent: Agent):
         agent.callback_manager.on_agent_start += self.on_agent_start
@@ -44,55 +44,53 @@ class HaystackAgentCallbackHandler:
         agent.tm.callback_manager.on_tool_finish += self.on_tool_finish
         agent.tm.callback_manager.on_tool_error += self.on_tool_error
 
-    def get_root_message(self):
-        if not context.session.root_message:
-            root_message = cl.Message(author=config.ui.name, content="")
-            cl.run_sync(root_message.send())
-
-        return context.session.root_message
-
     def on_agent_start(self, **kwargs: Any) -> None:
         # Prepare agent step message for streaming
         self.agent_name = kwargs.get("name", "Agent")
-        self.stack = Stack[cl.Message]()
-        self.stack.push(self.get_root_message())
+        self.stack = Stack[Step]()
+        root_message = context.session.root_message
+        parent_id = root_message.id if root_message else None
+        run_step = Step(name=self.agent_name, type="run", parent_id=parent_id)
+        run_step.start = datetime.utcnow().isoformat()
+        run_step.input = kwargs
 
-        agent_message = cl.Message(
-            author=self.agent_name, parent_id=self.stack.peek().id, content=""
-        )
-        self.stack.push(agent_message)
+        run_sync(run_step.send())
+
+        self.stack.push(run_step)
+
+    def on_agent_finish(self, agent_step: AgentStep, **kwargs: Any) -> None:
+        run_step = self.stack.pop()
+        run_step.end = datetime.utcnow().isoformat()
+        run_step.output = agent_step.prompt_node_response
+        run_sync(run_step.update())
 
     # This method is called when a step has finished
     def on_agent_step(self, agent_step: AgentStep, **kwargs: Any) -> None:
         # Send previous agent step message
-        self.latest_agent_message = self.stack.pop()
+        self.last_step = self.stack.pop()
 
         # If token streaming is disabled
-        if self.latest_agent_message.content == "":
-            self.latest_agent_message.content = agent_step.prompt_node_response
-
-        cl.run_sync(self.latest_agent_message.send())
+        if self.last_step.output == "":
+            self.last_step.output = agent_step.prompt_node_response
+        self.last_step.end = datetime.utcnow().isoformat()
+        run_sync(self.last_step.update())
 
         if not agent_step.is_last():
-            # Prepare message for next agent step
-            agent_message = cl.Message(
-                author=self.agent_name, parent_id=self.stack.peek().id, content=""
-            )
-            self.stack.push(agent_message)
-
-    def on_agent_finish(self, agent_step: AgentStep, **kwargs: Any) -> None:
-        self.latest_agent_message = None
-        self.stack.clear()
+            # Prepare step for next agent step
+            step = Step(name=self.agent_name, parent_id=self.stack.peek().id)
+            self.stack.push(step)
 
     def on_new_token(self, token, **kwargs: Any) -> None:
         # Stream agent step tokens
-        cl.run_sync(self.stack.peek().stream_token(token))
+        run_sync(self.stack.peek().stream_token(token))
 
     def on_tool_start(self, tool_input: str, tool: Tool, **kwargs: Any) -> None:
-        # Tool started, create message
-        parent_id = self.latest_agent_message.id if self.latest_agent_message else None
-        tool_message = cl.Message(author=tool.name, parent_id=parent_id, content="")
-        self.stack.push(tool_message)
+        # Tool started, create step
+        parent_id = self.stack.items[0].id if self.stack.items[0] else None
+        tool_step = Step(name=tool.name, type="tool", parent_id=parent_id)
+        tool_step.input = tool_input
+        tool_step.start = datetime.utcnow().isoformat()
+        self.stack.push(tool_step)
 
     def on_tool_finish(
         self,
@@ -101,12 +99,16 @@ class HaystackAgentCallbackHandler:
        tool_input: Optional[str] = None,
        **kwargs: Any
    ) -> None:
-        # Tool finished, send message with tool_result
-        tool_message = self.stack.pop()
-        tool_message.content = tool_result
-        cl.run_sync(tool_message.send())
+        # Tool finished, send step with tool_result
+        tool_step = self.stack.pop()
+        tool_step.output = tool_result
+        tool_step.end = datetime.utcnow().isoformat()
+        run_sync(tool_step.update())
 
     def on_tool_error(self, exception: Exception, tool: Tool, **kwargs: Any) -> None:
         # Tool error, send error message
-        cl.run_sync(self.stack.pop().remove())
-        cl.run_sync(cl.ErrorMessage(str(exception), author=tool.name).send())
+        error_step = self.stack.pop()
+        error_step.is_error = True
+        error_step.output = str(exception)
+        error_step.end = datetime.utcnow().isoformat()
+        run_sync(error_step.update())
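Wiring the rewritten handler is unchanged: it still registers itself on an existing Haystack Agent's callback managers at construction time, but each agent step and tool call is now emitted as a Step. A hypothetical usage sketch (the PromptNode model and Agent configuration are illustrative placeholders, not taken from this diff):

```python
import chainlit as cl
from chainlit.haystack.callbacks import HaystackAgentCallbackHandler
from haystack.agents import Agent
from haystack.nodes import PromptNode

# Placeholder agent setup; any configured Haystack Agent works the same way.
agent = Agent(prompt_node=PromptNode("gpt-3.5-turbo"))
HaystackAgentCallbackHandler(agent)  # hooks the on_agent_* / on_tool_* callbacks above

@cl.on_message
async def main(message: cl.Message):
    # Agent steps and tool calls show up as Steps in the UI while this runs.
    await cl.make_async(agent.run)(message.content)
```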
chainlit/hello.py CHANGED
@@ -8,5 +8,5 @@ async def main():
     res = await AskUserMessage(content="What is your name?", timeout=30).send()
     if res:
         await Message(
-            content=f"Your name is: {res['content']}.\nChainlit installation is working!\nYou can now start building your own chainlit apps!",
+            content=f"Your name is: {res['output']}.\nChainlit installation is working!\nYou can now start building your own chainlit apps!",
         ).send()
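This one-line change reflects a user-facing rename: the dict returned by AskUserMessage.send() now exposes the reply under the "output" key instead of "content". Apps written against 0.7.x need the same adjustment; a hedged sketch:

```python
import chainlit as cl

@cl.on_chat_start
async def start():
    res = await cl.AskUserMessage(content="What is your name?", timeout=30).send()
    if res:
        # 0.7.x exposed the reply as res["content"]; 1.0.0rc0 uses res["output"]
        await cl.Message(content=f"Hello {res['output']}!").send()
```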
chainlit/langchain/callbacks.py CHANGED
@@ -1,13 +1,15 @@
+from datetime import datetime
 from typing import Any, Dict, List, Optional, Union
 from uuid import UUID
 
 from chainlit.context import context_var
 from chainlit.message import Message
 from chainlit.playground.providers.openai import stringify_function_call
-from chainlit.prompt import Prompt, PromptMessage
+from chainlit.step import Step, TrueStepType
+from chainlit_client import ChatGeneration, CompletionGeneration, GenerationMessage
 from langchain.callbacks.tracers.base import BaseTracer
 from langchain.callbacks.tracers.schemas import Run
-from langchain.schema.messages import BaseMessage
+from langchain.schema import BaseMessage
 from langchain.schema.output import ChatGenerationChunk, GenerationChunk
 
 DEFAULT_ANSWER_PREFIX_TOKENS = ["Final", "Answer", ":"]
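The removed chainlit.prompt.Prompt/PromptMessage pair is replaced by ChatGeneration, CompletionGeneration and GenerationMessage imported from the separate chainlit_client package. A rough sketch of how the tracer below assembles them, using only the keyword arguments and attributes that appear in this diff (the actual chainlit_client constructors may differ):

```python
from chainlit_client import ChatGeneration, CompletionGeneration, GenerationMessage

# Chat-style prompt: role-tagged messages plus the inputs used to fill the template.
chat_gen = ChatGeneration(
    messages=[
        GenerationMessage(role="system", template="You are a helpful bot."),
        GenerationMessage(role="user", formatted="What is Chainlit?"),
    ],
    inputs={"topic": "Chainlit"},
)
chat_gen.provider = "openai-chat"        # illustrative value, set after the LLM run
chat_gen.completion = "Chainlit is ..."  # illustrative value

# Completion-style prompt: a single template string instead of messages.
completion_gen = CompletionGeneration(
    template="Answer the question: {question}",
    template_format="f-string",
    inputs={"question": "What is Chainlit?"},
)
```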
@@ -85,15 +87,15 @@ class FinalStreamHelper:
             self.last_tokens_stripped.pop(0)
 
 
-class PromptHelper:
-    prompt_sequence: List[Prompt]
+class GenerationHelper:
+    generation_sequence: List[Union[ChatGeneration, CompletionGeneration]]
 
     def __init__(self) -> None:
-        self.prompt_sequence = []
+        self.generation_sequence = []
 
     @property
-    def current_prompt(self):
-        return self.prompt_sequence[-1] if self.prompt_sequence else None
+    def current_generation(self):
+        return self.generation_sequence[-1] if self.generation_sequence else None
 
     def _convert_message_role(self, role: str):
         if "human" in role.lower():
@@ -109,7 +111,7 @@ class PromptHelper:
         self,
         message: Dict,
         template: Optional[str] = None,
-        template_format: Optional[str] = None,
+        template_format: str = "f-string",
     ):
         class_name = message["id"][-1]
         kwargs = message.get("kwargs", {})
@@ -118,7 +120,7 @@ class PromptHelper:
             content = stringify_function_call(function_call)
         else:
             content = kwargs.get("content", "")
-        return PromptMessage(
+        return GenerationMessage(
             name=kwargs.get("name"),
             role=self._convert_message_role(class_name),
             template=template,
@@ -130,7 +132,7 @@ class PromptHelper:
         self,
         message: Union[Dict, BaseMessage],
         template: Optional[str] = None,
-        template_format: Optional[str] = None,
+        template_format: str = "f-string",
     ):
         if isinstance(message, dict):
             return self._convert_message_dict(
@@ -141,7 +143,7 @@ class PromptHelper:
             content = stringify_function_call(function_call)
         else:
             content = message.content
-        return PromptMessage(
+        return GenerationMessage(
             name=getattr(message, "name", None),
             role=self._convert_message_role(message.type),
             template=template,
@@ -165,16 +167,16 @@ class PromptHelper:
 
         return chain_messages
 
-    def _build_prompt(self, serialized: Dict, inputs: Dict):
+    def _build_generation(self, serialized: Dict, inputs: Dict):
         messages = self._get_messages(serialized)
         if messages:
             # If prompt is chat, the formatted values will be added in on_chat_model_start
-            self._build_chat_template_prompt(messages, inputs)
+            self._build_chat_template_generation(messages, inputs)
         else:
             # For completion prompt everything is done here
-            self._build_completion_prompt(serialized, inputs)
+            self._build_completion_generation(serialized, inputs)
 
-    def _build_completion_prompt(self, serialized: Dict, inputs: Dict):
+    def _build_completion_generation(self, serialized: Dict, inputs: Dict):
         if not serialized:
             return
         kwargs = serialized.get("kwargs", {})
@@ -185,15 +187,15 @@ class PromptHelper:
         if not template:
             return
 
-        self.prompt_sequence.append(
-            Prompt(
+        self.generation_sequence.append(
+            CompletionGeneration(
                 template=template,
                 template_format=template_format,
                 inputs=stringified_inputs,
             )
         )
 
-    def _build_default_prompt(
+    def _build_default_generation(
         self,
         run: Run,
         generation_type: str,
@@ -203,12 +205,12 @@ class PromptHelper:
     ):
         """Build a prompt once an LLM has been executed if no current prompt exists (without template)"""
         if "chat" in generation_type.lower():
-            return Prompt(
+            return ChatGeneration(
                 provider=provider,
                 settings=llm_settings,
                 completion=completion,
                 messages=[
-                    PromptMessage(
+                    GenerationMessage(
                         formatted=formatted_prompt,
                         role=self._convert_message_role(formatted_prompt.split(":")[0]),
                     )
@@ -216,16 +218,16 @@ class PromptHelper:
                 ],
             )
         else:
-            return Prompt(
+            return CompletionGeneration(
                 provider=provider,
                 settings=llm_settings,
                 completion=completion,
                 formatted=run.inputs.get("prompts", [])[0],
             )
 
-    def _build_chat_template_prompt(self, lc_messages: List[Dict], inputs: Dict):
-        def build_template_messages() -> List[PromptMessage]:
-            template_messages = []  # type: List[PromptMessage]
+    def _build_chat_template_generation(self, lc_messages: List[Dict], inputs: Dict):
+        def build_template_messages() -> List[GenerationMessage]:
+            template_messages = []  # type: List[GenerationMessage]
 
             if not lc_messages:
                 return template_messages
@@ -249,13 +251,12 @@ class PromptHelper:
 
                 if placeholder_size:
                     template_messages += [
-                        PromptMessage(placeholder_size=placeholder_size)
+                        GenerationMessage(placeholder_size=placeholder_size)
                     ]
                 else:
                     template_messages += [
-                        PromptMessage(
+                        GenerationMessage(
                             template=template,
-                            template_format=template_format,
                             role=self._convert_message_role(class_name),
                         )
                     ]
@@ -267,18 +268,18 @@ class PromptHelper:
             return
 
         stringified_inputs = {k: str(v) for (k, v) in inputs.items()}
-        self.prompt_sequence.append(
-            Prompt(messages=template_messages, inputs=stringified_inputs)
+        self.generation_sequence.append(
+            ChatGeneration(messages=template_messages, inputs=stringified_inputs)
         )
 
-    def _build_chat_formatted_prompt(
+    def _build_chat_formatted_generation(
         self, lc_messages: Union[List[BaseMessage], List[dict]]
     ):
-        if not self.current_prompt:
+        if not self.current_generation:
             return
 
-        formatted_messages = []  # type: List[PromptMessage]
-        if self.current_prompt.messages:
+        formatted_messages = []  # type: List[GenerationMessage]
+        if self.current_generation.messages:
             # This is needed to compute the correct message index to read
             placeholder_offset = 0
             # The final list of messages
@@ -286,7 +287,7 @@ class PromptHelper:
             # Looping the messages built in build_prompt
             # They only contain the template
             for template_index, template_message in enumerate(
-                self.current_prompt.messages
+                self.current_generation.messages
             ):
                 # If a message has a placeholder size, we need to replace it
                 # With the N following messages, where N is the placeholder size
@@ -322,7 +323,7 @@ class PromptHelper:
                 self._convert_message(lc_message) for lc_message in lc_messages
             ]
 
-        self.current_prompt.messages = formatted_messages
+        self.current_generation.messages = formatted_messages
 
     def _build_llm_settings(
         self,
@@ -356,8 +357,8 @@ DEFAULT_TO_IGNORE = ["RunnableSequence", "RunnableParallel", "<lambda>"]
 DEFAULT_TO_KEEP = ["retriever", "llm", "agent", "chain", "tool"]
 
 
-class LangchainTracer(BaseTracer, PromptHelper, FinalStreamHelper):
-    llm_stream_message: Dict[str, Message]
+class LangchainTracer(BaseTracer, GenerationHelper, FinalStreamHelper):
+    steps: Dict[str, Step]
     parent_id_map: Dict[str, str]
     ignored_runs: set
 
@@ -376,7 +377,7 @@ class LangchainTracer(BaseTracer, PromptHelper, FinalStreamHelper):
         **kwargs: Any,
     ) -> None:
         BaseTracer.__init__(self, **kwargs)
-        PromptHelper.__init__(self)
+        GenerationHelper.__init__(self)
         FinalStreamHelper.__init__(
             self,
             answer_prefix_tokens=answer_prefix_tokens,
@@ -384,7 +385,7 @@ class LangchainTracer(BaseTracer, PromptHelper, FinalStreamHelper):
             force_stream_final_answer=force_stream_final_answer,
         )
         self.context = context_var.get()
-        self.llm_stream_message = {}
+        self.steps = {}
         self.parent_id_map = {}
         self.ignored_runs = set()
         self.root_parent_id = (
@@ -420,33 +421,37 @@ class LangchainTracer(BaseTracer, PromptHelper, FinalStreamHelper):
             return self.root_parent_id
 
         if current_parent_id not in self.parent_id_map:
-            return current_parent_id
+            return None
 
         while current_parent_id in self.parent_id_map:
-            current_parent_id = self.parent_id_map[current_parent_id]
+            # If the parent id is in the ignored runs, we need to get the parent id of the ignored run
+            if current_parent_id in self.ignored_runs:
+                current_parent_id = self.parent_id_map[current_parent_id]
+            else:
+                return current_parent_id
 
-        return current_parent_id
+        return self.root_parent_id
 
     def _should_ignore_run(self, run: Run):
         parent_id = self._get_run_parent_id(run)
 
+        if parent_id:
+            # Add the parent id of the ignored run in the mapping
+            # so we can re-attach a kept child to the right parent id
+            self.parent_id_map[str(run.id)] = parent_id
+
         ignore_by_name = run.name in self.to_ignore
         ignore_by_parent = parent_id in self.ignored_runs
 
         ignore = ignore_by_name or ignore_by_parent
 
-        if ignore:
-            if parent_id:
-                # Add the parent id of the ignored run in the mapping
-                # so we can re-attach a kept child to the right parent id
-                self.parent_id_map[str(run.id)] = parent_id
-            # Tag the run as ignored
-            self.ignored_runs.add(str(run.id))
-
         # If the ignore cause is the parent being ignored, check if we should nonetheless keep the child
         if ignore_by_parent and not ignore_by_name and run.run_type in self.to_keep:
-            return False, self._get_non_ignored_parent_id(str(run.id))
+            return False, self._get_non_ignored_parent_id(parent_id)
         else:
+            if ignore:
+                # Tag the run as ignored
+                self.ignored_runs.add(str(run.id))
             return ignore, parent_id
 
     def _is_annotable(self, run: Run):
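The run-filtering change above is subtle: every run now records its parent in parent_id_map up front, and a kept child of an ignored run is re-attached by walking that map until a non-ignored ancestor is found. A simplified, self-contained toy of that walk (the data and function name are made up; the real method also falls back to root_parent_id):

```python
from typing import Dict, Optional, Set

# Toy data: "child" is a kept run whose direct parent was ignored.
parent_id_map: Dict[str, str] = {"child": "ignored_mid", "ignored_mid": "root_run"}
ignored_runs: Set[str] = {"ignored_mid"}

def resolve_parent(run_id: str) -> Optional[str]:
    current = parent_id_map.get(run_id)
    while current in parent_id_map:
        if current in ignored_runs:
            current = parent_id_map[current]  # skip over the ignored run
        else:
            return current
    return current

print(resolve_parent("child"))  # -> "root_run"
```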
@@ -477,12 +482,12 @@ class LangchainTracer(BaseTracer, PromptHelper, FinalStreamHelper):
     ) -> Any:
         """Adding formatted content and new message to the previously built template prompt"""
         lc_messages = messages[0]
-        if not self.current_prompt:
-            self.prompt_sequence.append(
-                Prompt(messages=[self._convert_message(m) for m in lc_messages])
+        if not self.current_generation:
+            self.generation_sequence.append(
+                ChatGeneration(messages=[self._convert_message(m) for m in lc_messages])
             )
         else:
-            self._build_chat_formatted_prompt(lc_messages)
+            self._build_chat_formatted_generation(lc_messages)
 
         super().on_chat_model_start(
             serialized,
@@ -503,7 +508,7 @@ class LangchainTracer(BaseTracer, PromptHelper, FinalStreamHelper):
         parent_run_id: Optional[UUID] = None,
         **kwargs: Any,
     ) -> Any:
-        msg = self.llm_stream_message.get(str(run_id), None)
+        msg = self.steps.get(str(run_id), None)
         if msg:
             self._run_sync(msg.stream_token(token))
 
@@ -513,6 +518,7 @@ class LangchainTracer(BaseTracer, PromptHelper, FinalStreamHelper):
         if self.answer_reached:
             if not self.final_stream:
                 self.final_stream = Message(content="")
+                self._run_sync(self.final_stream.send())
             self._run_sync(self.final_stream.stream_token(token))
             self.has_streamed_final_answer = True
         else:
@@ -533,36 +539,41 @@ class LangchainTracer(BaseTracer, PromptHelper, FinalStreamHelper):
 
         if run.run_type in ["chain", "prompt"]:
             # Prompt templates are contained in chains or prompts (lcel)
-            self._build_prompt(run.serialized or {}, run.inputs)
+            self._build_generation(run.serialized or {}, run.inputs)
 
         ignore, parent_id = self._should_ignore_run(run)
 
         if ignore:
             return
 
-        disable_human_feedback = not self._is_annotable(run)
+        step_type: TrueStepType = "undefined"
 
-        if run.run_type == "llm":
-            msg = Message(
-                id=run.id,
-                content="",
-                author=run.name,
-                parent_id=parent_id,
-                disable_human_feedback=disable_human_feedback,
-            )
-            self.llm_stream_message[str(run.id)] = msg
-            self._run_sync(msg.send())
-            return
-
-        self._run_sync(
-            Message(
-                id=run.id,
-                content="",
-                author=run.name,
-                parent_id=parent_id,
-                disable_human_feedback=disable_human_feedback,
-            ).send()
+        if run.run_type in ["agent", "chain"]:
+            step_type = "run"
+        elif run.run_type == "llm":
+            step_type = "llm"
+        elif run.run_type == "retriever":
+            step_type = "retrieval"
+        elif run.run_type == "tool":
+            step_type = "tool"
+        elif run.run_type == "embedding":
+            step_type = "embedding"
+
+        disable_feedback = not self._is_annotable(run)
+
+        step = Step(
+            id=str(run.id),
+            name=run.name,
+            type=step_type,
+            parent_id=parent_id,
+            disable_feedback=disable_feedback,
         )
+        step.start = datetime.utcnow().isoformat()
+        step.input = run.inputs
+
+        self.steps[str(run.id)] = step
+
+        self._run_sync(step.send())
 
     def _on_run_update(self, run: Run) -> None:
         """Process a run upon update."""
@@ -573,74 +584,75 @@ class LangchainTracer(BaseTracer, PromptHelper, FinalStreamHelper):
         if ignore:
             return
 
-        disable_human_feedback = not self._is_annotable(run)
+        current_step = self.steps.get(str(run.id), None)
 
         if run.run_type in ["chain"]:
-            if self.prompt_sequence:
-                self.prompt_sequence.pop()
+            if self.generation_sequence:
+                self.generation_sequence.pop()
 
         if run.run_type == "llm":
             provider, llm_settings = self._build_llm_settings(
                 (run.serialized or {}), (run.extra or {}).get("invocation_params")
             )
             generations = (run.outputs or {}).get("generations", [])
+            llm_output = (run.outputs or {}).get("llm_output")
             completion, language = self._get_completion(generations[0][0])
-            current_prompt = (
-                self.prompt_sequence.pop() if self.prompt_sequence else None
+            current_generation = (
+                self.generation_sequence.pop() if self.generation_sequence else None
             )
 
-            if current_prompt:
-                current_prompt.provider = provider
-                current_prompt.settings = llm_settings
-                current_prompt.completion = completion
+            if current_generation:
+                current_generation.provider = provider
+                current_generation.settings = llm_settings
+                current_generation.completion = completion
             else:
                 generation_type = generations[0][0].get("type", "")
-                current_prompt = self._build_default_prompt(
+                current_generation = self._build_default_generation(
                     run, generation_type, provider, llm_settings, completion
                 )
 
-            msg = self.llm_stream_message.get(str(run.id), None)
-            if msg:
-                msg.content = completion
-                msg.language = language
-                msg.prompt = current_prompt
-                self._run_sync(msg.update())
+            if llm_output and current_generation:
+                token_count = llm_output.get("token_usage", {}).get("total_tokens")
+                current_generation.token_count = token_count
+
+            if current_step:
+                current_step.output = completion
+                current_step.language = language
+                current_step.end = datetime.utcnow().isoformat()
+                current_step.generation = current_generation
+                self._run_sync(current_step.update())
 
             if self.final_stream and self.has_streamed_final_answer:
                 self.final_stream.content = completion
                 self.final_stream.language = language
-                self.final_stream.prompt = current_prompt
-                self._run_sync(self.final_stream.send())
+                self._run_sync(self.final_stream.update())
+
             return
 
         outputs = run.outputs or {}
         output_keys = list(outputs.keys())
+        output = outputs
         if output_keys:
-            content = outputs.get(output_keys[0], "")
-        else:
-            return
+            output = outputs.get(output_keys[0], outputs)
 
-        if run.run_type in ["agent", "chain"]:
-            pass
-            # # Add the response of the chain/tool
-            # self._run_sync(
-            #     Message(
-            #         content=content,
-            #         author=run.name,
-            #         parent_id=parent_id,
-            #         disable_human_feedback=disable_human_feedback,
-            #     ).send()
-            # )
-        else:
-            self._run_sync(
-                Message(
-                    id=run.id,
-                    content=content,
-                    author=run.name,
-                    parent_id=parent_id,
-                    disable_human_feedback=disable_human_feedback,
-                ).update()
-            )
+        if current_step:
+            current_step.output = output
+            current_step.end = datetime.utcnow().isoformat()
+            self._run_sync(current_step.update())
+
+    def _on_error(self, error: BaseException, *, run_id: UUID, **kwargs: Any):
+        context_var.set(self.context)
+
+        if current_step := self.steps.get(str(run_id), None):
+            current_step.is_error = True
+            current_step.output = str(error)
+            current_step.end = datetime.utcnow().isoformat()
+            self._run_sync(current_step.update())
+
+    on_llm_error = _on_error
+    on_chain_error = _on_error
+    on_tool_error = _on_error
+    on_retriever_error = _on_error
 
 
 LangchainCallbackHandler = LangchainTracer
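As before, LangchainCallbackHandler remains an alias for LangchainTracer, so existing apps keep passing it as a LangChain callback; the difference is that each run now surfaces as a Step rather than a Message. A hedged usage sketch (the LCEL chain is a placeholder, and the alias is assumed to stay re-exported at package level as in 0.7.x):

```python
import chainlit as cl
from langchain.chat_models import ChatOpenAI
from langchain.prompts import ChatPromptTemplate
from langchain.schema.output_parser import StrOutputParser

@cl.on_message
async def on_message(message: cl.Message):
    # Placeholder LCEL chain; any chain or agent is traced the same way.
    prompt = ChatPromptTemplate.from_template("Answer briefly: {question}")
    chain = prompt | ChatOpenAI(streaming=True) | StrOutputParser()

    # Passing the tracer as a callback surfaces each run as a Step.
    answer = await chain.ainvoke(
        {"question": message.content},
        config={"callbacks": [cl.LangchainCallbackHandler()]},
    )
    await cl.Message(content=answer).send()
```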