hyperforge 1.0.0.post19__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- hyperforge/__init__.py +16 -0
- hyperforge/agent.py +81 -0
- hyperforge/api/__init__.py +20 -0
- hyperforge/api/app.py +155 -0
- hyperforge/api/authentication.py +271 -0
- hyperforge/api/commands.py +33 -0
- hyperforge/api/internal/__init__.py +4 -0
- hyperforge/api/internal/inspect.py +30 -0
- hyperforge/api/internal/router.py +3 -0
- hyperforge/api/logging.py +18 -0
- hyperforge/api/models.py +129 -0
- hyperforge/api/session.py +197 -0
- hyperforge/api/settings.py +38 -0
- hyperforge/api/utils.py +354 -0
- hyperforge/api/v1/__init__.py +23 -0
- hyperforge/api/v1/agents.py +531 -0
- hyperforge/api/v1/interaction.py +430 -0
- hyperforge/api/v1/mcp_content.py +311 -0
- hyperforge/api/v1/mcp_interaction.py +322 -0
- hyperforge/api/v1/oauth.py +60 -0
- hyperforge/api/v1/prompt.py +129 -0
- hyperforge/api/v1/router.py +3 -0
- hyperforge/api/v1/schema.py +56 -0
- hyperforge/api/v1/session.py +182 -0
- hyperforge/api/v1/utils.py +12 -0
- hyperforge/api/v1/workflows.py +643 -0
- hyperforge/arag.py +28 -0
- hyperforge/broker/__init__.py +52 -0
- hyperforge/broker/local.py +116 -0
- hyperforge/broker/redis.py +161 -0
- hyperforge/configure.py +571 -0
- hyperforge/context/__init__.py +0 -0
- hyperforge/context/agent.py +377 -0
- hyperforge/context/config.py +103 -0
- hyperforge/database.py +3 -0
- hyperforge/db/__init__.py +6 -0
- hyperforge/db/agents.py +1521 -0
- hyperforge/db/encryption.py +91 -0
- hyperforge/db/exceptions.py +26 -0
- hyperforge/db/settings.py +16 -0
- hyperforge/db/workflow_cleanup.py +69 -0
- hyperforge/definition.py +13 -0
- hyperforge/driver.py +31 -0
- hyperforge/dummy.py +28 -0
- hyperforge/engine.py +189 -0
- hyperforge/exceptions.py +14 -0
- hyperforge/feature_flag.py +105 -0
- hyperforge/fixtures.py +602 -0
- hyperforge/interaction.py +116 -0
- hyperforge/llm.py +75 -0
- hyperforge/manager.py +432 -0
- hyperforge/memory/__init__.py +5 -0
- hyperforge/memory/memory.py +974 -0
- hyperforge/minimal_fixtures.py +75 -0
- hyperforge/models.py +336 -0
- hyperforge/nua.py +336 -0
- hyperforge/openapi.py +63 -0
- hyperforge/prompts.py +188 -0
- hyperforge/pubsub.py +90 -0
- hyperforge/py.typed +0 -0
- hyperforge/redis_utils.py +82 -0
- hyperforge/retrieval/__init__.py +0 -0
- hyperforge/retrieval/agent.py +169 -0
- hyperforge/retrieval/config.py +94 -0
- hyperforge/server/__init__.py +5 -0
- hyperforge/server/cache.py +131 -0
- hyperforge/server/run.py +109 -0
- hyperforge/server/sandbox.py +60 -0
- hyperforge/server/session.py +421 -0
- hyperforge/server/settings.py +47 -0
- hyperforge/server/utils.py +57 -0
- hyperforge/server/web.py +31 -0
- hyperforge/settings.py +18 -0
- hyperforge/standalone/__init__.py +5 -0
- hyperforge/standalone/agent.py +189 -0
- hyperforge/standalone/app.py +264 -0
- hyperforge/standalone/config.py +137 -0
- hyperforge/standalone/const.py +1 -0
- hyperforge/standalone/run.py +60 -0
- hyperforge/standalone/settings.py +133 -0
- hyperforge/standalone/ui_router.py +241 -0
- hyperforge/trace.py +42 -0
- hyperforge/utils/__init__.py +112 -0
- hyperforge/utils/http.py +48 -0
- hyperforge/workflows.py +44 -0
- hyperforge-1.0.0.post19.dist-info/METADATA +95 -0
- hyperforge-1.0.0.post19.dist-info/RECORD +90 -0
- hyperforge-1.0.0.post19.dist-info/WHEEL +5 -0
- hyperforge-1.0.0.post19.dist-info/entry_points.txt +8 -0
- hyperforge-1.0.0.post19.dist-info/top_level.txt +1 -0
hyperforge/db/agents.py
ADDED
|
@@ -0,0 +1,1521 @@
|
|
|
1
|
+
import datetime
|
|
2
|
+
from typing import Any, List
|
|
3
|
+
from uuid import UUID
|
|
4
|
+
|
|
5
|
+
import databases
|
|
6
|
+
import sqlalchemy as sa
|
|
7
|
+
from cryptography.fernet import Fernet
|
|
8
|
+
from fastapi import UploadFile
|
|
9
|
+
from nucliadb_telemetry.utils import get_telemetry, init_telemetry
|
|
10
|
+
from pydantic import BaseModel, ValidationError
|
|
11
|
+
from sqlalchemy.dialects import postgresql as pg_dialect
|
|
12
|
+
from sqlalchemy.dialects.postgresql import JSONB
|
|
13
|
+
|
|
14
|
+
from hyperforge.agent import AgentConfig
|
|
15
|
+
from hyperforge.configure import (
|
|
16
|
+
get_agent_config_instance,
|
|
17
|
+
get_driver_config_instance,
|
|
18
|
+
get_driver_config_klass,
|
|
19
|
+
)
|
|
20
|
+
from hyperforge.database import metadata
|
|
21
|
+
from hyperforge.db import exceptions, logger
|
|
22
|
+
from hyperforge.db.encryption import (
|
|
23
|
+
decrypt_fields,
|
|
24
|
+
encrypt_fields,
|
|
25
|
+
fernet_key_from_passphrase,
|
|
26
|
+
)
|
|
27
|
+
from hyperforge.db.settings import DataManagerSettings
|
|
28
|
+
from hyperforge.driver import DriverConfig
|
|
29
|
+
from hyperforge.models import MemoryConfig, NucliaDBMemoryConfig, Rules
|
|
30
|
+
from hyperforge.prompts import PromptArgument, PromptConfig
|
|
31
|
+
from hyperforge.retrieval.config import (
|
|
32
|
+
RetrievalAgentConfig,
|
|
33
|
+
RetrievalAgentExportV1,
|
|
34
|
+
retrievalAgentAdapter,
|
|
35
|
+
)
|
|
36
|
+
from hyperforge.workflows import (
|
|
37
|
+
RetrievalAgent,
|
|
38
|
+
WorkflowData,
|
|
39
|
+
WorkflowInput,
|
|
40
|
+
WorkflowUpdate,
|
|
41
|
+
)
|
|
42
|
+
|
|
43
|
+
# Name handed to get_telemetry() when AgentManager.from_settings builds the tracer provider.
SERVICE_NAME = "TASK_MANAGER"
# NOTE(review): the unit is not stated in this module; 7 * 24 looks like "hours in a week" - confirm.
EXPIRATION = 24 * 7
# Default age threshold used by AgentManager.get_expired_deleted_workflows().
WORKFLOW_PURGE_RETENTION = datetime.timedelta(days=15)
|
|
46
|
+
|
|
47
|
+
|
|
48
|
+
def utc_now() -> datetime.datetime:
    """Return the current instant as a timezone-aware datetime in UTC."""
    utc = datetime.timezone.utc
    return datetime.datetime.now(tz=utc)
