hyperforge 1.0.0.post19__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
  Files changed (90)
  1. hyperforge/__init__.py +16 -0
  2. hyperforge/agent.py +81 -0
  3. hyperforge/api/__init__.py +20 -0
  4. hyperforge/api/app.py +155 -0
  5. hyperforge/api/authentication.py +271 -0
  6. hyperforge/api/commands.py +33 -0
  7. hyperforge/api/internal/__init__.py +4 -0
  8. hyperforge/api/internal/inspect.py +30 -0
  9. hyperforge/api/internal/router.py +3 -0
  10. hyperforge/api/logging.py +18 -0
  11. hyperforge/api/models.py +129 -0
  12. hyperforge/api/session.py +197 -0
  13. hyperforge/api/settings.py +38 -0
  14. hyperforge/api/utils.py +354 -0
  15. hyperforge/api/v1/__init__.py +23 -0
  16. hyperforge/api/v1/agents.py +531 -0
  17. hyperforge/api/v1/interaction.py +430 -0
  18. hyperforge/api/v1/mcp_content.py +311 -0
  19. hyperforge/api/v1/mcp_interaction.py +322 -0
  20. hyperforge/api/v1/oauth.py +60 -0
  21. hyperforge/api/v1/prompt.py +129 -0
  22. hyperforge/api/v1/router.py +3 -0
  23. hyperforge/api/v1/schema.py +56 -0
  24. hyperforge/api/v1/session.py +182 -0
  25. hyperforge/api/v1/utils.py +12 -0
  26. hyperforge/api/v1/workflows.py +643 -0
  27. hyperforge/arag.py +28 -0
  28. hyperforge/broker/__init__.py +52 -0
  29. hyperforge/broker/local.py +116 -0
  30. hyperforge/broker/redis.py +161 -0
  31. hyperforge/configure.py +571 -0
  32. hyperforge/context/__init__.py +0 -0
  33. hyperforge/context/agent.py +377 -0
  34. hyperforge/context/config.py +103 -0
  35. hyperforge/database.py +3 -0
  36. hyperforge/db/__init__.py +6 -0
  37. hyperforge/db/agents.py +1521 -0
  38. hyperforge/db/encryption.py +91 -0
  39. hyperforge/db/exceptions.py +26 -0
  40. hyperforge/db/settings.py +16 -0
  41. hyperforge/db/workflow_cleanup.py +69 -0
  42. hyperforge/definition.py +13 -0
  43. hyperforge/driver.py +31 -0
  44. hyperforge/dummy.py +28 -0
  45. hyperforge/engine.py +189 -0
  46. hyperforge/exceptions.py +14 -0
  47. hyperforge/feature_flag.py +105 -0
  48. hyperforge/fixtures.py +602 -0
  49. hyperforge/interaction.py +116 -0
  50. hyperforge/llm.py +75 -0
  51. hyperforge/manager.py +432 -0
  52. hyperforge/memory/__init__.py +5 -0
  53. hyperforge/memory/memory.py +974 -0
  54. hyperforge/minimal_fixtures.py +75 -0
  55. hyperforge/models.py +336 -0
  56. hyperforge/nua.py +336 -0
  57. hyperforge/openapi.py +63 -0
  58. hyperforge/prompts.py +188 -0
  59. hyperforge/pubsub.py +90 -0
  60. hyperforge/py.typed +0 -0
  61. hyperforge/redis_utils.py +82 -0
  62. hyperforge/retrieval/__init__.py +0 -0
  63. hyperforge/retrieval/agent.py +169 -0
  64. hyperforge/retrieval/config.py +94 -0
  65. hyperforge/server/__init__.py +5 -0
  66. hyperforge/server/cache.py +131 -0
  67. hyperforge/server/run.py +109 -0
  68. hyperforge/server/sandbox.py +60 -0
  69. hyperforge/server/session.py +421 -0
  70. hyperforge/server/settings.py +47 -0
  71. hyperforge/server/utils.py +57 -0
  72. hyperforge/server/web.py +31 -0
  73. hyperforge/settings.py +18 -0
  74. hyperforge/standalone/__init__.py +5 -0
  75. hyperforge/standalone/agent.py +189 -0
  76. hyperforge/standalone/app.py +264 -0
  77. hyperforge/standalone/config.py +137 -0
  78. hyperforge/standalone/const.py +1 -0
  79. hyperforge/standalone/run.py +60 -0
  80. hyperforge/standalone/settings.py +133 -0
  81. hyperforge/standalone/ui_router.py +241 -0
  82. hyperforge/trace.py +42 -0
  83. hyperforge/utils/__init__.py +112 -0
  84. hyperforge/utils/http.py +48 -0
  85. hyperforge/workflows.py +44 -0
  86. hyperforge-1.0.0.post19.dist-info/METADATA +95 -0
  87. hyperforge-1.0.0.post19.dist-info/RECORD +90 -0
  88. hyperforge-1.0.0.post19.dist-info/WHEEL +5 -0
  89. hyperforge-1.0.0.post19.dist-info/entry_points.txt +8 -0
  90. hyperforge-1.0.0.post19.dist-info/top_level.txt +1 -0
@@ -0,0 +1,1521 @@
1
+ import datetime
2
+ from typing import Any, List
3
+ from uuid import UUID
4
+
5
+ import databases
6
+ import sqlalchemy as sa
7
+ from cryptography.fernet import Fernet
8
+ from fastapi import UploadFile
9
+ from nucliadb_telemetry.utils import get_telemetry, init_telemetry
10
+ from pydantic import BaseModel, ValidationError
11
+ from sqlalchemy.dialects import postgresql as pg_dialect
12
+ from sqlalchemy.dialects.postgresql import JSONB
13
+
14
+ from hyperforge.agent import AgentConfig
15
+ from hyperforge.configure import (
16
+ get_agent_config_instance,
17
+ get_driver_config_instance,
18
+ get_driver_config_klass,
19
+ )
20
+ from hyperforge.database import metadata
21
+ from hyperforge.db import exceptions, logger
22
+ from hyperforge.db.encryption import (
23
+ decrypt_fields,
24
+ encrypt_fields,
25
+ fernet_key_from_passphrase,
26
+ )
27
+ from hyperforge.db.settings import DataManagerSettings
28
+ from hyperforge.driver import DriverConfig
29
+ from hyperforge.models import MemoryConfig, NucliaDBMemoryConfig, Rules
30
+ from hyperforge.prompts import PromptArgument, PromptConfig
31
+ from hyperforge.retrieval.config import (
32
+ RetrievalAgentConfig,
33
+ RetrievalAgentExportV1,
34
+ retrievalAgentAdapter,
35
+ )
36
+ from hyperforge.workflows import (
37
+ RetrievalAgent,
38
+ WorkflowData,
39
+ WorkflowInput,
40
+ WorkflowUpdate,
41
+ )
42
+
43
# Service name handed to get_telemetry() by AgentManager.from_settings().
SERVICE_NAME = "TASK_MANAGER"

# NOTE(review): not used in this part of the module; 24 * 7 suggests the number
# of hours in a week — confirm the unit at the call sites.
EXPIRATION = 24 * 7

# Default age after which soft-deleted workflows are reported by
# AgentManager.get_expired_deleted_workflows() (and can then be purged).
WORKFLOW_PURGE_RETENTION = datetime.timedelta(days=15)
46
+
47
+
48
def utc_now() -> datetime.datetime:
    """Return the current moment as a timezone-aware ``datetime`` in UTC."""
    utc = datetime.timezone.utc
    return datetime.datetime.now(tz=utc)
