chainlit 1.0.400__py3-none-any.whl → 2.0.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of chainlit might be problematic. Click here for more details.

Files changed (113) hide show
  1. chainlit/__init__.py +98 -279
  2. chainlit/_utils.py +8 -0
  3. chainlit/action.py +12 -10
  4. chainlit/{auth.py → auth/__init__.py} +28 -36
  5. chainlit/auth/cookie.py +122 -0
  6. chainlit/auth/jwt.py +39 -0
  7. chainlit/cache.py +4 -6
  8. chainlit/callbacks.py +362 -0
  9. chainlit/chat_context.py +64 -0
  10. chainlit/chat_settings.py +3 -1
  11. chainlit/cli/__init__.py +77 -8
  12. chainlit/config.py +181 -101
  13. chainlit/context.py +42 -13
  14. chainlit/copilot/dist/index.js +8750 -903
  15. chainlit/data/__init__.py +101 -416
  16. chainlit/data/acl.py +6 -2
  17. chainlit/data/base.py +107 -0
  18. chainlit/data/chainlit_data_layer.py +608 -0
  19. chainlit/data/dynamodb.py +590 -0
  20. chainlit/data/literalai.py +500 -0
  21. chainlit/data/sql_alchemy.py +721 -0
  22. chainlit/data/storage_clients/__init__.py +0 -0
  23. chainlit/data/storage_clients/azure.py +81 -0
  24. chainlit/data/storage_clients/azure_blob.py +89 -0
  25. chainlit/data/storage_clients/base.py +26 -0
  26. chainlit/data/storage_clients/gcs.py +88 -0
  27. chainlit/data/storage_clients/s3.py +75 -0
  28. chainlit/data/utils.py +29 -0
  29. chainlit/discord/__init__.py +6 -0
  30. chainlit/discord/app.py +354 -0
  31. chainlit/element.py +91 -33
  32. chainlit/emitter.py +80 -29
  33. chainlit/frontend/dist/assets/DailyMotion-C_XC7xJI.js +1 -0
  34. chainlit/frontend/dist/assets/Dataframe-Cs4l4hA1.js +22 -0
  35. chainlit/frontend/dist/assets/Facebook-CUeCH7hk.js +1 -0
  36. chainlit/frontend/dist/assets/FilePlayer-CB-fYkx8.js +1 -0
  37. chainlit/frontend/dist/assets/Kaltura-YX6qaq72.js +1 -0
  38. chainlit/frontend/dist/assets/Mixcloud-DGV0ldjP.js +1 -0
  39. chainlit/frontend/dist/assets/Mux-CmRss5oc.js +1 -0
  40. chainlit/frontend/dist/assets/Preview-DBVJn7-H.js +1 -0
  41. chainlit/frontend/dist/assets/SoundCloud-qLUb18oY.js +1 -0
  42. chainlit/frontend/dist/assets/Streamable-BvYP7bFp.js +1 -0
  43. chainlit/frontend/dist/assets/Twitch-CTHt-sGZ.js +1 -0
  44. chainlit/frontend/dist/assets/Vidyard-B-0mCJbm.js +1 -0
  45. chainlit/frontend/dist/assets/Vimeo-Dnp7ri8q.js +1 -0
  46. chainlit/frontend/dist/assets/Wistia-DW0x_UBn.js +1 -0
  47. chainlit/frontend/dist/assets/YouTube--98FipvA.js +1 -0
  48. chainlit/frontend/dist/assets/index-D71nZ46o.js +8665 -0
  49. chainlit/frontend/dist/assets/index-g8LTJwwr.css +1 -0
  50. chainlit/frontend/dist/assets/react-plotly-Cn_BQTQw.js +3484 -0
  51. chainlit/frontend/dist/index.html +2 -4
  52. chainlit/haystack/callbacks.py +4 -7
  53. chainlit/input_widget.py +8 -4
  54. chainlit/langchain/callbacks.py +107 -72
  55. chainlit/langflow/__init__.py +1 -0
  56. chainlit/llama_index/__init__.py +2 -2
  57. chainlit/llama_index/callbacks.py +67 -42
  58. chainlit/markdown.py +22 -6
  59. chainlit/message.py +54 -56
  60. chainlit/mistralai/__init__.py +50 -0
  61. chainlit/oauth_providers.py +266 -8
  62. chainlit/openai/__init__.py +10 -18
  63. chainlit/secret.py +1 -1
  64. chainlit/server.py +789 -228
  65. chainlit/session.py +108 -90
  66. chainlit/slack/__init__.py +6 -0
  67. chainlit/slack/app.py +397 -0
  68. chainlit/socket.py +199 -116
  69. chainlit/step.py +141 -89
  70. chainlit/sync.py +2 -1
  71. chainlit/teams/__init__.py +6 -0
  72. chainlit/teams/app.py +338 -0
  73. chainlit/translations/bn.json +235 -0
  74. chainlit/translations/en-US.json +83 -4
  75. chainlit/translations/gu.json +235 -0
  76. chainlit/translations/he-IL.json +235 -0
  77. chainlit/translations/hi.json +235 -0
  78. chainlit/translations/kn.json +235 -0
  79. chainlit/translations/ml.json +235 -0
  80. chainlit/translations/mr.json +235 -0
  81. chainlit/translations/nl-NL.json +233 -0
  82. chainlit/translations/ta.json +235 -0
  83. chainlit/translations/te.json +235 -0
  84. chainlit/translations/zh-CN.json +233 -0
  85. chainlit/translations.py +60 -0
  86. chainlit/types.py +133 -28
  87. chainlit/user.py +14 -3
  88. chainlit/user_session.py +6 -3
  89. chainlit/utils.py +52 -5
  90. chainlit/version.py +3 -2
  91. {chainlit-1.0.400.dist-info → chainlit-2.0.3.dist-info}/METADATA +48 -50
  92. chainlit-2.0.3.dist-info/RECORD +106 -0
  93. chainlit/cli/utils.py +0 -24
  94. chainlit/frontend/dist/assets/index-9711593e.js +0 -723
  95. chainlit/frontend/dist/assets/index-d088547c.css +0 -1
  96. chainlit/frontend/dist/assets/react-plotly-d8762cc2.js +0 -3602
  97. chainlit/playground/__init__.py +0 -2
  98. chainlit/playground/config.py +0 -40
  99. chainlit/playground/provider.py +0 -108
  100. chainlit/playground/providers/__init__.py +0 -13
  101. chainlit/playground/providers/anthropic.py +0 -118
  102. chainlit/playground/providers/huggingface.py +0 -75
  103. chainlit/playground/providers/langchain.py +0 -89
  104. chainlit/playground/providers/openai.py +0 -408
  105. chainlit/playground/providers/vertexai.py +0 -171
  106. chainlit/translations/pt-BR.json +0 -155
  107. chainlit-1.0.400.dist-info/RECORD +0 -66
  108. /chainlit/copilot/dist/assets/{logo_dark-2a3cf740.svg → logo_dark-IkGJ_IwC.svg} +0 -0
  109. /chainlit/copilot/dist/assets/{logo_light-b078e7bc.svg → logo_light-Bb_IPh6r.svg} +0 -0
  110. /chainlit/frontend/dist/assets/{logo_dark-2a3cf740.svg → logo_dark-IkGJ_IwC.svg} +0 -0
  111. /chainlit/frontend/dist/assets/{logo_light-b078e7bc.svg → logo_light-Bb_IPh6r.svg} +0 -0
  112. {chainlit-1.0.400.dist-info → chainlit-2.0.3.dist-info}/WHEEL +0 -0
  113. {chainlit-1.0.400.dist-info → chainlit-2.0.3.dist-info}/entry_points.txt +0 -0
