chainlit 1.0.401__py3-none-any.whl → 2.0.4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of chainlit might be problematic. See the registry's advisory page for this release for more details.

Files changed (113):
  1. chainlit/__init__.py +98 -279
  2. chainlit/_utils.py +8 -0
  3. chainlit/action.py +12 -10
  4. chainlit/{auth.py → auth/__init__.py} +28 -36
  5. chainlit/auth/cookie.py +123 -0
  6. chainlit/auth/jwt.py +39 -0
  7. chainlit/cache.py +4 -6
  8. chainlit/callbacks.py +362 -0
  9. chainlit/chat_context.py +64 -0
  10. chainlit/chat_settings.py +3 -1
  11. chainlit/cli/__init__.py +77 -8
  12. chainlit/config.py +191 -102
  13. chainlit/context.py +42 -13
  14. chainlit/copilot/dist/index.js +8750 -903
  15. chainlit/data/__init__.py +101 -416
  16. chainlit/data/acl.py +6 -2
  17. chainlit/data/base.py +107 -0
  18. chainlit/data/chainlit_data_layer.py +614 -0
  19. chainlit/data/dynamodb.py +590 -0
  20. chainlit/data/literalai.py +500 -0
  21. chainlit/data/sql_alchemy.py +721 -0
  22. chainlit/data/storage_clients/__init__.py +0 -0
  23. chainlit/data/storage_clients/azure.py +81 -0
  24. chainlit/data/storage_clients/azure_blob.py +89 -0
  25. chainlit/data/storage_clients/base.py +26 -0
  26. chainlit/data/storage_clients/gcs.py +88 -0
  27. chainlit/data/storage_clients/s3.py +75 -0
  28. chainlit/data/utils.py +29 -0
  29. chainlit/discord/__init__.py +6 -0
  30. chainlit/discord/app.py +354 -0
  31. chainlit/element.py +91 -33
  32. chainlit/emitter.py +81 -29
  33. chainlit/frontend/dist/assets/DailyMotion-Ce9dQoqZ.js +1 -0
  34. chainlit/frontend/dist/assets/Dataframe-C1XonMcV.js +22 -0
  35. chainlit/frontend/dist/assets/Facebook-DVVt6lrr.js +1 -0
  36. chainlit/frontend/dist/assets/FilePlayer-c7stW4vz.js +1 -0
  37. chainlit/frontend/dist/assets/Kaltura-BmMmgorA.js +1 -0
  38. chainlit/frontend/dist/assets/Mixcloud-Cw8hDmiO.js +1 -0
  39. chainlit/frontend/dist/assets/Mux-DiRZfeUf.js +1 -0
  40. chainlit/frontend/dist/assets/Preview-6Jt2mRHx.js +1 -0
  41. chainlit/frontend/dist/assets/SoundCloud-DKwcT58_.js +1 -0
  42. chainlit/frontend/dist/assets/Streamable-BVdxrEeX.js +1 -0
  43. chainlit/frontend/dist/assets/Twitch-DFqZR7Gu.js +1 -0
  44. chainlit/frontend/dist/assets/Vidyard-0BQAAtVk.js +1 -0
  45. chainlit/frontend/dist/assets/Vimeo-CRFSH0Vu.js +1 -0
  46. chainlit/frontend/dist/assets/Wistia-CKrmdQaG.js +1 -0
  47. chainlit/frontend/dist/assets/YouTube-CQpL-rvU.js +1 -0
  48. chainlit/frontend/dist/assets/index-DQmLRKyv.css +1 -0
  49. chainlit/frontend/dist/assets/index-QdmxtIMQ.js +8665 -0
  50. chainlit/frontend/dist/assets/react-plotly-B9hvVpUG.js +3484 -0
  51. chainlit/frontend/dist/index.html +2 -4
  52. chainlit/haystack/callbacks.py +4 -7
  53. chainlit/input_widget.py +8 -4
  54. chainlit/langchain/callbacks.py +103 -68
  55. chainlit/langflow/__init__.py +1 -0
  56. chainlit/llama_index/callbacks.py +65 -40
  57. chainlit/markdown.py +22 -6
  58. chainlit/message.py +54 -56
  59. chainlit/mistralai/__init__.py +50 -0
  60. chainlit/oauth_providers.py +266 -8
  61. chainlit/openai/__init__.py +10 -18
  62. chainlit/secret.py +1 -1
  63. chainlit/server.py +789 -228
  64. chainlit/session.py +108 -90
  65. chainlit/slack/__init__.py +6 -0
  66. chainlit/slack/app.py +397 -0
  67. chainlit/socket.py +199 -116
  68. chainlit/step.py +141 -89
  69. chainlit/sync.py +2 -1
  70. chainlit/teams/__init__.py +6 -0
  71. chainlit/teams/app.py +338 -0
  72. chainlit/translations/bn.json +244 -0
  73. chainlit/translations/en-US.json +122 -8
  74. chainlit/translations/gu.json +244 -0
  75. chainlit/translations/he-IL.json +244 -0
  76. chainlit/translations/hi.json +244 -0
  77. chainlit/translations/ja.json +242 -0
  78. chainlit/translations/kn.json +244 -0
  79. chainlit/translations/ml.json +244 -0
  80. chainlit/translations/mr.json +244 -0
  81. chainlit/translations/nl-NL.json +242 -0
  82. chainlit/translations/ta.json +244 -0
  83. chainlit/translations/te.json +244 -0
  84. chainlit/translations/zh-CN.json +243 -0
  85. chainlit/translations.py +60 -0
  86. chainlit/types.py +133 -28
  87. chainlit/user.py +14 -3
  88. chainlit/user_session.py +6 -3
  89. chainlit/utils.py +52 -5
  90. chainlit/version.py +3 -2
  91. {chainlit-1.0.401.dist-info → chainlit-2.0.4.dist-info}/METADATA +48 -50
  92. chainlit-2.0.4.dist-info/RECORD +107 -0
  93. chainlit/cli/utils.py +0 -24
  94. chainlit/frontend/dist/assets/index-9711593e.js +0 -723
  95. chainlit/frontend/dist/assets/index-d088547c.css +0 -1
  96. chainlit/frontend/dist/assets/react-plotly-d8762cc2.js +0 -3602
  97. chainlit/playground/__init__.py +0 -2
  98. chainlit/playground/config.py +0 -40
  99. chainlit/playground/provider.py +0 -108
  100. chainlit/playground/providers/__init__.py +0 -13
  101. chainlit/playground/providers/anthropic.py +0 -118
  102. chainlit/playground/providers/huggingface.py +0 -75
  103. chainlit/playground/providers/langchain.py +0 -89
  104. chainlit/playground/providers/openai.py +0 -408
  105. chainlit/playground/providers/vertexai.py +0 -171
  106. chainlit/translations/pt-BR.json +0 -155
  107. chainlit-1.0.401.dist-info/RECORD +0 -66
  108. /chainlit/copilot/dist/assets/{logo_dark-2a3cf740.svg → logo_dark-IkGJ_IwC.svg} +0 -0
  109. /chainlit/copilot/dist/assets/{logo_light-b078e7bc.svg → logo_light-Bb_IPh6r.svg} +0 -0
  110. /chainlit/frontend/dist/assets/{logo_dark-2a3cf740.svg → logo_dark-IkGJ_IwC.svg} +0 -0
  111. /chainlit/frontend/dist/assets/{logo_light-b078e7bc.svg → logo_light-Bb_IPh6r.svg} +0 -0
  112. {chainlit-1.0.401.dist-info → chainlit-2.0.4.dist-info}/WHEEL +0 -0
  113. {chainlit-1.0.401.dist-info → chainlit-2.0.4.dist-info}/entry_points.txt +0 -0