|
|
50
|
+
|
|
51
|
+
|
|
52
|
+
# Workflows belonging to an agent, keyed by (account, agent_id, workflow_id).
# The is_deleted / deleted_by / deleted_at columns carry soft-delete state.
retrieval_agent_workflow = sa.Table(
    "retrieval_agent_workflow",
    metadata,
    sa.Column("account", sa.String, nullable=False, primary_key=True, index=True),
    # Agent ID
    sa.Column("agent_id", sa.String, nullable=False, primary_key=True, index=True),
    # Workflow ID
    sa.Column("workflow_id", sa.String, nullable=False, primary_key=True, index=True),
    sa.Column("name", sa.String, nullable=False),
    sa.Column("description", sa.String, nullable=True),
    sa.Column("parameters", JSONB, nullable=True),
    sa.Column("required", JSONB, nullable=True),
    sa.Column("rules", JSONB, nullable=True),
    # "created" only has an insert-time default; "modified" only has an update-time value.
    sa.Column("created", sa.DateTime, default=sa.func.now()),
    sa.Column("modified", sa.DateTime, onupdate=sa.func.now()),
    sa.Column("is_deleted", sa.Boolean, default=False, nullable=False, index=True),
    sa.Column("deleted_by", sa.String, nullable=True),
    # Unlike created/modified, this timestamp is declared timezone-aware.
    sa.Column("deleted_at", sa.DateTime(timezone=True), nullable=True, index=True),
    # Updates and deletes of the parent agent config row cascade to its workflows.
    sa.ForeignKeyConstraint(
        columns=["account", "agent_id"],
        refcolumns=[
            "retrieval_agent_config.account",
            "retrieval_agent_config.agent_id",
        ],
        onupdate="CASCADE",
        ondelete="CASCADE",
    ),
)
|
|
79
|
+
|
|
80
|
+
# Prompts attached to an agent; each row has its own UUID primary key.
retrieval_agent_prompts = sa.Table(
    "retrieval_agent_prompts",
    metadata,
    # The id is produced by the database through uuid_generate_v4().
    # NOTE(review): in PostgreSQL that function usually comes from the
    # uuid-ossp extension - confirm the migrations enable it.
    sa.Column(
        "id",
        pg_dialect.UUID,
        server_default=sa.func.uuid_generate_v4(),
        primary_key=True,
    ),
    sa.Column("account", sa.String, index=True, nullable=False),
    sa.Column("agent_id", sa.String, index=True, nullable=False),
    sa.Column("name", sa.String, nullable=False),
    sa.Column("description", sa.String, nullable=False),
    sa.Column("prompt", sa.String, nullable=False),
    sa.Column("arguments", JSONB, nullable=True),
    sa.Column("icons", JSONB, nullable=True),
    sa.Column("meta", JSONB, nullable=True),
    sa.Column("created", sa.DateTime, default=sa.func.now()),
    sa.Column("modified", sa.DateTime, onupdate=sa.func.now()),
    # Updates and deletes of the parent agent config row cascade to its prompts.
    sa.ForeignKeyConstraint(
        columns=["account", "agent_id"],
        refcolumns=[
            "retrieval_agent_config.account",
            "retrieval_agent_config.agent_id",
        ],
        onupdate="CASCADE",
        ondelete="CASCADE",
    ),
)
|
|
116
|
+
|
|
117
|
+
# Root table for an agent, keyed by (account, agent_id). The workflow, prompt
# and driver tables in this module reference it through cascading foreign keys.
retrieval_agent_config = sa.Table(
    "retrieval_agent_config",
    metadata,
    sa.Column("account", sa.String, nullable=False, primary_key=True, index=True),
    # Agent ID
    sa.Column("agent_id", sa.String, nullable=False, primary_key=True, index=True),
    sa.Column("rules", JSONB, nullable=False),
    sa.Column("created", sa.DateTime, default=sa.func.now()),
    sa.Column("modified", sa.DateTime, onupdate=sa.func.now()),
    sa.Column("memory", JSONB, nullable=False),
    sa.Column("description", sa.String, nullable=True),
    sa.Column("title", sa.String, nullable=True),
    sa.Column("instructions", sa.String, nullable=True),
)
|
|
132
|
+
|
|
133
|
+
|
|
134
|
+
# Preprocess step configurations: one JSONB payload per row, attached to a
# workflow through (account, agent_id, workflow_id).
retrieval_agent_preprocess = sa.Table(
    "retrieval_agent_preprocess",
    metadata,
    # Row id generated by the database via uuid_generate_v4().
    sa.Column(
        "id",
        pg_dialect.UUID,
        server_default=sa.func.uuid_generate_v4(),
        primary_key=True,
    ),
    sa.Column("account", sa.String, index=True, nullable=False),
    sa.Column("agent_id", sa.String, index=True, nullable=False),
    sa.Column("workflow_id", sa.String, index=True, nullable=False),
    sa.Column("preprocess", JSONB, nullable=False),
    sa.Column("created", sa.DateTime, default=sa.func.now()),
    sa.Column("modified", sa.DateTime, onupdate=sa.func.now()),
    # Updates and deletes of the owning workflow row cascade to these rows.
    sa.ForeignKeyConstraint(
        columns=["account", "agent_id", "workflow_id"],
        refcolumns=[
            "retrieval_agent_workflow.account",
            "retrieval_agent_workflow.agent_id",
            "retrieval_agent_workflow.workflow_id",
        ],
        onupdate="CASCADE",
        ondelete="CASCADE",
    ),
)
|
|
175
|
+
|
|
176
|
+
|
|
177
|
+
# Postprocess step configurations: one JSONB payload per row, attached to a
# workflow through (account, agent_id, workflow_id).
retrieval_agent_postprocess = sa.Table(
    "retrieval_agent_postprocess",
    metadata,
    # Row id generated by the database via uuid_generate_v4().
    sa.Column(
        "id",
        pg_dialect.UUID,
        server_default=sa.func.uuid_generate_v4(),
        primary_key=True,
    ),
    sa.Column("account", sa.String, index=True, nullable=False),
    sa.Column("agent_id", sa.String, index=True, nullable=False),
    sa.Column("workflow_id", sa.String, index=True, nullable=False),
    sa.Column("postprocess", JSONB, nullable=False),
    sa.Column("created", sa.DateTime, default=sa.func.now()),
    sa.Column("modified", sa.DateTime, onupdate=sa.func.now()),
    # Updates and deletes of the owning workflow row cascade to these rows.
    sa.ForeignKeyConstraint(
        columns=["account", "agent_id", "workflow_id"],
        refcolumns=[
            "retrieval_agent_workflow.account",
            "retrieval_agent_workflow.agent_id",
            "retrieval_agent_workflow.workflow_id",
        ],
        onupdate="CASCADE",
        ondelete="CASCADE",
    ),
)
|
|
218
|
+
|
|
219
|
+
# Context step configurations: one JSONB payload per row, attached to a
# workflow through (account, agent_id, workflow_id).
retrieval_agent_context = sa.Table(
    "retrieval_agent_context",
    metadata,
    # Row id generated by the database via uuid_generate_v4().
    sa.Column(
        "id",
        pg_dialect.UUID,
        server_default=sa.func.uuid_generate_v4(),
        primary_key=True,
    ),
    sa.Column("account", sa.String, index=True, nullable=False),
    sa.Column("agent_id", sa.String, index=True, nullable=False),
    sa.Column("workflow_id", sa.String, index=True, nullable=False),
    sa.Column("context", JSONB, nullable=False),
    sa.Column("created", sa.DateTime, default=sa.func.now()),
    sa.Column("modified", sa.DateTime, onupdate=sa.func.now()),
    # Updates and deletes of the owning workflow row cascade to these rows.
    sa.ForeignKeyConstraint(
        columns=["account", "agent_id", "workflow_id"],
        refcolumns=[
            "retrieval_agent_workflow.account",
            "retrieval_agent_workflow.agent_id",
            "retrieval_agent_workflow.workflow_id",
        ],
        onupdate="CASCADE",
        ondelete="CASCADE",
    ),
)
|
|
260
|
+
|
|
261
|
+
|
|
262
|
+
# Generation step configurations: one JSONB payload per row, attached to a
# workflow through (account, agent_id, workflow_id).
retrieval_agent_generation = sa.Table(
    "retrieval_agent_generation",
    metadata,
    # Row id generated by the database via uuid_generate_v4().
    sa.Column(
        "id",
        pg_dialect.UUID,
        server_default=sa.func.uuid_generate_v4(),
        primary_key=True,
    ),
    sa.Column("account", sa.String, index=True, nullable=False),
    sa.Column("agent_id", sa.String, index=True, nullable=False),
    sa.Column("workflow_id", sa.String, index=True, nullable=False),
    sa.Column("generation", JSONB, nullable=False),
    sa.Column("created", sa.DateTime, default=sa.func.now()),
    sa.Column("modified", sa.DateTime, onupdate=sa.func.now()),
    # Updates and deletes of the owning workflow row cascade to these rows.
    sa.ForeignKeyConstraint(
        columns=["account", "agent_id", "workflow_id"],
        refcolumns=[
            "retrieval_agent_workflow.account",
            "retrieval_agent_workflow.agent_id",
            "retrieval_agent_workflow.workflow_id",
        ],
        onupdate="CASCADE",
        ondelete="CASCADE",
    ),
)
|
|
303
|
+
|
|
304
|
+
# Drivers configured for an agent. "config" is written through
# encrypt_fields() and read back through decrypt_fields() by AgentManager.
retrieval_agents_drivers = sa.Table(
    "retrieval_agents_drivers",
    metadata,
    # Row id generated by the database via uuid_generate_v4().
    sa.Column(
        "id",
        pg_dialect.UUID,
        server_default=sa.func.uuid_generate_v4(),
        primary_key=True,
    ),
    sa.Column("account", sa.String, nullable=False),
    sa.Column("agent_id", sa.String, nullable=False),
    sa.Column("driver", sa.String, nullable=False),
    sa.Column("provider", sa.String, nullable=False),
    sa.Column("identifier", sa.String, nullable=False),
    sa.Column("config", JSONB, nullable=False),
    sa.Column("created", sa.DateTime, default=sa.func.now()),
    sa.Column("modified", sa.DateTime, onupdate=sa.func.now()),
    # Updates and deletes of the parent agent config row cascade to its drivers.
    sa.ForeignKeyConstraint(
        columns=["account", "agent_id"],
        refcolumns=[
            "retrieval_agent_config.account",
            "retrieval_agent_config.agent_id",
        ],
        onupdate="CASCADE",
        ondelete="CASCADE",
    ),
    # A given identifier may appear only once per agent.
    sa.UniqueConstraint("account", "agent_id", "identifier"),
)
|
|
329
|
+
|
|
330
|
+
|
|
331
|
+
class AgentManager:
|
|
332
|
+
settings: DataManagerSettings
|
|
333
|
+
|
|
334
|
+
def __init__(
    self,
    database: databases.Database,
    settings: DataManagerSettings,
):
    """Keep references to the database handle and the manager settings.

    Nothing is connected here; ``initialize`` opens the connection.

    Args:
        database: Database handle used for every query issued by this manager.
        settings: Data manager settings.
    """
    self.settings = settings
    self.database = database
