quillsql 2.2.7__py3-none-any.whl → 2.2.9__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- quillsql/core.py +271 -1
- {quillsql-2.2.7.dist-info → quillsql-2.2.9.dist-info}/METADATA +43 -2
- {quillsql-2.2.7.dist-info → quillsql-2.2.9.dist-info}/RECORD +5 -5
- {quillsql-2.2.7.dist-info → quillsql-2.2.9.dist-info}/WHEEL +1 -1
- {quillsql-2.2.7.dist-info → quillsql-2.2.9.dist-info}/top_level.txt +0 -0
quillsql/core.py
CHANGED
|
@@ -1,4 +1,5 @@
|
|
|
1
1
|
import os
|
|
2
|
+
import codecs
|
|
2
3
|
from dotenv import load_dotenv
|
|
3
4
|
|
|
4
5
|
import requests
|
|
@@ -29,7 +30,7 @@ load_dotenv()
|
|
|
29
30
|
|
|
30
31
|
ENV = os.getenv("PYTHON_ENV")
|
|
31
32
|
DEV_HOST = "http://localhost:8080"
|
|
32
|
-
PROD_HOST = "https://quill
|
|
33
|
+
PROD_HOST = "https://api.quill.co"
|
|
33
34
|
HOST = DEV_HOST if ENV == "development" else PROD_HOST
|
|
34
35
|
|
|
35
36
|
SINGLE_TENANT = "QUILL_SINGLE_TENANT"
|
|
@@ -386,6 +387,275 @@ class Quill:
|
|
|
386
387
|
"status": "error",
|
|
387
388
|
"data": responseMetadata,
|
|
388
389
|
}
|
|
390
|
+
|
|
391
|
+
async def stream(
    self,
    tenants,
    metadata,
    flags=None,
    filters=None,
    admin_enabled=None,
):
    """Stream the events of a Quill task as an async generator.

    ``"chat"`` and ``"agent"`` tasks run through ``_agentic_chat_loop``,
    which executes requested tools locally between turns. Every other task
    is relayed from the ``/sdk/<task>`` endpoint via ``_stream_sse``.

    Args:
        tenants: Non-empty list of tenants; passed to ``extract_tenant_ids``
            and forwarded with each request.
        metadata: Request metadata; must contain a truthy ``"task"``.
        flags: Optional flags, forwarded to the streaming helpers.
        filters: Optional custom filters, forwarded to the streaming helpers.
        admin_enabled: Optional admin flag, forwarded to the streaming helpers.

    Yields:
        Event dicts. Missing metadata or task, and any exception raised
        while streaming, are reported as
        ``{"type": "error", "errorText": ...}`` instead of being raised.

    Raises:
        ValueError: If ``tenants`` is empty. Because this is a generator,
            the error surfaces on the first iteration, not at call time.
    """
    if not tenants:
        raise ValueError("You may not pass an empty tenants array.")

    if not metadata:
        yield {"type": "error", "errorText": "Missing metadata."}
        return

    task = metadata.get("task")
    if not task:
        yield {"type": "error", "errorText": "Missing task."}
        return

    try:
        # Set tenant IDs in the connection
        self.target_connection.tenant_ids = extract_tenant_ids(tenants)
        if task in ("chat", "agent"):
            events = self._agentic_chat_loop(
                tenants, metadata, flags, filters, admin_enabled
            )
        else:
            events = self._stream_sse(
                task, tenants, metadata, flags, filters, admin_enabled
            )
        # NOTE(review): both helpers are synchronous generators, so each
        # step runs blocking work on the event loop.
        for event in events:
            yield event
    except Exception as err:
        # Report the first non-blank line of the message. An empty message
        # (e.g. ``Exception()``) has no lines, so fall back to the
        # exception class name instead of indexing into an empty list.
        error_text = next(
            (line for line in str(err).splitlines() if line.strip()),
            type(err).__name__,
        )
        yield {"type": "error", "errorText": error_text}