chainlit/markdown.py CHANGED
@@ -1,7 +1,11 @@
1
1
  import os
2
+ from pathlib import Path
3
+ from typing import Optional
2
4
 
3
5
  from chainlit.logger import logger
4
6
 
7
+ from ._utils import is_path_inside
8
+
5
9
  # Default chainlit.md file created if none exists
6
10
  DEFAULT_MARKDOWN_STR = """# Welcome to Chainlit! 🚀🤖
7
11
 
@@ -30,12 +34,24 @@ def init_markdown(root: str):
30
34
  logger.info(f"Created default chainlit markdown file at {chainlit_md_file}")
31
35
 
32
36
 
33
- def get_markdown_str(root: str):
37
+ def get_markdown_str(root: str, language: str) -> Optional[str]:
34
38
  """Get the chainlit.md file as a string."""
35
- chainlit_md_path = os.path.join(root, "chainlit.md")
36
- if os.path.exists(chainlit_md_path):
37
- with open(chainlit_md_path, "r", encoding="utf-8") as f:
38
- chainlit_md = f.read()
39
- return chainlit_md
39
+ root_path = Path(root)
40
+ translated_chainlit_md_path = root_path / f"chainlit_{language}.md"
41
+ default_chainlit_md_path = root_path / "chainlit.md"
42
+
43
+ if (
44
+ is_path_inside(translated_chainlit_md_path, root_path)
45
+ and translated_chainlit_md_path.is_file()
46
+ ):
47
+ chainlit_md_path = translated_chainlit_md_path
48
+ else:
49
+ chainlit_md_path = default_chainlit_md_path
50
+ logger.warning(
51
+ f"Translated markdown file for {language} not found. Defaulting to chainlit.md."
52
+ )
53
+
54
+ if chainlit_md_path.is_file():
55
+ return chainlit_md_path.read_text(encoding="utf-8")
40
56
  else:
41
57
  return None
chainlit/message.py CHANGED
@@ -5,9 +5,13 @@ import uuid
5
5
  from abc import ABC
6
6
  from typing import Dict, List, Optional, Union, cast
7
7
 
8
+ from literalai.helper import utc_now
9
+ from literalai.observability.step import MessageStepType
10
+
8
11
  from chainlit.action import Action
12
+ from chainlit.chat_context import chat_context
9
13
  from chainlit.config import config
10
- from chainlit.context import context
14
+ from chainlit.context import context, local_steps
11
15
  from chainlit.data import get_data_layer
12
16
  from chainlit.element import ElementBased
13
17
  from chainlit.logger import logger
@@ -21,9 +25,6 @@ from chainlit.types import (
21
25
  AskSpec,
22
26
  FileDict,
23
27
  )
24
- from literalai import BaseGeneration
25
- from literalai.helper import utc_now
26
- from literalai.step import MessageStepType
27
28
 
28
29
 
29
30
  class MessageBase(ABC):
@@ -32,57 +33,60 @@ class MessageBase(ABC):
32
33
  author: str
33
34
  content: str = ""
34
35
  type: MessageStepType = "assistant_message"
35
- disable_feedback = False
36
36
  streaming = False
37
37
  created_at: Union[str, None] = None
38
38
  fail_on_persist_error: bool = False
39
39
  persisted = False
40
40
  is_error = False
41
+ parent_id: Optional[str] = None
41
42
  language: Optional[str] = None
43
+ metadata: Optional[Dict] = None
44
+ tags: Optional[List[str]] = None
42
45
  wait_for_answer = False
43
- indent: Optional[int] = None
44
- generation: Optional[BaseGeneration] = None
45
46
 
46
47
  def __post_init__(self) -> None:
47
48
  trace_event(f"init {self.__class__.__name__}")
48
49
  self.thread_id = context.session.thread_id
49
50
 
51
+ previous_steps = local_steps.get() or []
52
+ parent_step = previous_steps[-1] if previous_steps else None
53
+ if parent_step:
54
+ self.parent_id = parent_step.id
55
+
50
56
  if not getattr(self, "id", None):
51
57
  self.id = str(uuid.uuid4())
52
58
 
53
59
  @classmethod
54
60
  def from_dict(self, _dict: StepDict):
55
61
  type = _dict.get("type", "assistant_message")
56
- message = Message(
62
+ return Message(
57
63
  id=_dict["id"],
64
+ parent_id=_dict.get("parentId"),
58
65
  created_at=_dict["createdAt"],
59
66
  content=_dict["output"],
60
67
  author=_dict.get("name", config.ui.name),
61
68
  type=type, # type: ignore
62
- disable_feedback=_dict.get("disableFeedback", False),
63
69
  language=_dict.get("language"),
70
+ metadata=_dict.get("metadata", {}),
64
71
  )
65
72
 
66
- return message
67
-
68
73
  def to_dict(self) -> StepDict:
69
74
  _dict: StepDict = {
70
75
  "id": self.id,
71
76
  "threadId": self.thread_id,
77
+ "parentId": self.parent_id,
72
78
  "createdAt": self.created_at,
73
79
  "start": self.created_at,
74
80
  "end": self.created_at,
75
81
  "output": self.content,
76
82
  "name": self.author,
77
83
  "type": self.type,
78
- "createdAt": self.created_at,
79
84
  "language": self.language,
80
85
  "streaming": self.streaming,
81
- "disableFeedback": self.disable_feedback,
82
86
  "isError": self.is_error,
83
87
  "waitForAnswer": self.wait_for_answer,
84
- "indent": self.indent,
85
- "generation": self.generation.to_dict() if self.generation else None,
88
+ "metadata": self.metadata or {},
89
+ "tags": self.tags,
86
90
  }
87
91
 
88
92
  return _dict
@@ -99,6 +103,7 @@ class MessageBase(ABC):
99
103
  self.streaming = False
100
104
 
101
105
  step_dict = self.to_dict()
106
+ chat_context.add(self)
102
107
 
103
108
  data_layer = get_data_layer()
104
109
  if data_layer:
@@ -107,7 +112,7 @@ class MessageBase(ABC):
107
112
  except Exception as e:
108
113
  if self.fail_on_persist_error:
109
114
  raise e
110
- logger.error(f"Failed to persist message update: {str(e)}")
115
+ logger.error(f"Failed to persist message update: {e!s}")
111
116
 
112
117
  await context.emitter.update_step(step_dict)
113
118
 
@@ -118,7 +123,7 @@ class MessageBase(ABC):
118
123
  Remove a message already sent to the UI.
119
124
  """