50
+
51
+
52
# Workflows of an agent, keyed by (account, agent_id, workflow_id).
# AgentManager.delete_workflow soft-deletes rows (is_deleted / deleted_by /
# deleted_at); AgentManager.purge_deleted_workflow removes them for good later.
retrieval_agent_workflow = sa.Table(
    "retrieval_agent_workflow",
    metadata,
    sa.Column("account", sa.String, primary_key=True, nullable=False, index=True),
    sa.Column(
        "agent_id", sa.String, primary_key=True, nullable=False, index=True
    ),  # Agent ID
    sa.Column(
        "workflow_id", sa.String, primary_key=True, nullable=False, index=True
    ),  # Workflow ID; "default" is inserted by AgentManager.add_agent
    sa.Column("name", sa.String, nullable=False),
    sa.Column("description", sa.String, nullable=True),
    sa.Column("parameters", JSONB, nullable=True),
    sa.Column("required", JSONB, nullable=True),
    # Serialized Rules.model_dump() payload.
    sa.Column("rules", JSONB, nullable=True),
    # `default` is a SQLAlchemy-side default (no server_default), and `modified`
    # has no insert default at all, so it is NULL for freshly inserted rows.
    sa.Column("created", sa.DateTime, default=sa.func.now()),
    sa.Column("modified", sa.DateTime, onupdate=sa.func.now()),
    sa.Column("is_deleted", sa.Boolean, nullable=False, default=False, index=True),
    sa.Column("deleted_by", sa.String, nullable=True),
    # NOTE(review): timezone-aware, unlike the naive created/modified columns.
    sa.Column("deleted_at", sa.DateTime(timezone=True), nullable=True, index=True),
    # Declared ON UPDATE/DELETE CASCADE from the owning agent config row.
    sa.ForeignKeyConstraint(
        ["account", "agent_id"],
        ["retrieval_agent_config.account", "retrieval_agent_config.agent_id"],
        onupdate="CASCADE",
        ondelete="CASCADE",
    ),
)
79
+
80
# Prompts stored per agent (see AgentManager.add_prompt / get_prompt).
# Rows are addressed by the generated `id` UUID; there is no uniqueness
# constraint on `name`.
retrieval_agent_prompts = sa.Table(
    "retrieval_agent_prompts",
    metadata,
    sa.Column(
        "id",
        pg_dialect.UUID,
        primary_key=True,
        # uuid_generate_v4() is provided by PostgreSQL's uuid-ossp extension;
        # presumably installed by the migrations — confirm.
        server_default=sa.func.uuid_generate_v4(),
    ),
    sa.Column(
        "account",
        sa.String,
        nullable=False,
        index=True,
    ),
    sa.Column(
        "agent_id",
        sa.String,
        nullable=False,
        index=True,
    ),
    sa.Column("name", sa.String, nullable=False),
    sa.Column("description", sa.String, nullable=False),
    sa.Column("prompt", sa.String, nullable=False),
    # List of PromptArgument.model_dump() dicts, or NULL when the prompt has none.
    sa.Column("arguments", JSONB, nullable=True),
    sa.Column("icons", JSONB, nullable=True),
    sa.Column("meta", JSONB, nullable=True),
    sa.Column("created", sa.DateTime, default=sa.func.now()),
    sa.Column("modified", sa.DateTime, onupdate=sa.func.now()),
    # Declared ON UPDATE/DELETE CASCADE from the owning agent config row.
    sa.ForeignKeyConstraint(
        ["account", "agent_id"],
        ["retrieval_agent_config.account", "retrieval_agent_config.agent_id"],
        onupdate="CASCADE",
        ondelete="CASCADE",
    ),
)
116
+
117
# One row per agent, keyed by (account, agent_id). The other agent tables in this
# module reference it through foreign keys, directly or via the workflow table.
retrieval_agent_config = sa.Table(
    "retrieval_agent_config",
    metadata,
    sa.Column("account", sa.String, primary_key=True, nullable=False, index=True),
    sa.Column(
        "agent_id", sa.String, primary_key=True, nullable=False, index=True
    ),  # Agent ID
    # Agent-level Rules.model_dump() payload (workflows carry their own rules).
    sa.Column("rules", JSONB, nullable=False),
    sa.Column("created", sa.DateTime, default=sa.func.now()),
    sa.Column("modified", sa.DateTime, onupdate=sa.func.now()),
    # MemoryConfig.model_dump() payload.
    sa.Column("memory", JSONB, nullable=False),
    sa.Column("description", sa.String, nullable=True),
    sa.Column("title", sa.String, nullable=True),
    sa.Column("instructions", sa.String, nullable=True),
)
132
+
133
+
134
# Preprocess agent configs attached to one workflow of an agent.
# Read back by AgentManager.get_agent_config with kind == this table's name.
retrieval_agent_preprocess = sa.Table(
    "retrieval_agent_preprocess",
    metadata,
    sa.Column(
        "id",
        pg_dialect.UUID,
        primary_key=True,
        # Requires uuid_generate_v4() (PostgreSQL uuid-ossp extension) — confirm installed.
        server_default=sa.func.uuid_generate_v4(),
    ),
    sa.Column(
        "account",
        sa.String,
        nullable=False,
        index=True,
    ),
    sa.Column(
        "agent_id",
        sa.String,
        nullable=False,
        index=True,
    ),
    sa.Column(
        "workflow_id",
        sa.String,
        nullable=False,
        index=True,
    ),
    # Agent config payload; validated via get_agent_config_instance(agent_type="preprocess").
    sa.Column("preprocess", JSONB, nullable=False),
    sa.Column("created", sa.DateTime, default=sa.func.now()),
    sa.Column("modified", sa.DateTime, onupdate=sa.func.now()),
    # Declared ON UPDATE/DELETE CASCADE from the owning workflow row.
    sa.ForeignKeyConstraint(
        ["account", "agent_id", "workflow_id"],
        [
            "retrieval_agent_workflow.account",
            "retrieval_agent_workflow.agent_id",
            "retrieval_agent_workflow.workflow_id",
        ],
        onupdate="CASCADE",
        ondelete="CASCADE",
    ),
)
175
+
176
+
177
# Postprocess agent configs attached to one workflow of an agent.
# Read back by AgentManager.get_agent_config with kind == this table's name.
retrieval_agent_postprocess = sa.Table(
    "retrieval_agent_postprocess",
    metadata,
    sa.Column(
        "id",
        pg_dialect.UUID,
        primary_key=True,
        # Requires uuid_generate_v4() (PostgreSQL uuid-ossp extension) — confirm installed.
        server_default=sa.func.uuid_generate_v4(),
    ),
    sa.Column(
        "account",
        sa.String,
        nullable=False,
        index=True,
    ),
    sa.Column(
        "agent_id",
        sa.String,
        nullable=False,
        index=True,
    ),
    sa.Column(
        "workflow_id",
        sa.String,
        nullable=False,
        index=True,
    ),
    # Agent config payload; validated via get_agent_config_instance(agent_type="postprocess").
    sa.Column("postprocess", JSONB, nullable=False),
    sa.Column("created", sa.DateTime, default=sa.func.now()),
    sa.Column("modified", sa.DateTime, onupdate=sa.func.now()),
    # Declared ON UPDATE/DELETE CASCADE from the owning workflow row.
    sa.ForeignKeyConstraint(
        ["account", "agent_id", "workflow_id"],
        [
            "retrieval_agent_workflow.account",
            "retrieval_agent_workflow.agent_id",
            "retrieval_agent_workflow.workflow_id",
        ],
        onupdate="CASCADE",
        ondelete="CASCADE",
    ),
)
218
+
219
# Context agent configs attached to one workflow of an agent.
# Read back by AgentManager.get_agent_config with kind == this table's name.
retrieval_agent_context = sa.Table(
    "retrieval_agent_context",
    metadata,
    sa.Column(
        "id",
        pg_dialect.UUID,
        primary_key=True,
        # Requires uuid_generate_v4() (PostgreSQL uuid-ossp extension) — confirm installed.
        server_default=sa.func.uuid_generate_v4(),
    ),
    sa.Column(
        "account",
        sa.String,
        nullable=False,
        index=True,
    ),
    sa.Column(
        "agent_id",
        sa.String,
        nullable=False,
        index=True,
    ),
    sa.Column(
        "workflow_id",
        sa.String,
        nullable=False,
        index=True,
    ),
    # Agent config payload; validated via get_agent_config_instance(agent_type="context").
    sa.Column("context", JSONB, nullable=False),
    sa.Column("created", sa.DateTime, default=sa.func.now()),
    sa.Column("modified", sa.DateTime, onupdate=sa.func.now()),
    # Declared ON UPDATE/DELETE CASCADE from the owning workflow row.
    sa.ForeignKeyConstraint(
        ["account", "agent_id", "workflow_id"],
        [
            "retrieval_agent_workflow.account",
            "retrieval_agent_workflow.agent_id",
            "retrieval_agent_workflow.workflow_id",
        ],
        onupdate="CASCADE",
        ondelete="CASCADE",
    ),
)
260
+
261
+
262
# Generation agent configs attached to one workflow of an agent.
# Read back by AgentManager.get_agent_config with kind == this table's name.
retrieval_agent_generation = sa.Table(
    "retrieval_agent_generation",
    metadata,
    sa.Column(
        "id",
        pg_dialect.UUID,
        primary_key=True,
        # Requires uuid_generate_v4() (PostgreSQL uuid-ossp extension) — confirm installed.
        server_default=sa.func.uuid_generate_v4(),
    ),
    sa.Column(
        "account",
        sa.String,
        nullable=False,
        index=True,
    ),
    sa.Column(
        "agent_id",
        sa.String,
        nullable=False,
        index=True,
    ),
    sa.Column(
        "workflow_id",
        sa.String,
        nullable=False,
        index=True,
    ),
    # Agent config payload; validated via get_agent_config_instance(agent_type="generation").
    sa.Column("generation", JSONB, nullable=False),
    sa.Column("created", sa.DateTime, default=sa.func.now()),
    sa.Column("modified", sa.DateTime, onupdate=sa.func.now()),
    # Declared ON UPDATE/DELETE CASCADE from the owning workflow row.
    sa.ForeignKeyConstraint(
        ["account", "agent_id", "workflow_id"],
        [
            "retrieval_agent_workflow.account",
            "retrieval_agent_workflow.agent_id",
            "retrieval_agent_workflow.workflow_id",
        ],
        onupdate="CASCADE",
        ondelete="CASCADE",
    ),
)
303
+
304
# Drivers configured for an agent (agent-wide, not per workflow).
# `config` is stored after encrypt_fields() and decrypted on read.
retrieval_agents_drivers = sa.Table(
    "retrieval_agents_drivers",
    metadata,
    sa.Column(
        "id",
        pg_dialect.UUID,
        primary_key=True,
        # Requires uuid_generate_v4() (PostgreSQL uuid-ossp extension) — confirm installed.
        server_default=sa.func.uuid_generate_v4(),
    ),
    sa.Column("account", sa.String, nullable=False),
    sa.Column("agent_id", sa.String, nullable=False),
    # Driver name (DriverConfig.name).
    sa.Column("driver", sa.String, nullable=False),
    # Used to pick the DriverConfig subclass via get_driver_config_klass().
    sa.Column("provider", sa.String, nullable=False),
    sa.Column("identifier", sa.String, nullable=False),
    sa.Column("config", JSONB, nullable=False),
    sa.Column("created", sa.DateTime, default=sa.func.now()),
    sa.Column("modified", sa.DateTime, onupdate=sa.func.now()),
    # Declared ON UPDATE/DELETE CASCADE from the owning agent config row.
    sa.ForeignKeyConstraint(
        ["account", "agent_id"],
        ["retrieval_agent_config.account", "retrieval_agent_config.agent_id"],
        onupdate="CASCADE",
        ondelete="CASCADE",
    ),
    # An identifier may be used only once per agent.
    sa.UniqueConstraint("account", "agent_id", "identifier"),
)
329
+
330
+
331
+ class AgentManager:
332
+ settings: DataManagerSettings
333
+
334
def __init__(
    self,
    database: databases.Database,
    settings: DataManagerSettings,
):
    """Bind the manager to a database handle and its settings.

    No connection is opened here; ``initialize()`` does that.
    """
    self.settings = settings
    self.database = database