|
|
341
|
+
|
|
342
|
+
@classmethod
async def from_settings(
    cls,
    settings: DataManagerSettings,
):
    """Alternate constructor that builds the manager from settings.

    Telemetry is initialised when ``get_telemetry(SERVICE_NAME)`` returns a
    provider. The database object is created from ``settings.postgresql_dsn``
    but is not connected here.

    Args:
        settings: Data manager settings providing ``postgresql_dsn``.

    Returns:
        A new instance of ``cls``.
    """
    provider = get_telemetry(SERVICE_NAME)
    if provider:
        await init_telemetry(provider)

    return cls(
        database=databases.Database(settings.postgresql_dsn),
        settings=settings,
    )
|
|
354
|
+
|
|
355
|
+
async def initialize(self):
    """Open the database connection held by this manager."""
    db = self.database
    await db.connect()
|
|
357
|
+
|
|
358
|
+
async def finalize(self):
    """Close the database connection held by this manager."""
    db = self.database
    await db.disconnect()
|
|
360
|
+
|
|
361
|
+
async def patch_driver(
    self,
    account: str,
    agent_id: str,
    driver: str,
    config: DriverConfig,
):
    """Merge ``config`` into a stored driver and persist the result.

    The stored driver is loaded with ``get_driver`` and combined with
    ``config`` by ``update_driver_config``. Only the ``driver`` (name) and
    ``config`` columns are rewritten, and the config is passed through
    ``encrypt_fields`` first. If the driver does not exist the call is a
    silent no-op.

    Args:
        account: Account ID.
        agent_id: Agent ID.
        driver: Value matched against the ``id`` column of the driver row.
        config: Incoming driver configuration to merge.
    """
    try:
        stored = await self.get_driver(account, agent_id, driver)
    except exceptions.DriverNotFoundError:
        # Nothing stored for this driver, so there is nothing to patch.
        return

    merged: DriverConfig = update_driver_config(config, stored)

    table = retrieval_agents_drivers
    row_filter = sa.and_(
        table.c.account == account,
        table.c.agent_id == agent_id,
        table.c.id == driver,
    )
    new_values = {
        "driver": merged.name,
        "config": encrypt_fields(merged.config),
    }

    await self.database.execute(table.update().where(row_filter).values(**new_values))
|
|
388
|
+
|
|
389
|
+
async def add_driver(
    self,
    agent_id: str,
    account: str,
    config: DriverConfig,
):
    """Insert a new driver row for an agent.

    Note that ``agent_id`` comes before ``account`` here, unlike most
    sibling methods.

    Args:
        agent_id: Agent ID.
        account: Account ID.
        config: Driver configuration; its ``config`` attribute is passed
            through ``encrypt_fields`` before being stored.

    Returns:
        ``str()`` of whatever ``database.execute`` returns for the insert.
        NOTE(review): presumably the generated row id - confirm.
    """
    row = {
        "account": account,
        "agent_id": agent_id,
        "driver": config.name,
        "provider": config.provider,
        "identifier": config.identifier,
        "config": encrypt_fields(config.config),
    }
    insert_stmt = retrieval_agents_drivers.insert().values(**row)
    inserted = await self.database.execute(insert_stmt)
    return str(inserted)
|
|
407
|
+
|
|
408
|
+
async def delete_driver(self, account: str, agent_id: str, driver: str):
    """Delete the driver row matching account, agent and row id.

    Args:
        account: Account ID.
        agent_id: Agent ID.
        driver: Value matched against the ``id`` column of the driver row.
    """
    table = retrieval_agents_drivers
    row_filter = sa.and_(
        table.c.account == account,
        table.c.agent_id == agent_id,
        table.c.id == driver,
    )
    await self.database.execute(table.delete().where(row_filter))
|
|
416
|
+
|
|
417
|
+
async def get_driver(
    self, account: str, agent_id: str, driver: str
) -> DriverConfig:
    """Load a single driver and return it as its provider-specific config model.

    Args:
        account: Account ID.
        agent_id: Agent ID.
        driver: Value matched against the ``id`` column of the driver row.

    Returns:
        The model chosen by ``get_driver_config_klass`` for the row's
        provider, validated from the row, after ``decrypt_fields`` has been
        applied to its ``config``.

    Raises:
        exceptions.DriverNotFoundError: No row matches the three filters.
    """
    table = retrieval_agents_drivers
    query = table.select().where(
        sa.and_(
            table.c.account == account,
            table.c.agent_id == agent_id,
            table.c.id == driver,
        )
    )
    row = await self.database.fetch_one(query)
    if row is None:
        raise exceptions.DriverNotFoundError()

    model_cls = get_driver_config_klass(row["provider"])
    payload = {
        "id": str(row["id"]),
        "name": row["driver"],
        "identifier": row["identifier"],
        "provider": row["provider"],
        "config": row["config"],
    }
    loaded = model_cls.model_validate(payload)
    # The return value is ignored: decrypt_fields presumably works in place - confirm.
    decrypt_fields(loaded.config)
    return loaded
|
|
442
|
+
|
|
443
|
+
async def get_drivers(self, account: str, agent_id: str) -> List[DriverConfig]:
    """Load every driver configured for an agent.

    Args:
        account: Account ID.
        agent_id: Agent ID.

    Returns:
        One provider-specific config model per stored row, each with
        ``decrypt_fields`` applied to its ``config``. Empty when the agent
        has no drivers.
    """
    table = retrieval_agents_drivers
    query = table.select().where(
        sa.and_(
            table.c.account == account,
            table.c.agent_id == agent_id,
        )
    )
    rows = await self.database.fetch_all(query)

    loaded = []
    for row in rows:
        model_cls = get_driver_config_klass(row["provider"])
        payload = {
            "id": str(row["id"]),
            "name": row["driver"],
            "identifier": row["identifier"],
            "provider": row["provider"],
            "config": row["config"],
        }
        item = model_cls.model_validate(payload)
        # The return value is ignored: decrypt_fields presumably works in place - confirm.
        decrypt_fields(item.config)
        loaded.append(item)
    return loaded
|
|
465
|
+
|
|
466
|
+
async def add_agent(
    self, account: str, agent_id: str, memory: MemoryConfig, rules: Rules
):
    """Create an agent config row together with its "default" workflow.

    The ``rules`` argument is stored on the agent config row. The default
    workflow is created with empty parameters, an empty ``required`` list
    and an empty rule set.

    Both inserts run inside one database transaction, so a failure of the
    second insert cannot leave an agent without its default workflow.

    Args:
        account: Account ID.
        agent_id: Agent ID.
        memory: Memory configuration, stored via ``model_dump()``.
        rules: Agent-level rules, stored via ``model_dump()``.
    """
    config_insert = retrieval_agent_config.insert().values(
        account=account,
        agent_id=agent_id,
        rules=rules.model_dump(),
        memory=memory.model_dump(),
    )
    workflow_insert = retrieval_agent_workflow.insert().values(
        account=account,
        agent_id=agent_id,
        workflow_id="default",
        name="default",
        description="Default workflow",
        parameters={},
        rules=Rules(rules=[]).model_dump(),
        required=[],
        is_deleted=False,
    )

    # The workflow row references the config row, so the config insert goes
    # first; the transaction rolls both back if either statement fails.
    async with self.database.transaction():
        await self.database.execute(config_insert)
        await self.database.execute(workflow_insert)
|
|
491
|
+
|
|
492
|
+
async def delete_agent(self, account: str, agent_id: str):
    """Delete the agent config row for ``(account, agent_id)``.

    The workflow, prompt and driver tables in this module declare
    ``ON DELETE CASCADE`` foreign keys to this row.

    Args:
        account: Account ID.
        agent_id: Agent ID.
    """
    table = retrieval_agent_config
    row_filter = sa.and_(
        table.c.agent_id == agent_id,
        table.c.account == account,
    )
    await self.database.execute(table.delete().where(row_filter))