120
125
  trace_event("remove_message")
121
-
126
+ chat_context.remove(self)
122
127
  step_dict = self.to_dict()
123
128
  data_layer = get_data_layer()
124
129
  if data_layer:
@@ -127,7 +132,7 @@ class MessageBase(ABC):
127
132
  except Exception as e:
128
133
  if self.fail_on_persist_error:
129
134
  raise e
130
- logger.error(f"Failed to persist message deletion: {str(e)}")
135
+ logger.error(f"Failed to persist message deletion: {e!s}")
131
136
 
132
137
  await context.emitter.delete_step(step_dict)
133
138
 
@@ -143,7 +148,7 @@ class MessageBase(ABC):
143
148
  except Exception as e:
144
149
  if self.fail_on_persist_error:
145
150
  raise e
146
- logger.error(f"Failed to persist message creation: {str(e)}")
151
+ logger.error(f"Failed to persist message creation: {e!s}")
147
152
 
148
153
  return step_dict
149
154
 
@@ -160,30 +165,31 @@ class MessageBase(ABC):
160
165
  self.streaming = False
161
166
 
162
167
  step_dict = await self._create()
168
+ chat_context.add(self)
163
169
  await context.emitter.send_step(step_dict)
164
170
 
165
- return self.id
171
+ return self
166
172
 
