graphai-lib 0.0.1__tar.gz → 0.0.2__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {graphai_lib-0.0.1 → graphai_lib-0.0.2}/PKG-INFO +1 -1
- graphai_lib-0.0.2/graphai/callback.py +215 -0
- {graphai_lib-0.0.1 → graphai_lib-0.0.2}/pyproject.toml +1 -1
- graphai_lib-0.0.1/graphai/callback.py +0 -63
- {graphai_lib-0.0.1 → graphai_lib-0.0.2}/README.md +0 -0
- {graphai_lib-0.0.1 → graphai_lib-0.0.2}/graphai/__init__.py +0 -0
- {graphai_lib-0.0.1 → graphai_lib-0.0.2}/graphai/graph.py +0 -0
- {graphai_lib-0.0.1 → graphai_lib-0.0.2}/graphai/nodes/__init__.py +0 -0
- {graphai_lib-0.0.1 → graphai_lib-0.0.2}/graphai/nodes/base.py +0 -0
- {graphai_lib-0.0.1 → graphai_lib-0.0.2}/graphai/utils.py +0 -0
@@ -0,0 +1,215 @@
|
|
1
|
+
import asyncio
|
2
|
+
from pydantic import Field
|
3
|
+
from typing import Optional
|
4
|
+
from collections.abc import AsyncIterator
|
5
|
+
from semantic_router.utils.logger import logger
|
6
|
+
|
7
|
+
|
8
|
+
# Module-level flag; not referenced anywhere in this file — presumably a
# placeholder for toggling stream logging. TODO(review): confirm before removing.
log_stream = True
|
9
|
+
|
10
|
+
class Callback:
|
11
|
+
identifier: str = Field(
|
12
|
+
default="graphai",
|
13
|
+
description=(
|
14
|
+
"The identifier for special tokens. This allows us to easily "
|
15
|
+
"identify special tokens in the stream so we can handle them "
|
16
|
+
"correctly in any downstream process."
|
17
|
+
)
|
18
|
+
)
|
19
|
+
special_token_format: str = Field(
|
20
|
+
default="<{identifier}:{token}:{params}>",
|
21
|
+
description=(
|
22
|
+
"The format for special tokens. This is used to format special "
|
23
|
+
"tokens so they can be easily identified in the stream. "
|
24
|
+
"The format is a string with three possible components:\n"
|
25
|
+
"- {identifier}: An identifier shared by all special tokens, "
|
26
|
+
"by default this is 'graphai'.\n"
|
27
|
+
"- {token}: The special token type to be streamed. This may "
|
28
|
+
"be a tool name, identifier for start/end nodes, etc.\n"
|
29
|
+
"- {params}: Any additional parameters to be streamed. The parameters "
|
30
|
+
"are formatted as a comma-separated list of key-value pairs."
|
31
|
+
),
|
32
|
+
examples=[
|
33
|
+
"<{identifier}:{token}:{params}>",
|
34
|
+
"<[{identifier} | {token} | {params}]>",
|
35
|
+
"<{token}:{params}>"
|
36
|
+
]
|
37
|
+
)
|
38
|
+
token_format: str = Field(
|
39
|
+
default="{token}",
|
40
|
+
description=(
|
41
|
+
"The format for streamed tokens. This is used to format the "
|
42
|
+
"tokens typically returned from LLMs. By default, no special "
|
43
|
+
"formatting is applied."
|
44
|
+
)
|
45
|
+
)
|
46
|
+
_first_token: bool = Field(
|
47
|
+
default=True,
|
48
|
+
description="Whether this is the first token in the stream.",
|
49
|
+
exclude=True
|
50
|
+
)
|
51
|
+
_current_node_name: Optional[str] = Field(
|
52
|
+
default=None,
|
53
|
+
description="The name of the current node.",
|
54
|
+
exclude=True
|
55
|
+
)
|
56
|
+
_active: bool = Field(
|
57
|
+
default=True,
|
58
|
+
description="Whether the callback is active.",
|
59
|
+
exclude=True
|
60
|
+
)
|
61
|
+
_done: bool = Field(
|
62
|
+
default=False,
|
63
|
+
description="Whether the stream is done and should be closed.",
|
64
|
+
exclude=True
|
65
|
+
)
|
66
|
+
queue: asyncio.Queue
|
67
|
+
|
68
|
+
def __init__(
|
69
|
+
self,
|
70
|
+
identifier: str = "graphai",
|
71
|
+
special_token_format: str = "<{identifier}:{token}:{params}>",
|
72
|
+
token_format: str = "{token}",
|
73
|
+
):
|
74
|
+
self.identifier = identifier
|
75
|
+
self.special_token_format = special_token_format
|
76
|
+
self.token_format = token_format
|
77
|
+
self.queue = asyncio.Queue()
|
78
|
+
self._done = False
|
79
|
+
self._first_token = True
|
80
|
+
self._current_node_name = None
|
81
|
+
self._active = True
|
82
|
+
|
83
|
+
@property
|
84
|
+
def first_token(self) -> bool:
|
85
|
+
return self._first_token
|
86
|
+
|
87
|
+
@first_token.setter
|
88
|
+
def first_token(self, value: bool):
|
89
|
+
self._first_token = value
|
90
|
+
|
91
|
+
@property
|
92
|
+
def current_node_name(self) -> Optional[str]:
|
93
|
+
return self._current_node_name
|
94
|
+
|
95
|
+
@current_node_name.setter
|
96
|
+
def current_node_name(self, value: Optional[str]):
|
97
|
+
self._current_node_name = value
|
98
|
+
|
99
|
+
@property
|
100
|
+
def active(self) -> bool:
|
101
|
+
return self._active
|
102
|
+
|
103
|
+
@active.setter
|
104
|
+
def active(self, value: bool):
|
105
|
+
self._active = value
|
106
|
+
|
107
|
+
def __call__(self, token: str, node_name: Optional[str] = None):
|
108
|
+
if self._done:
|
109
|
+
raise RuntimeError("Cannot add tokens to a closed stream")
|
110
|
+
self._check_node_name(node_name=node_name)
|
111
|
+
# otherwise we just assume node is correct and send token
|
112
|
+
self.queue.put_nowait(token)
|
113
|
+
|
114
|
+
async def acall(self, token: str, node_name: Optional[str] = None):
|
115
|
+
# TODO JB: do we need to have `node_name` param?
|
116
|
+
if self._done:
|
117
|
+
raise RuntimeError("Cannot add tokens to a closed stream")
|
118
|
+
self._check_node_name(node_name=node_name)
|
119
|
+
# otherwise we just assume node is correct and send token
|
120
|
+
self.queue.put_nowait(token)
|
121
|
+
|
122
|
+
async def aiter(self) -> AsyncIterator[str]:
|
123
|
+
"""Used by receiver to get the tokens from the stream queue. Creates
|
124
|
+
a generator that yields tokens from the queue until the END token is
|
125
|
+
received.
|
126
|
+
"""
|
127
|
+
end_token = await self._build_special_token(
|
128
|
+
name="END",
|
129
|
+
params=None
|
130
|
+
)
|
131
|
+
while True: # Keep going until we see the END token
|
132
|
+
try:
|
133
|
+
if self._done and self.queue.empty():
|
134
|
+
break
|
135
|
+
token = await self.queue.get()
|
136
|
+
yield token
|
137
|
+
self.queue.task_done()
|
138
|
+
if token == end_token:
|
139
|
+
break
|
140
|
+
except asyncio.CancelledError:
|
141
|
+
break
|
142
|
+
self._done = True # Mark as done after processing all tokens
|
143
|
+
|
144
|
+
async def start_node(self, node_name: str, active: bool = True):
|
145
|
+
"""Starts a new node and emits the start token.
|
146
|
+
"""
|
147
|
+
if self._done:
|
148
|
+
raise RuntimeError("Cannot start node on a closed stream")
|
149
|
+
self.current_node_name = node_name
|
150
|
+
if self.first_token:
|
151
|
+
self.first_token = False
|
152
|
+
self.active = active
|
153
|
+
if self.active:
|
154
|
+
token = await self._build_special_token(
|
155
|
+
name=f"{self.current_node_name}:start",
|
156
|
+
params=None
|
157
|
+
)
|
158
|
+
self.queue.put_nowait(token)
|
159
|
+
# TODO JB: should we use two tokens here?
|
160
|
+
node_token = await self._build_special_token(
|
161
|
+
name=self.current_node_name,
|
162
|
+
params=None
|
163
|
+
)
|
164
|
+
self.queue.put_nowait(node_token)
|
165
|
+
|
166
|
+
async def end_node(self, node_name: str):
|
167
|
+
"""Emits the end token for the current node.
|
168
|
+
"""
|
169
|
+
if self._done:
|
170
|
+
raise RuntimeError("Cannot end node on a closed stream")
|
171
|
+
#self.current_node_name = node_name
|
172
|
+
if self.active:
|
173
|
+
node_token = await self._build_special_token(
|
174
|
+
name=f"{self.current_node_name}:end",
|
175
|
+
params=None
|
176
|
+
)
|
177
|
+
self.queue.put_nowait(node_token)
|
178
|
+
|
179
|
+
async def close(self):
|
180
|
+
"""Close the stream and prevent further tokens from being added.
|
181
|
+
This will send an END token and set the done flag to True.
|
182
|
+
"""
|
183
|
+
if self._done:
|
184
|
+
return
|
185
|
+
end_token = await self._build_special_token(
|
186
|
+
name="END",
|
187
|
+
params=None
|
188
|
+
)
|
189
|
+
self._done = True # Set done before putting the end token
|
190
|
+
self.queue.put_nowait(end_token)
|
191
|
+
# Don't wait for queue.join() as it can cause deadlock
|
192
|
+
# The stream will close when aiter processes the END token
|
193
|
+
|
194
|
+
def _check_node_name(self, node_name: Optional[str] = None):
|
195
|
+
if node_name:
|
196
|
+
# we confirm this is the current node
|
197
|
+
if self.current_node_name != node_name:
|
198
|
+
raise ValueError(
|
199
|
+
f"Node name mismatch: {self.current_node_name} != {node_name}"
|
200
|
+
)
|
201
|
+
|
202
|
+
async def _build_special_token(self, name: str, params: dict[str, any] | None = None):
|
203
|
+
if params:
|
204
|
+
params_str = ",".join([f"{k}={v}" for k, v in params.items()])
|
205
|
+
else:
|
206
|
+
params_str = ""
|
207
|
+
if self.identifier:
|
208
|
+
identifier = self.identifier
|
209
|
+
else:
|
210
|
+
identifier = ""
|
211
|
+
return self.special_token_format.format(
|
212
|
+
identifier=identifier,
|
213
|
+
token=name,
|
214
|
+
params=params_str
|
215
|
+
)
|
@@ -1,63 +0,0 @@
|
|
1
|
-
import asyncio
|
2
|
-
from typing import Optional
|
3
|
-
from collections.abc import AsyncIterator
|
4
|
-
from semantic_router.utils.logger import logger
|
5
|
-
|
6
|
-
|
7
|
-
log_stream = True
|
8
|
-
|
9
|
-
class Callback:
|
10
|
-
first_token = True
|
11
|
-
current_node_name: Optional[str] = None
|
12
|
-
active: bool = True
|
13
|
-
queue: asyncio.Queue
|
14
|
-
|
15
|
-
def __init__(self):
|
16
|
-
self.queue = asyncio.Queue()
|
17
|
-
|
18
|
-
def __call__(self, token: str, node_name: Optional[str] = None):
|
19
|
-
self._check_node_name(node_name=node_name)
|
20
|
-
# otherwise we just assume node is correct and send token
|
21
|
-
self.queue.put_nowait(token)
|
22
|
-
|
23
|
-
async def acall(self, token: str, node_name: Optional[str] = None):
|
24
|
-
self._check_node_name(node_name=node_name)
|
25
|
-
# otherwise we just assume node is correct and send token
|
26
|
-
self.queue.put_nowait(token)
|
27
|
-
|
28
|
-
async def aiter(self) -> AsyncIterator[str]:
|
29
|
-
"""Used by receiver to get the tokens from the stream queue. Creates
|
30
|
-
a generator that yields tokens from the queue until the END token is
|
31
|
-
received.
|
32
|
-
"""
|
33
|
-
while True:
|
34
|
-
token = await self.queue.get()
|
35
|
-
yield token
|
36
|
-
self.queue.task_done()
|
37
|
-
if token == "<graphai:END>":
|
38
|
-
break
|
39
|
-
|
40
|
-
async def start_node(self, node_name: str, active: bool = True):
|
41
|
-
self.current_node_name = node_name
|
42
|
-
if self.first_token:
|
43
|
-
# TODO JB: not sure if we need self.first_token
|
44
|
-
self.first_token = False
|
45
|
-
self.active = active
|
46
|
-
if self.active:
|
47
|
-
self.queue.put_nowait(f"<graphai:start:{node_name}>")
|
48
|
-
|
49
|
-
async def end_node(self, node_name: str):
|
50
|
-
self.current_node_name = None
|
51
|
-
if self.active:
|
52
|
-
self.queue.put_nowait(f"<graphai:end:{node_name}>")
|
53
|
-
|
54
|
-
async def close(self):
|
55
|
-
self.queue.put_nowait("<graphai:END>")
|
56
|
-
|
57
|
-
def _check_node_name(self, node_name: Optional[str] = None):
|
58
|
-
if node_name:
|
59
|
-
# we confirm this is the current node
|
60
|
-
if self.current_node_name != node_name:
|
61
|
-
raise ValueError(
|
62
|
-
f"Node name mismatch: {self.current_node_name} != {node_name}"
|
63
|
-
)
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|