stackraise 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- stackraise/__init__.py +6 -0
- stackraise/ai/__init__.py +2 -0
- stackraise/ai/rpa.py +380 -0
- stackraise/ai/toolset.py +227 -0
- stackraise/app.py +23 -0
- stackraise/auth/__init__.py +2 -0
- stackraise/auth/model.py +24 -0
- stackraise/auth/service.py +240 -0
- stackraise/ctrl/__init__.py +4 -0
- stackraise/ctrl/change_stream.py +40 -0
- stackraise/ctrl/crud_controller.py +63 -0
- stackraise/ctrl/file_storage.py +68 -0
- stackraise/db/__init__.py +11 -0
- stackraise/db/adapter.py +60 -0
- stackraise/db/collection.py +292 -0
- stackraise/db/cursor.py +229 -0
- stackraise/db/document.py +282 -0
- stackraise/db/exceptions.py +9 -0
- stackraise/db/id.py +79 -0
- stackraise/db/index.py +84 -0
- stackraise/db/persistence.py +238 -0
- stackraise/db/pipeline.py +245 -0
- stackraise/db/protocols.py +141 -0
- stackraise/di.py +36 -0
- stackraise/event.py +150 -0
- stackraise/inflection.py +28 -0
- stackraise/io/__init__.py +3 -0
- stackraise/io/imap_client.py +400 -0
- stackraise/io/smtp_client.py +102 -0
- stackraise/logging.py +22 -0
- stackraise/model/__init__.py +11 -0
- stackraise/model/core.py +16 -0
- stackraise/model/dto.py +12 -0
- stackraise/model/email_message.py +88 -0
- stackraise/model/file.py +154 -0
- stackraise/model/name_email.py +45 -0
- stackraise/model/query_filters.py +231 -0
- stackraise/model/time_range.py +285 -0
- stackraise/model/validation.py +8 -0
- stackraise/templating/__init__.py +4 -0
- stackraise/templating/exceptions.py +23 -0
- stackraise/templating/image/__init__.py +2 -0
- stackraise/templating/image/model.py +51 -0
- stackraise/templating/image/processor.py +154 -0
- stackraise/templating/parser.py +156 -0
- stackraise/templating/pptx/__init__.py +3 -0
- stackraise/templating/pptx/pptx_engine.py +204 -0
- stackraise/templating/pptx/slide_renderer.py +181 -0
- stackraise/templating/tracer.py +57 -0
- stackraise-0.1.0.dist-info/METADATA +37 -0
- stackraise-0.1.0.dist-info/RECORD +52 -0
- stackraise-0.1.0.dist-info/WHEEL +4 -0
stackraise/__init__.py
ADDED
stackraise/ai/rpa.py
ADDED
|
@@ -0,0 +1,380 @@
|
|
|
1
|
+
# %%
|
|
2
|
+
from __future__ import annotations
|
|
3
|
+
|
|
4
|
+
import asyncio
|
|
5
|
+
from contextlib import asynccontextmanager
|
|
6
|
+
import json
|
|
7
|
+
import logging
|
|
8
|
+
from abc import abstractmethod
|
|
9
|
+
from dataclasses import Field, dataclass
|
|
10
|
+
from functools import cached_property
|
|
11
|
+
from inspect import isawaitable
|
|
12
|
+
from textwrap import dedent
|
|
13
|
+
from typing import Annotated, Awaitable, Optional, Any
|
|
14
|
+
|
|
15
|
+
from openai import AsyncOpenAI
|
|
16
|
+
from stackraise import model
|
|
17
|
+
|
|
18
|
+
from openai.types.responses import (
|
|
19
|
+
ResponseInputFileParam,
|
|
20
|
+
ResponseInputImageParam,
|
|
21
|
+
ResponseInputParam,
|
|
22
|
+
ResponseFunctionToolCall,
|
|
23
|
+
ResponseOutputMessage,
|
|
24
|
+
)
|
|
25
|
+
|
|
26
|
+
logger = logging.getLogger(__name__)
|
|
27
|
+
|
|
28
|
+
|
|
29
|
+
EMBEDDED_IMAGE_URL_CONTENT_TYPE = {
|
|
30
|
+
"image/png",
|
|
31
|
+
"image/jpeg",
|
|
32
|
+
}
|
|
33
|
+
|
|
34
|
+
|
|
35
|
+
# # mime types that can be embedded in the message body via URL encoding
|
|
36
|
+
# EMBEDDED_IMAGE_URL_CONTENT_TYPE = {
|
|
37
|
+
# "image/png",
|
|
38
|
+
# "image/jpeg",
|
|
39
|
+
# }
|
|
40
|
+
|
|
41
|
+
# CONTENT_TYPE_TO_TOOL_MAPPING = {
|
|
42
|
+
# "image/jpeg": "code_interpreter",
|
|
43
|
+
# "image/png": "code_interpreter",
|
|
44
|
+
# "application/pdf": "file_search",
|
|
45
|
+
# "text/plain": "file_search",
|
|
46
|
+
# "application/vnd.openxmlformats-officedocument.wordprocessingml.document": "file_search",
|
|
47
|
+
# "application/msword": "file_search",
|
|
48
|
+
# "application/vnd.ms-excel": "code_interpreter",
|
|
49
|
+
# "application/vnd.openxmlformats-officedocument.spreadsheetml.sheet": "code_interpreter",
|
|
50
|
+
# "application/vnd.openxmlformats-officedocument.spreadsheetml.presentation": "code_interpreter",
|
|
51
|
+
# }
|
|
52
|
+
|
|
53
|
+
|
|
54
|
+
class StopRPA[R](Exception):
|
|
55
|
+
"""Exception to signal that the RPA process should finish with a result."""
|
|
56
|
+
|
|
57
|
+
def __init__(self, status: str, result: R):
|
|
58
|
+
self._status = status
|
|
59
|
+
self._result = result
|
|
60
|
+
|
|
61
|
+
@property
|
|
62
|
+
def result(self) -> R:
|
|
63
|
+
return self._result
|
|
64
|
+
|
|
65
|
+
@property
|
|
66
|
+
def status(self) -> str:
|
|
67
|
+
return self._status
|
|
68
|
+
|
|
69
|
+
|
|
70
|
+
class Rpa[C, R]: # TODO: RpaBase
|
|
71
|
+
"""
|
|
72
|
+
Assistant that processes emails using OpenAI's API to extract intentions and relevant data
|
|
73
|
+
C: Context
|
|
74
|
+
R: Result
|
|
75
|
+
"""
|
|
76
|
+
|
|
77
|
+
class _ToolBase[R](model.Base):
|
|
78
|
+
@abstractmethod
|
|
79
|
+
def run(self, rpa: Rpa.Context) -> Awaitable[R] | R:
|
|
80
|
+
raise NotImplementedError()
|
|
81
|
+
|
|
82
|
+
class Query(_ToolBase[str]): ...
|
|
83
|
+
|
|
84
|
+
class Action(_ToolBase[str]):
|
|
85
|
+
def run(self, rpa: Rpa.Context) -> str:
|
|
86
|
+
return "Successfully executed action: " + self.__class__.__name__
|
|
87
|
+
|
|
88
|
+
class Result[R](_ToolBase[R]): ...
|
|
89
|
+
|
|
90
|
+
class Context(model.Base):
|
|
91
|
+
# rpa: Rpa
|
|
92
|
+
performed_actions: Annotated[
|
|
93
|
+
list[Rpa.Action], model.Field(default_factory=list)
|
|
94
|
+
]
|
|
95
|
+
|
|
96
|
+
def make_prompt(self) -> str:
|
|
97
|
+
raise NotImplementedError(
|
|
98
|
+
"get_prompt() must be implemented in context subclasses"
|
|
99
|
+
)
|
|
100
|
+
|
|
101
|
+
def get_files(self) -> list[model.File]:
|
|
102
|
+
"""Return the list of files to be processed by the RPA."""
|
|
103
|
+
return []
|
|
104
|
+
|
|
105
|
+
def __init__(
|
|
106
|
+
self,
|
|
107
|
+
*,
|
|
108
|
+
model: str,
|
|
109
|
+
name: Optional[str] = None,
|
|
110
|
+
instructions: Optional[str] = None,
|
|
111
|
+
# toolsets: dict[str, Any] = {},
|
|
112
|
+
queries: list[Query] = [],
|
|
113
|
+
actions: list[Action] = [],
|
|
114
|
+
results: list[Result[R]],
|
|
115
|
+
):
|
|
116
|
+
"""Initialize the RPA assistant with OpenAI credentials"""
|
|
117
|
+
# super().__init__(toolsets=toolsets)
|
|
118
|
+
self.model = model
|
|
119
|
+
self.instructions = dedent(instructions or self.__doc__)
|
|
120
|
+
self.client = AsyncOpenAI()
|
|
121
|
+
self.queries = queries
|
|
122
|
+
self.actions = actions
|
|
123
|
+
self.results = results
|
|
124
|
+
self.name = name or self.__class__.__name__
|
|
125
|
+
# self.conversation = None
|
|
126
|
+
|
|
127
|
+
@cached_property
|
|
128
|
+
def _oai_tool_classes(self) -> list[type[_ToolBase]]:
|
|
129
|
+
"""Return all tools defined in the assistant"""
|
|
130
|
+
return [
|
|
131
|
+
*self.queries,
|
|
132
|
+
*self.actions,
|
|
133
|
+
*self.results,
|
|
134
|
+
]
|
|
135
|
+
|
|
136
|
+
@cached_property
|
|
137
|
+
def _oai_tool_by_name(self) -> dict[str, type[_ToolBase]]:
|
|
138
|
+
return {tool.__name__: tool for tool in self._oai_tool_classes}
|
|
139
|
+
|
|
140
|
+
async def _oai_create_file_content(
|
|
141
|
+
self, file: model.File
|
|
142
|
+
) -> ResponseInputFileParam | ResponseInputImageParam:
|
|
143
|
+
file_content = await file.content()
|
|
144
|
+
|
|
145
|
+
uploaded_file = await self.client.files.create(
|
|
146
|
+
file=(file.filename, file_content, file.content_type),
|
|
147
|
+
purpose="user_data",
|
|
148
|
+
)
|
|
149
|
+
|
|
150
|
+
if file.content_type in EMBEDDED_IMAGE_URL_CONTENT_TYPE:
|
|
151
|
+
return {
|
|
152
|
+
"type": "input_image",
|
|
153
|
+
"file_id": uploaded_file.id,
|
|
154
|
+
"detail": "high",
|
|
155
|
+
}
|
|
156
|
+
else:
|
|
157
|
+
return {
|
|
158
|
+
"type": "input_file",
|
|
159
|
+
"file_id": uploaded_file.id,
|
|
160
|
+
}
|
|
161
|
+
|
|
162
|
+
async def run(self, ctx: Context) -> R:
|
|
163
|
+
"""
|
|
164
|
+
Process an email to determine its type and extract relevant information
|
|
165
|
+
|
|
166
|
+
Args:
|
|
167
|
+
email: The email message to process
|
|
168
|
+
|
|
169
|
+
Returns:
|
|
170
|
+
InferResult with the detected label and associated action
|
|
171
|
+
"""
|
|
172
|
+
|
|
173
|
+
response = None
|
|
174
|
+
|
|
175
|
+
tools = [
|
|
176
|
+
# {"type": "file_search"},
|
|
177
|
+
# {"type": "code_interpreter"},
|
|
178
|
+
*[get_oai_tool_schema(tool) for tool in self._oai_tool_classes],
|
|
179
|
+
]
|
|
180
|
+
|
|
181
|
+
input: list[ResponseInputParam] = [
|
|
182
|
+
{
|
|
183
|
+
"role": "user",
|
|
184
|
+
"content": [
|
|
185
|
+
await self._oai_create_file_content(file)
|
|
186
|
+
for file in ctx.get_files()
|
|
187
|
+
],
|
|
188
|
+
},
|
|
189
|
+
]
|
|
190
|
+
|
|
191
|
+
try:
|
|
192
|
+
while True:
|
|
193
|
+
response = await self.client.responses.create(
|
|
194
|
+
model=self.model,
|
|
195
|
+
instructions=self.instructions,
|
|
196
|
+
tools=tools,
|
|
197
|
+
input=input,
|
|
198
|
+
)
|
|
199
|
+
|
|
200
|
+
input.extend(response.output)
|
|
201
|
+
|
|
202
|
+
match response.status:
|
|
203
|
+
# case 'completed':
|
|
204
|
+
# input.append({
|
|
205
|
+
# 'role': "system",
|
|
206
|
+
# 'content': "IMPORTANT: You must complete the task by calling a Result tool.",
|
|
207
|
+
# })
|
|
208
|
+
# continue
|
|
209
|
+
case "failed" | "cancelled":
|
|
210
|
+
raise StopRPA(
|
|
211
|
+
response.status, response.output.content[0].text.value
|
|
212
|
+
)
|
|
213
|
+
case _:
|
|
214
|
+
pass
|
|
215
|
+
|
|
216
|
+
for output_item in response.output:
|
|
217
|
+
match output_item:
|
|
218
|
+
case ResponseFunctionToolCall(
|
|
219
|
+
name=name, arguments=arguments, call_id=call_id
|
|
220
|
+
):
|
|
221
|
+
logger.info("Dispatching tool call %s with arguments %s", name, arguments)
|
|
222
|
+
|
|
223
|
+
result = await self._dispatch_tool_call(
|
|
224
|
+
ctx,
|
|
225
|
+
name,
|
|
226
|
+
arguments,
|
|
227
|
+
)
|
|
228
|
+
|
|
229
|
+
input.append(
|
|
230
|
+
{
|
|
231
|
+
"type": "function_call_output",
|
|
232
|
+
"call_id": call_id,
|
|
233
|
+
"output": str(result) if result is not None else "OK",
|
|
234
|
+
}
|
|
235
|
+
)
|
|
236
|
+
case ResponseOutputMessage(content=content):
|
|
237
|
+
print(f"RPA >>> {content}")
|
|
238
|
+
case _:
|
|
239
|
+
logger.info(f"Unhandled response output {type(output_item)}")
|
|
240
|
+
pass
|
|
241
|
+
|
|
242
|
+
except StopRPA as e:
|
|
243
|
+
if e.status == "completed":
|
|
244
|
+
return e.result
|
|
245
|
+
raise e
|
|
246
|
+
|
|
247
|
+
async def _dispatch_tool_call(self, ctx: Context, function_name, arguments):
|
|
248
|
+
try:
|
|
249
|
+
arguments = json.loads(arguments)
|
|
250
|
+
except json.JSONDecodeError as e:
|
|
251
|
+
return f"Invalid JSON in function arguments: {e}"
|
|
252
|
+
|
|
253
|
+
try:
|
|
254
|
+
tool_cls = self._oai_tool_by_name.get(function_name)
|
|
255
|
+
except KeyError:
|
|
256
|
+
return f"Tool {function_name} not found"
|
|
257
|
+
|
|
258
|
+
try:
|
|
259
|
+
tool = tool_cls.model_validate(arguments)
|
|
260
|
+
except ValueError as e:
|
|
261
|
+
return str(e)
|
|
262
|
+
|
|
263
|
+
logger.info("Running tool %s", tool)
|
|
264
|
+
|
|
265
|
+
try:
|
|
266
|
+
result = tool.run(ctx)
|
|
267
|
+
|
|
268
|
+
if isawaitable(result):
|
|
269
|
+
result = await result
|
|
270
|
+
|
|
271
|
+
if issubclass(tool_cls, self.Action):
|
|
272
|
+
ctx.performed_actions.append(tool)
|
|
273
|
+
|
|
274
|
+
if issubclass(tool_cls, self.Result):
|
|
275
|
+
raise StopRPA("completed", result)
|
|
276
|
+
except StopRPA as e:
|
|
277
|
+
raise e
|
|
278
|
+
except Exception as e:
|
|
279
|
+
logger.exception("Error running tool %s: %s", tool_cls.__name__, e)
|
|
280
|
+
return f"Error running tool {tool_cls.__name__}: {e}"
|
|
281
|
+
|
|
282
|
+
return result
|
|
283
|
+
|
|
284
|
+
|
|
285
|
+
def get_oai_tool_schema(model_class: type[model.Base], flatten_schema: bool = False):
    """Build the function-tool definition for ``model_class``.

    Args:
        model_class: Class providing ``model_json_schema()``; its name becomes
            the tool name and its docstring the tool description.
        flatten_schema: When true, drop ``$defs`` and inline every ``$ref``
            that points into it.

    Returns:
        A dict with ``type``, ``name``, ``description`` and ``parameters``.

    Raises:
        ValueError: If ``flatten_schema`` is true and a definition refers to
            itself (directly or indirectly), which cannot be inlined.
    """
    schema = model_class.model_json_schema()

    if flatten_schema:
        defs = schema.pop("$defs", {})

        def inline_refs(node, expanding: tuple = ()):
            """Return a copy of ``node`` with known ``$ref`` entries inlined."""
            if isinstance(node, list):
                return [inline_refs(item, expanding) for item in node]
            if not isinstance(node, dict):
                return node

            ref_path = node.get("$ref")
            ref_name = ref_path.split("/")[-1] if isinstance(ref_path, str) else None
            if ref_name is None or ref_name not in defs:
                # Plain object, or a reference we cannot resolve: keep as is.
                return {key: inline_refs(value, expanding) for key, value in node.items()}

            if ref_name in expanding:
                raise ValueError(
                    f"Cannot flatten recursive schema definition {ref_name!r}"
                )

            # Copy without the auto-generated title; the shared definition is
            # left untouched so every reference gets the same content.
            target = {key: value for key, value in defs[ref_name].items() if key != "title"}
            # Keys written next to the $ref (e.g. a field description) are more
            # specific than the definition's own, so they take precedence.
            siblings = {key: value for key, value in node.items() if key != "$ref"}
            return inline_refs({**target, **siblings}, expanding + (ref_name,))

        schema = inline_refs(schema)

    doc = model_class.__doc__
    return {
        "type": "function",
        "name": model_class.__name__,
        # Classes without a docstring have ``__doc__ is None``.
        "description": dedent(doc) if doc else "",
        "parameters": schema,
    }