|
|
500
|
+
|
|
501
|
+
async def add_workflow(self, account: str, agent_id: str, item: WorkflowInput):
    """Insert a new, active workflow for an agent.

    Args:
        account: Account ID.
        agent_id: Agent ID.
        item: Workflow definition. When ``item.rules`` is falsy an empty
            ``Rules`` set is stored instead.
    """
    effective_rules = item.rules if item.rules else Rules(rules=[])
    row = {
        "account": account,
        "agent_id": agent_id,
        "workflow_id": item.id,
        "name": item.name,
        "description": item.description,
        "parameters": item.parameters,
        "required": item.required,
        "rules": effective_rules.model_dump(),
        "is_deleted": False,
    }
    await self.database.execute(retrieval_agent_workflow.insert().values(**row))
|
|
516
|
+
|
|
517
|
+
def _active_workflow_condition(self):
    """Return the SQL condition selecting workflows that are not soft-deleted."""
    deleted_flag = retrieval_agent_workflow.c.is_deleted
    return deleted_flag.is_(False)
|
|
519
|
+
|
|
520
|
+
async def ensure_workflow_active(
    self, account: str, agent_id: str, workflow_id: str
):
    """Check that a workflow exists and is not soft-deleted.

    Args:
        account: Account ID.
        agent_id: Agent ID.
        workflow_id: Workflow ID.

    Raises:
        exceptions.NotFoundError: No active workflow matches.
    """
    table = retrieval_agent_workflow
    query = table.select().where(
        sa.and_(
            table.c.account == account,
            table.c.agent_id == agent_id,
            table.c.workflow_id == workflow_id,
            self._active_workflow_condition(),
        )
    )
    row = await self.database.fetch_one(query)
    if row is None:
        raise exceptions.NotFoundError("Workflow not found")
|
|
533
|
+
|
|
534
|
+
async def set_workflow(
    self, account: str, agent_id: str, workflow_id: str, item: WorkflowUpdate
):
    """Overwrite the editable fields of an active workflow.

    Args:
        account: Account ID.
        agent_id: Agent ID.
        workflow_id: Workflow ID.
        item: New values. When ``item.rules`` is falsy an empty ``Rules``
            set is stored instead.

    Raises:
        exceptions.NotFoundError: The workflow is missing or soft-deleted.
    """
    await self.ensure_workflow_active(account, agent_id, workflow_id)

    effective_rules = item.rules if item.rules else Rules(rules=[])
    new_values = {
        "name": item.name,
        "description": item.description,
        "parameters": item.parameters,
        "required": item.required,
        "rules": effective_rules.model_dump(),
    }

    table = retrieval_agent_workflow
    row_filter = sa.and_(
        table.c.account == account,
        table.c.agent_id == agent_id,
        table.c.workflow_id == workflow_id,
        # Repeated on the UPDATE so a soft-deleted row is never modified.
        self._active_workflow_condition(),
    )
    await self.database.execute(table.update().where(row_filter).values(**new_values))
|
|
555
|
+
|
|
556
|
+
async def delete_workflow(
    self, account: str, agent_id: str, workflow_id: str, deleted_by: str
):
    """Soft-delete a workflow.

    The row is kept and flagged with ``is_deleted``, ``deleted_by`` and a
    UTC ``deleted_at`` timestamp.

    Args:
        account: Account ID.
        agent_id: Agent ID.
        workflow_id: Workflow ID; ``"default"`` is rejected.
        deleted_by: Value stored in the ``deleted_by`` column.

    Raises:
        exceptions.ProtectedWorkflowError: ``workflow_id`` is ``"default"``.
        exceptions.NotFoundError: The workflow is missing or already deleted.
    """
    is_protected = workflow_id == "default"
    if is_protected:
        raise exceptions.ProtectedWorkflowError(
            "Default workflow cannot be deleted"
        )

    await self.ensure_workflow_active(account, agent_id, workflow_id)

    table = retrieval_agent_workflow
    row_filter = sa.and_(
        table.c.account == account,
        table.c.agent_id == agent_id,
        table.c.workflow_id == workflow_id,
        self._active_workflow_condition(),
    )
    tombstone = {
        "is_deleted": True,
        "deleted_by": deleted_by,
        "deleted_at": utc_now(),
    }
    await self.database.execute(table.update().where(row_filter).values(**tombstone))
|
|
575
|
+
|
|
576
|
+
async def get_expired_deleted_workflows(
    self, older_than: datetime.timedelta = WORKFLOW_PURGE_RETENTION
):
    """Fetch soft-deleted workflows whose deletion is older than a given age.

    Args:
        older_than: Minimum age of the ``deleted_at`` timestamp, measured
            from the current UTC time.

    Returns:
        All matching rows, across every account and agent.
    """
    cutoff = utc_now() - older_than
    table = retrieval_agent_workflow
    query = table.select().where(
        sa.and_(
            table.c.is_deleted.is_(True),
            table.c.deleted_at < cutoff,
        )
    )
    rows = await self.database.fetch_all(query)
    return rows
|
|
586
|
+
|
|
587
|
+
async def purge_deleted_workflow(
    self, account: str, agent_id: str, workflow_id: str
):
    """Permanently remove a workflow that has already been soft-deleted.

    Active workflows are never removed because the DELETE is restricted to
    rows with ``is_deleted`` set.

    Args:
        account: Account ID.
        agent_id: Agent ID.
        workflow_id: Workflow ID.
    """
    table = retrieval_agent_workflow
    row_filter = sa.and_(
        table.c.account == account,
        table.c.agent_id == agent_id,
        table.c.workflow_id == workflow_id,
        table.c.is_deleted.is_(True),
    )
    await self.database.execute(table.delete().where(row_filter))
|
|
598
|
+
|
|
599
|
+
async def set_workflow_rules(
    self, agent_id: str, workflow_id: str, account: str, rules: Rules
):
    """Replace the rules of an active workflow.

    Note that the parameter order is ``agent_id, workflow_id, account``,
    unlike most sibling methods.

    Args:
        agent_id: Agent ID.
        workflow_id: Workflow ID.
        account: Account ID.
        rules: New rule set, stored via ``model_dump()``.

    Raises:
        exceptions.NotFoundError: The workflow is missing or soft-deleted.
    """
    await self.ensure_workflow_active(account, agent_id, workflow_id)

    table = retrieval_agent_workflow
    row_filter = sa.and_(
        table.c.account == account,
        table.c.agent_id == agent_id,
        table.c.workflow_id == workflow_id,
        self._active_workflow_condition(),
    )
    update_stmt = table.update().where(row_filter).values(rules=rules.model_dump())
    await self.database.execute(update_stmt)
|
|
613
|
+
|
|
614
|
+
async def workflows_list(self, account: str, agent_id: str) -> List[WorkflowData]:
    """List the active (not soft-deleted) workflows of an agent.

    Args:
        account: Account ID.
        agent_id: Agent ID.

    Returns:
        One ``WorkflowData`` per row, with ``id`` taken from the row's
        ``workflow_id`` and the remaining row columns passed as keyword
        arguments.
    """
    table = retrieval_agent_workflow
    query = table.select().where(
        sa.and_(
            table.c.account == account,
            table.c.agent_id == agent_id,
            self._active_workflow_condition(),
        )
    )
    rows = await self.database.fetch_all(query)
    return [
        WorkflowData(id=row["workflow_id"], **row)  # type: ignore
        for row in rows
    ]
|
|
626
|
+
|
|
627
|
+
async def get_agent_config_basic(
    self, account: str, agent_id: str
) -> RetrievalAgent:
    """Load only the agent's own config row.

    Preprocess, context, generation and postprocess data are not loaded.

    Args:
        account (str): Account ID
        agent_id (str): Agent ID

    Raises:
        exceptions.NotFoundError: No config row exists for the agent.
    """
    table = retrieval_agent_config
    query = table.select().where(
        sa.and_(
            table.c.account == account,
            table.c.agent_id == agent_id,
        )
    )
    row = await self.database.fetch_one(query)
    if row is None:
        raise exceptions.NotFoundError("Agent config not found")

    # Columns copied one-to-one from the row into the model.
    copied_fields = (
        "account",
        "agent_id",
        "description",
        "memory",
        "title",
        "instructions",
        "created",
        "modified",
    )
    return RetrievalAgent(**{field: row[field] for field in copied_fields})