|
|
441
|
+
|
|
442
|
+
def _normalize_tenant_flags(self, tenants, flags):
    """Build the ``flags`` value sent with single-tenant SDK requests.

    This only applies when the first tenant is ``SINGLE_TENANT`` and
    ``flags`` is non-empty. In that case:

    - if the first flag is a dict, the flags are wrapped as
      ``[{"tenantField": SINGLE_TENANT, "flags": flags}]``;
    - otherwise ``flags`` is returned unchanged.

    In every other case the result is ``None``.
    """
    if not tenants or tenants[0] != SINGLE_TENANT:
        return None
    if not flags:
        return None
    if isinstance(flags[0], dict):
        return [{"tenantField": SINGLE_TENANT, "flags": flags}]
    return flags
|
|
450
|
+
|
|
451
|
+
def _agentic_chat_loop(self, tenants, metadata, flags, filters, admin_enabled):
    """Drive a multi-turn agent conversation, running tool calls locally.

    Each turn streams the ``agent`` endpoint with the accumulated
    ``messages`` history and re-yields its dict events. For each
    ``tool-input-available`` event:

    - a ``tool-executing`` event is emitted;
    - the tool is run with ``_execute_tool_locally``;
    - a ``tool-result`` event is emitted.

    If a turn requested tools, the assistant message and the tool results
    are appended to the history and another turn starts. The loop ends when
    a turn requests no tools, the server emits an ``error`` event, or
    ``max_iterations`` turns have run. A final ``{"type": "done"}`` event is
    always yielded.
    """

    def _build_tool_result_for_history(tool_result):
        """Return the content of the ``tool`` message sent back to the model.

        Query-like results (those with ``rows``/``fields`` lists or a
        ``dbMismatched`` flag) are summarized as status, SQL, error and row
        count rather than the rows themselves. Any other result is passed
        through unchanged.
        """
        result = tool_result.get("result") or {}
        has_rows = isinstance(result.get("rows"), list)
        has_fields = isinstance(result.get("fields"), list)
        is_query_result = has_rows or has_fields or result.get("dbMismatched")
        if not is_query_result:
            return result
        tool_input = tool_result.get("input") or {}
        error = result.get("error") or tool_input.get("error")
        status = "error" if error or result.get("dbMismatched") else "success"
        payload = {"status": status}
        if tool_input.get("sql"):
            payload["sql"] = tool_input.get("sql")
        if error:
            payload["error"] = error
        if result.get("dbMismatched"):
            payload["meta"] = {
                "dbMismatched": True,
                "backendDatabaseType": result.get("backendDatabaseType"),
            }
        elif not error:
            payload["meta"] = {
                "rowsFetchedSuccessfully": True,
                "rowCount": len(result.get("rows") or []),
            }
        return payload

    messages = list(metadata.get("messages") or [])
    max_iterations = 10

    for _ in range(max_iterations):
        payload = {
            **metadata,
            "messages": messages,
        }

        assistant_text = ""
        tool_results = []
        tool_calls = []
        stream_failed = False

        for event in self._stream_sse(
            "agent",
            tenants,
            payload,
            flags,
            filters,
            admin_enabled,
        ):
            # Events are expected to be dicts. Drop anything else (e.g. a raw
            # trailing string) so neither this loop nor the caller, which
            # indexes event["type"], crashes on it.
            if not isinstance(event, dict):
                continue
            yield event

            event_type = event.get("type")
            if event_type == "text-delta":
                # A null delta would otherwise raise TypeError on concatenation.
                assistant_text += event.get("delta") or ""
            elif event_type == "tool-input-available":
                tool_name = event.get("toolName")
                tool_call_id = event.get("toolCallId")
                tool_input = event.get("input") or {}

                if tool_call_id is None:
                    yield {
                        "type": "error",
                        "errorText": "Missing toolCallId for tool-input-available event.",
                    }
                    continue

                yield {
                    "type": "tool-executing",
                    "toolCallId": tool_call_id,
                    "toolName": tool_name,
                }

                result = self._execute_tool_locally(tool_name, tool_input)
                tool_results.append(
                    {
                        "toolCallId": tool_call_id,
                        "toolName": tool_name,
                        "input": tool_input,
                        "result": result,
                    }
                )
                tool_calls.append(
                    {
                        "id": tool_call_id,
                        "type": "function",
                        "function": {
                            "name": tool_name,
                            "arguments": json.dumps(tool_input),
                        },
                    }
                )

                yield {
                    "type": "tool-result",
                    "toolCallId": tool_call_id,
                    "result": result,
                }
            elif event_type in ("finish", "error"):
                stream_failed = event_type == "error"
                break

        # Stop when the model asked for no tools, or when the turn failed.
        # Re-sending the history after a server error would just re-request
        # a broken turn.
        if not tool_calls or stream_failed:
            break

        messages.append(
            {
                "role": "assistant",
                "content": assistant_text or None,
                "tool_calls": tool_calls,
            }
        )
        for tool_result in tool_results:
            messages.append(
                {
                    "role": "tool",
                    "tool_call_id": tool_result["toolCallId"],
                    "content": json.dumps(_build_tool_result_for_history(tool_result)),
                }
            )

    yield {"type": "done"}