167
173
  async def stream_token(self, token: str, is_sequence=False):
168
174
  """
169
175
  Sends a token to the UI. This is useful for streaming messages.
170
176
  Once all tokens have been streamed, call .send() to end the stream and persist the message if persistence is enabled.
171
177
  """
172
-
173
- if not self.streaming:
174
- self.streaming = True
175
- step_dict = self.to_dict()
176
- await context.emitter.stream_start(step_dict)
177
-
178
178
  if is_sequence:
179
179
  self.content = token
180
180
  else:
181
181
  self.content += token
182
182
 
183
183
  assert self.id
184
- await context.emitter.send_token(
185
- id=self.id, token=token, is_sequence=is_sequence
186
- )
184
+
185
+ if not self.streaming:
186
+ self.streaming = True
187
+ step_dict = self.to_dict()
188
+ await context.emitter.stream_start(step_dict)
189
+ else:
190
+ await context.emitter.send_token(
191
+ id=self.id, token=token, is_sequence=is_sequence
192
+ )
187
193
 
188
194
 
189
195
  class Message(MessageBase):
@@ -192,29 +198,28 @@ class Message(MessageBase):
192
198
 
193
199
  Args:
194
200
  content (Union[str, Dict]): The content of the message.
195
- author (str, optional): The author of the message, this will be used in the UI. Defaults to the chatbot name (see config).
201
+ author (str, optional): The author of the message, this will be used in the UI. Defaults to the assistant name (see config).
196
202
  language (str, optional): Language of the code is the content is code. See https://react-code-blocks-rajinwonderland.vercel.app/?path=/story/codeblock--supported-languages for a list of supported languages.
197
203
  actions (List[Action], optional): A list of actions to send with the message.
198
204
  elements (List[ElementBased], optional): A list of elements to send with the message.