|
|
658
|
+
|
|
659
|
+
async def get_agent_config(
    self,
    account: str,
    agent_id: str,
    internal_nucliadb_url: str | None = None,
    default_memory: bool = False,
    workflow_id: str = "default",
) -> RetrievalAgentConfig:
    """Loads the full configuration of an agent for one workflow.

    Stage configs (preprocess, context, generation, postprocess), drivers,
    agent rules and memory are fetched together with one UNION query so they
    come from a consistent view. The workflow row itself is fetched with a
    separate query.

    Args:
        account (str): Account ID
        agent_id (str): Agent ID
        internal_nucliadb_url (str | None): Internal NucliaDB URL to use if no memory is set
        default_memory (bool): Whether to ignore the stored memory config and use a default one
        workflow_id (str): Workflow ID to load the agent for

    Raises:
        exceptions.NotFoundError: the workflow row is missing (or filtered out
            by the active-workflow condition), or no rules row was returned
            for the agent.
        Exception: no memory config could be determined.
    """

    # Queries for each agent type. Every sub-select is projected onto the same
    # six columns (kind, id, identifier, name, provider, config) so the
    # results can be combined with a UNION below.
    queries = [
        sa.select(
            sa.literal_column(f"'{table.name}'").label("kind"),  # type:ignore
            table.c.id,
            sa.null().label("identifier"),  # type:ignore
            sa.null().label("name"),
            sa.null().label("provider"),
            column.label("config"),
        )
        .where(table.c.account == account)
        .where(table.c.agent_id == agent_id)
        .where(table.c.workflow_id == workflow_id)
        for (table, column) in [
            (retrieval_agent_preprocess, retrieval_agent_preprocess.c.preprocess),
            (retrieval_agent_context, retrieval_agent_context.c.context),
            (retrieval_agent_generation, retrieval_agent_generation.c.generation),
            (
                retrieval_agent_postprocess,
                retrieval_agent_postprocess.c.postprocess,
            ),
        ]
    ]
    # Query for drivers (filtered by account and agent only, not by workflow)
    queries.append(
        sa.select(
            sa.literal_column("'driver'").label("kind"),  # type:ignore
            retrieval_agents_drivers.c.id,
            retrieval_agents_drivers.c.identifier,  # type:ignore
            retrieval_agents_drivers.c.driver.label("name"),
            retrieval_agents_drivers.c.provider,
            retrieval_agents_drivers.c.config,
        )
        .where(retrieval_agents_drivers.c.account == account)
        .where(retrieval_agents_drivers.c.agent_id == agent_id)
    )
    # Query for rules (agent-level, read from the agent config row)
    queries.append(
        sa.select(
            sa.literal_column("'rules'").label("kind"),  # type:ignore
            sa.null().label("id"),
            sa.null().label("identifier"),  # type:ignore
            sa.null().label("name"),
            sa.null().label("provider"),
            retrieval_agent_config.c.rules.label("config"),
        )
        .where(retrieval_agent_config.c.account == account)
        .where(retrieval_agent_config.c.agent_id == agent_id)
    )
    # Query for memory (agent-level, read from the agent config row)
    queries.append(
        sa.select(
            sa.literal_column("'memory'").label("kind"),  # type:ignore
            sa.null().label("id"),
            sa.null().label("identifier"),  # type:ignore
            sa.null().label("name"),
            sa.null().label("provider"),
            retrieval_agent_config.c.memory.label("config"),
        )
        .where(retrieval_agent_config.c.account == account)
        .where(retrieval_agent_config.c.agent_id == agent_id)
    )
    # Separate query for the workflow row; only active workflows match.
    workflow_query = (
        sa.select(
            retrieval_agent_workflow.c.name,  # type: ignore
            retrieval_agent_workflow.c.description,
            retrieval_agent_workflow.c.parameters,  # type: ignore
            retrieval_agent_workflow.c.rules,
        )
        .where(retrieval_agent_workflow.c.account == account)
        .where(retrieval_agent_workflow.c.agent_id == agent_id)
        .where(retrieval_agent_workflow.c.workflow_id == workflow_id)
        .where(self._active_workflow_condition())
    )

    preprocess: list[Any] = []
    context: list[Any] = []
    generation: list[Any] = []
    postprocess: list[Any] = []
    drivers: list[Any] = []
    rules: Rules | None = None
    # With default_memory the stored memory is ignored and a default
    # MemoryConfig is used from the start.
    memory: MemoryConfig | None = None if not default_memory else MemoryConfig()
    rows = await self.database.fetch_all(sa.union(*queries))
    workflow_data = await self.database.fetch_one(workflow_query)
    # Dispatch each row on its "kind" label; stage rows carry their table name.
    for row in rows:
        match row["kind"]:
            case retrieval_agent_preprocess.name:
                preprocess.append(
                    get_agent_config_instance(
                        agent_config={"id": row["id"], **row["config"]},
                        agent_type="preprocess",
                    )
                )  # Validate config
            case retrieval_agent_context.name:
                context.append(
                    get_agent_config_instance(
                        agent_config={"id": row["id"], **row["config"]},
                        agent_type="context",
                    )
                )  # Validate config
            case retrieval_agent_generation.name:
                generation.append(
                    get_agent_config_instance(
                        {"id": row["id"], **row["config"]}, agent_type="generation"
                    )
                )  # Validate config
            case retrieval_agent_postprocess.name:
                postprocess.append(
                    get_agent_config_instance(
                        agent_config={"id": row["id"], **row["config"]},
                        agent_type="postprocess",
                    )
                )  # Validate config
            case "driver":
                driver: DriverConfig = get_driver_config_instance(
                    {
                        "id": str(row["id"]),
                        "identifier": row["identifier"],
                        "name": row["name"],
                        "provider": row["provider"],
                        "config": row["config"],
                    }
                )  # Validate config
                # NOTE(review): presumably decrypts the driver config's
                # encrypted fields in place (return value is unused) - confirm.
                decrypt_fields(driver.config)
                drivers.append(driver)

            case "rules":
                rules = Rules.model_validate(row["config"])
            case "memory":
                if not default_memory and row["config"] is not None:
                    memory = MemoryConfig.model_validate(row["config"])

    if workflow_data is None:
        raise exceptions.NotFoundError("Workflow not found")

    workflow = WorkflowData(
        id=workflow_id,
        name=workflow_data["name"],
        description=workflow_data["description"],
        parameters=workflow_data["parameters"],
        rules=Rules.model_validate(workflow_data["rules"]),
        # NOTE(review): workflow_query above does not select a "required"
        # column, so this presumably always falls back to [] - confirm.
        required=workflow_data["required"] if "required" in workflow_data else [],  # noqa
    )

    # The rules sub-select returns a row whenever the agent config row exists,
    # so a missing rules value means the agent config was not found.
    if rules is None:
        raise exceptions.NotFoundError("Agent config not found")

    # No stored memory: fall back to an internal NucliaDB memory that uses the
    # agent id as the kbid.
    if memory is None and internal_nucliadb_url is not None and not default_memory:
        memory = MemoryConfig(
            nucliadb=NucliaDBMemoryConfig(
                url=internal_nucliadb_url, kbid=agent_id, internal=True
            )
        )

    if memory is None:
        raise Exception("Agent memory config not found")

    return RetrievalAgentConfig(
        preprocess=preprocess,
        context=context,
        generation=generation,
        postprocess=postprocess,
        drivers=drivers,
        rules=rules,
        memory=memory,
        workflow=workflow,
    )
|
|
843
|
+
|
|
844
|
+
async def set_rules(self, agent_id: str, account: str, rules: Rules):
    """Overwrite the agent-level rules stored on the agent config row.

    Args:
        agent_id (str): Agent ID
        account (str): Account ID
        rules (Rules): Rules to store; persisted as ``rules.model_dump()``
    """
    cols = retrieval_agent_config.c
    serialized_rules = rules.model_dump()
    update_query = (
        retrieval_agent_config.update()
        .values(rules=serialized_rules)
        .where(cols.account == account)
        .where(cols.agent_id == agent_id)
    )
    await self.database.execute(update_query)
|
|
853
|
+
|
|
854
|
+
async def add_prompt(
    self, agent_id: str, account: str, prompt: PromptConfig
) -> str:
    """Insert a new prompt for an agent.

    Args:
        agent_id (str): Agent ID
        account (str): Account ID
        prompt (PromptConfig): Prompt to store

    Returns:
        The value returned by the insert, converted to ``str``.
    """
    # Arguments are stored as plain dicts; a missing list stays None.
    if prompt.arguments is None:
        serialized_arguments = None
    else:
        serialized_arguments = [
            argument.model_dump() for argument in prompt.arguments
        ]

    insert_query = retrieval_agent_prompts.insert().values(
        account=account,
        agent_id=agent_id,
        name=prompt.name,
        prompt=prompt.prompt,
        description=prompt.description,
        arguments=serialized_arguments,
        icons=prompt.icons,
        meta=prompt.meta,
    )
    inserted = await self.database.execute(insert_query)
    return str(inserted)
|
|
871
|
+
|
|
872
|
+
async def set_prompt(
    self, agent_id: str, account: str, prompt_id: str, prompt: PromptConfig
):
    """Replace the stored fields of an existing prompt.

    Args:
        agent_id (str): Agent ID
        account (str): Account ID
        prompt_id (str): ID of the prompt row to update
        prompt (PromptConfig): New prompt contents
    """
    # Arguments are stored as plain dicts; a missing list stays None.
    if prompt.arguments is None:
        serialized_arguments = None
    else:
        serialized_arguments = [
            argument.model_dump() for argument in prompt.arguments
        ]

    cols = retrieval_agent_prompts.c
    update_query = (
        retrieval_agent_prompts.update()
        .values(
            name=prompt.name,
            description=prompt.description,
            prompt=prompt.prompt,
            arguments=serialized_arguments,
            icons=prompt.icons,
            meta=prompt.meta,
        )
        .where(cols.account == account)
        .where(cols.agent_id == agent_id)
        .where(cols.id == prompt_id)
    )

    await self.database.execute(update_query)