@@ -1,20 +1,21 @@
1
1
  from typing import Any, Dict, List, Optional
2
2
 
3
+ from literalai import ChatGeneration, CompletionGeneration, GenerationMessage
4
+ from literalai.helper import utc_now
5
+ from llama_index.core.callbacks import TokenCountingHandler
6
+ from llama_index.core.callbacks.schema import CBEventType, EventPayload
7
+ from llama_index.core.llms import ChatMessage, ChatResponse, CompletionResponse
8
+ from llama_index.core.tools.types import ToolMetadata
9
+
3
10
  from chainlit.context import context_var
4
11
  from chainlit.element import Text
5
12
  from chainlit.step import Step, StepType
6
- from literalai import ChatGeneration, CompletionGeneration, GenerationMessage
7
- from literalai.helper import utc_now
8
- from llama_index.callbacks import TokenCountingHandler
9
- from llama_index.callbacks.schema import CBEventType, EventPayload
10
- from llama_index.llms.base import ChatMessage, ChatResponse, CompletionResponse
11
13
 
12
14
  DEFAULT_IGNORE = [
13
15
  CBEventType.CHUNKING,
14
16
  CBEventType.SYNTHESIZE,
15
17
  CBEventType.EMBEDDING,
16
18
  CBEventType.NODE_PARSING,
17
- CBEventType.QUERY,
18
19
  CBEventType.TREE,
19
20
  ]
