chainlit 0.7.604rc2-py3-none-any.whl → 1.0.0rc0-py3-none-any.whl

This diff compares the contents of two publicly released versions of the package as they appear in their public registry. It is provided for informational purposes only.

Potentially problematic release: this version of chainlit has been flagged as possibly problematic.

Files changed (40)
  1. chainlit/__init__.py +32 -23
  2. chainlit/auth.py +9 -10
  3. chainlit/cache.py +3 -3
  4. chainlit/cli/__init__.py +12 -2
  5. chainlit/config.py +22 -13
  6. chainlit/context.py +7 -3
  7. chainlit/data/__init__.py +375 -9
  8. chainlit/data/acl.py +6 -5
  9. chainlit/element.py +86 -123
  10. chainlit/emitter.py +117 -50
  11. chainlit/frontend/dist/assets/index-6aee009a.js +697 -0
  12. chainlit/frontend/dist/assets/{react-plotly-16f7de12.js → react-plotly-2f07c02a.js} +1 -1
  13. chainlit/frontend/dist/index.html +1 -1
  14. chainlit/haystack/callbacks.py +45 -43
  15. chainlit/hello.py +1 -1
  16. chainlit/langchain/callbacks.py +135 -120
  17. chainlit/llama_index/callbacks.py +68 -48
  18. chainlit/message.py +179 -207
  19. chainlit/oauth_providers.py +39 -34
  20. chainlit/playground/provider.py +44 -30
  21. chainlit/playground/providers/anthropic.py +4 -4
  22. chainlit/playground/providers/huggingface.py +2 -2
  23. chainlit/playground/providers/langchain.py +8 -10
  24. chainlit/playground/providers/openai.py +19 -13
  25. chainlit/server.py +155 -99
  26. chainlit/session.py +109 -40
  27. chainlit/socket.py +54 -38
  28. chainlit/step.py +393 -0
  29. chainlit/types.py +78 -21
  30. chainlit/user.py +32 -0
  31. chainlit/user_session.py +1 -5
  32. {chainlit-0.7.604rc2.dist-info → chainlit-1.0.0rc0.dist-info}/METADATA +12 -31
  33. chainlit-1.0.0rc0.dist-info/RECORD +60 -0
  34. chainlit/client/base.py +0 -169
  35. chainlit/client/cloud.py +0 -500
  36. chainlit/frontend/dist/assets/index-c58dbd4b.js +0 -871
  37. chainlit/prompt.py +0 -40
  38. chainlit-0.7.604rc2.dist-info/RECORD +0 -61
  39. {chainlit-0.7.604rc2.dist-info → chainlit-1.0.0rc0.dist-info}/WHEEL +0 -0
  40. {chainlit-0.7.604rc2.dist-info → chainlit-1.0.0rc0.dist-info}/entry_points.txt +0 -0
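
Most of the churn in this list traces back to two entries: the new chainlit/step.py (+393) and the removal of chainlit/prompt.py. Nested cl.Message trees and the Prompt/PromptMessage classes give way to Step objects plus the generation classes imported from chainlit_client. As a rough, hypothetical sketch of the step lifecycle the callback diffs below rely on (constructor arguments, field names, and the send/update/stream_token calls are taken from those hunks, not from the package documentation):

from datetime import datetime

import chainlit as cl
from chainlit.step import Step


@cl.on_message
async def on_message(message: cl.Message):
    # "tool" is one of the step types used in the diffs below
    # ("run", "tool", "llm", "retrieval", "embedding", "undefined").
    step = Step(name="lookup", type="tool")
    step.start = datetime.utcnow().isoformat()
    step.input = message.content
    await step.send()  # create the step in the UI

    step.output = "result of the (hypothetical) tool call"
    step.end = datetime.utcnow().isoformat()
    await step.update()  # finalize output and timing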
chainlit/frontend/dist/index.html CHANGED
@@ -20,7 +20,7 @@
  <script>
  const global = globalThis;
  </script>
- <script type="module" crossorigin src="/assets/index-c58dbd4b.js"></script>
+ <script type="module" crossorigin src="/assets/index-6aee009a.js"></script>
  <link rel="stylesheet" href="/assets/index-8942cb2d.css">
  </head>
  <body>
chainlit/haystack/callbacks.py CHANGED
@@ -1,12 +1,12 @@
+ from datetime import datetime
  from typing import Any, Generic, List, Optional, TypeVar

- from chainlit.config import config
  from chainlit.context import context
+ from chainlit.step import Step
+ from chainlit.sync import run_sync
  from haystack.agents import Agent, Tool
  from haystack.agents.agent_step import AgentStep

- import chainlit as cl
-
  T = TypeVar("T")


@@ -31,8 +31,8 @@ class Stack(Generic[T]):


  class HaystackAgentCallbackHandler:
-     stack: Stack[cl.Message]
-     latest_agent_message: Optional[cl.Message]
+     stack: Stack[Step]
+     last_step: Optional[Step]

      def __init__(self, agent: Agent):
          agent.callback_manager.on_agent_start += self.on_agent_start
@@ -44,55 +44,53 @@ class HaystackAgentCallbackHandler:
          agent.tm.callback_manager.on_tool_finish += self.on_tool_finish
          agent.tm.callback_manager.on_tool_error += self.on_tool_error

-     def get_root_message(self):
-         if not context.session.root_message:
-             root_message = cl.Message(author=config.ui.name, content="")
-             cl.run_sync(root_message.send())
-
-         return context.session.root_message
-
      def on_agent_start(self, **kwargs: Any) -> None:
          # Prepare agent step message for streaming
          self.agent_name = kwargs.get("name", "Agent")
-         self.stack = Stack[cl.Message]()
-         self.stack.push(self.get_root_message())
+         self.stack = Stack[Step]()
+         root_message = context.session.root_message
+         parent_id = root_message.id if root_message else None
+         run_step = Step(name=self.agent_name, type="run", parent_id=parent_id)
+         run_step.start = datetime.utcnow().isoformat()
+         run_step.input = kwargs

-         agent_message = cl.Message(
-             author=self.agent_name, parent_id=self.stack.peek().id, content=""
-         )
-         self.stack.push(agent_message)
+         run_sync(run_step.send())
+
+         self.stack.push(run_step)
+
+     def on_agent_finish(self, agent_step: AgentStep, **kwargs: Any) -> None:
+         run_step = self.stack.pop()
+         run_step.end = datetime.utcnow().isoformat()
+         run_step.output = agent_step.prompt_node_response
+         run_sync(run_step.update())

      # This method is called when a step has finished
      def on_agent_step(self, agent_step: AgentStep, **kwargs: Any) -> None:
          # Send previous agent step message
-         self.latest_agent_message = self.stack.pop()
+         self.last_step = self.stack.pop()

          # If token streaming is disabled
-         if self.latest_agent_message.content == "":
-             self.latest_agent_message.content = agent_step.prompt_node_response
-
-         cl.run_sync(self.latest_agent_message.send())
+         if self.last_step.output == "":
+             self.last_step.output = agent_step.prompt_node_response
+         self.last_step.end = datetime.utcnow().isoformat()
+         run_sync(self.last_step.update())

          if not agent_step.is_last():
-             # Prepare message for next agent step
-             agent_message = cl.Message(
-                 author=self.agent_name, parent_id=self.stack.peek().id, content=""
-             )
-             self.stack.push(agent_message)
-
-     def on_agent_finish(self, agent_step: AgentStep, **kwargs: Any) -> None:
-         self.latest_agent_message = None
-         self.stack.clear()
+             # Prepare step for next agent step
+             step = Step(name=self.agent_name, parent_id=self.stack.peek().id)
+             self.stack.push(step)

      def on_new_token(self, token, **kwargs: Any) -> None:
          # Stream agent step tokens
-         cl.run_sync(self.stack.peek().stream_token(token))
+         run_sync(self.stack.peek().stream_token(token))

      def on_tool_start(self, tool_input: str, tool: Tool, **kwargs: Any) -> None:
-         # Tool started, create message
-         parent_id = self.latest_agent_message.id if self.latest_agent_message else None
-         tool_message = cl.Message(author=tool.name, parent_id=parent_id, content="")
-         self.stack.push(tool_message)
+         # Tool started, create step
+         parent_id = self.stack.items[0].id if self.stack.items[0] else None
+         tool_step = Step(name=tool.name, type="tool", parent_id=parent_id)
+         tool_step.input = tool_input
+         tool_step.start = datetime.utcnow().isoformat()
+         self.stack.push(tool_step)

      def on_tool_finish(
          self,
@@ -101,12 +99,16 @@ class HaystackAgentCallbackHandler:
          tool_input: Optional[str] = None,
          **kwargs: Any
      ) -> None:
-         # Tool finished, send message with tool_result
-         tool_message = self.stack.pop()
-         tool_message.content = tool_result
-         cl.run_sync(tool_message.send())
+         # Tool finished, send step with tool_result
+         tool_step = self.stack.pop()
+         tool_step.output = tool_result
+         tool_step.end = datetime.utcnow().isoformat()
+         run_sync(tool_step.update())

      def on_tool_error(self, exception: Exception, tool: Tool, **kwargs: Any) -> None:
          # Tool error, send error message
-         cl.run_sync(self.stack.pop().remove())
-         cl.run_sync(cl.ErrorMessage(str(exception), author=tool.name).send())
+         error_step = self.stack.pop()
+         error_step.is_error = True
+         error_step.output = str(exception)
+         error_step.end = datetime.utcnow().isoformat()
+         run_sync(error_step.update())
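
Net effect of the haystack changes: the handler opens a "run" step for the agent, pushes intermediate and "tool" steps onto the stack, and closes each one with update() instead of sending nested messages; on_tool_error now marks the step as failed rather than removing it. A minimal, hypothetical wiring of this handler (build_agent() stands in for your own haystack Agent factory; the handler subscribes itself to the agent's callback manager in __init__, as shown above):

import chainlit as cl
from chainlit.haystack.callbacks import HaystackAgentCallbackHandler


@cl.on_chat_start
async def start():
    agent = build_agent()  # placeholder: construct your haystack Agent here
    HaystackAgentCallbackHandler(agent)  # hooks the agent's callback events
    cl.user_session.set("agent", agent)


@cl.on_message
async def on_message(message: cl.Message):
    agent = cl.user_session.get("agent")
    # Agent steps and tool calls are rendered as Steps instead of nested Messages.
    await cl.make_async(agent.run)(message.content)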
chainlit/hello.py CHANGED
@@ -8,5 +8,5 @@ async def main():
      res = await AskUserMessage(content="What is your name?", timeout=30).send()
      if res:
          await Message(
-             content=f"Your name is: {res['content']}.\nChainlit installation is working!\nYou can now start building your own chainlit apps!",
+             content=f"Your name is: {res['output']}.\nChainlit installation is working!\nYou can now start building your own chainlit apps!",
          ).send()
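
The content → output rename is user-visible: an AskUserMessage response now carries the reply under the "output" key. A small, hypothetical handler showing the 0.7 → 1.0 difference:

import chainlit as cl


@cl.on_chat_start
async def start():
    res = await cl.AskUserMessage(content="What is your name?", timeout=30).send()
    if res:
        # 0.7.x: res["content"]; 1.0.0rc0: res["output"], per the hello.py hunk above.
        await cl.Message(content=f"Hello {res['output']}!").send()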
chainlit/langchain/callbacks.py CHANGED
@@ -1,13 +1,15 @@
+ from datetime import datetime
  from typing import Any, Dict, List, Optional, Union
  from uuid import UUID

  from chainlit.context import context_var
  from chainlit.message import Message
  from chainlit.playground.providers.openai import stringify_function_call
- from chainlit.prompt import Prompt, PromptMessage
+ from chainlit.step import Step, TrueStepType
+ from chainlit_client import ChatGeneration, CompletionGeneration, GenerationMessage
  from langchain.callbacks.tracers.base import BaseTracer
  from langchain.callbacks.tracers.schemas import Run
- from langchain.schema.messages import BaseMessage
+ from langchain.schema import BaseMessage
  from langchain.schema.output import ChatGenerationChunk, GenerationChunk

  DEFAULT_ANSWER_PREFIX_TOKENS = ["Final", "Answer", ":"]
@@ -85,15 +87,15 @@ class FinalStreamHelper:
              self.last_tokens_stripped.pop(0)


- class PromptHelper:
-     prompt_sequence: List[Prompt]
+ class GenerationHelper:
+     generation_sequence: List[Union[ChatGeneration, CompletionGeneration]]

      def __init__(self) -> None:
-         self.prompt_sequence = []
+         self.generation_sequence = []

      @property
-     def current_prompt(self):
-         return self.prompt_sequence[-1] if self.prompt_sequence else None
+     def current_generation(self):
+         return self.generation_sequence[-1] if self.generation_sequence else None

      def _convert_message_role(self, role: str):
          if "human" in role.lower():
@@ -109,7 +111,7 @@
          self,
          message: Dict,
          template: Optional[str] = None,
-         template_format: Optional[str] = None,
+         template_format: str = "f-string",
      ):
          class_name = message["id"][-1]
          kwargs = message.get("kwargs", {})
@@ -118,7 +120,7 @@
              content = stringify_function_call(function_call)
          else:
              content = kwargs.get("content", "")
-         return PromptMessage(
+         return GenerationMessage(
              name=kwargs.get("name"),
              role=self._convert_message_role(class_name),
              template=template,
@@ -130,7 +132,7 @@
          self,
          message: Union[Dict, BaseMessage],
          template: Optional[str] = None,
-         template_format: Optional[str] = None,
+         template_format: str = "f-string",
      ):
          if isinstance(message, dict):
              return self._convert_message_dict(
@@ -141,7 +143,7 @@
              content = stringify_function_call(function_call)
          else:
              content = message.content
-         return PromptMessage(
+         return GenerationMessage(
              name=getattr(message, "name", None),
              role=self._convert_message_role(message.type),
              template=template,
@@ -165,16 +167,16 @@ class PromptHelper:

          return chain_messages

-     def _build_prompt(self, serialized: Dict, inputs: Dict):
+     def _build_generation(self, serialized: Dict, inputs: Dict):
          messages = self._get_messages(serialized)
          if messages:
              # If prompt is chat, the formatted values will be added in on_chat_model_start
-             self._build_chat_template_prompt(messages, inputs)
+             self._build_chat_template_generation(messages, inputs)
          else:
              # For completion prompt everything is done here
-             self._build_completion_prompt(serialized, inputs)
+             self._build_completion_generation(serialized, inputs)

-     def _build_completion_prompt(self, serialized: Dict, inputs: Dict):
+     def _build_completion_generation(self, serialized: Dict, inputs: Dict):
          if not serialized:
              return
          kwargs = serialized.get("kwargs", {})
@@ -185,15 +187,15 @@
          if not template:
              return

-         self.prompt_sequence.append(
-             Prompt(
+         self.generation_sequence.append(
+             CompletionGeneration(
                  template=template,
                  template_format=template_format,
                  inputs=stringified_inputs,
              )
          )

-     def _build_default_prompt(
+     def _build_default_generation(
          self,
          run: Run,
          generation_type: str,
@@ -203,12 +205,12 @@
      ):
          """Build a prompt once an LLM has been executed if no current prompt exists (without template)"""
          if "chat" in generation_type.lower():
-             return Prompt(
+             return ChatGeneration(
                  provider=provider,
                  settings=llm_settings,
                  completion=completion,
                  messages=[
-                     PromptMessage(
+                     GenerationMessage(
                          formatted=formatted_prompt,
                          role=self._convert_message_role(formatted_prompt.split(":")[0]),
                      )
@@ -216,16 +218,16 @@
                  ],
              )
          else:
-             return Prompt(
+             return CompletionGeneration(
                  provider=provider,
                  settings=llm_settings,
                  completion=completion,
                  formatted=run.inputs.get("prompts", [])[0],
              )

-     def _build_chat_template_prompt(self, lc_messages: List[Dict], inputs: Dict):
-         def build_template_messages() -> List[PromptMessage]:
-             template_messages = []  # type: List[PromptMessage]
+     def _build_chat_template_generation(self, lc_messages: List[Dict], inputs: Dict):
+         def build_template_messages() -> List[GenerationMessage]:
+             template_messages = []  # type: List[GenerationMessage]

              if not lc_messages:
                  return template_messages
@@ -241,18 +243,20 @@
                  if "placeholder" in class_name.lower():
                      variable_name = lc_message.get(
                          "variable_name"
+                     ) or message_kwargs.get(
+                         "variable_name"
                      )  # type: Optional[str]
                      variable = inputs.get(variable_name, [])
                      placeholder_size = len(variable)
+
                      if placeholder_size:
                          template_messages += [
-                             PromptMessage(placeholder_size=placeholder_size)
+                             GenerationMessage(placeholder_size=placeholder_size)
                          ]
                  else:
                      template_messages += [
-                         PromptMessage(
+                         GenerationMessage(
                              template=template,
-                             template_format=template_format,
                              role=self._convert_message_role(class_name),
                          )
                      ]
@@ -264,18 +268,18 @@
              return

          stringified_inputs = {k: str(v) for (k, v) in inputs.items()}
-         self.prompt_sequence.append(
-             Prompt(messages=template_messages, inputs=stringified_inputs)
+         self.generation_sequence.append(
+             ChatGeneration(messages=template_messages, inputs=stringified_inputs)
          )

-     def _build_chat_formatted_prompt(
+     def _build_chat_formatted_generation(
          self, lc_messages: Union[List[BaseMessage], List[dict]]
      ):
-         if not self.current_prompt:
+         if not self.current_generation:
              return

-         formatted_messages = []  # type: List[PromptMessage]
-         if self.current_prompt.messages:
+         formatted_messages = []  # type: List[GenerationMessage]
+         if self.current_generation.messages:
              # This is needed to compute the correct message index to read
              placeholder_offset = 0
              # The final list of messages
@@ -283,7 +287,7 @@
              # Looping the messages built in build_prompt
              # They only contain the template
              for template_index, template_message in enumerate(
-                 self.current_prompt.messages
+                 self.current_generation.messages
              ):
                  # If a message has a placeholder size, we need to replace it
                  # With the N following messages, where N is the placeholder size
@@ -319,7 +323,7 @@
                  self._convert_message(lc_message) for lc_message in lc_messages
              ]

-         self.current_prompt.messages = formatted_messages
+         self.current_generation.messages = formatted_messages

      def _build_llm_settings(
          self,
@@ -353,8 +357,8 @@ DEFAULT_TO_IGNORE = ["RunnableSequence", "RunnableParallel", "<lambda>"]
  DEFAULT_TO_KEEP = ["retriever", "llm", "agent", "chain", "tool"]


- class LangchainTracer(BaseTracer, PromptHelper, FinalStreamHelper):
-     llm_stream_message: Dict[str, Message]
+ class LangchainTracer(BaseTracer, GenerationHelper, FinalStreamHelper):
+     steps: Dict[str, Step]
      parent_id_map: Dict[str, str]
      ignored_runs: set

@@ -373,7 +377,7 @@ class LangchainTracer(BaseTracer, PromptHelper, FinalStreamHelper):
          **kwargs: Any,
      ) -> None:
          BaseTracer.__init__(self, **kwargs)
-         PromptHelper.__init__(self)
+         GenerationHelper.__init__(self)
          FinalStreamHelper.__init__(
              self,
              answer_prefix_tokens=answer_prefix_tokens,
@@ -381,7 +385,7 @@ class LangchainTracer(BaseTracer, PromptHelper, FinalStreamHelper):
              force_stream_final_answer=force_stream_final_answer,
          )
          self.context = context_var.get()
-         self.llm_stream_message = {}
+         self.steps = {}
          self.parent_id_map = {}
          self.ignored_runs = set()
          self.root_parent_id = (
@@ -417,33 +421,37 @@ class LangchainTracer(BaseTracer, PromptHelper, FinalStreamHelper):
              return self.root_parent_id

          if current_parent_id not in self.parent_id_map:
-             return current_parent_id
+             return None

          while current_parent_id in self.parent_id_map:
-             current_parent_id = self.parent_id_map[current_parent_id]
+             # If the parent id is in the ignored runs, we need to get the parent id of the ignored run
+             if current_parent_id in self.ignored_runs:
+                 current_parent_id = self.parent_id_map[current_parent_id]
+             else:
+                 return current_parent_id

-         return current_parent_id
+         return self.root_parent_id

      def _should_ignore_run(self, run: Run):
          parent_id = self._get_run_parent_id(run)

+         if parent_id:
+             # Add the parent id of the ignored run in the mapping
+             # so we can re-attach a kept child to the right parent id
+             self.parent_id_map[str(run.id)] = parent_id
+
          ignore_by_name = run.name in self.to_ignore
          ignore_by_parent = parent_id in self.ignored_runs

          ignore = ignore_by_name or ignore_by_parent

-         if ignore:
-             if parent_id:
-                 # Add the parent id of the ignored run in the mapping
-                 # so we can re-attach a kept child to the right parent id
-                 self.parent_id_map[str(run.id)] = parent_id
-             # Tag the run as ignored
-             self.ignored_runs.add(str(run.id))
-
          # If the ignore cause is the parent being ignored, check if we should nonetheless keep the child
          if ignore_by_parent and not ignore_by_name and run.run_type in self.to_keep:
-             return False, self._get_non_ignored_parent_id(str(run.id))
+             return False, self._get_non_ignored_parent_id(parent_id)
          else:
+             if ignore:
+                 # Tag the run as ignored
+                 self.ignored_runs.add(str(run.id))
              return ignore, parent_id

      def _is_annotable(self, run: Run):
@@ -474,12 +482,12 @@ class LangchainTracer(BaseTracer, PromptHelper, FinalStreamHelper):
      ) -> Any:
          """Adding formatted content and new message to the previously built template prompt"""
          lc_messages = messages[0]
-         if not self.current_prompt:
-             self.prompt_sequence.append(
-                 Prompt(messages=[self._convert_message(m) for m in lc_messages])
+         if not self.current_generation:
+             self.generation_sequence.append(
+                 ChatGeneration(messages=[self._convert_message(m) for m in lc_messages])
              )
          else:
-             self._build_chat_formatted_prompt(lc_messages)
+             self._build_chat_formatted_generation(lc_messages)

          super().on_chat_model_start(
              serialized,
@@ -500,7 +508,7 @@ class LangchainTracer(BaseTracer, PromptHelper, FinalStreamHelper):
          parent_run_id: Optional[UUID] = None,
          **kwargs: Any,
      ) -> Any:
-         msg = self.llm_stream_message.get(str(run_id), None)
+         msg = self.steps.get(str(run_id), None)
          if msg:
              self._run_sync(msg.stream_token(token))

@@ -510,6 +518,7 @@ class LangchainTracer(BaseTracer, PromptHelper, FinalStreamHelper):
          if self.answer_reached:
              if not self.final_stream:
                  self.final_stream = Message(content="")
+                 self._run_sync(self.final_stream.send())
              self._run_sync(self.final_stream.stream_token(token))
              self.has_streamed_final_answer = True
          else:
@@ -530,36 +539,41 @@ class LangchainTracer(BaseTracer, PromptHelper, FinalStreamHelper):

          if run.run_type in ["chain", "prompt"]:
              # Prompt templates are contained in chains or prompts (lcel)
-             self._build_prompt(run.serialized or {}, run.inputs)
+             self._build_generation(run.serialized or {}, run.inputs)

          ignore, parent_id = self._should_ignore_run(run)

          if ignore:
              return

-         disable_human_feedback = not self._is_annotable(run)
+         step_type: TrueStepType = "undefined"

-         if run.run_type == "llm":
-             msg = Message(
-                 id=run.id,
-                 content="",
-                 author=run.name,
-                 parent_id=parent_id,
-                 disable_human_feedback=disable_human_feedback,
-             )
-             self.llm_stream_message[str(run.id)] = msg
-             self._run_sync(msg.send())
-             return
-
-         self._run_sync(
-             Message(
-                 id=run.id,
-                 content="",
-                 author=run.name,
-                 parent_id=parent_id,
-                 disable_human_feedback=disable_human_feedback,
-             ).send()
+         if run.run_type in ["agent", "chain"]:
+             step_type = "run"
+         elif run.run_type == "llm":
+             step_type = "llm"
+         elif run.run_type == "retriever":
+             step_type = "retrieval"
+         elif run.run_type == "tool":
+             step_type = "tool"
+         elif run.run_type == "embedding":
+             step_type = "embedding"
+
+         disable_feedback = not self._is_annotable(run)
+
+         step = Step(
+             id=str(run.id),
+             name=run.name,
+             type=step_type,
+             parent_id=parent_id,
+             disable_feedback=disable_feedback,
          )
+         step.start = datetime.utcnow().isoformat()
+         step.input = run.inputs
+
+         self.steps[str(run.id)] = step
+
+         self._run_sync(step.send())

      def _on_run_update(self, run: Run) -> None:
          """Process a run upon update."""
@@ -570,74 +584,75 @@ class LangchainTracer(BaseTracer, PromptHelper, FinalStreamHelper):
          if ignore:
              return

-         disable_human_feedback = not self._is_annotable(run)
+         current_step = self.steps.get(str(run.id), None)

          if run.run_type in ["chain"]:
-             if self.prompt_sequence:
-                 self.prompt_sequence.pop()
+             if self.generation_sequence:
+                 self.generation_sequence.pop()

          if run.run_type == "llm":
              provider, llm_settings = self._build_llm_settings(
                  (run.serialized or {}), (run.extra or {}).get("invocation_params")
              )
              generations = (run.outputs or {}).get("generations", [])
+             llm_output = (run.outputs or {}).get("llm_output")
              completion, language = self._get_completion(generations[0][0])
-             current_prompt = (
-                 self.prompt_sequence.pop() if self.prompt_sequence else None
+             current_generation = (
+                 self.generation_sequence.pop() if self.generation_sequence else None
              )

-             if current_prompt:
-                 current_prompt.provider = provider
-                 current_prompt.settings = llm_settings
-                 current_prompt.completion = completion
+             if current_generation:
+                 current_generation.provider = provider
+                 current_generation.settings = llm_settings
+                 current_generation.completion = completion
              else:
                  generation_type = generations[0][0].get("type", "")
-                 current_prompt = self._build_default_prompt(
+                 current_generation = self._build_default_generation(
                      run, generation_type, provider, llm_settings, completion
                  )

-             msg = self.llm_stream_message.get(str(run.id), None)
-             if msg:
-                 msg.content = completion
-                 msg.language = language
-                 msg.prompt = current_prompt
-                 self._run_sync(msg.update())
+             if llm_output and current_generation:
+                 token_count = llm_output.get("token_usage", {}).get("total_tokens")
+                 current_generation.token_count = token_count
+
+             if current_step:
+                 current_step.output = completion
+                 current_step.language = language
+                 current_step.end = datetime.utcnow().isoformat()
+                 current_step.generation = current_generation
+                 self._run_sync(current_step.update())

              if self.final_stream and self.has_streamed_final_answer:
                  self.final_stream.content = completion
                  self.final_stream.language = language
-                 self.final_stream.prompt = current_prompt
-                 self._run_sync(self.final_stream.send())
+                 self._run_sync(self.final_stream.update())
+
              return

          outputs = run.outputs or {}
          output_keys = list(outputs.keys())
+         output = outputs
          if output_keys:
-             content = outputs.get(output_keys[0], "")
-         else:
-             return
+             output = outputs.get(output_keys[0], outputs)

-         if run.run_type in ["agent", "chain"]:
-             pass
-             # # Add the response of the chain/tool
-             # self._run_sync(
-             #     Message(
-             #         content=content,
-             #         author=run.name,
-             #         parent_id=parent_id,
-             #         disable_human_feedback=disable_human_feedback,
-             #     ).send()
-             # )
-         else:
-             self._run_sync(
-                 Message(
-                     id=run.id,
-                     content=content,
-                     author=run.name,
-                     parent_id=parent_id,
-                     disable_human_feedback=disable_human_feedback,
-                 ).update()
-             )
+         if current_step:
+             current_step.output = output
+             current_step.end = datetime.utcnow().isoformat()
+             self._run_sync(current_step.update())
+
+     def _on_error(self, error: BaseException, *, run_id: UUID, **kwargs: Any):
+         context_var.set(self.context)
+
+         if current_step := self.steps.get(str(run_id), None):
+             current_step.is_error = True
+             current_step.output = str(error)
+             current_step.end = datetime.utcnow().isoformat()
+             self._run_sync(current_step.update())
+
+     on_llm_error = _on_error
+     on_chain_error = _on_error
+     on_tool_error = _on_error
+     on_retriever_error = _on_error


  LangchainCallbackHandler = LangchainTracer
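
The tracer keeps its public alias (LangchainCallbackHandler = LangchainTracer), so existing integrations mostly change what gets rendered: each LangChain run becomes a Step of type run/llm/retrieval/tool/embedding, with the matched generation attached, instead of a tree of Messages carrying a prompt. A minimal, hypothetical wiring (assumes a chain was stored in the user session at chat start; the "text" output key depends on that chain):

import chainlit as cl
from chainlit.langchain.callbacks import LangchainCallbackHandler


@cl.on_message
async def on_message(message: cl.Message):
    chain = cl.user_session.get("chain")  # assumed to be set in @cl.on_chat_start
    res = await chain.acall(
        message.content, callbacks=[LangchainCallbackHandler()]
    )
    await cl.Message(content=res["text"]).send()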