|
|
338
|
+
|
|
339
|
+
|
|
340
|
+
if __name__ == "__main__":
    # Manual smoke test: a toy agent with one query tool and one result tool.
    from asyncio import run

    # ``Rpa`` takes two type arguments (context, result); subscripting it with
    # a single one raises TypeError at class-creation time.
    class WeatherAgent(Rpa[Rpa.Context, str]):
        "You are a weather agent that provides current weather information."

        class QueryCurrentWeather(Rpa.Query):
            "get the current weather in a given city"

            city: str

            async def run(self, rpa: WeatherAgent):
                return (
                    f"current weather in {self.city} are sunny and 25 degrees Celsius"
                )

        class CurrentWeatherResult(Rpa.Result[str]):
            """
            Result of the current weather query

            You must call this tool when you have the current weather information to finalize the RPA process.
            """

            city: str
            temperature: float

            async def run(self, rpa: WeatherAgent):
                return f"The current weather in {self.city} is {self.temperature}°C."

        def __init__(self):
            super().__init__(
                model="gpt-4o",
                queries=[self.QueryCurrentWeather],
                actions=[],
                results=[self.CurrentWeatherResult],
            )

    # ``Rpa.run`` expects a Context (it calls ``ctx.get_files()``), not a bare
    # string, so the question is wrapped in a minimal context.
    # NOTE(review): ``Rpa.run`` as written does not send ``make_prompt()`` to
    # the model, so this question may never reach it — confirm intended wiring.
    class WeatherContext(Rpa.Context):
        question: str

        def make_prompt(self) -> str:
            return self.question

    rpa = WeatherAgent()
    response = run(
        rpa.run(WeatherContext(question="Get the current weather in New York City"))
    )
    print(response)