|
|
893
|
+
|
|
894
|
+
async def delete_prompt(self, agent_id: str, account: str, prompt_id: str):
    """Delete one prompt of an agent.

    Args:
        agent_id (str): Agent ID
        account (str): Account ID
        prompt_id (str): ID of the prompt row to delete
    """
    cols = retrieval_agent_prompts.c
    delete_query = (
        retrieval_agent_prompts.delete()
        .where(cols.account == account)
        .where(cols.agent_id == agent_id)
        .where(cols.id == prompt_id)
    )
    await self.database.execute(delete_query)
|
|
903
|
+
|
|
904
|
+
async def get_prompt(
    self, agent_id: str, account: str, prompt_id: str
) -> PromptConfig:
    """Load a single prompt of an agent.

    Args:
        agent_id (str): Agent ID
        account (str): Account ID
        prompt_id (str): ID of the prompt row

    Raises:
        exceptions.NotFoundError: no prompt row matches.
    """
    cols = retrieval_agent_prompts.c
    query = (
        retrieval_agent_prompts.select()
        .where(cols.account == account)
        .where(cols.agent_id == agent_id)
        .where(cols.id == prompt_id)
    )
    row = await self.database.fetch_one(query)
    if row is None:
        raise exceptions.NotFoundError("Prompt not found")

    # Stored arguments are validated back into PromptArgument models; a
    # missing list stays None.
    stored_arguments = row["arguments"]
    if stored_arguments is None:
        arguments = None
    else:
        arguments = [PromptArgument.model_validate(item) for item in stored_arguments]

    return PromptConfig(
        name=row["name"],
        description=row["description"],
        prompt=row["prompt"],
        arguments=arguments,
        icons=row["icons"],
        meta=row["meta"],
        prompt_id=str(row["id"]),
    )
|
|
932
|
+
|
|
933
|
+
async def get_prompts(self, agent_id: str, account: str) -> List[PromptConfig]:
    """Return all prompts stored for an agent.

    Args:
        agent_id (str): Agent ID
        account (str): Account ID

    Returns:
        One ``PromptConfig`` per prompt row of the agent (empty list if none).
    """
    # Filter on the prompts table's own columns. Filtering on
    # retrieval_agent_config columns instead pulls that table into the FROM
    # clause as an implicit cross join, so the filter no longer restricts
    # which prompt rows come back and other agents' prompts can be returned.
    statement = (
        retrieval_agent_prompts.select()
        .where(retrieval_agent_prompts.c.account == account)
        .where(retrieval_agent_prompts.c.agent_id == agent_id)
    )
    rows = await self.database.fetch_all(statement)

    prompts = []
    for row in rows:
        # Stored arguments are validated back into PromptArgument models; a
        # missing list stays None.
        stored_arguments = row["arguments"]
        arguments = (
            [PromptArgument.model_validate(arg) for arg in stored_arguments]
            if stored_arguments is not None
            else None
        )
        prompts.append(
            PromptConfig(
                name=row["name"],
                description=row["description"],
                prompt=row["prompt"],
                arguments=arguments,
                icons=row["icons"],
                meta=row["meta"],
                prompt_id=str(row["id"]),
            )
        )

    return prompts
|
|
958
|
+
|
|
959
|
+
async def set_memory(self, agent_id: str, account: str, memory: MemoryConfig):
    """Overwrite the memory configuration stored on the agent config row.

    Args:
        agent_id (str): Agent ID
        account (str): Account ID
        memory (MemoryConfig): Memory config; persisted as ``memory.model_dump()``
    """
    cols = retrieval_agent_config.c
    serialized_memory = memory.model_dump()
    update_query = (
        retrieval_agent_config.update()
        .values(memory=serialized_memory)
        .where(cols.account == account)
        .where(cols.agent_id == agent_id)
    )
    await self.database.execute(update_query)
|
|
968
|
+
|
|
969
|
+
async def get_rules(self, account: str, agent_id: str) -> Rules:
    """Load the agent-level rules.

    Args:
        account (str): Account ID
        agent_id (str): Agent ID

    Returns:
        The stored rules, or an empty ``Rules`` when the agent config row
        does not exist.
    """
    cols = retrieval_agent_config.c
    query = (
        retrieval_agent_config.select()
        .where(cols.account == account)
        .where(cols.agent_id == agent_id)
    )
    row = await self.database.fetch_one(query)
    if row is None:
        return Rules(rules=[])
    return Rules.model_validate(row["rules"])
|
|
982
|
+
|
|
983
|
+
async def get_workflow_rules(
    self, account: str, agent_id: str, workflow_id: str
) -> Rules:
    """Load the rules stored on one active workflow.

    Args:
        account (str): Account ID
        agent_id (str): Agent ID
        workflow_id (str): Workflow ID

    Raises:
        exceptions.NotFoundError: no active workflow row matches.
    """
    cols = retrieval_agent_workflow.c
    query = (
        retrieval_agent_workflow.select()
        .where(cols.account == account)
        .where(cols.agent_id == agent_id)
        .where(cols.workflow_id == workflow_id)
        .where(self._active_workflow_condition())
    )
    row = await self.database.fetch_one(query)
    if row is not None:
        return Rules.model_validate(row["rules"])

    raise exceptions.NotFoundError("Workflow not found")
|
|
998
|
+
|
|
999
|
+
async def add_preprocess(
    self,
    agent_id: str,
    account: str,
    agent: BaseModel,
    workflow_id: str = "default",
):
    """Store a new preprocess step for a workflow.

    ``ensure_workflow_active`` is awaited first; the step is persisted as
    ``agent.model_dump()``.

    Returns:
        The value returned by the insert, converted to ``str``.
    """
    await self.ensure_workflow_active(account, agent_id, workflow_id)
    row_values = {
        "account": account,
        "agent_id": agent_id,
        "workflow_id": workflow_id,
        "preprocess": agent.model_dump(),
    }
    insert_query = retrieval_agent_preprocess.insert().values(**row_values)
    inserted = await self.database.execute(insert_query)
    return str(inserted)
|
|
1015
|
+
|
|
1016
|
+
async def patch_preprocess(
    self,
    agent_id: str,
    account: str,
    preprocess: str,
    agent: BaseModel,
    workflow_id: str = "default",
):
    """Replace the stored config of one preprocess step.

    ``ensure_workflow_active`` is awaited first. ``preprocess`` is the ID of
    the row to update; the new config is ``agent.model_dump()``.
    """
    await self.ensure_workflow_active(account, agent_id, workflow_id)
    cols = retrieval_agent_preprocess.c
    update_query = (
        retrieval_agent_preprocess.update()
        .where(cols.id == preprocess)
        .where(cols.agent_id == agent_id)
        .where(cols.account == account)
        .where(cols.workflow_id == workflow_id)
        .values(preprocess=agent.model_dump())
    )
    await self.database.execute(update_query)
|
|
1034
|
+
|
|
1035
|
+
async def delete_preprocess(
    self, account: str, agent_id: str, preprocess: str, workflow_id: str = "default"
):
    """Delete one preprocess step of a workflow.

    ``ensure_workflow_active`` is awaited first. ``preprocess`` is the ID of
    the row to delete.
    """
    await self.ensure_workflow_active(account, agent_id, workflow_id)
    cols = retrieval_agent_preprocess.c
    delete_query = (
        retrieval_agent_preprocess.delete()
        .where(cols.account == account)
        .where(cols.agent_id == agent_id)
        .where(cols.workflow_id == workflow_id)
        .where(cols.id == preprocess)
    )
    await self.database.execute(delete_query)
|
|
1047
|
+
|
|
1048
|
+
async def get_preprocess(
    self, account: str, agent_id: str, workflow_id: str = "default"
) -> List[AgentConfig]:
    """List the preprocess steps of a workflow.

    ``ensure_workflow_active`` is awaited first. Each stored config is turned
    into a config instance and given the row ID (as ``str``) as its ``id``.
    """
    await self.ensure_workflow_active(account, agent_id, workflow_id)
    cols = retrieval_agent_preprocess.c
    query = (
        retrieval_agent_preprocess.select()
        .where(cols.account == account)
        .where(cols.agent_id == agent_id)
        .where(cols.workflow_id == workflow_id)
    )

    def _to_config(row):
        config = get_agent_config_instance(
            agent_config=row["preprocess"], agent_type="preprocess"
        )
        config.id = str(row["id"])
        return config

    return [_to_config(row) for row in await self.database.fetch_all(query)]
|
|
1068
|
+
|
|
1069
|
+
async def delete_postprocess(
    self,
    account: str,
    agent_id: str,
    postprocess: str,
    workflow_id: str = "default",
):
    """Delete one postprocess step of a workflow.

    ``ensure_workflow_active`` is awaited first. ``postprocess`` is the ID of
    the row to delete.
    """
    await self.ensure_workflow_active(account, agent_id, workflow_id)
    cols = retrieval_agent_postprocess.c
    delete_query = (
        retrieval_agent_postprocess.delete()
        .where(cols.account == account)
        .where(cols.agent_id == agent_id)
        .where(cols.workflow_id == workflow_id)
        .where(cols.id == postprocess)
    )
    await self.database.execute(delete_query)