341
+
342
@classmethod
async def from_settings(
    cls,
    settings: DataManagerSettings,
):
    """Build a manager from ``settings``.

    If ``get_telemetry(SERVICE_NAME)`` returns a provider, telemetry is
    initialised with it first. The database handle is created from
    ``settings.postgresql_dsn`` but not connected; call ``initialize()``.
    """
    provider = get_telemetry(SERVICE_NAME)
    if provider:
        await init_telemetry(provider)

    return cls(
        database=databases.Database(settings.postgresql_dsn),
        settings=settings,
    )
354
+
355
async def initialize(self):
    """Connect the underlying database handle."""
    db = self.database
    await db.connect()
357
+
358
async def finalize(self):
    """Disconnect the underlying database handle."""
    db = self.database
    await db.disconnect()
360
+
361
async def patch_driver(
    self,
    account: str,
    agent_id: str,
    driver: str,
    config: DriverConfig,
):
    """Combine ``config`` with a stored driver and persist the result.

    ``driver`` is the driver row id. The new values come from
    ``update_driver_config(config, stored)``. Only the ``driver`` name and the
    re-encrypted ``config`` columns are written. If no such driver exists for
    ``(account, agent_id)``, the call silently does nothing.
    """
    try:
        stored = await self.get_driver(account, agent_id, driver)
    except exceptions.DriverNotFoundError:
        # Nothing stored under this id, so there is nothing to patch.
        return

    merged: DriverConfig = update_driver_config(config, stored)
    table = retrieval_agents_drivers
    query = (
        table.update()
        .where(
            table.c.account == account,
            table.c.agent_id == agent_id,
            table.c.id == driver,
        )
        .values(
            driver=merged.name,
            config=encrypt_fields(merged.config),
        )
    )
    await self.database.execute(query)
388
+
389
async def add_driver(
    self,
    agent_id: str,
    account: str,
    config: DriverConfig,
):
    """Insert a driver for an agent and return ``str()`` of the execute result.

    Note the ``(agent_id, account)`` parameter order, which differs from most
    methods here. ``config.config`` goes through ``encrypt_fields`` before it
    is stored.
    """
    row = dict(
        account=account,
        agent_id=agent_id,
        driver=config.name,
        provider=config.provider,
        identifier=config.identifier,
        config=encrypt_fields(config.config),
    )
    inserted = await self.database.execute(
        retrieval_agents_drivers.insert().values(**row)
    )
    return str(inserted)
407
+
408
async def delete_driver(self, account: str, agent_id: str, driver: str):
    """Delete an agent's driver by row id; an unknown id raises no error."""
    table = retrieval_agents_drivers
    await self.database.execute(
        table.delete().where(
            table.c.account == account,
            table.c.agent_id == agent_id,
            table.c.id == driver,
        )
    )
416
+
417
async def get_driver(
    self, account: str, agent_id: str, driver: str
) -> DriverConfig:
    """Load one driver by row id, with its config decrypted.

    The concrete config class is chosen from the stored ``provider``.

    Raises:
        exceptions.DriverNotFoundError: no driver with that id for the agent.
    """
    table = retrieval_agents_drivers
    row = await self.database.fetch_one(
        table.select().where(
            table.c.account == account,
            table.c.agent_id == agent_id,
            table.c.id == driver,
        )
    )
    if row is None:
        raise exceptions.DriverNotFoundError()

    klass = get_driver_config_klass(row["provider"])
    payload = {
        "id": str(row["id"]),
        "name": row["driver"],
        "identifier": row["identifier"],
        "provider": row["provider"],
        "config": row["config"],
    }
    driver_config = klass.model_validate(payload)
    # The return value of decrypt_fields is unused, so it presumably decrypts in place.
    decrypt_fields(driver_config.config)
    return driver_config
442
+
443
async def get_drivers(self, account: str, agent_id: str) -> List[DriverConfig]:
    """Return every driver of an agent, with configs decrypted.

    Row order is unspecified (the query has no ORDER BY).
    """
    table = retrieval_agents_drivers
    rows = await self.database.fetch_all(
        table.select().where(
            table.c.account == account,
            table.c.agent_id == agent_id,
        )
    )

    found: List[DriverConfig] = []
    for row in rows:
        klass = get_driver_config_klass(row["provider"])
        item = klass.model_validate(
            {
                "id": str(row["id"]),
                "name": row["driver"],
                "identifier": row["identifier"],
                "provider": row["provider"],
                "config": row["config"],
            }
        )
        decrypt_fields(item.config)
        found.append(item)
    return found
465
+
466
async def add_agent(
    self, account: str, agent_id: str, memory: MemoryConfig, rules: Rules
):
    """Create an agent's config row together with its ``default`` workflow.

    Both inserts run inside a single database transaction. Previously the two
    statements were committed independently. If the workflow insert failed, the
    agent was left without a ``default`` workflow, which ``get_agent_config``
    loads by default.

    Args:
        account: Account ID.
        agent_id: Agent ID.
        memory: Memory configuration, stored via ``model_dump()``.
        rules: Agent-level rules, stored via ``model_dump()``.
    """
    config_insert = retrieval_agent_config.insert().values(
        account=account,
        agent_id=agent_id,
        rules=rules.model_dump(),
        memory=memory.model_dump(),
    )
    default_workflow_insert = retrieval_agent_workflow.insert().values(
        account=account,
        agent_id=agent_id,
        workflow_id="default",
        name="default",
        description="Default workflow",
        parameters={},
        rules=Rules(rules=[]).model_dump(),
        required=[],
        is_deleted=False,
    )

    async with self.database.transaction():
        # The config row must exist first: the workflow row references it via FK.
        await self.database.execute(config_insert)
        await self.database.execute(default_workflow_insert)