|
stackraise/ai/toolset.py
ADDED
|
@@ -0,0 +1,227 @@
|
|
|
1
|
+
# %%
|
|
2
|
+
from datetime import datetime
|
|
3
|
+
from venv import logger
|
|
4
|
+
from pydantic import Field, create_model, BaseModel
|
|
5
|
+
|
|
6
|
+
from functools import cache, cached_property, update_wrapper
|
|
7
|
+
from inspect import isfunction, get_annotations, isawaitable, signature, Parameter
|
|
8
|
+
from typing import Annotated, Any, Callable, Optional
|
|
9
|
+
from textwrap import dedent
|
|
10
|
+
|
|
11
|
+
from types import MethodType
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
# Common parent of the argument models generated by ToolDescriptor.ArgsModel.
# The docstring text is kept verbatim because pydantic can surface it in the
# generated JSON schema.
class ArgsModelBase(BaseModel):
    """Base class for tool arguments. All tools should inherit from this class."""
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
class ToolDescriptor[*A, R]:
|
|
19
|
+
# __slots__ = (
|
|
20
|
+
# # "__module__",
|
|
21
|
+
# "__name__",
|
|
22
|
+
# "__qualname__",
|
|
23
|
+
# "__doc__",
|
|
24
|
+
# "__annotations__",
|
|
25
|
+
# # "__type_params__",
|
|
26
|
+
# "__func__",
|
|
27
|
+
# "__owner__",
|
|
28
|
+
# )
|
|
29
|
+
|
|
30
|
+
def __init__(self, func: Callable[[*A], R]):
|
|
31
|
+
self.__func__ = func
|
|
32
|
+
# update_wrapper(self, func)
|
|
33
|
+
|
|
34
|
+
def __set_name__(self, owner, name):
|
|
35
|
+
assert (
|
|
36
|
+
self.__func__.__name__ == name
|
|
37
|
+
), f"Tool name mismatch: {self.__name__} != {name}"
|
|
38
|
+
self.__owner__ = owner
|
|
39
|
+
|
|
40
|
+
def __get__(self, instance, owner):
|
|
41
|
+
if instance is None:
|
|
42
|
+
return self
|
|
43
|
+
return MethodType(self.__func__, instance)
|
|
44
|
+
|
|
45
|
+
#def __call__(self, instance, *args, **kwargs): ...
|
|
46
|
+
|
|
47
|
+
@cached_property
|
|
48
|
+
def ArgsModel(self) -> type[BaseModel]:
|
|
49
|
+
assert hasattr(self, "__owner__"), "ToolDescriptor must be set on a class"
|
|
50
|
+
|
|
51
|
+
sign = signature(self.__func__)
|
|
52
|
+
|
|
53
|
+
field_definitions = {
|
|
54
|
+
f.name: (
|
|
55
|
+
(f.annotation, f.default)
|
|
56
|
+
if f.default is not Parameter.empty
|
|
57
|
+
else f.annotation
|
|
58
|
+
)
|
|
59
|
+
for f in sign.parameters.values()
|
|
60
|
+
if f.annotation is not Parameter.empty
|
|
61
|
+
}
|
|
62
|
+
|
|
63
|
+
arg_model = create_model(
|
|
64
|
+
"ArgsModel",
|
|
65
|
+
__base__=ArgsModelBase,
|
|
66
|
+
__module__=self.__func__.__module__,
|
|
67
|
+
__doc__=self.__func__.__doc__,
|
|
68
|
+
**field_definitions,
|
|
69
|
+
)
|
|
70
|
+
|
|
71
|
+
setattr(
|
|
72
|
+
arg_model,
|
|
73
|
+
"__qualname__",
|
|
74
|
+
f"{self.__owner__.__name__}.{self.__func__.__name__}.ArgsModel",
|
|
75
|
+
)
|
|
76
|
+
|
|
77
|
+
return arg_model
|
|
78
|
+
|
|
79
|
+
# @cache
|
|
80
|
+
def generate_schema(self, namespace: Optional[str] = None):
|
|
81
|
+
return {
|
|
82
|
+
"name": (
|
|
83
|
+
f"{namespace}-{self.__func__.__name__}"
|
|
84
|
+
if namespace
|
|
85
|
+
else self.__func__.__name__
|
|
86
|
+
),
|
|
87
|
+
"type": "function",
|
|
88
|
+
"description": dedent(self.__func__.__doc__.strip()) if self.__func__.__doc__ else "",
|
|
89
|
+
"parameters": self.ArgsModel.model_json_schema(),
|
|
90
|
+
}
|
|
91
|
+
|
|
92
|
+
|
|
93
|
+
def tool(fn):
    """
    Decorator to mark a function as a tool in the Toolset.

    Raises:
        TypeError: If ``fn`` is not a plain Python function.
    """
    if isfunction(fn):
        return ToolDescriptor(fn)

    raise TypeError("tool decorator can only be applied to functions")