20
21
 
@@ -34,33 +35,17 @@ class LlamaIndexCallbackHandler(TokenCountingHandler):
34
35
  event_starts_to_ignore=event_starts_to_ignore,
35
36
  event_ends_to_ignore=event_ends_to_ignore,
36
37
  )
37
- self.context = context_var.get()
38
38
 
39
39
  self.steps = {}
40
40
 
41
41
  def _get_parent_id(self, event_parent_id: Optional[str] = None) -> Optional[str]:
42
42
  if event_parent_id and event_parent_id in self.steps:
43
43
  return event_parent_id
44
- elif self.context.current_step:
45
- return self.context.current_step.id
46
- elif self.context.session.root_message:
47
- return self.context.session.root_message.id
44
+ elif context_var.get().current_step:
45
+ return context_var.get().current_step.id
48
46
  else:
49
47
  return None
50
48
 
51
- def _restore_context(self) -> None:
52
- """Restore Chainlit context in the current thread
53
-
54
- Chainlit context is local to the main thread, and LlamaIndex
55
- runs the callbacks in its own threads, so they don't have a
56
- Chainlit context by default.
57
-
58
- This method restores the context in which the callback handler
59
- has been created (it's always created in the main thread), so
60
- that we can actually send messages.
61
- """
62
- context_var.set(self.context)
63
-
64
49
  def on_event_start(
65
50
  self,
66
51
  event_type: CBEventType,
@@ -70,26 +55,36 @@ class LlamaIndexCallbackHandler(TokenCountingHandler):
70
55
  **kwargs: Any,
71
56
  ) -> str:
72
57
  """Run when an event starts and return id of event."""
73
- self._restore_context()
74
58
  step_type: StepType = "undefined"
75
- if event_type == CBEventType.RETRIEVE:
76
- step_type = "retrieval"
59
+ step_name: str = event_type.value
60
+ step_input: Optional[Dict[str, Any]] = payload
61
+ if event_type == CBEventType.FUNCTION_CALL:
62
+ step_type = "tool"
63
+ if payload:
64
+ metadata: Optional[ToolMetadata] = payload.get(EventPayload.TOOL)
65
+ if metadata:
66
+ step_name = getattr(metadata, "name", step_name)
67
+ step_input = payload.get(EventPayload.FUNCTION_CALL)
68
+ elif event_type == CBEventType.RETRIEVE:
69
+ step_type = "tool"
70
+ elif event_type == CBEventType.QUERY:
71
+ step_type = "tool"
77
72
  elif event_type == CBEventType.LLM:
78
73
  step_type = "llm"
79
74
  else:
80
75
  return event_id
81
76
 
82
77
  step = Step(
83
- name=event_type.value,
78
+ name=step_name,
84
79
  type=step_type,
85
80
  parent_id=self._get_parent_id(parent_id),
86
81
  id=event_id,
87
- disable_feedback=False,
88
82
  )
83
+
89
84
  self.steps[event_id] = step
90
85
  step.start = utc_now()
91
- step.input = payload or {}
92
- self.context.loop.create_task(step.send())
86
+ step.input = step_input or {}
87
+ context_var.get().loop.create_task(step.send())
93
88
  return event_id
94
89
 
