chainlit 0.7.700__py3-none-any.whl → 1.0.0rc0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of chainlit might be problematic. Click here for more details.

Files changed (38)
  1. chainlit/__init__.py +32 -23
  2. chainlit/auth.py +9 -10
  3. chainlit/cli/__init__.py +1 -2
  4. chainlit/config.py +13 -12
  5. chainlit/context.py +7 -3
  6. chainlit/data/__init__.py +375 -9
  7. chainlit/data/acl.py +6 -5
  8. chainlit/element.py +86 -123
  9. chainlit/emitter.py +117 -50
  10. chainlit/frontend/dist/assets/{index-71698725.js → index-6aee009a.js} +118 -292
  11. chainlit/frontend/dist/assets/{react-plotly-2c0acdf0.js → react-plotly-2f07c02a.js} +1 -1
  12. chainlit/frontend/dist/index.html +1 -1
  13. chainlit/haystack/callbacks.py +45 -43
  14. chainlit/hello.py +1 -1
  15. chainlit/langchain/callbacks.py +132 -120
  16. chainlit/llama_index/callbacks.py +68 -48
  17. chainlit/message.py +179 -207
  18. chainlit/oauth_providers.py +39 -34
  19. chainlit/playground/provider.py +44 -30
  20. chainlit/playground/providers/anthropic.py +4 -4
  21. chainlit/playground/providers/huggingface.py +2 -2
  22. chainlit/playground/providers/langchain.py +8 -10
  23. chainlit/playground/providers/openai.py +19 -13
  24. chainlit/server.py +155 -99
  25. chainlit/session.py +109 -40
  26. chainlit/socket.py +47 -36
  27. chainlit/step.py +393 -0
  28. chainlit/types.py +78 -21
  29. chainlit/user.py +32 -0
  30. chainlit/user_session.py +1 -5
  31. {chainlit-0.7.700.dist-info → chainlit-1.0.0rc0.dist-info}/METADATA +12 -31
  32. chainlit-1.0.0rc0.dist-info/RECORD +60 -0
  33. chainlit/client/base.py +0 -169
  34. chainlit/client/cloud.py +0 -502
  35. chainlit/prompt.py +0 -40
  36. chainlit-0.7.700.dist-info/RECORD +0 -61
  37. {chainlit-0.7.700.dist-info → chainlit-1.0.0rc0.dist-info}/WHEEL +0 -0
  38. {chainlit-0.7.700.dist-info → chainlit-1.0.0rc0.dist-info}/entry_points.txt +0 -0
@@ -1,11 +1,11 @@
1
- import asyncio
1
+ from datetime import datetime
2
2
  from typing import Any, Dict, List, Optional
3
3
 
4
4
  from chainlit.context import context_var
5
5
  from chainlit.element import Text
6
- from chainlit.message import Message
7
- from chainlit.prompt import Prompt, PromptMessage
8
- from llama_index.callbacks.base import BaseCallbackHandler
6
+ from chainlit.step import Step, StepType
7
+ from chainlit_client import ChatGeneration, CompletionGeneration, GenerationMessage
8
+ from llama_index.callbacks import TokenCountingHandler
9
9
  from llama_index.callbacks.schema import CBEventType, EventPayload
10
10
  from llama_index.llms.base import ChatMessage, ChatResponse, CompletionResponse
11
11
 
@@ -19,18 +19,31 @@ DEFAULT_IGNORE = [
19
19
  ]
20
20
 
21
21
 
22
- class LlamaIndexCallbackHandler(BaseCallbackHandler):
22
+ class LlamaIndexCallbackHandler(TokenCountingHandler):
23
23
  """Base callback handler that can be used to track event starts and ends."""
24
24
 
25
+ steps: Dict[str, Step]
26
+
25
27
  def __init__(
26
28
  self,
27
29
  event_starts_to_ignore: List[CBEventType] = DEFAULT_IGNORE,
28
30
  event_ends_to_ignore: List[CBEventType] = DEFAULT_IGNORE,
29
31
  ) -> None:
30
32
  """Initialize the base callback handler."""
33
+ super().__init__(
34
+ event_starts_to_ignore=event_starts_to_ignore,
35
+ event_ends_to_ignore=event_ends_to_ignore,
36
+ )
31
37
  self.context = context_var.get()
32
- self.event_starts_to_ignore = tuple(event_starts_to_ignore)
33
- self.event_ends_to_ignore = tuple(event_ends_to_ignore)
38
+
39
+ self.steps = {}
40
+
41
+ def _get_parent_id(self, event_parent_id: Optional[str] = None) -> Optional[str]:
42
+ if event_parent_id and event_parent_id in self.steps:
43
+ return event_parent_id
44
+ if root_message := self.context.session.root_message:
45
+ return root_message.id
46
+ return None
34
47
 
35
48
  def _restore_context(self) -> None:
36
49
  """Restore Chainlit context in the current thread
@@ -45,12 +58,6 @@ class LlamaIndexCallbackHandler(BaseCallbackHandler):
45
58
  """
46
59
  context_var.set(self.context)
47
60
 
48
- def _get_parent_id(self) -> Optional[str]:
49
- """Get the parent message id"""
50
- if root_message := self.context.session.root_message:
51
- return root_message.id
52
- return None
53
-
54
61
  def on_event_start(
55
62
  self,
56
63
  event_type: CBEventType,
@@ -61,14 +68,25 @@ class LlamaIndexCallbackHandler(BaseCallbackHandler):
61
68
  ) -> str:
62
69
  """Run when an event starts and return id of event."""
63
70
  self._restore_context()
64
- asyncio.run(
65
- Message(
66
- content="",
67
- author=event_type,
68
- parent_id=self._get_parent_id(),
69
- ).send()
71
+ step_type: StepType = "undefined"
72
+ if event_type == CBEventType.RETRIEVE:
73
+ step_type = "retrieval"
74
+ elif event_type == CBEventType.LLM:
75
+ step_type = "llm"
76
+ else:
77
+ return event_id
78
+
79
+ step = Step(
80
+ name=event_type.value,
81
+ type=step_type,
82
+ parent_id=self._get_parent_id(parent_id),
83
+ id=event_id,
84
+ disable_feedback=False,
70
85
  )
71
-
86
+ self.steps[event_id] = step
87
+ step.start = datetime.utcnow().isoformat()
88
+ step.input = payload or {}
89
+ self.context.loop.create_task(step.send())
72
90
  return event_id
73
91
 
74
92
  def on_event_end(
@@ -79,31 +97,27 @@ class LlamaIndexCallbackHandler(BaseCallbackHandler):
79
97
  **kwargs: Any,
80
98
  ) -> None:
81
99
  """Run when an event ends."""
82
- if payload is None:
100
+ step = self.steps.get(event_id, None)
101
+
102
+ if payload is None or step is None:
83
103
  return
84
104
 
85
105
  self._restore_context()
86
106
 
107
+ step.end = datetime.utcnow().isoformat()
108
+
87
109
  if event_type == CBEventType.RETRIEVE:
88
110
  sources = payload.get(EventPayload.NODES)
89
111
  if sources:
90
- elements = [
91
- Text(name=f"Source {idx}", content=source.node.get_text())
92
- for idx, source in enumerate(sources)
93
- ]
94
112
  source_refs = "\, ".join(
95
113
  [f"Source {idx}" for idx, _ in enumerate(sources)]
96
114
  )
97
- content = f"Retrieved the following sources: {source_refs}"
98
-
99
- asyncio.run(
100
- Message(
101
- content=content,
102
- author=event_type,
103
- elements=elements,
104
- parent_id=self._get_parent_id(),
105
- ).send()
106
- )
115
+ step.elements = [
116
+ Text(name=f"Source {idx}", content=source.node.get_text())
117
+ for idx, source in enumerate(sources)
118
+ ]
119
+ step.output = f"Retrieved the following sources: {source_refs}"
120
+ self.context.loop.create_task(step.update())
107
121
 
108
122
  if event_type == CBEventType.LLM:
109
123
  formatted_messages = payload.get(
@@ -114,7 +128,7 @@ class LlamaIndexCallbackHandler(BaseCallbackHandler):
114
128
 
115
129
  if formatted_messages:
116
130
  messages = [
117
- PromptMessage(role=m.role.value, formatted=m.content) # type: ignore[arg-type]
131
+ GenerationMessage(role=m.role.value, formatted=m.content) # type: ignore[arg-type]
118
132
  for m in formatted_messages
119
133
  ]
120
134
  else:
@@ -127,18 +141,24 @@ class LlamaIndexCallbackHandler(BaseCallbackHandler):
127
141
  else:
128
142
  content = ""
129
143
 
130
- asyncio.run(
131
- Message(
132
- content=content,
133
- author=event_type,
134
- parent_id=self._get_parent_id(),
135
- prompt=Prompt(
136
- formatted=formatted_prompt,
137
- messages=messages,
138
- completion=content,
139
- ),
140
- ).send()
141
- )
144
+ step.output = content
145
+
146
+ token_count = self.total_llm_token_count or None
147
+
148
+ if messages:
149
+ step.generation = ChatGeneration(
150
+ messages=messages, completion=content, token_count=token_count
151
+ )
152
+ elif formatted_prompt:
153
+ step.generation = CompletionGeneration(
154
+ formatted=formatted_prompt,
155
+ completion=content,
156
+ token_count=token_count,
157
+ )
158
+
159
+ self.context.loop.create_task(step.update())
160
+
161
+ self.steps.pop(event_id, None)
142
162
 
143
163
  def _noop(self, *args, **kwargs):
144
164
  pass