|
|
101
|
+
|
|
102
|
+
|
|
103
|
+
def _generate_tools_schema(toolsets: dict[str, type[Any]]):
|
|
104
|
+
"""
|
|
105
|
+
Create tool descriptors for all tools in the toolsets.
|
|
106
|
+
"""
|
|
107
|
+
for namespace, toolset in toolsets.items():
|
|
108
|
+
for name, tool in vars(toolset).items():
|
|
109
|
+
if not isinstance(tool, ToolDescriptor):
|
|
110
|
+
continue
|
|
111
|
+
yield tool.generate_schema(namespace=namespace)
|
|
112
|
+
|
|
113
|
+
|
|
114
|
+
def _generate_tools_schema(toolsets: dict[str, type[Any]]):
|
|
115
|
+
"""
|
|
116
|
+
Create tool descriptors for all tools in the toolsets.
|
|
117
|
+
"""
|
|
118
|
+
for namespace, toolset in toolsets.items():
|
|
119
|
+
for name, tool in vars(toolset).items():
|
|
120
|
+
if not isinstance(tool, ToolDescriptor):
|
|
121
|
+
continue
|
|
122
|
+
yield tool.generate_schema(namespace=namespace)
|
|
123
|
+
|
|
124
|
+
|
|
125
|
+
class ToolDispatcher:
    """Route tool calls by name to methods of the registered toolsets."""

    def __init__(self, toolsets: dict[str, Any]):
        """
        Initialize the ToolDispatcher with a dictionary of toolsets.
        Each toolset should be a class with methods decorated with @tool.
        """
        self._toolsets = toolsets

    @cached_property
    def tools_schema(self) -> list[dict]:
        """Schemas of every tool found on the toolsets' classes."""
        return list(
            _generate_tools_schema({nm: type(v) for nm, v in self._toolsets.items()})
        )

    @cached_property
    def tool_mapping(self) -> dict[str, tuple[Any, ToolDescriptor]]:
        """Map each full tool name to its ``(toolset, descriptor)`` pair."""
        result = {}
        for item in self.tools_schema:
            fullname = item["name"]
            # Function names cannot contain "-", so split on the LAST hyphen;
            # this also copes with namespaces that contain "-" or are empty.
            namespace, _, fn_name = fullname.rpartition("-")
            toolset = self._toolsets.get(namespace, None)
            assert (
                toolset is not None
            ), f"Toolset '{namespace}' not found in {self._toolsets.keys()}"
            tool_descriptor = getattr(type(toolset), fn_name, None)
            assert (
                tool_descriptor is not None
            ), f"Function '{fn_name}' not found in toolset '{namespace}'"
            result[fullname] = toolset, tool_descriptor
        return result

    async def _dispatch_tool_call(self, name: str, raw_args: Any) -> str:
        """Validate ``raw_args`` and run the tool called ``name``.

        Args:
            name: Full tool name as produced by ``tools_schema``.
            raw_args: JSON text/bytes or an already-parsed object.

        Returns:
            The tool's result (dumped to JSON when it is a pydantic model), or
            an error message when the tool is unknown, the arguments are
            invalid, or the tool raises.
        """
        # The module-level ``logger`` is imported from ``venv``; use a logger
        # named after this module so messages land where users expect.
        import logging

        log = logging.getLogger(__name__)

        # Explicit lookup: only a missing name is reported as "not found".
        entry = self.tool_mapping.get(name)
        if entry is None:
            return f"ERROR: Tool '{name}' not found."
        toolset, tool = entry

        args_model_class = tool.ArgsModel

        # validate the model
        try:
            if isinstance(raw_args, (str, bytes, bytearray)):
                args_model = args_model_class.model_validate_json(raw_args)
            else:
                args_model = args_model_class.model_validate(raw_args)
        except Exception as e:
            log.debug("Error validating args for tool '%s': %s", name, e)
            return f"Error: Invalid arguments for tool '{name}': {str(e)}"

        # convert model fields to function arguments
        args = {nm: getattr(args_model, nm) for nm in args_model_class.model_fields}

        try:
            result = tool.__func__(toolset, **args)

            if isawaitable(result):
                result = await result
        except Exception as e:
            log.debug("Error executing tool '%s': %s", name, e)
            return f"Error executing tool '{name}': {str(e)}"

        log.debug("Tool '%s' executed successfully: %s", name, result)

        if isinstance(result, BaseModel):
            # Convert BaseModel to JSON string
            result = result.model_dump_json()

        return result