95
90
  def on_event_end(
@@ -105,37 +100,59 @@ class LlamaIndexCallbackHandler(TokenCountingHandler):
105
100
  if payload is None or step is None:
106
101
  return
107
102
 
108
- self._restore_context()
109
-
110
103
  step.end = utc_now()
111
104
 
112
- if event_type == CBEventType.RETRIEVE:
105
+ if event_type == CBEventType.FUNCTION_CALL:
106
+ response = payload.get(EventPayload.FUNCTION_OUTPUT)
107
+ if response:
108
+ step.output = f"{response}"
109
+ context_var.get().loop.create_task(step.update())
110
+
111
+ elif event_type == CBEventType.QUERY:
112
+ response = payload.get(EventPayload.RESPONSE)
113
+ source_nodes = getattr(response, "source_nodes", None)
114
+ if source_nodes:
115
+ source_refs = ", ".join(
116
+ [f"Source {idx}" for idx, _ in enumerate(source_nodes)]
117
+ )
118
+ step.elements = [
119
+ Text(
120
+ name=f"Source {idx}",
121
+ content=source.text or "Empty node",
122
+ display="side",
123
+ )
124
+ for idx, source in enumerate(source_nodes)
125
+ ]
126
+ step.output = f"Retrieved the following sources: {source_refs}"
127
+ context_var.get().loop.create_task(step.update())
128
+
129
+ elif event_type == CBEventType.RETRIEVE:
113
130
  sources = payload.get(EventPayload.NODES)
114
131
  if sources:
115
- source_refs = "\, ".join(
132
+ source_refs = ", ".join(
116
133
  [f"Source {idx}" for idx, _ in enumerate(sources)]
117
134
  )
118
135
  step.elements = [
119
136
  Text(
120
137
  name=f"Source {idx}",
138
+ display="side",
121
139
  content=source.node.get_text() or "Empty node",
122
140
  )
123
141
  for idx, source in enumerate(sources)
124
142
  ]
125
143
  step.output = f"Retrieved the following sources: {source_refs}"
126
- self.context.loop.create_task(step.update())
144
+ context_var.get().loop.create_task(step.update())
127
145
 
128
- if event_type == CBEventType.LLM:
129
- formatted_messages = payload.get(
130
- EventPayload.MESSAGES
131
- ) # type: Optional[List[ChatMessage]]
146
+ elif event_type == CBEventType.LLM:
147
+ formatted_messages = payload.get(EventPayload.MESSAGES) # type: Optional[List[ChatMessage]]
132
148
  formatted_prompt = payload.get(EventPayload.PROMPT)
133
149
  response = payload.get(EventPayload.RESPONSE)
134
150
 
135
151
  if formatted_messages:
136
152
  messages = [
137
153
  GenerationMessage(
138
- role=m.role.value, content=m.content or "" # type: ignore
154
+ role=m.role.value, # type: ignore
155
+ content=m.content or "",
139
156
  )
140
157
  for m in formatted_messages
141
158
  ]
@@ -152,10 +169,13 @@ class LlamaIndexCallbackHandler(TokenCountingHandler):
152
169
  step.output = content
153
170
 
154
171
  token_count = self.total_llm_token_count or None
172
+ raw_response = response.raw if response else None
173
+ model = getattr(raw_response, "model", None)
155
174
 
156
175
  if messages and isinstance(response, ChatResponse):
157
176
  msg: ChatMessage = response.message
158
177
  step.generation = ChatGeneration(
178
+ model=model,
159
179
  messages=messages,
160
180
  message_completion=GenerationMessage(
161
181
  role=msg.role.value, # type: ignore
@@ -165,12 +185,17 @@ class LlamaIndexCallbackHandler(TokenCountingHandler):
165
185
  )
166
186
  elif formatted_prompt:
167
187
  step.generation = CompletionGeneration(
188
+ model=model,
168
189
  prompt=formatted_prompt,
169
190
  completion=content,
170
191
  token_count=token_count,
171
192
  )
172
193
 
173
- self.context.loop.create_task(step.update())
194
+ context_var.get().loop.create_task(step.update())
195
+
196
+ else:
197
+ step.output = payload
198
+ context_var.get().loop.create_task(step.update())
174
199
 
175
200
  self.steps.pop(event_id, None)
176
201
 
chainlit/markdown.py CHANGED
@@ -1,7 +1,11 @@
1
1
  import os
2
+ from pathlib import Path
3
+ from typing import Optional
2
4
 
3
5
  from chainlit.logger import logger
4
6
 
7
+ from ._utils import is_path_inside
8
+
5
9
  # Default chainlit.md file created if none exists
6
10
  DEFAULT_MARKDOWN_STR = """# Welcome to Chainlit! 🚀🤖
7
11
 
@@ -30,12 +34,24 @@ def init_markdown(root: str):
30
34
  logger.info(f"Created default chainlit markdown file at {chainlit_md_file}")
31
35
 
32
36
 
33
- def get_markdown_str(root: str):
37
+ def get_markdown_str(root: str, language: str) -> Optional[str]:
34
38
  """Get the chainlit.md file as a string."""
35
- chainlit_md_path = os.path.join(root, "chainlit.md")
36
- if os.path.exists(chainlit_md_path):
37
- with open(chainlit_md_path, "r", encoding="utf-8") as f:
38
- chainlit_md = f.read()
39
- return chainlit_md
39
+ root_path = Path(root)
40
+ translated_chainlit_md_path = root_path / f"chainlit_{language}.md"
41
+ default_chainlit_md_path = root_path / "chainlit.md"
42
+
43
+ if (
44
+ is_path_inside(translated_chainlit_md_path, root_path)
45
+ and translated_chainlit_md_path.is_file()
46
+ ):
47
+ chainlit_md_path = translated_chainlit_md_path
48
+ else:
49
+ chainlit_md_path = default_chainlit_md_path
50
+ logger.warning(
51
+ f"Translated markdown file for {language} not found. Defaulting to chainlit.md."
52
+ )
53
+
54
+ if chainlit_md_path.is_file():
55
+ return chainlit_md_path.read_text(encoding="utf-8")
40
56
  else:
41
57
  return None
chainlit/message.py CHANGED
@@ -5,9 +5,13 @@ import uuid
5
5
  from abc import ABC
6
6
  from typing import Dict, List, Optional, Union, cast
7
7
 
8
+ from literalai.helper import utc_now
9
+ from literalai.observability.step import MessageStepType
10
+
8
11
  from chainlit.action import Action
12
+ from chainlit.chat_context import chat_context
9
13
  from chainlit.config import config
10
- from chainlit.context import context
14
+ from chainlit.context import context, local_steps
11
15
  from chainlit.data import get_data_layer
12
16
  from chainlit.element import ElementBased
13
17
  from chainlit.logger import logger
@@ -21,9 +25,6 @@ from chainlit.types import (
21
25
  AskSpec,
22
26
  FileDict,
23
27
  )
24
- from literalai import BaseGeneration
25
- from literalai.helper import utc_now
26
- from literalai.step import MessageStepType
27
28
 
28
29
 
29
30
  class MessageBase(ABC):
@@ -32,57 +33,60 @@ class MessageBase(ABC):
32
33
  author: str
33
34
  content: str = ""
34
35
  type: MessageStepType = "assistant_message"
35
- disable_feedback = False
36
36
  streaming = False
37
37
  created_at: Union[str, None] = None
38
38
  fail_on_persist_error: bool = False
39
39
  persisted = False
40
40
  is_error = False
41
+ parent_id: Optional[str] = None
41
42
  language: Optional[str] = None
43
+ metadata: Optional[Dict] = None
44
+ tags: Optional[List[str]] = None
42
45
  wait_for_answer = False
43
- indent: Optional[int] = None
44
- generation: Optional[BaseGeneration] = None
45
46
 
46
47
  def __post_init__(self) -> None:
47
48
  trace_event(f"init {self.__class__.__name__}")
48
49
  self.thread_id = context.session.thread_id
49
50
 
51
+ previous_steps = local_steps.get() or []
52
+ parent_step = previous_steps[-1] if previous_steps else None
53
+ if parent_step:
54
+ self.parent_id = parent_step.id
55
+
50
56
  if not getattr(self, "id", None):
51
57
  self.id = str(uuid.uuid4())
52
58
 
53
59
  @classmethod
54
60
  def from_dict(self, _dict: StepDict):
55
61
  type = _dict.get("type", "assistant_message")
56
- message = Message(
62
+ return Message(
57
63
  id=_dict["id"],
64
+ parent_id=_dict.get("parentId"),
58
65
  created_at=_dict["createdAt"],
59
66
  content=_dict["output"],
60
67
  author=_dict.get("name", config.ui.name),
61
68
  type=type, # type: ignore
62
- disable_feedback=_dict.get("disableFeedback", False),
63
69
  language=_dict.get("language"),
70
+ metadata=_dict.get("metadata", {}),
64
71
  )
65
72
 
66
- return message
67
-
68
73
  def to_dict(self) -> StepDict:
69
74
  _dict: StepDict = {
70
75
  "id": self.id,
71
76
  "threadId": self.thread_id,
77
+ "parentId": self.parent_id,
72
78
  "createdAt": self.created_at,
73
79
  "start": self.created_at,
74
80
  "end": self.created_at,
75
81
  "output": self.content,
76
82
  "name": self.author,
77
83
  "type": self.type,
78
- "createdAt": self.created_at,
79
84
  "language": self.language,
80
85
  "streaming": self.streaming,
81
- "disableFeedback": self.disable_feedback,
82
86
  "isError": self.is_error,
83
87
  "waitForAnswer": self.wait_for_answer,
84
- "indent": self.indent,
85
- "generation": self.generation.to_dict() if self.generation else None,
88
+ "metadata": self.metadata or {},
89
+ "tags": self.tags,
86
90
  }
87
91
 
88
92
  return _dict
@@ -99,6 +103,7 @@ class MessageBase(ABC):
99
103
  self.streaming = False
100
104
 
101
105
  step_dict = self.to_dict()
106
+ chat_context.add(self)
102
107
 
103
108
  data_layer = get_data_layer()
104
109
  if data_layer:
@@ -107,7 +112,7 @@ class MessageBase(ABC):
107
112
  except Exception as e:
108
113
  if self.fail_on_persist_error:
109
114
  raise e
110
- logger.error(f"Failed to persist message update: {str(e)}")
115
+ logger.error(f"Failed to persist message update: {e!s}")
111
116
 
112
117
  await context.emitter.update_step(step_dict)
113
118
 
@@ -118,7 +123,7 @@ class MessageBase(ABC):
118
123
  Remove a message already sent to the UI.
119
124
  """
120
125
  trace_event("remove_message")
121
-
126
+ chat_context.remove(self)
122
127
  step_dict = self.to_dict()
123
128
  data_layer = get_data_layer()
124
129
  if data_layer:
@@ -127,7 +132,7 @@ class MessageBase(ABC):
127
132
  except Exception as e:
128
133
  if self.fail_on_persist_error:
129
134
  raise e
130
- logger.error(f"Failed to persist message deletion: {str(e)}")
135
+ logger.error(f"Failed to persist message deletion: {e!s}")
131
136
 
132
137
  await context.emitter.delete_step(step_dict)
133
138
 
@@ -143,7 +148,7 @@ class MessageBase(ABC):
143
148
  except Exception as e:
144
149
  if self.fail_on_persist_error:
145
150
  raise e
146
- logger.error(f"Failed to persist message creation: {str(e)}")
151
+ logger.error(f"Failed to persist message creation: {e!s}")
147
152
 
148
153
  return step_dict
149
154
 
@@ -160,30 +165,31 @@ class MessageBase(ABC):
160
165
  self.streaming = False
161
166
 
162
167
  step_dict = await self._create()
168
+ chat_context.add(self)
163
169
  await context.emitter.send_step(step_dict)
164
170
 
165
- return self.id
171
+ return self
166
172
 
167
173
  async def stream_token(self, token: str, is_sequence=False):
168
174
  """
169
175
  Sends a token to the UI. This is useful for streaming messages.
170
176
  Once all tokens have been streamed, call .send() to end the stream and persist the message if persistence is enabled.
171
177
  """
172
-
173
- if not self.streaming:
174
- self.streaming = True
175
- step_dict = self.to_dict()
176
- await context.emitter.stream_start(step_dict)
177
-
178
178
  if is_sequence:
179
179
  self.content = token
180
180
  else:
181
181
  self.content += token
182
182
 
183
183
  assert self.id
184
- await context.emitter.send_token(
185
- id=self.id, token=token, is_sequence=is_sequence
186
- )
184
+
185
+ if not self.streaming:
186
+ self.streaming = True
187
+ step_dict = self.to_dict()
188
+ await context.emitter.stream_start(step_dict)
189
+ else:
190
+ await context.emitter.send_token(
191
+ id=self.id, token=token, is_sequence=is_sequence
192
+ )
187
193
 
188
194
 
189
195
  class Message(MessageBase):
@@ -192,29 +198,28 @@ class Message(MessageBase):
192
198
 
193
199
  Args:
194
200
  content (Union[str, Dict]): The content of the message.
195
- author (str, optional): The author of the message, this will be used in the UI. Defaults to the chatbot name (see config).
201
+ author (str, optional): The author of the message, this will be used in the UI. Defaults to the assistant name (see config).
196
202
  language (str, optional): Language of the code is the content is code. See https://react-code-blocks-rajinwonderland.vercel.app/?path=/story/codeblock--supported-languages for a list of supported languages.
197
203
  actions (List[Action], optional): A list of actions to send with the message.
198
204
  elements (List[ElementBased], optional): A list of elements to send with the message.
199
- disable_feedback (bool, optional): Hide the feedback buttons for this specific message
200
205
  """
201
206
 
202
207
  def __init__(
203
208
  self,
204
209
  content: Union[str, Dict],
205
- author: str = config.ui.name,
210
+ author: Optional[str] = None,
206
211
  language: Optional[str] = None,
207
212
  actions: Optional[List[Action]] = None,
208
213
  elements: Optional[List[ElementBased]] = None,
209
- disable_feedback: bool = False,
210
214
  type: MessageStepType = "assistant_message",
211
- generation: Optional[BaseGeneration] = None,
215
+ metadata: Optional[Dict] = None,
216
+ tags: Optional[List[str]] = None,
212
217
  id: Optional[str] = None,
218
+ parent_id: Optional[str] = None,
213
219
  created_at: Union[str, None] = None,
214
220
  ):
215
221
  time.sleep(0.001)
216
222
  self.language = language
217
- self.generation = generation
218
223
  if isinstance(content, dict):
219
224
  try:
220
225
  self.content = json.dumps(content, indent=4, ensure_ascii=False)
@@ -231,18 +236,23 @@ class Message(MessageBase):
231
236
  if id:
232
237
  self.id = str(id)
233
238
 
239
+ if parent_id:
240
+ self.parent_id = str(parent_id)
241
+
234
242
  if created_at:
235
243
  self.created_at = created_at
236
244
 
237
- self.author = author
245
+ self.metadata = metadata
246
+ self.tags = tags
247
+
248
+ self.author = author or config.ui.name
238
249
  self.type = type
239
250
  self.actions = actions if actions is not None else []
240
251
  self.elements = elements if elements is not None else []
241
- self.disable_feedback = disable_feedback
242
252
 
243
253
  super().__post_init__()
244
254
 
245
- async def send(self) -> str:
255
+ async def send(self):
246
256
  """
247
257
  Send the message to the UI and persist it in the cloud if a project ID is configured.
248
258
  Return the ID of the message.
@@ -250,8 +260,6 @@ class Message(MessageBase):
250
260
  trace_event("send_message")
251
261
  await super().send()
252
262
 
253
- context.session.root_message = self
254
-
255
263
  # Create tasks for all actions and elements
256
264
  tasks = [action.send(for_id=self.id) for action in self.actions]
257
265
  tasks.extend(element.send(for_id=self.id) for element in self.elements)
@@ -259,7 +267,7 @@ class Message(MessageBase):
259
267
  # Run all tasks concurrently
260
268
  await asyncio.gather(*tasks)
261
269
 
262
- return self.id
270
+ return self
263
271
 
264
272
  async def update(self):
265
273
  """
@@ -294,9 +302,7 @@ class ErrorMessage(MessageBase):
294
302
 
295
303
  Args:
296
304
  content (str): Text displayed above the upload button.
297
- author (str, optional): The author of the message, this will be used in the UI. Defaults to the chatbot name (see config).
298
- parent_id (str, optional): If provided, the message will be nested inside the parent in the UI.
299
- indent (int, optional): If positive, the message will be nested in the UI.
305
+ author (str, optional): The author of the message, this will be used in the UI. Defaults to the assistant name (see config).
300
306
  """
301
307
 
302
308
  def __init__(
@@ -307,7 +313,7 @@ class ErrorMessage(MessageBase):
307
313
  ):
308
314
  self.content = content
309
315
  self.author = author
310
- self.type = "system_message"
316
+ self.type = "assistant_message"
311
317
  self.is_error = True
312
318
  self.fail_on_persist_error = fail_on_persist_error
313
319
 
@@ -337,8 +343,7 @@ class AskUserMessage(AskMessageBase):
337
343
 
338
344
  Args:
339
345
  content (str): The content of the prompt.
340
- author (str, optional): The author of the message, this will be used in the UI. Defaults to the chatbot name (see config).
341
- disable_feedback (bool, optional): Hide the feedback buttons for this specific message
346
+ author (str, optional): The author of the message, this will be used in the UI. Defaults to the assistant name (see config).
342
347
  timeout (int, optional): The number of seconds to wait for an answer before raising a TimeoutError.
343
348
  raise_on_timeout (bool, optional): Whether to raise a socketio TimeoutError if the user does not answer in time.
344
349
  """
@@ -348,7 +353,6 @@ class AskUserMessage(AskMessageBase):
348
353
  content: str,
349
354
  author: str = config.ui.name,
350
355
  type: MessageStepType = "assistant_message",
351
- disable_feedback: bool = False,
352
356
  timeout: int = 60,
353
357
  raise_on_timeout: bool = False,
354
358
  ):
@@ -356,7 +360,6 @@ class AskUserMessage(AskMessageBase):
356
360
  self.author = author
357
361
  self.timeout = timeout
358
362
  self.type = type
359
- self.disable_feedback = disable_feedback
360
363
  self.raise_on_timeout = raise_on_timeout
361
364
 
362
365
  super().__post_init__()
@@ -402,8 +405,7 @@ class AskFileMessage(AskMessageBase):
402
405
  accept (Union[List[str], Dict[str, List[str]]]): List of mime type to accept like ["text/csv", "application/pdf"] or a dict like {"text/plain": [".txt", ".py"]}.
403
406
  max_size_mb (int, optional): Maximum size per file in MB. Maximum value is 100.
404
407
  max_files (int, optional): Maximum number of files to upload. Maximum value is 10.
405
- author (str, optional): The author of the message, this will be used in the UI. Defaults to the chatbot name (see config).
406
- disable_feedback (bool, optional): Hide the feedback buttons for this specific message
408
+ author (str, optional): The author of the message, this will be used in the UI. Defaults to the assistant name (see config).
407
409
  timeout (int, optional): The number of seconds to wait for an answer before raising a TimeoutError.
408
410
  raise_on_timeout (bool, optional): Whether to raise a socketio TimeoutError if the user does not answer in time.
409
411
  """
@@ -416,7 +418,6 @@ class AskFileMessage(AskMessageBase):
416
418
  max_files=1,
417
419
  author=config.ui.name,
418
420
  type: MessageStepType = "assistant_message",
419
- disable_feedback: bool = False,
420
421
  timeout=90,
421
422
  raise_on_timeout=False,
422
423
  ):
@@ -428,7 +429,6 @@ class AskFileMessage(AskMessageBase):
428
429
  self.author = author
429
430
  self.timeout = timeout
430
431
  self.raise_on_timeout = raise_on_timeout
431
- self.disable_feedback = disable_feedback
432
432
 
433
433
  super().__post_init__()
434
434
 
@@ -492,14 +492,12 @@ class AskActionMessage(AskMessageBase):
492
492
  content: str,
493
493
  actions: List[Action],
494
494
  author=config.ui.name,
495
- disable_feedback=False,
496
495
  timeout=90,
497
496
  raise_on_timeout=False,
498
497
  ):
499
498
  self.content = content
500
499
  self.actions = actions
501
500
  self.author = author
502
- self.disable_feedback = disable_feedback
503
501
  self.timeout = timeout
504
502
  self.raise_on_timeout = raise_on_timeout
505
503
 
@@ -542,7 +540,7 @@ class AskActionMessage(AskMessageBase):
542
540
  if res is None:
543
541
  self.content = "Timed out: no action was taken"
544
542
  else:
545
- self.content = f'**Selected action:** {res["label"]}'
543
+ self.content = f"**Selected:** {res['label']}"
546
544
 
547
545
  self.wait_for_answer = False
548
546
 
@@ -0,0 +1,50 @@
1
+ import asyncio
2
+ from typing import Union
3
+
4
+ from literalai import ChatGeneration, CompletionGeneration
5
+ from literalai.helper import timestamp_utc
6
+
7
+ from chainlit.context import get_context
8
+ from chainlit.step import Step
9
+
10
+
11
+ def instrument_mistralai():
12
+ from literalai.instrumentation.mistralai import instrument_mistralai
13
+
14
+ def on_new_generation(
15
+ generation: Union["ChatGeneration", "CompletionGeneration"], timing
16
+ ):
17
+ context = get_context()
18
+
19
+ parent_id = None
20
+ if context.current_step:
21
+ parent_id = context.current_step.id
22
+
23
+ step = Step(
24
+ name=generation.model if generation.model else generation.provider,
25
+ type="llm",
26
+ parent_id=parent_id,
27
+ )
28
+ step.generation = generation
29
+ # Convert start/end time from seconds to milliseconds
30
+ step.start = (
31
+ timestamp_utc(timing.get("start"))
32
+ if timing.get("start", None) is not None
33
+ else None
34
+ )
35
+ step.end = (
36
+ timestamp_utc(timing.get("end"))
37
+ if timing.get("end", None) is not None
38
+ else None
39
+ )
40
+
41
+ if isinstance(generation, ChatGeneration):
42
+ step.input = generation.messages
43
+ step.output = generation.message_completion # type: ignore
44
+ else:
45
+ step.input = generation.prompt
46
+ step.output = generation.completion
47
+
48
+ asyncio.create_task(step.send())
49
+
50
+ instrument_mistralai(None, on_new_generation)