491
+
492
async def delete_agent(self, account: str, agent_id: str):
    """Delete an agent's config row.

    Tables whose foreign keys point at ``retrieval_agent_config`` declare
    ``ondelete="CASCADE"``. The database should therefore remove their
    dependent rows as well (assuming those FKs exist in the deployed schema).
    """
    cfg = retrieval_agent_config
    await self.database.execute(
        cfg.delete().where(
            cfg.c.agent_id == agent_id,
            cfg.c.account == account,
        )
    )
500
+
501
async def add_workflow(self, account: str, agent_id: str, item: WorkflowInput):
    """Insert a new, active workflow for an agent.

    A falsy ``item.rules`` is stored as an empty rule set.

    NOTE(review): ``delete_workflow`` only soft-deletes, so re-adding the id of a
    deleted-but-not-yet-purged workflow collides with the composite primary key.
    """
    if item.rules:
        rules_payload = item.rules.model_dump()
    else:
        rules_payload = Rules(rules=[]).model_dump()

    statement = retrieval_agent_workflow.insert().values(
        account=account,
        agent_id=agent_id,
        workflow_id=item.id,
        name=item.name,
        description=item.description,
        parameters=item.parameters,
        required=item.required,
        rules=rules_payload,
        is_deleted=False,
    )
    await self.database.execute(statement)
516
+
517
def _active_workflow_condition(self):
    """SQL predicate matching workflows that have not been soft-deleted."""
    is_deleted = retrieval_agent_workflow.c.is_deleted
    return is_deleted.is_(False)
519
+
520
async def ensure_workflow_active(
    self, account: str, agent_id: str, workflow_id: str
):
    """Raise unless the workflow exists and has not been soft-deleted.

    Raises:
        exceptions.NotFoundError: the workflow is missing or soft-deleted.
    """
    wf = retrieval_agent_workflow
    found = await self.database.fetch_one(
        wf.select().where(
            wf.c.account == account,
            wf.c.agent_id == agent_id,
            wf.c.workflow_id == workflow_id,
            self._active_workflow_condition(),
        )
    )
    if found is None:
        raise exceptions.NotFoundError("Workflow not found")
533
+
534
async def set_workflow(
    self, account: str, agent_id: str, workflow_id: str, item: WorkflowUpdate
):
    """Overwrite an active workflow's editable fields.

    The fields are name, description, parameters, required and rules. A falsy
    ``item.rules`` is stored as an empty rule set.

    Raises:
        exceptions.NotFoundError: the workflow is missing or soft-deleted.
    """
    await self.ensure_workflow_active(account, agent_id, workflow_id)

    if item.rules:
        rules_payload = item.rules.model_dump()
    else:
        rules_payload = Rules(rules=[]).model_dump()

    wf = retrieval_agent_workflow
    await self.database.execute(
        wf.update()
        .where(
            wf.c.account == account,
            wf.c.agent_id == agent_id,
            wf.c.workflow_id == workflow_id,
            self._active_workflow_condition(),
        )
        .values(
            name=item.name,
            description=item.description,
            parameters=item.parameters,
            required=item.required,
            rules=rules_payload,
        )
    )
555
+
556
async def delete_workflow(
    self, account: str, agent_id: str, workflow_id: str, deleted_by: str
):
    """Soft-delete a workflow, recording who deleted it and when (UTC).

    The row is kept until ``purge_deleted_workflow`` removes it.

    Raises:
        exceptions.ProtectedWorkflowError: for the ``"default"`` workflow.
        exceptions.NotFoundError: the workflow is missing or already deleted.
    """
    if workflow_id == "default":
        raise exceptions.ProtectedWorkflowError(
            "Default workflow cannot be deleted"
        )

    await self.ensure_workflow_active(account, agent_id, workflow_id)

    wf = retrieval_agent_workflow
    tombstone = dict(is_deleted=True, deleted_by=deleted_by, deleted_at=utc_now())
    await self.database.execute(
        wf.update()
        .where(
            wf.c.account == account,
            wf.c.agent_id == agent_id,
            wf.c.workflow_id == workflow_id,
            self._active_workflow_condition(),
        )
        .values(**tombstone)
    )
575
+
576
async def get_expired_deleted_workflows(
    self, older_than: datetime.timedelta = WORKFLOW_PURGE_RETENTION
):
    """Return soft-deleted workflow rows deleted more than ``older_than`` ago.

    Rows whose ``deleted_at`` is NULL never match the comparison.
    """
    cutoff = utc_now() - older_than
    wf = retrieval_agent_workflow
    query = wf.select().where(
        wf.c.is_deleted.is_(True),
        wf.c.deleted_at < cutoff,
    )
    return await self.database.fetch_all(query)
586
+
587
async def purge_deleted_workflow(
    self, account: str, agent_id: str, workflow_id: str
):
    """Hard-delete a workflow row, but only if it is already soft-deleted.

    Active workflows are never touched by this call. The per-workflow agent
    tables declare ``ondelete="CASCADE"`` against this table.
    """
    wf = retrieval_agent_workflow
    await self.database.execute(
        wf.delete().where(
            wf.c.account == account,
            wf.c.agent_id == agent_id,
            wf.c.workflow_id == workflow_id,
            wf.c.is_deleted.is_(True),
        )
    )
598
+
599
async def set_workflow_rules(
    self, agent_id: str, workflow_id: str, account: str, rules: Rules
):
    """Replace the rules of an active workflow.

    Note the ``(agent_id, workflow_id, account)`` parameter order.

    Raises:
        exceptions.NotFoundError: the workflow is missing or soft-deleted.
    """
    await self.ensure_workflow_active(account, agent_id, workflow_id)

    wf = retrieval_agent_workflow
    await self.database.execute(
        wf.update()
        .where(
            wf.c.account == account,
            wf.c.agent_id == agent_id,
            wf.c.workflow_id == workflow_id,
            self._active_workflow_condition(),
        )
        .values(rules=rules.model_dump())
    )
613
+
614
async def workflows_list(self, account: str, agent_id: str) -> List[WorkflowData]:
    """List an agent's active (not soft-deleted) workflows.

    Each row is passed whole to ``WorkflowData`` with ``id`` taken from
    ``workflow_id``. Row order is unspecified (no ORDER BY).
    """
    wf = retrieval_agent_workflow
    rows = await self.database.fetch_all(
        wf.select().where(
            wf.c.account == account,
            wf.c.agent_id == agent_id,
            self._active_workflow_condition(),
        )
    )
    return [
        WorkflowData(id=row["workflow_id"], **row)  # type: ignore
        for row in rows
    ]
626
+
627
async def get_agent_config_basic(
    self, account: str, agent_id: str
) -> RetrievalAgent:
    """Load only the agent's config row.

    Preprocess, context, generation and postprocess agents are not loaded.

    Args:
        account (str): Account ID
        agent_id (str): Agent ID

    Raises:
        exceptions.NotFoundError: no config row exists for the agent.
    """
    cfg = retrieval_agent_config
    row = await self.database.fetch_one(
        cfg.select().where(
            cfg.c.account == account,
            cfg.c.agent_id == agent_id,
        )
    )
    if row is None:
        raise exceptions.NotFoundError("Agent config not found")

    copied_columns = (
        "account",
        "agent_id",
        "description",
        "memory",
        "title",
        "instructions",
        "created",
        "modified",
    )
    return RetrievalAgent(**{column: row[column] for column in copied_columns})