|
|
200
|
+
|
|
201
|
+
|
|
202
|
+
if __name__ == "__main__":
    # Scratch cell: declares a sample toolset. The calls that would exercise it
    # are kept commented out below.

    class MyToolset:

        # class Response(BaseModel):
        #     message: str

        # TODO: custom responses
        @tool
        def query_weather(
            self,
            city: Annotated[str, Field(description="nombre de la ciudad")],
            time: Annotated[datetime, Field(description="hora de la consulta")],
        ) -> str:
            """
            Example tool that takes two arguments and returns a string.
            """

    # ts = td.tools_schema

    # ts = MyToolset()
    # ts("query_weather", '{"city": "Madrid", "time": "2023-03-15T12:00:00Z"}')
    # print(MyToolset.tool_schemas)
    # %%
    # f = await td.dispatch_tool_call('mytools-query_weather', '{"city": "Madrid", "time": "2023-03-15T12:00:00Z"}')
|
|
227
|
+
|
stackraise/app.py
ADDED
|
@@ -0,0 +1,23 @@
|
|
|
1
|
+
import stackraise.db as db
|
|
2
|
+
|
|
3
|
+
class Middleware:
    """
    ASGI middleware that runs each call inside a persistence session.

    Example:
    ```
    from fastapi import FastAPI
    import stackraise.db as db
    from stackraise.app import Middleware

    app = FastAPI()
    app.add_middleware(Middleware, persistence=db.Persistence())
    ```
    """

    def __init__(self, app, persistence: db.Persistence):
        """Keep the wrapped ASGI app and the persistence object."""
        self.persistence = persistence
        self.app = app

    async def __call__(self, scope, receive, send):
        """Open a session, then delegate the ASGI call to the wrapped app."""
        session_ctx = self.persistence.session()
        async with session_ctx:
            return await self.app(scope, receive, send)