|
|
1085
|
+
|
|
1086
|
+
async def add_postprocess(
    self,
    agent_id: str,
    account: str,
    agent: BaseModel,
    workflow_id: str = "default",
):
    """Store a new postprocess step for a workflow.

    ``ensure_workflow_active`` is awaited first; the step is persisted as
    ``agent.model_dump()``.

    Returns:
        The value returned by the insert, converted to ``str``.
    """
    await self.ensure_workflow_active(account, agent_id, workflow_id)
    row_values = {
        "account": account,
        "agent_id": agent_id,
        "postprocess": agent.model_dump(),
        "workflow_id": workflow_id,
    }
    insert_query = retrieval_agent_postprocess.insert().values(**row_values)
    inserted = await self.database.execute(insert_query)
    return str(inserted)
|
|
1102
|
+
|
|
1103
|
+
async def patch_postprocess(
    self,
    agent_id: str,
    account: str,
    postprocess: str,
    agent: BaseModel,
    workflow_id: str = "default",
):
    """Replace the stored config of one postprocess step.

    ``ensure_workflow_active`` is awaited first. ``postprocess`` is the ID of
    the row to update; the new config is ``agent.model_dump()``.
    """
    await self.ensure_workflow_active(account, agent_id, workflow_id)
    cols = retrieval_agent_postprocess.c
    update_query = (
        retrieval_agent_postprocess.update()
        .where(cols.id == postprocess)
        .where(cols.agent_id == agent_id)
        .where(cols.account == account)
        .where(cols.workflow_id == workflow_id)
        .values(postprocess=agent.model_dump())
    )
    await self.database.execute(update_query)
|
|
1121
|
+
|
|
1122
|
+
async def get_postprocess(
    self, account: str, agent_id: str, workflow_id: str = "default"
) -> List[AgentConfig]:
    """List the postprocess steps of a workflow.

    ``ensure_workflow_active`` is awaited first. Each stored config is turned
    into a config instance and given the row ID (as ``str``) as its ``id``.
    """
    await self.ensure_workflow_active(account, agent_id, workflow_id)
    cols = retrieval_agent_postprocess.c
    query = (
        retrieval_agent_postprocess.select()
        .where(cols.account == account)
        .where(cols.agent_id == agent_id)
        .where(cols.workflow_id == workflow_id)
    )

    def _to_config(row):
        config = get_agent_config_instance(
            agent_config=row["postprocess"], agent_type="postprocess"
        )
        config.id = str(row["id"])
        return config

    return [_to_config(row) for row in await self.database.fetch_all(query)]
|
|
1142
|
+
|
|
1143
|
+
async def add_context(
    self,
    agent_id: str,
    account: str,
    agent: BaseModel,
    workflow_id: str = "default",
):
    """Store a new context step for a workflow.

    ``ensure_workflow_active`` is awaited first; the step is persisted as
    ``agent.model_dump()``.

    Returns:
        The value returned by the insert, converted to ``str``.
    """
    await self.ensure_workflow_active(account, agent_id, workflow_id)
    row_values = {
        "account": account,
        "agent_id": agent_id,
        "context": agent.model_dump(),
        "workflow_id": workflow_id,
    }
    insert_query = retrieval_agent_context.insert().values(**row_values)
    inserted = await self.database.execute(insert_query)
    return str(inserted)
|
|
1160
|
+
|
|
1161
|
+
async def delete_context(
    self, agent_id: str, account: str, context: UUID, workflow_id: str = "default"
):
    """Delete one context step of a workflow.

    ``ensure_workflow_active`` is awaited first. ``context`` is the ID of the
    row to delete.
    """
    await self.ensure_workflow_active(account, agent_id, workflow_id)
    cols = retrieval_agent_context.c
    delete_query = (
        retrieval_agent_context.delete()
        .where(cols.id == context)
        .where(cols.agent_id == agent_id)
        .where(cols.account == account)
        .where(cols.workflow_id == workflow_id)
    )
    await self.database.execute(delete_query)
|
|
1173
|
+
|
|
1174
|
+
async def patch_context(
    self,
    agent_id: str,
    account: str,
    context: UUID,
    agent: BaseModel,
    workflow_id: str = "default",
):
    """Replace the stored config of one context step.

    ``ensure_workflow_active`` is awaited first. ``context`` is the ID of the
    row to update; the new config is ``agent.model_dump()``.
    """
    await self.ensure_workflow_active(account, agent_id, workflow_id)
    cols = retrieval_agent_context.c
    update_query = (
        retrieval_agent_context.update()
        .where(cols.id == context)
        .where(cols.agent_id == agent_id)
        .where(cols.account == account)
        .where(cols.workflow_id == workflow_id)
        .values(context=agent.model_dump())
    )
    await self.database.execute(update_query)
|
|
1192
|
+
|
|
1193
|
+
async def get_context(
    self, account: str, agent_id: str, workflow_id: str = "default"
) -> List[AgentConfig]:
    """List the context steps of a workflow.

    ``ensure_workflow_active`` is awaited first. Each stored config is turned
    into a config instance and given the row ID (as ``str``) as its ``id``.
    """
    await self.ensure_workflow_active(account, agent_id, workflow_id)
    cols = retrieval_agent_context.c
    query = (
        retrieval_agent_context.select()
        .where(cols.account == account)
        .where(cols.agent_id == agent_id)
        .where(cols.workflow_id == workflow_id)
    )

    def _to_config(row):
        config = get_agent_config_instance(
            agent_config=row["context"], agent_type="context"
        )
        config.id = str(row["id"])
        return config

    return [_to_config(row) for row in await self.database.fetch_all(query)]
|
|
1213
|
+
|
|
1214
|
+
async def add_generation(
    self,
    agent_id: str,
    account: str,
    agent: BaseModel,
    workflow_id: str = "default",
):
    """Store a new generation step for a workflow.

    ``ensure_workflow_active`` is awaited first; the step is persisted as
    ``agent.model_dump()``.

    Returns:
        The value returned by the insert, converted to ``str``.
    """
    await self.ensure_workflow_active(account, agent_id, workflow_id)
    row_values = {
        "account": account,
        "agent_id": agent_id,
        "generation": agent.model_dump(),
        "workflow_id": workflow_id,
    }
    insert_query = retrieval_agent_generation.insert().values(**row_values)
    inserted = await self.database.execute(insert_query)
    return str(inserted)
|
|
1231
|
+
|
|
1232
|
+
async def delete_generation(
    self,
    agent_id: str,
    account: str,
    generation: str,
    workflow_id: str = "default",
):
    """Delete one generation step of a workflow.

    ``ensure_workflow_active`` is awaited first. ``generation`` is the ID of
    the row to delete.
    """
    await self.ensure_workflow_active(account, agent_id, workflow_id)
    cols = retrieval_agent_generation.c
    delete_query = (
        retrieval_agent_generation.delete()
        .where(cols.id == generation)
        .where(cols.agent_id == agent_id)
        .where(cols.account == account)
        .where(cols.workflow_id == workflow_id)
    )
    await self.database.execute(delete_query)
|
|
1248
|
+
|
|
1249
|
+
async def patch_generation(
    self,
    agent_id: str,
    account: str,
    generation: str,
    agent: BaseModel,
    workflow_id: str = "default",
):
    """Replace the stored config of one generation step.

    ``ensure_workflow_active`` is awaited first. ``generation`` is the ID of
    the row to update; the new config is ``agent.model_dump()``.
    """
    await self.ensure_workflow_active(account, agent_id, workflow_id)
    cols = retrieval_agent_generation.c
    update_query = (
        retrieval_agent_generation.update()
        .where(cols.id == generation)
        .where(cols.agent_id == agent_id)
        .where(cols.account == account)
        .where(cols.workflow_id == workflow_id)
        .values(generation=agent.model_dump())
    )
    await self.database.execute(update_query)
|
|
1267
|
+
|
|
1268
|
+
async def get_generation(
    self, account: str, agent_id: str, workflow_id: str = "default"
) -> List[AgentConfig]:
    """List the generation steps of a workflow.

    ``ensure_workflow_active`` is awaited first. Each stored config is turned
    into a config instance and given the row ID (as ``str``) as its ``id``.
    """
    await self.ensure_workflow_active(account, agent_id, workflow_id)
    cols = retrieval_agent_generation.c
    query = (
        retrieval_agent_generation.select()
        .where(cols.account == account)
        .where(cols.agent_id == agent_id)
        .where(cols.workflow_id == workflow_id)
    )

    def _to_config(row):
        config = get_agent_config_instance(
            agent_config=row["generation"], agent_type="generation"
        )
        config.id = str(row["id"])
        return config

    return [_to_config(row) for row in await self.database.fetch_all(query)]