199
- disable_feedback (bool, optional): Hide the feedback buttons for this specific message
200
205
  """
201
206
 
202
207
  def __init__(
203
208
  self,
204
209
  content: Union[str, Dict],
205
- author: str = config.ui.name,
210
+ author: Optional[str] = None,
206
211
  language: Optional[str] = None,
207
212
  actions: Optional[List[Action]] = None,
208
213
  elements: Optional[List[ElementBased]] = None,
209
- disable_feedback: bool = False,
210
214
  type: MessageStepType = "assistant_message",
211
- generation: Optional[BaseGeneration] = None,
215
+ metadata: Optional[Dict] = None,
216
+ tags: Optional[List[str]] = None,
212
217
  id: Optional[str] = None,
218
+ parent_id: Optional[str] = None,
213
219
  created_at: Union[str, None] = None,
214
220
  ):
215
221
  time.sleep(0.001)
216
222
  self.language = language
217
- self.generation = generation
218
223
  if isinstance(content, dict):
219
224
  try:
220
225
  self.content = json.dumps(content, indent=4, ensure_ascii=False)
@@ -231,18 +236,23 @@ class Message(MessageBase):
231
236
  if id:
232
237
  self.id = str(id)
233
238
 
239
+ if parent_id:
240
+ self.parent_id = str(parent_id)
241
+
234
242
  if created_at:
235
243
  self.created_at = created_at
236
244
 
237
- self.author = author
245
+ self.metadata = metadata
246
+ self.tags = tags
247
+
248
+ self.author = author or config.ui.name
238
249
  self.type = type
239
250
  self.actions = actions if actions is not None else []
240
251
  self.elements = elements if elements is not None else []
241
- self.disable_feedback = disable_feedback
242
252
 
243
253
  super().__post_init__()
244
254
 
245
- async def send(self) -> str:
255
+ async def send(self):
246
256
  """
247
257
  Send the message to the UI and persist it in the cloud if a project ID is configured.
248
258
  Return the ID of the message.
@@ -250,8 +260,6 @@ class Message(MessageBase):
250
260
  trace_event("send_message")
251
261
  await super().send()
252
262
 
253
- context.session.root_message = self
254
-
255
263
  # Create tasks for all actions and elements
256
264
  tasks = [action.send(for_id=self.id) for action in self.actions]
257
265
  tasks.extend(element.send(for_id=self.id) for element in self.elements)
@@ -259,7 +267,7 @@ class Message(MessageBase):
259
267
  # Run all tasks concurrently
260
268
  await asyncio.gather(*tasks)
261
269
 
262
- return self.id
270
+ return self
263
271
 
264
272
  async def update(self):
265
273
  """
@@ -294,9 +302,7 @@ class ErrorMessage(MessageBase):
294
302
 
295
303
  Args:
296
304
  content (str): Text displayed above the upload button.
297
- author (str, optional): The author of the message, this will be used in the UI. Defaults to the chatbot name (see config).
298
- parent_id (str, optional): If provided, the message will be nested inside the parent in the UI.
299
- indent (int, optional): If positive, the message will be nested in the UI.
305
+ author (str, optional): The author of the message, this will be used in the UI. Defaults to the assistant name (see config).
300
306
  """
301
307
 
302
308
  def __init__(
@@ -307,7 +313,7 @@ class ErrorMessage(MessageBase):
307
313
  ):
308
314
  self.content = content
309
315
  self.author = author
310
- self.type = "system_message"
316
+ self.type = "assistant_message"
311
317
  self.is_error = True
312
318
  self.fail_on_persist_error = fail_on_persist_error
313
319
 
@@ -337,8 +343,7 @@ class AskUserMessage(AskMessageBase):
337
343
 
338
344
  Args:
339
345
  content (str): The content of the prompt.
340
- author (str, optional): The author of the message, this will be used in the UI. Defaults to the chatbot name (see config).
341
- disable_feedback (bool, optional): Hide the feedback buttons for this specific message
346
+ author (str, optional): The author of the message, this will be used in the UI. Defaults to the assistant name (see config).
342
347
  timeout (int, optional): The number of seconds to wait for an answer before raising a TimeoutError.
343
348
  raise_on_timeout (bool, optional): Whether to raise a socketio TimeoutError if the user does not answer in time.