|
stackraise/auth/model.py
ADDED
|
@@ -0,0 +1,24 @@
|
|
|
1
|
+
from abc import ABC
|
|
2
|
+
from enum import Enum
|
|
3
|
+
|
|
4
|
+
from typing import ClassVar, Annotated, Self
|
|
5
|
+
from pydantic import Field, EmailStr
|
|
6
|
+
|
|
7
|
+
import stackraise.db as db
|
|
8
|
+
|
|
9
|
+
# Abstract base document for user accounts. Concrete account models are
# expected to provide their own ``Scope`` members and the ``SCOPES`` mapping.
# (Documented with comments rather than a class docstring so the generated
# model schema is unchanged.)
class BaseUserAccount(db.Document, ABC, abstract=True):
    # Placeholder enum of permission scopes; it has no members here.
    class Scope(str, Enum):
        ...

    # Scope -> text mapping; declared only, no value is assigned in this class.
    SCOPES: ClassVar[dict[Scope, str]]
    # NOTE(review): presumably the path of the login endpoint — confirm with the
    # auth service.
    LOGIN_URL: ClassVar[str] = "/auth"

    email: Annotated[EmailStr, Field()]
    scopes: Annotated[list[Scope], Field()]

    # Stored credential material for the account.
    password_salt: Annotated[str, Field()]
    password_hash: Annotated[str, Field()]

    @classmethod
    async def fetch_by_email(cls, email: str) -> Self | None:
        """Fetch a user account by email."""
        query = {"email": email}
        return await cls.collection._find_one(query)
|