658
+
659
+ async def get_agent_config(
660
+ self,
661
+ account: str,
662
+ agent_id: str,
663
+ internal_nucliadb_url: str | None = None,
664
+ default_memory: bool = False,
665
+ workflow_id: str = "default",
666
+ ) -> RetrievalAgentConfig:
667
+ """Loads the configuration in a single query to ensure we get a consistent view
668
+
669
+ Args:
670
+ account (str): Account ID
671
+ agent_id (str): Agent ID
672
+ internal_nucliadb_url (str | None): Internal NucliaDB URL to use if no memory is set
673
+ default_memory (bool): Whether to ignore the stored memory config and use a default one
674
+ workflow (str): Workflow name to load the agent for
675
+ """
676
+
677
+ # Queries for each agent type
678
+ queries = [
679
+ sa.select(
680
+ sa.literal_column(f"'{table.name}'").label("kind"), # type:ignore
681
+ table.c.id,
682
+ sa.null().label("identifier"), # type:ignore
683
+ sa.null().label("name"),
684
+ sa.null().label("provider"),
685
+ column.label("config"),
686
+ )
687
+ .where(table.c.account == account)
688
+ .where(table.c.agent_id == agent_id)
689
+ .where(table.c.workflow_id == workflow_id)
690
+ for (table, column) in [
691
+ (retrieval_agent_preprocess, retrieval_agent_preprocess.c.preprocess),
692
+ (retrieval_agent_context, retrieval_agent_context.c.context),
693
+ (retrieval_agent_generation, retrieval_agent_generation.c.generation),
694
+ (
695
+ retrieval_agent_postprocess,
696
+ retrieval_agent_postprocess.c.postprocess,
697
+ ),
698
+ ]
699
+ ]
700
+ # Query for drivers
701
+ queries.append(
702
+ sa.select(
703
+ sa.literal_column("'driver'").label("kind"), # type:ignore
704
+ retrieval_agents_drivers.c.id,
705
+ retrieval_agents_drivers.c.identifier, # type:ignore
706
+ retrieval_agents_drivers.c.driver.label("name"),
707
+ retrieval_agents_drivers.c.provider,
708
+ retrieval_agents_drivers.c.config,
709
+ )
710
+ .where(retrieval_agents_drivers.c.account == account)
711
+ .where(retrieval_agents_drivers.c.agent_id == agent_id)
712
+ )
713
+ # Query for rules
714
+ queries.append(
715
+ sa.select(
716
+ sa.literal_column("'rules'").label("kind"), # type:ignore
717
+ sa.null().label("id"),
718
+ sa.null().label("identifier"), # type:ignore
719
+ sa.null().label("name"),
720
+ sa.null().label("provider"),
721
+ retrieval_agent_config.c.rules.label("config"),
722
+ )
723
+ .where(retrieval_agent_config.c.account == account)
724
+ .where(retrieval_agent_config.c.agent_id == agent_id)
725
+ )
726
+ queries.append(
727
+ sa.select(
728
+ sa.literal_column("'memory'").label("kind"), # type:ignore
729
+ sa.null().label("id"),
730
+ sa.null().label("identifier"), # type:ignore
731
+ sa.null().label("name"),
732
+ sa.null().label("provider"),
733
+ retrieval_agent_config.c.memory.label("config"),
734
+ )
735
+ .where(retrieval_agent_config.c.account == account)
736
+ .where(retrieval_agent_config.c.agent_id == agent_id)
737
+ )
738
+ workflow_query = (
739
+ sa.select(
740
+ retrieval_agent_workflow.c.name, # type: ignore
741
+ retrieval_agent_workflow.c.description,
742
+ retrieval_agent_workflow.c.parameters, # type: ignore
743
+ retrieval_agent_workflow.c.rules,
744
+ )
745
+ .where(retrieval_agent_workflow.c.account == account)
746
+ .where(retrieval_agent_workflow.c.agent_id == agent_id)
747
+ .where(retrieval_agent_workflow.c.workflow_id == workflow_id)
748
+ .where(self._active_workflow_condition())
749
+ )
750
+
751
+ preprocess: list[Any] = []
752
+ context: list[Any] = []
753
+ generation: list[Any] = []
754
+ postprocess: list[Any] = []
755
+ drivers: list[Any] = []
756
+ rules: Rules | None = None
757
+ memory: MemoryConfig | None = None if not default_memory else MemoryConfig()
758
+ rows = await self.database.fetch_all(sa.union(*queries))
759
+ workflow_data = await self.database.fetch_one(workflow_query)
760
+ for row in rows:
761
+ match row["kind"]:
762
+ case retrieval_agent_preprocess.name:
763
+ preprocess.append(
764
+ get_agent_config_instance(
765
+ agent_config={"id": row["id"], **row["config"]},
766
+ agent_type="preprocess",
767
+ )
768
+ ) # Validate config
769
+ case retrieval_agent_context.name:
770
+ context.append(
771
+ get_agent_config_instance(
772
+ agent_config={"id": row["id"], **row["config"]},
773
+ agent_type="context",
774
+ )
775
+ ) # Validate config
776
+ case retrieval_agent_generation.name:
777
+ generation.append(
778
+ get_agent_config_instance(
779
+ {"id": row["id"], **row["config"]}, agent_type="generation"
780
+ )
781
+ ) # Validate config
782
+ case retrieval_agent_postprocess.name:
783
+ postprocess.append(
784
+ get_agent_config_instance(
785
+ agent_config={"id": row["id"], **row["config"]},
786
+ agent_type="postprocess",
787
+ )
788
+ ) # Validate config
789
+ case "driver":
790
+ driver: DriverConfig = get_driver_config_instance(
791
+ {
792
+ "id": str(row["id"]),
793
+ "identifier": row["identifier"],
794
+ "name": row["name"],
795
+ "provider": row["provider"],
796
+ "config": row["config"],
797
+ }
798
+ ) # Validate config
799
+ decrypt_fields(driver.config)
800
+ drivers.append(driver)
801
+
802
+ case "rules":
803
+ rules = Rules.model_validate(row["config"])
804
+ case "memory":
805
+ if not default_memory and row["config"] is not None:
806
+ memory = MemoryConfig.model_validate(row["config"])
807
+
808
+ if workflow_data is None:
809
+ raise exceptions.NotFoundError("Workflow not found")
810
+
811
+ workflow = WorkflowData(
812
+ id=workflow_id,
813
+ name=workflow_data["name"],
814
+ description=workflow_data["description"],
815
+ parameters=workflow_data["parameters"],
816
+ rules=Rules.model_validate(workflow_data["rules"]),
817
+ required=workflow_data["required"] if "required" in workflow_data else [], # noqa
818
+ )
819
+
820
+ if rules is None:
821
+ raise exceptions.NotFoundError("Agent config not found")
822
+
823
+ if memory is None and internal_nucliadb_url is not None and not default_memory:
824
+ memory = MemoryConfig(
825
+ nucliadb=NucliaDBMemoryConfig(
826
+ url=internal_nucliadb_url, kbid=agent_id, internal=True
827
+ )
828
+ )
829
+
830
+ if memory is None:
831
+ raise Exception("Agent memory config not found")
832
+
833
+ return RetrievalAgentConfig(
834
+ preprocess=preprocess,
835
+ context=context,
836
+ generation=generation,
837
+ postprocess=postprocess,
838
+ drivers=drivers,
839
+ rules=rules,
840
+ memory=memory,
841
+ workflow=workflow,
842
+ )
843
+
844
async def set_rules(self, agent_id: str, account: str, rules: Rules):
    """Replace the agent-level rules stored on the config row.

    Note the ``(agent_id, account)`` parameter order.
    """
    cfg = retrieval_agent_config
    dumped = rules.model_dump()
    await self.database.execute(
        cfg.update()
        .where(
            cfg.c.account == account,
            cfg.c.agent_id == agent_id,
        )
        .values(rules=dumped)
    )
853
+
854
async def add_prompt(
    self, agent_id: str, account: str, prompt: PromptConfig
) -> str:
    """Insert a prompt for an agent and return ``str()`` of the execute result.

    ``prompt.arguments`` is stored as a list of dumped dicts, or NULL when it
    is ``None``.
    """
    if prompt.arguments is None:
        arguments = None
    else:
        arguments = [argument.model_dump() for argument in prompt.arguments]

    inserted = await self.database.execute(
        retrieval_agent_prompts.insert().values(
            account=account,
            agent_id=agent_id,
            name=prompt.name,
            prompt=prompt.prompt,
            description=prompt.description,
            arguments=arguments,
            icons=prompt.icons,
            meta=prompt.meta,
        )
    )
    return str(inserted)
871
+
872
async def set_prompt(
    self, agent_id: str, account: str, prompt_id: str, prompt: PromptConfig
):
    """Overwrite a stored prompt's fields.

    An unknown ``prompt_id`` updates nothing and raises no error.
    """
    if prompt.arguments is None:
        arguments = None
    else:
        arguments = [argument.model_dump() for argument in prompt.arguments]

    prompts = retrieval_agent_prompts
    await self.database.execute(
        prompts.update()
        .where(
            prompts.c.account == account,
            prompts.c.agent_id == agent_id,
            prompts.c.id == prompt_id,
        )
        .values(
            name=prompt.name,
            description=prompt.description,
            prompt=prompt.prompt,
            arguments=arguments,
            icons=prompt.icons,
            meta=prompt.meta,
        )
    )