344
349
  """
@@ -348,7 +353,6 @@ class AskUserMessage(AskMessageBase):
348
353
  content: str,
349
354
  author: str = config.ui.name,
350
355
  type: MessageStepType = "assistant_message",
351
- disable_feedback: bool = False,
352
356
  timeout: int = 60,
353
357
  raise_on_timeout: bool = False,
354
358
  ):
@@ -356,7 +360,6 @@ class AskUserMessage(AskMessageBase):
356
360
  self.author = author
357
361
  self.timeout = timeout
358
362
  self.type = type
359
- self.disable_feedback = disable_feedback
360
363
  self.raise_on_timeout = raise_on_timeout
361
364
 
362
365
  super().__post_init__()
@@ -402,8 +405,7 @@ class AskFileMessage(AskMessageBase):
402
405
  accept (Union[List[str], Dict[str, List[str]]]): List of mime type to accept like ["text/csv", "application/pdf"] or a dict like {"text/plain": [".txt", ".py"]}.
403
406
  max_size_mb (int, optional): Maximum size per file in MB. Maximum value is 100.
404
407
  max_files (int, optional): Maximum number of files to upload. Maximum value is 10.
405
- author (str, optional): The author of the message, this will be used in the UI. Defaults to the chatbot name (see config).
406
- disable_feedback (bool, optional): Hide the feedback buttons for this specific message
408
+ author (str, optional): The author of the message, this will be used in the UI. Defaults to the assistant name (see config).
407
409
  timeout (int, optional): The number of seconds to wait for an answer before raising a TimeoutError.
408
410
  raise_on_timeout (bool, optional): Whether to raise a socketio TimeoutError if the user does not answer in time.
409
411
  """
@@ -416,7 +418,6 @@ class AskFileMessage(AskMessageBase):
416
418
  max_files=1,
417
419
  author=config.ui.name,
418
420
  type: MessageStepType = "assistant_message",
419
- disable_feedback: bool = False,
420
421
  timeout=90,
421
422
  raise_on_timeout=False,
422
423
  ):
@@ -428,7 +429,6 @@ class AskFileMessage(AskMessageBase):
428
429
  self.author = author
429
430
  self.timeout = timeout
430
431
  self.raise_on_timeout = raise_on_timeout
431
- self.disable_feedback = disable_feedback
432
432
 
433
433
  super().__post_init__()
434
434
 
@@ -492,14 +492,12 @@ class AskActionMessage(AskMessageBase):
492
492
  content: str,
493
493
  actions: List[Action],
494
494
  author=config.ui.name,
495
- disable_feedback=False,
496
495
  timeout=90,
497
496
  raise_on_timeout=False,
498
497
  ):
499
498
  self.content = content
500
499
  self.actions = actions
501
500
  self.author = author
502
- self.disable_feedback = disable_feedback
503
501
  self.timeout = timeout
504
502
  self.raise_on_timeout = raise_on_timeout
505
503
 
@@ -542,7 +540,7 @@ class AskActionMessage(AskMessageBase):
542
540
  if res is None:
543
541
  self.content = "Timed out: no action was taken"
544
542
  else:
545
- self.content = f'**Selected action:** {res["label"]}'
543
+ self.content = f"**Selected:** {res['label']}"
546
544
 
547
545
  self.wait_for_answer = False
548
546
 
chainlit/mistralai/__init__.py ADDED
@@ -0,0 +1,50 @@
1
+ import asyncio
2
+ from typing import Union
3
+
4
+ from literalai import ChatGeneration, CompletionGeneration
5
+ from literalai.helper import timestamp_utc
6
+
7
+ from chainlit.context import get_context
8
+ from chainlit.step import Step
9
+
10
+
11
+ def instrument_mistralai():
12
+ from literalai.instrumentation.mistralai import instrument_mistralai
13
+
14
+ def on_new_generation(
15
+ generation: Union["ChatGeneration", "CompletionGeneration"], timing
16
+ ):
17
+ context = get_context()
18
+
19
+ parent_id = None
20
+ if context.current_step:
21
+ parent_id = context.current_step.id
22
+
23
+ step = Step(
24
+ name=generation.model if generation.model else generation.provider,
25
+ type="llm",
26
+ parent_id=parent_id,
27
+ )
28
+ step.generation = generation
29
+ # Convert start/end time from seconds to milliseconds
30
+ step.start = (
31
+ timestamp_utc(timing.get("start"))
32
+ if timing.get("start", None) is not None
33
+ else None
34
+ )
35
+ step.end = (
36
+ timestamp_utc(timing.get("end"))
37
+ if timing.get("end", None) is not None
38
+ else None
39
+ )
40
+
41
+ if isinstance(generation, ChatGeneration):
42
+ step.input = generation.messages
43
+ step.output = generation.message_completion # type: ignore
44
+ else:
45
+ step.input = generation.prompt
46
+ step.output = generation.completion
47
+
48
+ asyncio.create_task(step.send())
49
+
50
+ instrument_mistralai(None, on_new_generation)