onegenie 0.1.0__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- onegenie-0.1.0/.gitignore +57 -0
- onegenie-0.1.0/PKG-INFO +15 -0
- onegenie-0.1.0/README.md +1 -0
- onegenie-0.1.0/onegenie/__init__.py +10 -0
- onegenie-0.1.0/onegenie/agent.py +51 -0
- onegenie-0.1.0/onegenie/context.py +22 -0
- onegenie-0.1.0/onegenie/daemon.py +270 -0
- onegenie-0.1.0/onegenie/mcp_agent.py +235 -0
- onegenie-0.1.0/onegenie/util.py +120 -0
- onegenie-0.1.0/pyproject.toml +38 -0
|
@@ -0,0 +1,57 @@
|
|
|
1
|
+
coverage
|
|
2
|
+
dist
|
|
3
|
+
node_modules
|
|
4
|
+
*.log
|
|
5
|
+
.direnv
|
|
6
|
+
.envrc.custom
|
|
7
|
+
versions.json
|
|
8
|
+
versions.html
|
|
9
|
+
.DS_Store
|
|
10
|
+
user-upload/
|
|
11
|
+
.babel-error
|
|
12
|
+
src/uid/global.css
|
|
13
|
+
secrets/
|
|
14
|
+
env
|
|
15
|
+
.eslintcache
|
|
16
|
+
*.sql
|
|
17
|
+
!db/seeds/cmdb-data.sql
|
|
18
|
+
newapp-bin
|
|
19
|
+
.dbchecksum
|
|
20
|
+
.skip-starting-bg-jobs
|
|
21
|
+
.mixin-outsync
|
|
22
|
+
|
|
23
|
+
.yarn/*
|
|
24
|
+
!.yarn/cache
|
|
25
|
+
!.yarn/patches
|
|
26
|
+
!.yarn/plugins
|
|
27
|
+
!.yarn/releases
|
|
28
|
+
!.yarn/sdks
|
|
29
|
+
!.yarn/versions
|
|
30
|
+
|
|
31
|
+
*.blend1
|
|
32
|
+
of.csv
|
|
33
|
+
|
|
34
|
+
*.dxf
|
|
35
|
+
uploads/
|
|
36
|
+
minio-data/
|
|
37
|
+
|
|
38
|
+
build/qdrant_data
|
|
39
|
+
*.gz
|
|
40
|
+
downloads/
|
|
41
|
+
thei/
|
|
42
|
+
.restart-express
|
|
43
|
+
|
|
44
|
+
sqlite.db
|
|
45
|
+
oasis.tar
|
|
46
|
+
test/aisc-support/
|
|
47
|
+
|
|
48
|
+
generated.jpg
|
|
49
|
+
|
|
50
|
+
test/ntpo/*.csv
|
|
51
|
+
__pycache__/
|
|
52
|
+
*.db
|
|
53
|
+
*.db-shm
|
|
54
|
+
*.db-wal
|
|
55
|
+
*.db-journal
|
|
56
|
+
|
|
57
|
+
employees.sql.gz
|
onegenie-0.1.0/PKG-INFO
ADDED
|
@@ -0,0 +1,15 @@
|
|
|
1
|
+
Metadata-Version: 2.4
|
|
2
|
+
Name: onegenie
|
|
3
|
+
Version: 0.1.0
|
|
4
|
+
Summary: The official Python SDK for AskGenie AI Platform
|
|
5
|
+
Requires-Python: >=3.12
|
|
6
|
+
Requires-Dist: json-repair>=0.57.0
|
|
7
|
+
Requires-Dist: mcp>=1.26.0
|
|
8
|
+
Requires-Dist: openai>=2.0.0
|
|
9
|
+
Requires-Dist: orjson>=3.11.0
|
|
10
|
+
Requires-Dist: pydantic>=2.0.0
|
|
11
|
+
Requires-Dist: redis>=7.0.0
|
|
12
|
+
Requires-Dist: requests>=2.31.0
|
|
13
|
+
Description-Content-Type: text/markdown
|
|
14
|
+
|
|
15
|
+
# OneGenie Official SDK
|
onegenie-0.1.0/README.md
ADDED
|
@@ -0,0 +1 @@
|
|
|
1
|
+
# OneGenie Official SDK
|
|
@@ -0,0 +1,51 @@
|
|
|
1
|
+
from typing import List, Callable, Optional, Dict, Any
|
|
2
|
+
from pydantic import BaseModel
|
|
3
|
+
from .context import Context
|
|
4
|
+
|
|
5
|
+
|
|
6
|
+
class ModelConfig(BaseModel):
    """One model entry of an agent card.

    ``Agent.serve`` serializes each entry with ``model_dump(exclude_none=True)``,
    so ``features`` is omitted from the registration payload when it is None.
    """

    # Model identifier as understood by the platform.
    id: str
    # Optional free-form per-model settings; dropped from the card when None.
    features: Optional[Dict[str, Any]] = None
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
class Agent:
    """An AskGenie agent definition plus the handler that serves its jobs.

    Typical use::

        agent = Agent(id=..., name=..., desc=..., models=[...])

        @agent.on_run
        async def run(ctx):
            yield "text chunk"

        agent.serve()
    """

    def __init__(
        self,
        id: str,
        name: str,
        desc: str,
        models: List[ModelConfig],
        extra_params: Any = None,
        survey: Any = None,
    ):
        """Store the fields that make up the agent card.

        Args:
            id: Agent identifier sent on registration.
            name: Display name.
            desc: Human-readable description.
            models: Model entries; each is dumped with ``exclude_none=True``.
            extra_params: Optional mapping sent as the card's
                ``extra_params``; ``None``/empty means ``{}``.
            survey: Optional value added as ``extra_params["survey"]``.
        """
        self.id = id
        self.name = name
        self.desc = desc
        self.models = models
        self.extra_params = extra_params
        self.survey = survey
        self._run_handler: Optional[Callable[[Context], Any]] = None

    def on_run(self, func: Callable[[Context], Any]):
        """Decorator: register *func* as the job handler and return it unchanged."""
        self._run_handler = func
        return func

    def _build_card(self) -> Dict[str, Any]:
        """Return the registration payload without mutating ``self.extra_params``."""
        extra = self.extra_params or {}
        if self.survey is not None:
            # Copy first: the original code wrote "survey" straight into the
            # caller-supplied dict, leaking it into later uses of that object.
            extra = dict(extra)
            extra["survey"] = self.survey
        return {
            "id": self.id,
            "name": self.name,
            "desc": self.desc,
            "models": [m.model_dump(exclude_none=True) for m in self.models],
            "extra_params": extra,
        }

    def serve(self):
        """Register the agent and run the job-consuming daemon.

        Blocks inside ``start_daemon`` until the daemon stops.

        Raises:
            ValueError: no handler was registered with :meth:`on_run`.
        """
        if not self._run_handler:
            raise ValueError("Missing @agent.on_run")

        agent_card = self._build_card()

        # Imported lazily: the daemon module pulls in the Redis/OpenAI stack
        # and reads the platform endpoint from the environment at import time.
        from .daemon import start_daemon

        start_daemon(agent_card, self._run_handler)
|
|
@@ -0,0 +1,22 @@
|
|
|
1
|
+
from dataclasses import dataclass
|
|
2
|
+
from typing import List, Dict, Any, Callable
|
|
3
|
+
|
|
4
|
+
|
|
5
|
+
class LLMTools:
    """LLM helpers handed to a run handler.

    Attributes:
        gen_stream: Callable producing a streamed completion.
        gen: Callable producing a single, non-streamed completion.
    """

    def __init__(self, gen_stream: Callable, gen: Callable):
        # Both callables are kept exactly as given; nothing is wrapped.
        self.gen_stream, self.gen = gen_stream, gen
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
class FSTools:
    """File helpers handed to a run handler (currently only uploading)."""

    def __init__(self, upload_file: Callable):
        """Keep a reference to the upload callable, e.g. a session's ``upload_file``."""
        self.upload_file: Callable = upload_file
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
@dataclass
class Context:
    """Per-job context passed to the handler registered with ``Agent.on_run``.

    The daemon builds one of these for every job it takes from the stream.
    """

    # Chat history taken from the job payload's "messages" entry.
    messages: List[Dict[str, Any]]
    # Job payload's "extra_params"; NOTE(review): read with dict.get upstream,
    # so it may be None when the payload omits it.
    extra_params: Dict[str, Any]
    # LLM helpers (streaming and one-shot) bound to the job's API token.
    llm: LLMTools
    # File helpers (upload) bound to the job's S3 ticket.
    fs: FSTools
    # MCP client for tool calling; the daemon passes an MCPAgent that has not
    # been started yet (use it as `async with ctx.mcp:` or call start/stop).
    mcp: Any
|
|
@@ -0,0 +1,270 @@
|
|
|
1
|
+
import os
|
|
2
|
+
import uuid
|
|
3
|
+
import asyncio
|
|
4
|
+
import requests
|
|
5
|
+
import redis.asyncio as redis
|
|
6
|
+
import redis as sync_redis
|
|
7
|
+
|
|
8
|
+
from .util import dumps, loads, gen, gen_stream
|
|
9
|
+
from .mcp_agent import MCPAgent
|
|
10
|
+
from .context import Context, LLMTools, FSTools
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
def register_agent(card):
    """Register *card* with the AskGenie platform and return the JSON reply.

    Reads ``ASKGENIE_ENDPOINT`` and ``ASKGENIE_DEVELOPER_KEY`` from the
    environment and POSTs ``{"agent": card}`` to ``<endpoint>/agent/register``
    with a 10 second timeout.

    Raises:
        Exception: either environment variable is unset.
        requests.HTTPError: the platform answered with an error status.
    """
    endpoint = os.getenv("ASKGENIE_ENDPOINT")
    if endpoint is None:
        raise Exception("Expected ASKGENIE_ENDPOINT but missing")

    developer_key = os.getenv("ASKGENIE_DEVELOPER_KEY")
    if developer_key is None:
        raise Exception("Expected ASKGENIE_DEVELOPER_KEY but missing")

    print("Register agent to system...")
    reply = requests.post(
        f"{endpoint}/agent/register",
        json={"agent": card},
        headers={"X-API-KEY": developer_key},
        timeout=10,
    )

    if reply.status_code != 200:
        # Show the server's explanation before raising.
        print(reply.text)
    # Raises only for 4xx/5xx, so calling it for a 200 reply is a no-op.
    reply.raise_for_status()
    return reply.json()
|
|
36
|
+
|
|
37
|
+
|
|
38
|
+
# Configuration
# Redis consumer-group name used with XREADGROUP/XACK on the job stream.
GROUP_NAME = "default"
# Consumer name inside the group. By default every daemon process uses the
# same name; set ASKGENIE_CONSUMER_NAME to give each worker its own identity.
CONSUMER_NAME = os.getenv("ASKGENIE_CONSUMER_NAME", "consumer-1")
# SSE endpoint of the DGG MCP gateway; set ASKGENIE_DGG_SSE_URL when the
# gateway does not run on localhost:8081.
DGG_SSE_URL = os.getenv("ASKGENIE_DGG_SSE_URL", "http://localhost:8081/sse")
|
|
42
|
+
|
|
43
|
+
|
|
44
|
+
class TaskSession:
    """Session object managing one job's lifecycle and its platform API tools.

    Args:
        r: asyncio Redis client used for ``publish``.
        job_id: Job identifier echoed as ``"jobId"`` in published messages.
        api_token: Per-job token forwarded to the LLM helpers.
        s3_ticket: Presigned-POST ticket with ``"url"`` and ``"fields"``; the
            ``"key"`` field may contain the ``${filename}`` placeholder.
        response_channel: Pub/sub channel receiving the job's messages.
    """

    def __init__(self, r, job_id, api_token, s3_ticket, response_channel):
        self.r = r
        self.job_id = job_id
        self.api_token = api_token
        self.s3_ticket = s3_ticket
        self.response_channel = response_channel

    async def upload_file(self, *, filename, srcpath=None, content=None, type="file"):
        """Upload a file with the S3 ticket and announce it on the job channel.

        Exactly one of *srcpath* (local path) or *content* (``str`` encoded as
        UTF-8, or ``bytes``) must be given. The object is stored under a random
        UUID name that keeps the lower-cased extension of *filename*; *filename*
        itself is only published for display/download.

        Returns:
            The S3 object key when the upload answered 204, otherwise ``None``
            (the failure body is printed; previously both cases returned None).

        Raises:
            ValueError: unsupported *type*, not exactly one source, missing
                extension, or an unsafe *filename*.
        """
        if type not in ("file", "streamlit_app"):
            raise ValueError(
                f"Upload Error: unsupported type '{type}'. Allowed types are 'file' and 'streamlit_app'."
            )

        # Both missing or both given is an error.
        if (srcpath is None) == (content is None):
            raise ValueError(
                "Upload Error: Must provide exactly one of 'srcpath' or 'content'"
            )

        _, file_ext = os.path.splitext(filename)
        file_ext = file_ext.lower()
        if not file_ext:
            raise ValueError(
                "Upload Error: extension is expected but missing in filename"
            )

        # filename security check: no path separators, parent refs or NUL.
        unsafe_chars = ["/", "\\", "..", "\0"]
        if any(char in filename for char in unsafe_chars) or filename.strip() == "":
            raise ValueError(f"Upload Error: illegal filename ('{filename}')")

        # Store under a random name; the user-facing name travels in the payload.
        uuid_filename = f"{uuid.uuid4()}{file_ext}"
        url = self.s3_ticket["url"]
        fields = self.s3_ticket["fields"].copy()
        fields["key"] = fields["key"].replace("${filename}", uuid_filename)
        s3_filepath = fields["key"]
        print(f"Ready to upload: File [{filename}] -> S3 path [{s3_filepath}]")

        def _perform_upload():
            # Runs in a worker thread because requests is blocking.
            # NOTE(review): no timeout is set, so a stalled endpoint blocks
            # this thread indefinitely.
            if srcpath is not None:
                with open(srcpath, "rb") as f:
                    # Send under the storage name, like the in-memory branch,
                    # instead of the local file's basename.
                    return requests.post(
                        url,
                        data=fields,
                        files={"file": (uuid_filename, f)},
                        allow_redirects=False,
                    )
            # In-memory content: encode str to UTF-8, pass bytes through.
            file_data = content.encode("utf-8") if isinstance(content, str) else content
            return requests.post(
                url,
                data=fields,
                files={"file": (uuid_filename, file_data)},
                allow_redirects=False,
            )

        response = await asyncio.to_thread(_perform_upload)
        print(f"Upload {uuid_filename} Status code: {response.status_code}")

        # Only 204 counts as success.
        if response.status_code != 204:
            print(f"Upload Failed: {response.text}")
            return None

        payload = {
            "jobId": self.job_id,
            "status": "processing",
            "filename": filename,  # for display and download
            "s3_filepath": s3_filepath,  # for storage
            "type": type,
        }
        msg = dumps(payload)

        print(f"job_id: {self.job_id} | Ready to Publish: {payload}")
        await self.r.publish(self.response_channel, msg)
        return s3_filepath

    # streaming output
    async def gen_stream(self, *, messages=None, model, tools=None):
        """Stream completion deltas for *messages* with the job's API token.

        ``None`` defaults replace the former shared mutable ``[]`` defaults;
        a missing *messages* is still sent as an empty list.
        """
        async for chunk in gen_stream(
            messages=[] if messages is None else messages,
            model=model,
            api_token=self.api_token,
            tools=tools,
        ):
            yield chunk

    # blocking one time generation
    async def gen(self, *, messages=None, model, tools=None):
        """Return a single (non-streamed) completion message for *messages*."""
        return await gen(
            messages=[] if messages is None else messages,
            model=model,
            api_token=self.api_token,
            tools=tools,
        )
|
|
139
|
+
|
|
140
|
+
|
|
141
|
+
async def process_task(r, task_data, handler_func, response_channel):
    """Run one job through the user handler and publish its results.

    Args:
        r: asyncio Redis client used for ``publish``.
        task_data: Decoded job payload. Reads "id", "agent", "messages",
            "token", "extra_params" and "s3_ticket"; missing keys become None.
        handler_func: The ``@agent.on_run`` handler. It is called with a
            :class:`Context` and iterated with ``async for``, so it must return
            an async iterator of ``str`` chunks.
        response_channel: Pub/sub channel receiving the job's messages.

    Each chunk is published as a "processing" message and the concatenated
    output as a "done" message. An exception from the handler is published
    as an "error" message (now carrying "jobId") instead of being raised.
    An error while publishing that error message still propagates.
    """
    job_id = task_data.get("id")
    agent = task_data.get("agent")
    messages = task_data.get("messages")
    api_token = task_data.get("token")
    extra_params = task_data.get("extra_params")
    s3_ticket = task_data.get("s3_ticket")

    print(f"Processing job agent[{agent}] job_id[{job_id}]...")

    session = TaskSession(r, job_id, api_token, s3_ticket, response_channel)
    mcp_agent = MCPAgent(DGG_SSE_URL, api_token, session.gen_stream)

    ctx = Context(
        messages=messages,
        extra_params=extra_params,
        llm=LLMTools(gen_stream=session.gen_stream, gen=session.gen),
        fs=FSTools(upload_file=session.upload_file),
        mcp=mcp_agent,
    )

    try:
        parts = []
        async for chunk in handler_func(ctx):
            msg = dumps({"jobId": job_id, "status": "processing", "chunk": chunk})
            parts.append(chunk)
            await r.publish(response_channel, msg)

        await r.publish(
            response_channel,
            dumps({"jobId": job_id, "status": "done", "output": "".join(parts)}),
        )

    except Exception as e:
        print(f"Error processing {job_id}: {e}")
        # Include jobId so consumers of the shared output channel can tell
        # which job failed.
        err_msg = dumps({"jobId": job_id, "status": "error", "error": str(e)})
        await r.publish(response_channel, err_msg)

    finally:
        # Close the MCP connection if the handler started it but did not stop
        # it. Closing an unstarted or already-closed agent is a no-op.
        try:
            await mcp_agent.stop()
        except Exception as close_err:
            print(f"MCP cleanup failed for {job_id}: {close_err}")
|
|
178
|
+
|
|
179
|
+
|
|
180
|
+
async def start_heartbeat(r, hb_key):
    """Refresh *hb_key* forever: set it to "alive" with a 45 s TTL every 30 s.

    Failures of the ``set`` call are printed and retried on the next tick.
    Cancellation is not intercepted, so cancelling the task stops the loop.
    """
    while True:
        try:
            await r.set(hb_key, "alive", ex=45)
        except Exception as err:
            print(f"Heartbeat failed: {err}")
        await asyncio.sleep(30)
|
|
188
|
+
|
|
189
|
+
|
|
190
|
+
async def main(agent_card, handler_func, ctx):
    """Register the agent, then consume and dispatch jobs from Redis forever.

    Args:
        agent_card: Registration payload built by ``Agent.serve``.
        handler_func: The user handler forwarded to :func:`process_task`.
        ctx: Mutable dict shared with :func:`start_daemon`. This function
            stores "redis_url" and "hb_key" in it so the caller can delete the
            heartbeat key after shutdown.

    At most ``MAX_CONCURRENCY`` jobs run at once. A semaphore permit is taken
    before each read. It is released by the job's task when one is spawned,
    otherwise right away. Errors in the read loop are logged and retried
    after one second.
    """
    info = register_agent(agent_card)
    print("Registration completed.", info)

    ctx["redis_url"] = info["redisUrl"]
    r = redis.from_url(info["redisUrl"], decode_responses=True)
    print(f"Python Agent Daemon connected to {info['redisUrl']}")

    worker_id = uuid.uuid4().hex
    ctx["hb_key"] = info["heartbeatPattern"].replace(":*", f":{worker_id}")
    print(f"heartbeat_key: {ctx['hb_key']}")

    MAX_CONCURRENCY = 10
    sem = asyncio.Semaphore(MAX_CONCURRENCY)

    # Strong references to in-flight jobs (the event loop only keeps weak ones).
    background_tasks = set()

    heartbeat_task = asyncio.create_task(start_heartbeat(r, ctx["hb_key"]))

    async def task_wrapper(mid, task_data):
        """Process one job, ack it on success, and always return the permit."""
        try:
            await process_task(r, task_data, handler_func, info["outputChannel"])
            await r.xack(info["streamKey"], GROUP_NAME, mid)
        except Exception as e:
            print(f"Task processing error: {e}")
        finally:
            sem.release()

    try:
        while True:
            await sem.acquire()
            # Once a job task exists it owns the permit (task_wrapper releases it).
            handed_off = False
            try:
                response = await r.xreadgroup(
                    GROUP_NAME,
                    CONSUMER_NAME,
                    streams={info["streamKey"]: ">"},
                    count=1,
                    block=0,
                )
                if response:
                    _stream_name, messages = response[0]
                    message_id, data = messages[0]

                    # NOTE(review): a payload that fails to decode is left
                    # unacknowledged in the group's pending list.
                    task_payload = loads(data.get("payload"))

                    t = asyncio.create_task(task_wrapper(message_id, task_payload))
                    handed_off = True
                    background_tasks.add(t)
                    # Remove task from set when done
                    t.add_done_callback(background_tasks.discard)

            except Exception as e:
                # Keep the heartbeat alive and retry. Cancelling and awaiting it
                # here would re-raise CancelledError and kill the daemon.
                print(f"Critical error in loop: {e}")
                await asyncio.sleep(1)  # Prevent tight loop on Redis connection fail

            finally:
                if not handed_off:
                    sem.release()
    finally:
        heartbeat_task.cancel()
        # return_exceptions=True absorbs the heartbeat's CancelledError.
        await asyncio.gather(heartbeat_task, return_exceptions=True)
|
|
253
|
+
|
|
254
|
+
|
|
255
|
+
def start_daemon(agent_card, handler_func):
    """Run :func:`main` until interrupted, then delete this worker's heartbeat key.

    A KeyboardInterrupt stops the daemon quietly; any other exception
    propagates after cleanup. Cleanup uses the synchronous Redis client
    because the event loop has already been closed by ``asyncio.run``.
    """
    ctx = {}
    try:
        asyncio.run(main(agent_card, handler_func, ctx))
    except KeyboardInterrupt:
        print("Daemon stopped.")
    finally:
        # main() sets "redis_url" before "hb_key", so both exist here.
        if "hb_key" in ctx:
            print(f"Cleaning up heartbeat: {ctx['hb_key']}...")
            try:
                # Use standard 'redis' (not asyncio) to guarantee execution;
                # the context manager closes the client's connections afterwards.
                with sync_redis.from_url(ctx["redis_url"]) as r:
                    r.delete(ctx["hb_key"])
                print("Cleanup successful.")
            except Exception as e:
                print(f"Cleanup failed: {e}")
|
|
@@ -0,0 +1,235 @@
|
|
|
1
|
+
from json_repair import repair_json
|
|
2
|
+
|
|
3
|
+
from contextlib import AsyncExitStack
|
|
4
|
+
from typing import Optional
|
|
5
|
+
|
|
6
|
+
from mcp import ClientSession
|
|
7
|
+
from mcp.client.sse import sse_client
|
|
8
|
+
|
|
9
|
+
|
|
10
|
+
class MCPAgent:
    """ReAct-style agent that calls tools through the DGG MCP gateway over SSE.

    Usage::

        async with MCPAgent(url, token, gen_stream) as agent:
            async for text in agent.chat(history):
                ...

    :meth:`chat` lists the gateway's tools and converts them to OpenAI
    function tools. It streams the model's reply and runs any requested tool
    calls through MCP, repeating until the model answers without a tool call.
    """

    def __init__(self, dgg_url: str, api_token: str, default_gen_stream=None):
        """
        Args:
            dgg_url: SSE endpoint of the DGG gateway.
            api_token: Token forwarded to DGG in the ``X-JWT-TOKEN`` header.
            default_gen_stream: Async generator function called as
                ``default_gen_stream(messages=..., model=..., tools=...)``. It
                must yield OpenAI-style deltas; plain ``str`` items (used by
                util.gen_stream to report failures) are treated as text.
        """
        self.dgg_url = dgg_url
        self.api_token = api_token
        self.exit_stack = AsyncExitStack()
        self.dgg: Optional[ClientSession] = None
        self.messages = []
        self.default_gen_stream = default_gen_stream

    async def __aenter__(self):
        """Connect when entering an ``async with`` block."""
        await self.start()
        return self

    async def __aexit__(self, exc_type, exc_val, exc_tb):
        """Close the connection when leaving the block, even after an exception.

        Exceptions raised inside the block are not suppressed.
        """
        await self.stop()

    async def start(self):
        """Open the SSE transport and MCP session, sending the user token as a header."""
        print("Connecting to DGG...")

        # The user context travels in a header; DGG presumably reads it to
        # filter which tools this user may list/call.
        headers = {"X-JWT-TOKEN": self.api_token}

        # The exit stack owns both context managers so stop() closes them together.
        sse_transport = await self.exit_stack.enter_async_context(
            sse_client(self.dgg_url, headers=headers)
        )
        self.dgg = await self.exit_stack.enter_async_context(
            ClientSession(sse_transport[0], sse_transport[1])
        )

        await self.dgg.initialize()
        print("✅ Connected & Initialized!")

    def _convert_mcp_to_openai_tools(self, mcp_tools):
        """Convert an MCP ``list_tools`` result into OpenAI/vLLM function tools."""
        return [
            {
                "type": "function",
                "function": {
                    "name": tool.name,
                    # MCP allows a missing description; send "" rather than null.
                    "description": tool.description or "",
                    "parameters": tool.inputSchema,  # MCP schema is JSON Schema
                },
            }
            for tool in mcp_tools.tools
        ]

    def _robust_parse_arguments(self, args_str: str) -> dict:
        """Parse a possibly malformed tool-call arguments string into a dict.

        Returns ``{}`` for empty input or when parsing fails.
        """
        if not args_str:
            return {}

        args_str = args_str.strip()

        # First line of defence: patch the outer structure. Some models emit
        # arguments without the opening and/or closing brace.
        if not args_str.startswith("{"):
            args_str = "{" + args_str
        if not args_str.endswith("}"):
            args_str = args_str + "}"

        # Second line of defence: json_repair fixes syntax details (single
        # quotes, trailing commas, unescaped characters, ...).
        try:
            decoded = repair_json(args_str, return_objects=True)

            # Always hand back a dict: a failed repair may yield another type.
            if isinstance(decoded, dict):
                return decoded
            elif isinstance(decoded, list):
                # Rare, but wrap it so callers still get a dict.
                return {"items": decoded}
            else:
                print(f"⚠️ json_repair returned non-dict: {type(decoded)}")
                # Last resort: loose "key: value" parsing.
                return self._fallback_parse(args_str)

        except Exception as e:
            print(f"❌ JSON Parsing Failed: {e}")
            return {}

    def _fallback_parse(self, args_str: str) -> dict:
        """Best-effort parse of loose ``key: value`` / ``key=value`` pairs.

        Last resort when json_repair does not produce an object. Drops the
        outer braces and splits pairs on commas and newlines. Keys and values
        are stripped of whitespace and surrounding quotes, and values stay
        strings. Segments without a separator or with an empty key are
        skipped; commas inside values are not supported.
        """
        body = (args_str or "").strip()
        if body.startswith("{"):
            body = body[1:]
        if body.endswith("}"):
            body = body[:-1]

        parsed = {}
        for segment in body.replace("\n", ",").split(","):
            if ":" in segment:
                key, _, value = segment.partition(":")
            elif "=" in segment:
                key, _, value = segment.partition("=")
            else:
                continue
            key = key.strip().strip("'\"")
            if key:
                parsed[key] = value.strip().strip("'\"")
        return parsed

    async def chat(self, history: list):
        """Run the ReAct loop over *history*, yielding assistant text as it streams.

        *history* is adopted (not copied) as ``self.messages``. It is extended
        in place with assistant tool-call messages, tool results and the final
        assistant reply. Text is only streamed while no tool call is being
        assembled in the current round.

        Raises:
            RuntimeError: the agent has not been started.
        """
        if not self.dgg:
            raise RuntimeError("Agent not started")

        self.messages = history

        # 1. Fetch tools from DGG (DGG filters them based on the header).
        mcp_tools = await self.dgg.list_tools()
        openai_tools = self._convert_mcp_to_openai_tools(mcp_tools)

        # --- ReAct loop ---
        while True:
            current_tool_calls = {}  # tool-call index -> fragments assembled so far
            is_collecting_tool = False
            full_response_content = ""

            # A. Call the LLM.
            async for delta in self.default_gen_stream(
                messages=self.messages,
                model="llm",
                tools=openai_tools if openai_tools else None,
            ):
                if isinstance(delta, str):
                    # util.gen_stream reports failures as a plain string;
                    # surface it as text instead of crashing on attribute access.
                    text, tool_call_deltas = delta, None
                else:
                    text, tool_call_deltas = delta.content, delta.tool_calls

                # Case A: the LLM decided to call tools (intercept mode).
                if tool_call_deltas:
                    is_collecting_tool = True
                    for tc in tool_call_deltas:
                        entry = current_tool_calls.setdefault(
                            tc.index, {"id": None, "name": None, "arguments": ""}
                        )
                        # id/name normally come in the first fragment; keep the
                        # first non-empty value in case a provider sends them later.
                        entry["id"] = entry["id"] or tc.id
                        fn = tc.function
                        if fn is not None:
                            entry["name"] = entry["name"] or fn.name
                            # Concatenate the JSON argument fragments.
                            if fn.arguments:
                                entry["arguments"] += fn.arguments

                # Case B: the LLM is talking (pass-through mode).
                elif text:
                    if not is_collecting_tool:
                        full_response_content += text
                        yield text

            # 2. No tool call: store the complete reply and finish.
            if not is_collecting_tool:
                self.messages.append(
                    {"role": "assistant", "content": full_response_content}
                )
                return

            # Record the assistant's tool-call intent; the LLM must see it next round.
            tool_calls_data = [
                {
                    "id": tc["id"],
                    "type": "function",
                    "function": {"name": tc["name"], "arguments": tc["arguments"]},
                }
                for tc in current_tool_calls.values()
            ]
            self.messages.append(
                {
                    "role": "assistant",
                    # Keep any text streamed before the tool call; None if there was none.
                    "content": full_response_content or None,
                    "tool_calls": tool_calls_data,
                }
            )

            for tc in tool_calls_data:
                func_name = tc["function"]["name"]
                func_args_str = tc["function"]["arguments"]
                call_id = tc["id"]

                print(f"⚡ Executing: {func_name} with {func_args_str}")

                try:
                    func_args = self._robust_parse_arguments(func_args_str)
                    # Call the tool through DGG (MCP protocol).
                    result = await self.dgg.call_tool(func_name, func_args)
                    # Keep only the text parts of the result.
                    tool_output = "".join(
                        part.text for part in result.content if part.type == "text"
                    )
                except Exception as e:
                    tool_output = f"Error: {str(e)}"

                print(f" -> Result: {tool_output[:50]}...")

                # Feed the result back to the LLM (role "tool" with tool_call_id).
                self.messages.append(
                    {
                        "role": "tool",
                        "tool_call_id": call_id,
                        "name": func_name,
                        "content": tool_output,
                    }
                )

            # Loop again: the next round sees the tool results and usually answers in text.

    async def stop(self):
        """Close the MCP session and SSE transport; a no-op when not started."""
        await self.exit_stack.aclose()
        # Make chat() fail fast instead of using a closed session.
        self.dgg = None
|
|
@@ -0,0 +1,120 @@
|
|
|
1
|
+
import re
|
|
2
|
+
import os
|
|
3
|
+
import json
|
|
4
|
+
import orjson
|
|
5
|
+
from openai import AsyncOpenAI, APIError
|
|
6
|
+
|
|
7
|
+
# Model-gateway base URL: the platform endpoint with "/api/v1" replaced by
# "/mpg/v1". It is read once at import time, so importing this module raises
# KeyError when ASKGENIE_ENDPOINT is unset.
base_url = os.environ["ASKGENIE_ENDPOINT"].replace(
    "/api/v1",
    "/mpg/v1",
)

# Chunk-size constant; not used within this module (purpose unclear — TODO confirm).
chunksize = 30_000
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
def dumps(obj) -> str:
    """Serialize *obj* to JSON with orjson and return it as a ``str``."""
    encoded = orjson.dumps(obj)  # orjson produces UTF-8 bytes
    return encoded.decode("utf-8")
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
def loads(json_str: str | bytes):
    """Parse a JSON document given as ``str`` or ``bytes`` using orjson."""
    parsed = orjson.loads(json_str)
    return parsed
|
|
18
|
+
|
|
19
|
+
|
|
20
|
+
def read_file(path, default=None):
    """Return the text content of *path*, decoded as UTF-8.

    Args:
        path: File to read.
        default: Value returned when *path* does not exist. ``None`` (the
            default) means "no fallback", so a missing file raises.

    Raises:
        Exception: *path* does not exist and *default* is None.
    """
    print(f"Reading File content from {path}...")
    if not os.path.exists(path):
        if default is None:
            raise Exception(f"File do not exists: {path}")
        return default
    # Explicit encoding: the locale default differs between hosts (PEP 597).
    with open(path, "r", encoding="utf-8") as file:
        return file.read()
|
|
30
|
+
|
|
31
|
+
|
|
32
|
+
async def gen(messages, model, api_token, tools=None):
    """Run one non-streamed chat completion through the model gateway.

    Args:
        messages: OpenAI-format chat messages.
        model: Model name understood by the gateway.
        api_token: Key sent to the gateway.
        tools: Optional OpenAI function tools; when given, ``tool_choice`` is
            set to ``"auto"``.

    Returns:
        The first choice's message object.

    Raises:
        Whatever the OpenAI client raises; the error is printed first.
    """
    params = {
        "model": model,
        "messages": messages,
    }

    if tools:
        params["tools"] = tools
        params["tool_choice"] = "auto"

    # A client is created per call; close it so its HTTP connection pool is
    # released instead of leaking on every request.
    async with AsyncOpenAI(base_url=base_url, api_key=api_token) as client:
        try:
            response = await client.chat.completions.create(**params)
            return response.choices[0].message
        except Exception as e:
            print(f"LLM Error: {e}")
            raise
|
|
50
|
+
|
|
51
|
+
|
|
52
|
+
# def genWithlogprobs(prompt, top_logprobs=10):
|
|
53
|
+
# response = llm.chat.completions.create(
|
|
54
|
+
# model=model_id,
|
|
55
|
+
# messages=[{"role": "user", "content": prompt}],
|
|
56
|
+
# logprobs=True,
|
|
57
|
+
# top_logprobs=top_logprobs,
|
|
58
|
+
# )
|
|
59
|
+
# return response
|
|
60
|
+
#
|
|
61
|
+
#
|
|
62
|
+
# async def agen(prompt):
|
|
63
|
+
# response = await llm_stream.chat.completions.create(
|
|
64
|
+
# model=model_id, messages=[{"role": "user", "content": prompt}]
|
|
65
|
+
# )
|
|
66
|
+
# return response.choices[0].message.content
|
|
67
|
+
|
|
68
|
+
|
|
69
|
+
# history = [
|
|
70
|
+
# {"role": "system", "content": "You are a helpful assistant."},
|
|
71
|
+
# {"role": "user", "content": "What's the capital of France?"},
|
|
72
|
+
# {"role": "assistant", "content": "Paris."},
|
|
73
|
+
# ]
|
|
74
|
+
async def gen_stream(messages, model, api_token, tools=None):
    """Stream chat-completion deltas from the model gateway.

    Args:
        messages: OpenAI-format chat messages.
        model: Model name understood by the gateway.
        api_token: Key sent to the gateway.
        tools: Optional OpenAI function tools; when given, ``tool_choice`` is
            set to ``"auto"``.

    Yields:
        Each non-empty ``choices[0].delta``. On failure, a single
        ``"Failed: ..."`` string is yielded instead of raising, so consumers
        must handle both deltas and ``str``.
    """
    mpg = AsyncOpenAI(base_url=base_url, api_key=api_token)

    params = {
        "model": model,
        "messages": messages,
        "stream": True,
    }

    if tools:
        params["tools"] = tools
        params["tool_choice"] = "auto"

    try:
        # Close the per-call client (and its connection pool) when the stream
        # ends, fails, or this generator is closed early.
        async with mpg:
            stream = await mpg.chat.completions.create(**params)
            async for chunk in stream:
                if chunk.choices:
                    delta = chunk.choices[0].delta
                    if delta:
                        yield delta
                # Chunks without choices (e.g. a usage-only chunk) are skipped.

    except APIError as e:
        yield f"Failed: {e.message}"

    except Exception as e:
        yield f"Failed: {str(e)}"
|
|
105
|
+
|
|
106
|
+
|
|
107
|
+
def trim_and_load_json(input):
    """Extract and parse the JSON object embedded in the string *input*.

    Takes the text from the first ``{`` through the last ``}``. If an opening
    brace exists but no closing one, a ``}`` is appended first. Trailing
    commas before ``]`` or ``}`` are removed, then the slice is parsed.

    Returns:
        The parsed value, or ``[]`` when no object is found or parsing fails.
    """
    text = input
    opening = text.find("{")
    closing = text.rfind("}") + 1  # one past the last "}", or 0 when absent

    if opening == -1:
        candidate = ""
    else:
        if closing == 0:
            # Presumably truncated output: an object was opened but never closed.
            text += "}"
            closing = len(text)
        candidate = text[opening:closing]

    candidate = re.sub(r",\s*([\]}])", r"\1", candidate)

    try:
        return json.loads(candidate)
    except Exception:
        return []
|
|
@@ -0,0 +1,38 @@
|
|
|
1
|
+
[build-system]
|
|
2
|
+
requires = ["hatchling"]
|
|
3
|
+
build-backend = "hatchling.build"
|
|
4
|
+
|
|
5
|
+
[project]
|
|
6
|
+
name = "onegenie"
|
|
7
|
+
version = "0.1.0"
|
|
8
|
+
description = "The official Python SDK for AskGenie AI Platform"
|
|
9
|
+
|
|
10
|
+
readme = "README.md"
|
|
11
|
+
|
|
12
|
+
requires-python = ">=3.12"
|
|
13
|
+
dependencies = [
|
|
14
|
+
"pydantic>=2.0.0",
|
|
15
|
+
"redis>=7.0.0",
|
|
16
|
+
"requests>=2.31.0",
|
|
17
|
+
"orjson>=3.11.0",
|
|
18
|
+
"mcp>=1.26.0",
|
|
19
|
+
"openai>=2.0.0",
|
|
20
|
+
"json-repair>=0.57.0",
|
|
21
|
+
]
|
|
22
|
+
|
|
23
|
+
[dependency-groups]
|
|
24
|
+
dev = [
|
|
25
|
+
"honcho>=2.0.0",
|
|
26
|
+
"watchdog>=6.0.0",
|
|
27
|
+
"pandas==2.3.3",
|
|
28
|
+
"boto3>=1.42.82",
|
|
29
|
+
"arize-phoenix-client==1.24.0",
|
|
30
|
+
]
|
|
31
|
+
|
|
32
|
+
[tool.hatch.build.targets.wheel]
|
|
33
|
+
packages = ["onegenie"]
|
|
34
|
+
|
|
35
|
+
[tool.hatch.build.targets.sdist]
|
|
36
|
+
include = [
|
|
37
|
+
"/onegenie",
|
|
38
|
+
]
|