|
|
574
|
+
|
|
575
|
+
def _execute_tool_locally(self, tool_name, tool_input):
    """Run one agent-requested tool and return its result dict.

    Supported tools:

    - ``generateReport``: reports an upstream ``error`` from the input, or
      runs ``tool_input["sql"]`` through ``run_queries`` against the target
      connection's database type. It returns one of:
        - the ``dbMismatched`` response unchanged;
        - the first query result when it carries an ``error``;
        - ``{"rows": ..., "fields": ...}`` (empty lists when there is no
          usable result).
    - ``createChart``: echoes the input back as ``{"chartConfig": ...}``.

    Any other tool name yields ``{"error": "Unknown tool: <name>"}``.
    """
    if tool_name == "generateReport":
        reported_error = tool_input.get("error")
        if reported_error:
            return {"error": reported_error}

        sql = tool_input.get("sql")
        if not sql:
            return {"error": "No SQL provided"}

        response = self.run_queries([sql], self.target_connection.database_type)
        if response.get("dbMismatched"):
            return response

        results = response.get("queryResults") or []
        first = results[0] if results else None
        if not isinstance(first, dict):
            return {"rows": [], "fields": []}
        if first.get("error"):
            return first
        return {"rows": first.get("rows", []), "fields": first.get("fields", [])}

    if tool_name == "createChart":
        return {"chartConfig": tool_input}

    return {"error": f"Unknown tool: {tool_name}"}
|
|
602
|
+
|
|
603
|
+
def _stream_sse(self, endpoint, tenants, payload, flags, filters, admin_enabled):
    """POST ``payload`` to ``/sdk/<endpoint>`` and yield the streamed events.

    The request body is ``payload`` plus:

    - ``tenants``;
    - ``flags`` normalized by ``_normalize_tenant_flags``;
    - ``sdkFilters`` converted from ``filters`` (when given);
    - ``adminEnabled`` (when not ``None``).

    Enum values in the body are serialized by their ``.value``.

    Yields:
        Event dicts produced by ``_iter_sse_events``. A 4xx/5xx response is
        reported as a single ``{"type": "error", "errorText": ...}`` event
        instead of being parsed as SSE.
    """
    tenant_flags = self._normalize_tenant_flags(tenants, flags)
    request_payload = {
        **payload,
        "tenants": tenants,
        "flags": tenant_flags,
    }
    if filters:
        request_payload["sdkFilters"] = [convert_custom_filter(f) for f in filters]
    if admin_enabled is not None:
        request_payload["adminEnabled"] = admin_enabled

    # Custom JSON Encoder to handle Enums
    class EnumEncoder(json.JSONEncoder):
        def default(self, obj):
            if isinstance(obj, Enum):
                return obj.value  # serialize enums by their value
            return super().default(obj)

    url = f"{self.baseUrl}/sdk/{endpoint}"
    headers = {
        "Authorization": f"Bearer {self.private_key}",
        "Content-Type": "application/json",
        "Accept": "text/event-stream",
    }
    encoded = json.dumps(request_payload, cls=EnumEncoder)

    # The ``with`` block releases the connection even when this generator
    # is closed early (e.g. the agent loop breaking on "finish").
    with requests.post(url, data=encoded, headers=headers, stream=True) as resp:
        if not resp.ok:
            body_lines = [line.strip() for line in (resp.text or "").splitlines() if line.strip()]
            error_text = f"Quill request failed with status {resp.status_code}"
            if body_lines:
                error_text += f": {body_lines[0]}"
            yield {"type": "error", "errorText": error_text}
            return
        yield from self._iter_sse_events(resp.iter_content(chunk_size=4096))