893
+
894
+ async def delete_prompt(self, agent_id: str, account: str, prompt_id: str):
895
+ statement = (
896
+ retrieval_agent_prompts.delete()
897
+ .where(retrieval_agent_prompts.c.account == account)
898
+ .where(retrieval_agent_prompts.c.agent_id == agent_id)
899
+ .where(retrieval_agent_prompts.c.id == prompt_id)
900
+ )
901
+
902
+ await self.database.execute(statement)
903
+
904
+ async def get_prompt(
905
+ self, agent_id: str, account: str, prompt_id: str
906
+ ) -> PromptConfig:
907
+ statement = (
908
+ retrieval_agent_prompts.select()
909
+ .where(retrieval_agent_prompts.c.account == account)
910
+ .where(retrieval_agent_prompts.c.agent_id == agent_id)
911
+ .where(retrieval_agent_prompts.c.id == prompt_id)
912
+ )
913
+ result = await self.database.fetch_one(statement)
914
+ if result is None:
915
+ raise exceptions.NotFoundError("Prompt not found")
916
+
917
+ prompt = PromptConfig(
918
+ name=result["name"],
919
+ description=result["description"],
920
+ prompt=result["prompt"],
921
+ arguments=[
922
+ PromptArgument.model_validate(arg) for arg in result["arguments"]
923
+ ]
924
+ if result["arguments"] is not None
925
+ else None,
926
+ icons=result["icons"],
927
+ meta=result["meta"],
928
+ prompt_id=str(result["id"]),
929
+ )
930
+
931
+ return prompt
932
+
933
+ async def get_prompts(self, agent_id: str, account: str) -> List[PromptConfig]:
934
+ statement = (
935
+ retrieval_agent_prompts.select()
936
+ .where(retrieval_agent_config.c.account == account)
937
+ .where(retrieval_agent_config.c.agent_id == agent_id)
938
+ )
939
+ result = await self.database.fetch_all(statement)
940
+ prompts = []
941
+ for row in result:
942
+ prompt = PromptConfig(
943
+ name=row["name"],
944
+ description=row["description"],
945
+ prompt=row["prompt"],
946
+ arguments=[
947
+ PromptArgument.model_validate(arg) for arg in row["arguments"]
948
+ ]
949
+ if row["arguments"] is not None
950
+ else None,
951
+ icons=row["icons"],
952
+ meta=row["meta"],
953
+ prompt_id=str(row["id"]),
954
+ )
955
+ prompts.append(prompt)
956
+
957
+ return prompts
958
+
959
+ async def set_memory(self, agent_id: str, account: str, memory: MemoryConfig):
960
+ statement = (
961
+ retrieval_agent_config.update()
962
+ .values(memory=memory.model_dump())
963
+ .where(retrieval_agent_config.c.account == account)
964
+ .where(retrieval_agent_config.c.agent_id == agent_id)
965
+ )
966
+
967
+ await self.database.execute(statement)
968
+
969
+ async def get_rules(self, account: str, agent_id: str) -> Rules:
970
+ statement = (
971
+ retrieval_agent_config.select()
972
+ .where(retrieval_agent_config.c.account == account)
973
+ .where(retrieval_agent_config.c.agent_id == agent_id)
974
+ )
975
+ result = await self.database.fetch_one(statement)
976
+ rules = Rules(rules=[])
977
+
978
+ if result is not None:
979
+ rules = Rules.model_validate(result["rules"])
980
+
981
+ return rules
982
+
983
+ async def get_workflow_rules(
984
+ self, account: str, agent_id: str, workflow_id: str
985
+ ) -> Rules:
986
+ statement = (
987
+ retrieval_agent_workflow.select()
988
+ .where(retrieval_agent_workflow.c.account == account)
989
+ .where(retrieval_agent_workflow.c.agent_id == agent_id)
990
+ .where(retrieval_agent_workflow.c.workflow_id == workflow_id)
991
+ .where(self._active_workflow_condition())
992
+ )
993
+ result = await self.database.fetch_one(statement)
994
+ if result is None:
995
+ raise exceptions.NotFoundError("Workflow not found")
996
+
997
+ return Rules.model_validate(result["rules"])
998
+
999
+ async def add_preprocess(
1000
+ self,
1001
+ agent_id: str,
1002
+ account: str,
1003
+ agent: BaseModel,
1004
+ workflow_id: str = "default",
1005
+ ):
1006
+ await self.ensure_workflow_active(account, agent_id, workflow_id)
1007
+ statement = retrieval_agent_preprocess.insert().values(
1008
+ account=account,
1009
+ agent_id=agent_id,
1010
+ workflow_id=workflow_id,
1011
+ preprocess=agent.model_dump(),
1012
+ )
1013
+ result = await self.database.execute(statement)
1014
+ return str(result)
1015
+
1016
+ async def patch_preprocess(
1017
+ self,
1018
+ agent_id: str,
1019
+ account: str,
1020
+ preprocess: str,
1021
+ agent: BaseModel,
1022
+ workflow_id: str = "default",
1023
+ ):
1024
+ await self.ensure_workflow_active(account, agent_id, workflow_id)
1025
+ statement = (
1026
+ retrieval_agent_preprocess.update()
1027
+ .where(retrieval_agent_preprocess.c.id == preprocess)
1028
+ .where(retrieval_agent_preprocess.c.agent_id == agent_id)
1029
+ .where(retrieval_agent_preprocess.c.account == account)
1030
+ .where(retrieval_agent_preprocess.c.workflow_id == workflow_id)
1031
+ .values(preprocess=agent.model_dump())
1032
+ )
1033
+ await self.database.execute(statement)
1034
+
1035
+ async def delete_preprocess(
1036
+ self, account: str, agent_id: str, preprocess: str, workflow_id: str = "default"
1037
+ ):
1038
+ await self.ensure_workflow_active(account, agent_id, workflow_id)
1039
+ statement = (
1040
+ retrieval_agent_preprocess.delete()
1041
+ .where(retrieval_agent_preprocess.c.account == account)
1042
+ .where(retrieval_agent_preprocess.c.agent_id == agent_id)
1043
+ .where(retrieval_agent_preprocess.c.workflow_id == workflow_id)
1044
+ .where(retrieval_agent_preprocess.c.id == preprocess)
1045
+ )
1046
+ await self.database.execute(statement)
1047
+
1048
+ async def get_preprocess(
1049
+ self, account: str, agent_id: str, workflow_id: str = "default"
1050
+ ) -> List[AgentConfig]:
1051
+ await self.ensure_workflow_active(account, agent_id, workflow_id)
1052
+ statement = (
1053
+ retrieval_agent_preprocess.select()
1054
+ .where(retrieval_agent_preprocess.c.account == account)
1055
+ .where(retrieval_agent_preprocess.c.agent_id == agent_id)
1056
+ .where(retrieval_agent_preprocess.c.workflow_id == workflow_id)
1057
+ )
1058
+ results = await self.database.fetch_all(statement)
1059
+ preprocess = []
1060
+ for result in results:
1061
+ base_config = get_agent_config_instance(
1062
+ agent_config=result["preprocess"], agent_type="preprocess"
1063
+ )
1064
+ base_config.id = str(result["id"])
1065
+ preprocess.append(base_config)
1066
+
1067
+ return preprocess
1068
+
1069
+ async def delete_postprocess(
1070
+ self,
1071
+ account: str,
1072
+ agent_id: str,
1073
+ postprocess: str,
1074
+ workflow_id: str = "default",
1075
+ ):
1076
+ await self.ensure_workflow_active(account, agent_id, workflow_id)
1077
+ statement = (
1078
+ retrieval_agent_postprocess.delete()
1079
+ .where(retrieval_agent_postprocess.c.account == account)
1080
+ .where(retrieval_agent_postprocess.c.agent_id == agent_id)
1081
+ .where(retrieval_agent_postprocess.c.workflow_id == workflow_id)
1082
+ .where(retrieval_agent_postprocess.c.id == postprocess)
1083
+ )
1084
+ await self.database.execute(statement)
1085
+
1086
+ async def add_postprocess(
1087
+ self,
1088
+ agent_id: str,
1089
+ account: str,
1090
+ agent: BaseModel,
1091
+ workflow_id: str = "default",
1092
+ ):
1093
+ await self.ensure_workflow_active(account, agent_id, workflow_id)
1094
+ statement = retrieval_agent_postprocess.insert().values(
1095
+ account=account,
1096
+ agent_id=agent_id,
1097
+ postprocess=agent.model_dump(),
1098
+ workflow_id=workflow_id,
1099
+ )
1100
+ result = await self.database.execute(statement)
1101
+ return str(result)
1102
+
1103
+ async def patch_postprocess(
1104
+ self,
1105
+ agent_id: str,
1106
+ account: str,
1107
+ postprocess: str,
1108
+ agent: BaseModel,
1109
+ workflow_id: str = "default",
1110
+ ):
1111
+ await self.ensure_workflow_active(account, agent_id, workflow_id)
1112
+ statement = (
1113
+ retrieval_agent_postprocess.update()
1114
+ .where(retrieval_agent_postprocess.c.id == postprocess)
1115
+ .where(retrieval_agent_postprocess.c.agent_id == agent_id)
1116
+ .where(retrieval_agent_postprocess.c.account == account)
1117
+ .where(retrieval_agent_postprocess.c.workflow_id == workflow_id)
1118
+ .values(postprocess=agent.model_dump())
1119
+ )
1120
+ await self.database.execute(statement)
1121
+
1122
+ async def get_postprocess(
1123
+ self, account: str, agent_id: str, workflow_id: str = "default"
1124
+ ) -> List[AgentConfig]:
1125
+ await self.ensure_workflow_active(account, agent_id, workflow_id)
1126
+ statement = (
1127
+ retrieval_agent_postprocess.select()
1128
+ .where(retrieval_agent_postprocess.c.account == account)
1129
+ .where(retrieval_agent_postprocess.c.agent_id == agent_id)
1130
+ .where(retrieval_agent_postprocess.c.workflow_id == workflow_id)
1131
+ )
1132
+ results = await self.database.fetch_all(statement)
1133
+ postprocess = []
1134
+ for result in results:
1135
+ base_config = get_agent_config_instance(
1136
+ agent_config=result["postprocess"], agent_type="postprocess"
1137
+ )
1138
+ base_config.id = str(result["id"])
1139
+ postprocess.append(base_config)
1140
+
1141
+ return postprocess
1142
+
1143
+ async def add_context(
1144
+ self,
1145
+ agent_id: str,
1146
+ account: str,
1147
+ agent: BaseModel,
1148
+ workflow_id: str = "default",
1149
+ ):
1150
+ await self.ensure_workflow_active(account, agent_id, workflow_id)
1151
+ statement = retrieval_agent_context.insert().values(
1152
+ account=account,
1153
+ agent_id=agent_id,
1154
+ context=agent.model_dump(),
1155
+ workflow_id=workflow_id,
1156
+ )
1157
+ result = await self.database.execute(statement)
1158
+
1159
+ return str(result)
1160
+
1161
+ async def delete_context(
1162
+ self, agent_id: str, account: str, context: UUID, workflow_id: str = "default"
1163
+ ):
1164
+ await self.ensure_workflow_active(account, agent_id, workflow_id)
1165
+ statement = (
1166
+ retrieval_agent_context.delete()
1167
+ .where(retrieval_agent_context.c.id == context)
1168
+ .where(retrieval_agent_context.c.agent_id == agent_id)
1169
+ .where(retrieval_agent_context.c.account == account)
1170
+ .where(retrieval_agent_context.c.workflow_id == workflow_id)
1171
+ )
1172
+ await self.database.execute(statement)
1173
+
1174
+ async def patch_context(
1175
+ self,
1176
+ agent_id: str,
1177
+ account: str,
1178
+ context: UUID,
1179
+ agent: BaseModel,
1180
+ workflow_id: str = "default",
1181
+ ):
1182
+ await self.ensure_workflow_active(account, agent_id, workflow_id)
1183
+ statement = (
1184
+ retrieval_agent_context.update()
1185
+ .where(retrieval_agent_context.c.id == context)
1186
+ .where(retrieval_agent_context.c.agent_id == agent_id)
1187
+ .where(retrieval_agent_context.c.account == account)
1188
+ .where(retrieval_agent_context.c.workflow_id == workflow_id)
1189
+ .values(context=agent.model_dump())
1190
+ )
1191
+ await self.database.execute(statement)
1192
+
1193
+ async def get_context(
1194
+ self, account: str, agent_id: str, workflow_id: str = "default"
1195
+ ) -> List[AgentConfig]:
1196
+ await self.ensure_workflow_active(account, agent_id, workflow_id)
1197
+ statement = (
1198
+ retrieval_agent_context.select()
1199
+ .where(retrieval_agent_context.c.account == account)
1200
+ .where(retrieval_agent_context.c.agent_id == agent_id)
1201
+ .where(retrieval_agent_context.c.workflow_id == workflow_id)
1202
+ )
1203
+ results = await self.database.fetch_all(statement)
1204
+ context = []
1205
+ for result in results:
1206
+ base_config = get_agent_config_instance(
1207
+ agent_config=result["context"], agent_type="context"
1208
+ )
1209
+ base_config.id = str(result["id"])
1210
+ context.append(base_config)
1211
+
1212
+ return context
1213
+
1214
+ async def add_generation(
1215
+ self,
1216
+ agent_id: str,
1217
+ account: str,
1218
+ agent: BaseModel,
1219
+ workflow_id: str = "default",
1220
+ ):
1221
+ await self.ensure_workflow_active(account, agent_id, workflow_id)
1222
+ statement = retrieval_agent_generation.insert().values(
1223
+ account=account,
1224
+ agent_id=agent_id,
1225
+ generation=agent.model_dump(),
1226
+ workflow_id=workflow_id,
1227
+ )
1228
+ result = await self.database.execute(statement)
1229
+
1230
+ return str(result)
1231
+
1232
+ async def delete_generation(
1233
+ self,
1234
+ agent_id: str,
1235
+ account: str,
1236
+ generation: str,
1237
+ workflow_id: str = "default",
1238
+ ):
1239
+ await self.ensure_workflow_active(account, agent_id, workflow_id)
1240
+ statement = (
1241
+ retrieval_agent_generation.delete()
1242
+ .where(retrieval_agent_generation.c.id == generation)
1243
+ .where(retrieval_agent_generation.c.agent_id == agent_id)
1244
+ .where(retrieval_agent_generation.c.account == account)
1245
+ .where(retrieval_agent_generation.c.workflow_id == workflow_id)
1246
+ )
1247
+ await self.database.execute(statement)
1248
+
1249
+ async def patch_generation(
1250
+ self,
1251
+ agent_id: str,
1252
+ account: str,
1253
+ generation: str,
1254
+ agent: BaseModel,
1255
+ workflow_id: str = "default",
1256
+ ):
1257
+ await self.ensure_workflow_active(account, agent_id, workflow_id)
1258
+ statement = (
1259
+ retrieval_agent_generation.update()
1260
+ .where(retrieval_agent_generation.c.id == generation)
1261
+ .where(retrieval_agent_generation.c.agent_id == agent_id)
1262
+ .where(retrieval_agent_generation.c.account == account)
1263
+ .where(retrieval_agent_generation.c.workflow_id == workflow_id)
1264
+ .values(generation=agent.model_dump())
1265
+ )
1266
+ await self.database.execute(statement)
1267
+
1268
+ async def get_generation(
1269
+ self, account: str, agent_id: str, workflow_id: str = "default"
1270
+ ) -> List[AgentConfig]:
1271
+ await self.ensure_workflow_active(account, agent_id, workflow_id)
1272
+ statement = (
1273
+ retrieval_agent_generation.select()
1274
+ .where(retrieval_agent_generation.c.account == account)
1275
+ .where(retrieval_agent_generation.c.agent_id == agent_id)
1276
+ .where(retrieval_agent_generation.c.workflow_id == workflow_id)
1277
+ )
1278
+ results = await self.database.fetch_all(statement)
1279
+ generation = []
1280
+ for result in results:
1281
+ base_config = get_agent_config_instance(
1282
+ agent_config=result["generation"], agent_type="generation"
1283
+ )
1284
+ base_config.id = str(result["id"])
1285
+ generation.append(base_config)
1286
+
1287
+ return generation
1288
+
1289
+ async def export(
1290
+ self, account: str, agent_id: str, passphrase: str
1291
+ ) -> tuple[bytes, str | None]:
1292
+ # Passphrase validation
1293
+ if len(passphrase) < 16:
1294
+ raise Exception("Passphrase too short, minimum 16 characters required")
1295
+ workflow_config = {}
1296
+ default_config = None
1297
+ workflows = [wf.id for wf in await self.workflows_list(account, agent_id)]
1298
+ for workflow in workflows:
1299
+ try:
1300
+ agent_config = await self.get_agent_config(
1301
+ account, agent_id, default_memory=True, workflow_id=workflow
1302
+ )
1303
+ if workflow == "default":
1304
+ default_config = agent_config
1305
+ workflow_config[workflow] = agent_config
1306
+ except Exception:
1307
+ logger.exception("Retrieval agent not found for export")
1308
+ raise Exception("Retrieval agent not found for export")
1309
+
1310
+ # XXX: ID sanitization not required since current implementation ignores IDs on import
1311
+
1312
+ export_model = RetrievalAgentExportV1(
1313
+ agent_config=default_config,
1314
+ agent_config_workflows=workflow_config,
1315
+ prompts=await self.get_prompts(agent_id, account),
1316
+ )
1317
+ export_bytes = export_model.model_dump_json().encode("utf-8")
1318
+ key, salt = fernet_key_from_passphrase(passphrase, None)
1319
+ fernet = Fernet(key)
1320
+ try:
1321
+ encrypted_bytes = salt + fernet.encrypt(export_bytes)
1322
+ except Exception:
1323
+ logger.exception("Error encrypting retrieval agent export")
1324
+ raise Exception("Error encrypting export")
1325
+
1326
+ return encrypted_bytes, None
1327
+
1328
+ async def import_config(
1329
+ self,
1330
+ account: str,
1331
+ agent_id: str,
1332
+ import_file: UploadFile,
1333
+ passphrase: str,
1334
+ overwrite: bool,
1335
+ ):
1336
+ # Get the current agent config
1337
+ try:
1338
+ destination_agent_config = await self.get_agent_config(
1339
+ account, agent_id, default_memory=True, workflow_id="default"
1340
+ )
1341
+ except exceptions.NotFoundError:
1342
+ raise
1343
+ # If not overwriting and the agent config is not empty, raise error
1344
+ if not destination_agent_config.is_empty():
1345
+ if not overwrite:
1346
+ raise exceptions.InvalidTargetAgentError()
1347
+ else:
1348
+ # Delete current configuration
1349
+ await self.delete_agent(account=account, agent_id=agent_id)
1350
+ # Recreate empty configuration
1351
+ await self.add_agent(
1352
+ account=account,
1353
+ agent_id=agent_id,
1354
+ memory=MemoryConfig(),
1355
+ rules=Rules(rules=[]),
1356
+ )
1357
+
1358
+ # Read the salt (first 16 bytes)
1359
+ salt = await import_file.read(16)
1360
+ key, _ = fernet_key_from_passphrase(passphrase, salt)
1361
+ fernet = Fernet(key)
1362
+ # Read the rest of the file in chunks
1363
+
1364
+ import_bytes = bytearray()
1365
+ while True:
1366
+ chunk = await import_file.read(self.settings.export_read_chunk_size)
1367
+ if not chunk:
1368
+ break
1369
+ import_bytes.extend(chunk)
1370
+ if len(import_bytes) > self.settings.export_read_max_size:
1371
+ await import_file.close()
1372
+ raise exceptions.ParseExportError("Import file too large")
1373
+ await import_file.close()
1374
+ # Decrypt the bytes
1375
+ try:
1376
+ encrypted_bytes = fernet.decrypt(bytes(import_bytes))
1377
+ except Exception as e:
1378
+ raise exceptions.ExportEncryptionError from e
1379
+ # Load the model
1380
+ try:
1381
+ parsed_export = retrievalAgentAdapter.validate_json(
1382
+ encrypted_bytes.decode("utf-8")
1383
+ )
1384
+ except ValidationError as e:
1385
+ raise exceptions.ParseExportError from e
1386
+
1387
+ if isinstance(parsed_export, RetrievalAgentExportV1):
1388
+ workflow_configs: dict[str, RetrievalAgentConfig] = {}
1389
+ agent_config = parsed_export.agent_config
1390
+ if agent_config is not None:
1391
+ workflow_id = "default"
1392
+ workflow_configs[workflow_id] = agent_config
1393
+ for wf_id, wf_config in parsed_export.agent_config_workflows.items():
1394
+ workflow_configs[wf_id] = wf_config
1395
+
1396
+ drivers = {}
1397
+ agent_rules = None
1398
+ try:
1399
+ for workflow_id, agent_config_workflow in workflow_configs.items():
1400
+ for driver in agent_config_workflow.drivers:
1401
+ drivers[driver.name] = driver
1402
+ agent_rules = agent_config_workflow.rules
1403
+ wf = agent_config_workflow.workflow
1404
+ if workflow_id == "default":
1405
+ await self.set_workflow(
1406
+ account=account,
1407
+ agent_id=agent_id,
1408
+ workflow_id=workflow_id,
1409
+ item=WorkflowUpdate(
1410
+ name=wf.name,
1411
+ description=wf.description or "",
1412
+ parameters=wf.parameters or {},
1413
+ rules=wf.rules,
1414
+ ),
1415
+ )
1416
+ else:
1417
+ await self.add_workflow(
1418
+ account=account,
1419
+ agent_id=agent_id,
1420
+ item=WorkflowInput(
1421
+ id=workflow_id,
1422
+ name=wf.name,
1423
+ description=wf.description,
1424
+ parameters=wf.parameters,
1425
+ rules=wf.rules or Rules(rules=[]),
1426
+ ),
1427
+ )
1428
+
1429
+ # Store preprocess
1430
+ for preprocess in agent_config_workflow.preprocess:
1431
+ await self.add_preprocess(
1432
+ agent_id=agent_id,
1433
+ account=account,
1434
+ agent=preprocess,
1435
+ workflow_id=workflow_id,
1436
+ )
1437
+ # Store context
1438
+ for context in agent_config_workflow.context:
1439
+ await self.add_context(
1440
+ agent_id=agent_id,
1441
+ account=account,
1442
+ agent=context,
1443
+ workflow_id=workflow_id,
1444
+ )
1445
+ # Store generation
1446
+ for generation in agent_config_workflow.generation:
1447
+ await self.add_generation(
1448
+ agent_id=agent_id,
1449
+ account=account,
1450
+ agent=generation,
1451
+ workflow_id=workflow_id,
1452
+ )
1453
+ # Store postprocess
1454
+ for postprocess in agent_config_workflow.postprocess:
1455
+ await self.add_postprocess(
1456
+ agent_id=agent_id,
1457
+ account=account,
1458
+ agent=postprocess,
1459
+ workflow_id=workflow_id,
1460
+ )
1461
+
1462
+ if agent_rules is not None:
1463
+ # Store rules
1464
+ await self.set_rules(
1465
+ agent_id=agent_id,
1466
+ account=account,
1467
+ rules=agent_rules,
1468
+ )
1469
+ # XXX: We are pruposely not importing memory configuration and will use the one already set
1470
+ # As KBs are created with the agent, if when importing we overwrite, we leave a dangling KB
1471
+ # And cross region imports with KB memory would break as well
1472
+ # This can be revisited in future versions if needed
1473
+
1474
+ # Store drivers — done once after all workflows are processed to avoid
1475
+ # duplicate insertions (drivers are global per-agent but appear in each
1476
+ # workflow's exported config).
1477
+ for driver in drivers.values():
1478
+ await self.add_driver(
1479
+ agent_id=agent_id, account=account, config=driver
1480
+ )
1481
+ for prompt in parsed_export.prompts:
1482
+ await self.add_prompt(
1483
+ agent_id=agent_id,
1484
+ account=account,
1485
+ prompt=prompt,
1486
+ )
1487
+ except Exception as e:
1488
+ raise exceptions.ParseExportError(
1489
+ "Failed to import retrieval agent configuration"
1490
+ ) from e
1491
+ else:
1492
+ raise exceptions.ParseExportError("Unsupported export version")
1493
+
1494
+
1495
+ def update_driver_config(
1496
+ config: DriverConfig, previous_config: DriverConfig
1497
+ ) -> DriverConfig:
1498
+ # Generate configuration for values set in the request. All values present on
1499
+ # the request will override values of the current configuration. If there was
1500
+ # any value on the stored configuration NOT SET yet, and it's not on the update
1501
+ # the default will be added.
1502
+ desired_configuration = config.model_dump(exclude_unset=True)
1503
+ updated_configuration = previous_config.model_dump()
1504
+
1505
+ return config.__class__(**deep_update(updated_configuration, desired_configuration))
1506
+
1507
+
1508
+ def deep_update(original: dict, updates: dict) -> dict:
1509
+ """
1510
+ Recursively update a nested dictionary.
1511
+ """
1512
+ for key, value in updates.items():
1513
+ if (
1514
+ key in original
1515
+ and isinstance(original[key], dict)
1516
+ and isinstance(value, dict)
1517
+ ):
1518
+ deep_update(original[key], value)
1519
+ else:
1520
+ original[key] = value
1521
+ return original