|
|
1288
|
+
|
|
1289
|
+
async def export(
    self, account: str, agent_id: str, passphrase: str
) -> tuple[bytes, str | None]:
    """Serialize an agent's configuration and encrypt it with a passphrase.

    Every workflow returned by ``workflows_list`` is loaded with
    ``default_memory=True`` and exported together with the agent's prompts.
    The result is the key-derivation salt followed by the Fernet token.

    Args:
        account (str): Account ID
        agent_id (str): Agent ID
        passphrase (str): Passphrase used to derive the key (min. 16 characters)

    Returns:
        ``(encrypted_bytes, None)``; the second element is always ``None``.

    Raises:
        Exception: passphrase too short, a workflow config failed to load,
            or encryption failed.
    """
    # Passphrase validation
    if len(passphrase) < 16:
        raise Exception("Passphrase too short, minimum 16 characters required")

    configs_by_workflow = {}
    default_config = None
    workflow_ids = [wf.id for wf in await self.workflows_list(account, agent_id)]
    for workflow_id in workflow_ids:
        try:
            loaded_config = await self.get_agent_config(
                account, agent_id, default_memory=True, workflow_id=workflow_id
            )
        except Exception:
            logger.exception("Retrieval agent not found for export")
            raise Exception("Retrieval agent not found for export")
        if workflow_id == "default":
            default_config = loaded_config
        configs_by_workflow[workflow_id] = loaded_config

    # XXX: ID sanitization not required since current implementation ignores IDs on import

    export_model = RetrievalAgentExportV1(
        agent_config=default_config,
        agent_config_workflows=configs_by_workflow,
        prompts=await self.get_prompts(agent_id, account),
    )
    payload = export_model.model_dump_json().encode("utf-8")
    # Passing None as salt: the helper hands back the salt to prepend.
    key, salt = fernet_key_from_passphrase(passphrase, None)
    cipher = Fernet(key)
    try:
        return salt + cipher.encrypt(payload), None
    except Exception:
        logger.exception("Error encrypting retrieval agent export")
        raise Exception("Error encrypting export")
|
|
1327
|
+
|
|
1328
|
+
async def import_config(
    self,
    account: str,
    agent_id: str,
    import_file: UploadFile,
    passphrase: str,
    overwrite: bool,
):
    """Import an encrypted agent export into an existing agent.

    The upload is read, decrypted and parsed BEFORE the destination agent is
    touched, so a wrong passphrase, an oversized file or a malformed export
    can no longer wipe an existing configuration when ``overwrite`` is set.

    Args:
        account (str): Account ID
        agent_id (str): Agent ID of the destination agent (must already exist)
        import_file (UploadFile): Export file: 16 bytes of salt followed by
            the Fernet token
        passphrase (str): Passphrase the export was encrypted with
        overwrite (bool): Replace a non-empty destination configuration

    Raises:
        exceptions.NotFoundError: propagated from ``get_agent_config`` for
            the destination agent.
        exceptions.InvalidTargetAgentError: the destination is not empty and
            ``overwrite`` is false.
        exceptions.ExportEncryptionError: the payload could not be decrypted.
        exceptions.ParseExportError: the file is too large, fails validation,
            has an unsupported version, or storing its contents failed.
    """
    # Get the current agent config (NotFoundError propagates to the caller)
    destination_agent_config = await self.get_agent_config(
        account, agent_id, default_memory=True, workflow_id="default"
    )
    # If not overwriting and the agent config is not empty, raise error
    must_reset = not destination_agent_config.is_empty()
    if must_reset and not overwrite:
        raise exceptions.InvalidTargetAgentError()

    # Read the whole upload; the file is closed on every exit path.
    try:
        # Read the salt (first 16 bytes)
        salt = await import_file.read(16)
        # Read the rest of the file in chunks, enforcing the size limit
        import_bytes = bytearray()
        while True:
            chunk = await import_file.read(self.settings.export_read_chunk_size)
            if not chunk:
                break
            import_bytes.extend(chunk)
            if len(import_bytes) > self.settings.export_read_max_size:
                raise exceptions.ParseExportError("Import file too large")
    finally:
        await import_file.close()

    key, _ = fernet_key_from_passphrase(passphrase, salt)
    fernet = Fernet(key)
    # Decrypt the bytes
    try:
        decrypted_bytes = fernet.decrypt(bytes(import_bytes))
    except Exception as e:
        raise exceptions.ExportEncryptionError from e
    # Load the model
    try:
        parsed_export = retrievalAgentAdapter.validate_json(
            decrypted_bytes.decode("utf-8")
        )
    except ValidationError as e:
        raise exceptions.ParseExportError from e

    if not isinstance(parsed_export, RetrievalAgentExportV1):
        raise exceptions.ParseExportError("Unsupported export version")

    # The payload is valid: only now is it safe to drop the current config.
    if must_reset:
        # Delete current configuration
        await self.delete_agent(account=account, agent_id=agent_id)
        # Recreate empty configuration
        await self.add_agent(
            account=account,
            agent_id=agent_id,
            memory=MemoryConfig(),
            rules=Rules(rules=[]),
        )

    # The top-level agent_config is the "default" workflow; entries of
    # agent_config_workflows are applied afterwards and take precedence.
    workflow_configs: dict[str, RetrievalAgentConfig] = {}
    agent_config = parsed_export.agent_config
    if agent_config is not None:
        workflow_configs["default"] = agent_config
    for wf_id, wf_config in parsed_export.agent_config_workflows.items():
        workflow_configs[wf_id] = wf_config

    drivers = {}
    agent_rules = None
    try:
        for workflow_id, agent_config_workflow in workflow_configs.items():
            for driver in agent_config_workflow.drivers:
                drivers[driver.name] = driver
            agent_rules = agent_config_workflow.rules
            wf = agent_config_workflow.workflow
            if workflow_id == "default":
                # "default" is updated in place rather than added.
                await self.set_workflow(
                    account=account,
                    agent_id=agent_id,
                    workflow_id=workflow_id,
                    item=WorkflowUpdate(
                        name=wf.name,
                        description=wf.description or "",
                        parameters=wf.parameters or {},
                        rules=wf.rules,
                    ),
                )
            else:
                await self.add_workflow(
                    account=account,
                    agent_id=agent_id,
                    item=WorkflowInput(
                        id=workflow_id,
                        name=wf.name,
                        description=wf.description,
                        parameters=wf.parameters,
                        rules=wf.rules or Rules(rules=[]),
                    ),
                )

            # Store preprocess
            for preprocess in agent_config_workflow.preprocess:
                await self.add_preprocess(
                    agent_id=agent_id,
                    account=account,
                    agent=preprocess,
                    workflow_id=workflow_id,
                )
            # Store context
            for context in agent_config_workflow.context:
                await self.add_context(
                    agent_id=agent_id,
                    account=account,
                    agent=context,
                    workflow_id=workflow_id,
                )
            # Store generation
            for generation in agent_config_workflow.generation:
                await self.add_generation(
                    agent_id=agent_id,
                    account=account,
                    agent=generation,
                    workflow_id=workflow_id,
                )
            # Store postprocess
            for postprocess in agent_config_workflow.postprocess:
                await self.add_postprocess(
                    agent_id=agent_id,
                    account=account,
                    agent=postprocess,
                    workflow_id=workflow_id,
                )

        if agent_rules is not None:
            # Store rules
            await self.set_rules(
                agent_id=agent_id,
                account=account,
                rules=agent_rules,
            )
        # XXX: We are purposely not importing memory configuration and will use the one already set
        # As KBs are created with the agent, if when importing we overwrite, we leave a dangling KB
        # And cross region imports with KB memory would break as well
        # This can be revisited in future versions if needed

        # Store drivers — done once after all workflows are processed to avoid
        # duplicate insertions (drivers are global per-agent but appear in each
        # workflow's exported config).
        for driver in drivers.values():
            await self.add_driver(
                agent_id=agent_id, account=account, config=driver
            )
        for prompt in parsed_export.prompts:
            await self.add_prompt(
                agent_id=agent_id,
                account=account,
                prompt=prompt,
            )
    except Exception as e:
        raise exceptions.ParseExportError(
            "Failed to import retrieval agent configuration"
        ) from e
|
|
1493
|
+
|
|
1494
|
+
|
|
1495
|
+
def update_driver_config(
    config: DriverConfig, previous_config: DriverConfig
) -> DriverConfig:
    """Merge a driver update onto the previously stored driver configuration.

    Only the fields explicitly set on ``config`` (dumped with
    ``exclude_unset=True``) override the stored values. The stored
    configuration is dumped in full, so fields it never set carry their
    defaults into the result unless the update provides them.

    Returns:
        A new instance of ``config``'s class built from the merged values.
    """
    overrides = config.model_dump(exclude_unset=True)
    merged = deep_update(previous_config.model_dump(), overrides)
    return config.__class__(**merged)
|
|
1506
|
+
|
|
1507
|
+
|
|
1508
|
+
def deep_update(original: dict, updates: dict) -> dict:
    """
    Merge ``updates`` into ``original`` recursively, in place.

    When both sides hold a dict under the same key the two are merged;
    otherwise the value from ``updates`` replaces the one in ``original``.
    Returns ``original``.
    """
    for name in updates:
        incoming = updates[name]
        # .get() yields None for a missing key, which is never a dict, so a
        # missing key falls through to plain assignment.
        current = original.get(name)
        if isinstance(current, dict) and isinstance(incoming, dict):
            deep_update(current, incoming)
            continue
        original[name] = incoming
    return original
|