def _iter_sse_events(self, chunks):
    """Decode an iterable of SSE byte chunks into event dicts.

    Framing and decoding:

    - Events are separated by blank lines; CRLF line endings are normalized
      to LF first.
    - Bytes are decoded incrementally, so multi-byte UTF-8 characters split
      across chunks are handled.

    Each event's ``data:`` lines are joined and parsed as JSON:

    - JSON objects are yielded as-is.
    - JSON strings become ``{"type": "text-delta", "id": "0", "delta": ...}``.
    - Anything else (comment-only events, invalid JSON, other JSON values)
      is skipped.

    A ``[DONE]`` payload ends the stream. A final event that lacks a trailing
    blank line is still parsed.
    """
    done = object()  # sentinel returned by _parse for the "[DONE]" terminator

    def _parse(raw_event):
        # Returns `done`, None (nothing to emit), or an event dict.
        data_lines = [
            line[len("data:"):].strip()
            for line in raw_event.splitlines()
            if line.startswith("data:")
        ]
        if not data_lines:
            return None
        data = "\n".join(data_lines)
        if data == "[DONE]":
            return done
        try:
            parsed = json.loads(data)
        except json.JSONDecodeError:
            return None
        if isinstance(parsed, str):
            return {"type": "text-delta", "id": "0", "delta": parsed}
        # Consumers index event["type"], so only JSON objects are emitted.
        return parsed if isinstance(parsed, dict) else None

    decoder = codecs.getincrementaldecoder("utf-8")()
    buf = ""
    for chunk in chunks:
        buf = (buf + decoder.decode(chunk)).replace("\r\n", "\n")
        while "\n\n" in buf:
            raw_event, buf = buf.split("\n\n", 1)
            event = _parse(raw_event)
            if event is done:
                return
            if event is not None:
                yield event

    # Flush any partial code points at the end, then parse a final event
    # that was not followed by a blank line.
    buf = (buf + decoder.decode(b"", final=True)).replace("\r\n", "\n")
    event = _parse(buf)
    if event is not None and event is not done:
        yield event
|
|
389
659
|
|
|
390
660
|
def apply_limit(self, query, limit):
|
|
391
661
|
# Simple logic: if query already has a limit, don't add another
|
|
@@ -1,6 +1,6 @@
|
|
|
1
|
-
Metadata-Version: 2.
|
|
1
|
+
Metadata-Version: 2.4
|
|
2
2
|
Name: quillsql
|
|
3
|
-
Version: 2.2.7
|
|
3
|
+
Version: 2.2.9
|
|
4
4
|
Summary: Quill SDK for Python.
|
|
5
5
|
Home-page: https://github.com/quill-sql/quill-python
|
|
6
6
|
Author: Quill
|
|
@@ -67,3 +67,44 @@ async def quill_post(data: Request, user: dict = Depends(authenticate_jwt)):
|
|
|
67
67
|
|
|
68
68
|
Then you can run your app like normally. Pass in this route to our react library
|
|
69
69
|
on the frontend and you're all set!
|
|
70
|
+
|
|
71
|
+
## Streaming
|
|
72
|
+
|
|
73
|
+
```python
|
|
74
|
+
from quillsql import Quill
|
|
75
|
+
from fastapi.responses import StreamingResponse
|
|
76
|
+
import asyncio
|
|
77
|
+
|
|
78
|
+
quill = Quill(
|
|
79
|
+
private_key=os.getenv("QUILL_PRIVATE_KEY"),
|
|
80
|
+
database_connection_string=os.getenv("POSTGRES_READ"),
|
|
81
|
+
database_type="postgresql"
|
|
82
|
+
)
|
|
83
|
+
|
|
84
|
+
@app.post("/quill-stream")
|
|
85
|
+
async def quill_post(data: Request, user: dict = Depends(authenticate_jwt)):
|
|
86
|
+
# assumes the user fetched by the auth middleware has a user_id
|
|
87
|
+
user_id = user["user_id"]
|
|
88
|
+
body = await data.json()
|
|
89
|
+
metadata = body.get("metadata")
|
|
90
|
+
|
|
91
|
+
quill_stream = quill.stream(
|
|
92
|
+
tenants=[{"tenantField": "user_id", "tenantIds": [user_id]}],
|
|
93
|
+
metadata=metadata,
|
|
94
|
+
)
|
|
95
|
+
|
|
96
|
+
async def event_generator():
|
|
97
|
+
# Full event types list: https://ai-sdk.dev/docs/ai-sdk-ui/stream-protocol#data-stream-protocol
|
|
98
|
+
async for event in quill_stream:
|
|
99
|
+
if event["type"] == "start":
|
|
100
|
+
pass
|
|
101
|
+
elif event["type"] == "text-delta":
|
|
102
|
+
yield event['delta']
|
|
103
|
+
elif event["type"] == "finish":
|
|
104
|
+
return
|
|
105
|
+
elif event["type"] == "error":
|
|
106
|
+
yield event['errorText']
|
|
107
|
+
await asyncio.sleep(0)
|
|
108
|
+
|
|
109
|
+
return StreamingResponse(event_generator(), media_type="text/event-stream")
|
|
110
|
+
```
|
|
@@ -1,5 +1,5 @@
|
|
|
1
1
|
quillsql/__init__.py,sha256=wjJfszle5vheUbgUfJMHQqtqhx2W3UaDN4ndcRIfmkQ,236
|
|
2
|
-
quillsql/core.py,sha256=
|
|
2
|
+
quillsql/core.py,sha256=VzWtu0up8tcSxijHt_x3bSbZnPjbZFIyGKXCwerpPrg,33704
|
|
3
3
|
quillsql/error.py,sha256=n9VKHw4FAgg7ZEAz2YQ8L_8FdRG_1shwGngf2iWhUSM,175
|
|
4
4
|
quillsql/assets/__init__.py,sha256=oXQ2ZS5XDXkXTYjADxNfGt55cIn_rqfgWL2EDqjTyoI,45
|
|
5
5
|
quillsql/assets/pgtypes.py,sha256=-B_2wUaoAsdX7_HnJhUlx4ptZQ6x-cXwuST9ACgGFdE,33820
|
|
@@ -15,7 +15,7 @@ quillsql/utils/post_quill_executor.py,sha256=DB1RHNfqHPYarMM10vSv--UjpCZqe4qYTjq
|
|
|
15
15
|
quillsql/utils/run_query_processes.py,sha256=QwnMr5UwXdtO_W88lv5nBaf6pJ_h5oWQnYd8K9oHQ5s,1030
|
|
16
16
|
quillsql/utils/schema_conversion.py,sha256=TFfMibN9nOsxNRhHw5YIFl3jGTvipG81bxX4LFDulUY,314
|
|
17
17
|
quillsql/utils/tenants.py,sha256=ZD2FuKz0gjBVSsThHDv1P8PU6EL8E009NWihE5hAH-Q,2022
|
|
18
|
-
quillsql-2.2.
|
|
19
|
-
quillsql-2.2.
|
|
20
|
-
quillsql-2.2.
|
|
21
|
-
quillsql-2.2.
|
|
18
|
+
quillsql-2.2.9.dist-info/METADATA,sha256=zVlzKEUZAAIQR30Re6w6tdebu6hFjtTQnOeIXlOwZz4,3052
|
|
19
|
+
quillsql-2.2.9.dist-info/WHEEL,sha256=wUyA8OaulRlbfwMtmQsvNngGrxQHAvkKcvRmdizlJi0,92
|
|
20
|
+
quillsql-2.2.9.dist-info/top_level.txt,sha256=eU2vHnVqwpYQJ3ADl1Q-DIBzbYejZRUhcMdN_4zMCz8,9
|
|
21
|
+
quillsql-2.2.9.dist-info/RECORD,,
|
|
File without changes
|