solana-agent 20.1.2__py3-none-any.whl → 31.4.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- solana_agent/__init__.py +10 -5
- solana_agent/adapters/ffmpeg_transcoder.py +375 -0
- solana_agent/adapters/mongodb_adapter.py +15 -2
- solana_agent/adapters/openai_adapter.py +679 -0
- solana_agent/adapters/openai_realtime_ws.py +1813 -0
- solana_agent/adapters/pinecone_adapter.py +543 -0
- solana_agent/cli.py +128 -0
- solana_agent/client/solana_agent.py +180 -20
- solana_agent/domains/agent.py +13 -13
- solana_agent/domains/routing.py +18 -8
- solana_agent/factories/agent_factory.py +239 -38
- solana_agent/guardrails/pii.py +107 -0
- solana_agent/interfaces/client/client.py +95 -12
- solana_agent/interfaces/guardrails/guardrails.py +26 -0
- solana_agent/interfaces/plugins/plugins.py +2 -1
- solana_agent/interfaces/providers/__init__.py +0 -0
- solana_agent/interfaces/providers/audio.py +40 -0
- solana_agent/interfaces/providers/data_storage.py +9 -2
- solana_agent/interfaces/providers/llm.py +86 -9
- solana_agent/interfaces/providers/memory.py +13 -1
- solana_agent/interfaces/providers/realtime.py +212 -0
- solana_agent/interfaces/providers/vector_storage.py +53 -0
- solana_agent/interfaces/services/agent.py +27 -12
- solana_agent/interfaces/services/knowledge_base.py +59 -0
- solana_agent/interfaces/services/query.py +41 -8
- solana_agent/interfaces/services/routing.py +0 -1
- solana_agent/plugins/manager.py +37 -16
- solana_agent/plugins/registry.py +34 -19
- solana_agent/plugins/tools/__init__.py +0 -5
- solana_agent/plugins/tools/auto_tool.py +1 -0
- solana_agent/repositories/memory.py +332 -111
- solana_agent/services/__init__.py +1 -1
- solana_agent/services/agent.py +390 -241
- solana_agent/services/knowledge_base.py +768 -0
- solana_agent/services/query.py +1858 -153
- solana_agent/services/realtime.py +626 -0
- solana_agent/services/routing.py +104 -51
- solana_agent-31.4.0.dist-info/METADATA +1070 -0
- solana_agent-31.4.0.dist-info/RECORD +49 -0
- {solana_agent-20.1.2.dist-info → solana_agent-31.4.0.dist-info}/WHEEL +1 -1
- solana_agent-31.4.0.dist-info/entry_points.txt +3 -0
- solana_agent/adapters/llm_adapter.py +0 -160
- solana_agent-20.1.2.dist-info/METADATA +0 -464
- solana_agent-20.1.2.dist-info/RECORD +0 -35
- {solana_agent-20.1.2.dist-info → solana_agent-31.4.0.dist-info/licenses}/LICENSE +0 -0
solana_agent/factories/agent_factory.py

@@ -4,30 +4,84 @@ Factory for creating and wiring components of the Solana Agent system.
 This module handles the creation and dependency injection for all
 services and components used in the system.
 """
-
+
+import importlib
+import logging
+from typing import Dict, Any, List

 # Service imports
+from solana_agent.adapters.pinecone_adapter import PineconeAdapter
+from solana_agent.interfaces.guardrails.guardrails import (
+    InputGuardrail,
+    OutputGuardrail,
+)
 from solana_agent.services.query import QueryService
 from solana_agent.services.agent import AgentService
 from solana_agent.services.routing import RoutingService
+from solana_agent.services.knowledge_base import KnowledgeBaseService
+# Realtime is now managed per-call in QueryService.process; no factory wiring

 # Repository imports
 from solana_agent.repositories.memory import MemoryRepository

 # Adapter imports
-from solana_agent.adapters.
+from solana_agent.adapters.openai_adapter import OpenAIAdapter
 from solana_agent.adapters.mongodb_adapter import MongoDBAdapter

 # Domain and plugin imports
 from solana_agent.domains.agent import BusinessMission
 from solana_agent.plugins.manager import PluginManager

+# Setup logger for this module
+logger = logging.getLogger(__name__)
+

 class SolanaAgentFactory:
     """Factory for creating and wiring components of the Solana Agent system."""

     @staticmethod
-    def
+    def _create_guardrails(guardrail_configs: List[Dict[str, Any]]) -> List[Any]:
+        """Instantiates guardrails from configuration."""
+        guardrails = []
+        if not guardrail_configs:
+            return guardrails
+
+        for config in guardrail_configs:
+            class_path = config.get("class")
+            guardrail_config = config.get("config", {})
+            if not class_path:
+                logger.warning(
+                    f"Guardrail config missing 'class': {config}"
+                )  # Use logger.warning
+                continue
+            try:
+                module_path, class_name = class_path.rsplit(".", 1)
+                module = importlib.import_module(module_path)
+                guardrail_class = getattr(module, class_name)
+                # Instantiate the guardrail, handling potential errors during init
+                try:
+                    guardrails.append(guardrail_class(config=guardrail_config))
+                    logger.info(
+                        f"Successfully loaded guardrail: {class_path}"
+                    )  # Use logger.info
+                except Exception as init_e:
+                    logger.error(
+                        f"Error initializing guardrail '{class_path}': {init_e}"
+                    )  # Use logger.error
+                    # Optionally re-raise or just skip this guardrail
+
+            except (ImportError, AttributeError, ValueError) as e:
+                logger.error(
+                    f"Error loading guardrail class '{class_path}': {e}"
+                )  # Use logger.error
+            except Exception as e:  # Catch unexpected errors during import/getattr
+                logger.exception(
+                    f"Unexpected error loading guardrail '{class_path}': {e}"
+                )  # Use logger.exception
+        return guardrails
+
+    @staticmethod
+    def create_from_config(config: Dict[str, Any]) -> QueryService:  # pragma: no cover
         """Create the agent system from configuration.

         Args:
@@ -39,7 +93,6 @@ class SolanaAgentFactory:
         # Create adapters

         if "mongo" in config:
-            # MongoDB connection string and database name
             if "connection_string" not in config["mongo"]:
                 raise ValueError("MongoDB connection string is required.")
             if "database" not in config["mongo"]:
@@ -51,9 +104,40 @@
         else:
             db_adapter = None

-
-
-
+        # Determine which LLM provider to use (Grok or OpenAI)
+        # Priority: grok > openai
+        llm_api_key = None
+        llm_base_url = None
+        llm_model = None
+
+        if "grok" in config and "api_key" in config["grok"]:
+            llm_api_key = config["grok"]["api_key"]
+            llm_base_url = config["grok"].get("base_url", "https://api.x.ai/v1")
+            llm_model = config["grok"].get("model", "grok-4-1-fast-non-reasoning")
+            logger.info(f"Using Grok as LLM provider with model: {llm_model}")
+        elif "openai" in config and "api_key" in config["openai"]:
+            llm_api_key = config["openai"]["api_key"]
+            llm_base_url = None  # Use default OpenAI endpoint
+            llm_model = None  # Will use OpenAI adapter defaults
+            logger.info("Using OpenAI as LLM provider")
+        else:
+            raise ValueError("Either OpenAI or Grok API key is required in config.")
+
+        if "logfire" in config:
+            if "api_key" not in config["logfire"]:
+                raise ValueError("Pydantic Logfire API key is required.")
+            llm_adapter = OpenAIAdapter(
+                api_key=llm_api_key,
+                base_url=llm_base_url,
+                model=llm_model,
+                logfire_api_key=config["logfire"].get("api_key"),
+            )
+        else:
+            llm_adapter = OpenAIAdapter(
+                api_key=llm_api_key,
+                base_url=llm_base_url,
+                model=llm_model,
+            )

         # Create business mission if specified in config
         business_mission = None
@@ -61,92 +145,209 @@
             org_config = config["business"]
             business_mission = BusinessMission(
                 mission=org_config.get("mission", ""),
-                values=[
-
+                values=[
+                    {"name": k, "description": v}
+                    for k, v in org_config.get("values", {}).items()
+                ],
                 goals=org_config.get("goals", []),
-                voice=org_config.get("voice", "")
+                voice=org_config.get("voice", ""),
             )

+        # capture_mode removed: repository now always upserts/merges per capture
+
         # Create repositories
         memory_provider = None

         if "zep" in config and "mongo" in config:
-
-
-
-
+            mem_kwargs: Dict[str, Any] = {
+                "mongo_adapter": db_adapter,
+                "zep_api_key": config["zep"].get("api_key"),
+            }
+            memory_provider = MemoryRepository(**mem_kwargs)

-        if "mongo" in config and
-
+        if "mongo" in config and "zep" not in config:
+            mem_kwargs = {"mongo_adapter": db_adapter}
+            memory_provider = MemoryRepository(**mem_kwargs)

-        if "zep" in config and
+        if "zep" in config and "mongo" not in config:
             if "api_key" not in config["zep"]:
                 raise ValueError("Zep API key is required.")
-
-
-
-
+            mem_kwargs = {"zep_api_key": config["zep"].get("api_key")}
+            memory_provider = MemoryRepository(**mem_kwargs)
+
+        guardrail_config = config.get("guardrails", {})
+        input_guardrails: List[InputGuardrail] = SolanaAgentFactory._create_guardrails(
+            guardrail_config.get("input", [])
+        )
+        output_guardrails: List[OutputGuardrail] = (
+            SolanaAgentFactory._create_guardrails(guardrail_config.get("output", []))
+        )
+        logger.info(  # Use logger.info
+            f"Loaded {len(input_guardrails)} input guardrails and {len(output_guardrails)} output guardrails."
+        )

         # Create primary services
         agent_service = AgentService(
             llm_provider=llm_adapter,
             business_mission=business_mission,
             config=config,
+            api_key=llm_api_key,
+            base_url=llm_base_url,
+            model=llm_model,
+            output_guardrails=output_guardrails,
         )

-        # Debug the agent service tool registry
-        print(
-            f"Agent service tools after initialization: {agent_service.tool_registry.list_all_tools()}")
-
         # Create routing service
+        # Use Grok model if configured, otherwise check for OpenAI routing_model override
+        routing_model = llm_model  # Use the same model as the main LLM by default
+        if not routing_model:
+            # Fall back to OpenAI routing_model config if no Grok model
+            routing_model = (
+                config.get("openai", {}).get("routing_model")
+                if isinstance(config.get("openai"), dict)
+                else None
+            )
         routing_service = RoutingService(
             llm_provider=llm_adapter,
             agent_service=agent_service,
+            api_key=llm_api_key,
+            base_url=llm_base_url,
+            model=routing_model,
+        )
+
+        # Debug the agent service tool registry
+        logger.debug(  # Use logger.debug
+            f"Agent service tools after initialization: {agent_service.tool_registry.list_all_tools()}"
         )

         # Initialize plugin system
         agent_service.plugin_manager = PluginManager(
-            config=config,
-            tool_registry=agent_service.tool_registry
+            config=config, tool_registry=agent_service.tool_registry
         )
         try:
             loaded_plugins = agent_service.plugin_manager.load_plugins()
-
+            logger.info(f"Loaded {loaded_plugins} plugins")  # Use logger.info
         except Exception as e:
-
+            logger.error(f"Error loading plugins: {e}")  # Use logger.error
             loaded_plugins = 0

         # Register predefined agents
-        for agent_config in config.get("agents", []):
+        for agent_config in config.get("agents", []):  # pragma: no cover
+            extra_kwargs = {}
+            if "capture_name" in agent_config:
+                extra_kwargs["capture_name"] = agent_config.get("capture_name")
+            if "capture_schema" in agent_config:
+                extra_kwargs["capture_schema"] = agent_config.get("capture_schema")
+
             agent_service.register_ai_agent(
                 name=agent_config["name"],
                 instructions=agent_config["instructions"],
                 specialization=agent_config["specialization"],
+                **extra_kwargs,
             )

             # Register tools for this agent
             if "tools" in agent_config:
                 for tool_name in agent_config["tools"]:
-
-                    f"Available tools before registering {tool_name}: {agent_service.tool_registry.list_all_tools()}"
-
-
+                    logger.debug(  # Use logger.debug
+                        f"Available tools before registering {tool_name}: {agent_service.tool_registry.list_all_tools()}"
+                    )
+                    agent_service.assign_tool_for_agent(agent_config["name"], tool_name)
+                    logger.info(  # Use logger.info
+                        f"Successfully registered {tool_name} for agent {agent_config['name']}"
                     )
-                    print(
-                        f"Successfully registered {tool_name} for agent {agent_config['name']}")

         # Global tool registrations
         if "agent_tools" in config:
             for agent_name, tools in config["agent_tools"].items():
                 for tool_name in tools:
-                    agent_service.assign_tool_for_agent(
-
+                    agent_service.assign_tool_for_agent(agent_name, tool_name)
+
+        # Initialize Knowledge Base if configured
+        knowledge_base = None
+        kb_config = config.get("knowledge_base")
+        # Requires both KB config section and MongoDB adapter
+        if kb_config and db_adapter:
+            try:
+                pinecone_config = kb_config.get("pinecone", {})
+                splitter_config = kb_config.get("splitter", {})
+                # Get OpenAI embedding config (used by KBService)
+                openai_embed_config = kb_config.get("openai_embeddings", {})
+
+                # Determine OpenAI model and dimensions for KBService
+                openai_model_name = openai_embed_config.get(
+                    "model_name", "text-embedding-3-large"
+                )
+                if openai_model_name == "text-embedding-3-large":
+                    openai_dimensions = 3072
+                elif openai_model_name == "text-embedding-3-small":  # pragma: no cover
+                    openai_dimensions = 1536  # pragma: no cover
+                else:  # pragma: no cover
+                    logger.warning(
+                        f"Unknown OpenAI embedding model '{openai_model_name}' specified for KB. Defaulting dimensions to 3072."
+                    )  # pragma: no cover
+                    openai_dimensions = 3072  # pragma: no cover
+
+                # Create Pinecone adapter for KB
+                # It now relies on external embeddings, so dimension MUST match OpenAI model
+                pinecone_adapter = PineconeAdapter(
+                    api_key=pinecone_config.get("api_key"),
+                    index_name=pinecone_config.get("index_name"),
+                    # This dimension MUST match the OpenAI model used by KBService
+                    embedding_dimensions=openai_dimensions,
+                    cloud_provider=pinecone_config.get("cloud_provider", "aws"),
+                    region=pinecone_config.get("region", "us-east-1"),
+                    metric=pinecone_config.get("metric", "cosine"),
+                    create_index_if_not_exists=pinecone_config.get(
+                        "create_index", True
+                    ),
+                    # Reranking config
+                    use_reranking=pinecone_config.get("use_reranking", False),
+                    rerank_model=pinecone_config.get("rerank_model"),
+                    rerank_top_k=pinecone_config.get("rerank_top_k", 3),
+                    initial_query_top_k_multiplier=pinecone_config.get(
+                        "initial_query_top_k_multiplier", 5
+                    ),
+                    rerank_text_field=pinecone_config.get("rerank_text_field", "text"),
+                )
+
+                # Create the KB service using OpenAI embeddings
+                knowledge_base = KnowledgeBaseService(
+                    pinecone_adapter=pinecone_adapter,
+                    mongodb_adapter=db_adapter,
+                    # Pass OpenAI config directly
+                    openai_api_key=openai_embed_config.get("api_key")
+                    or config.get("openai", {}).get("api_key"),
+                    openai_model_name=openai_model_name,
+                    collection_name=kb_config.get(
+                        "collection_name", "knowledge_documents"
+                    ),
+                    # Pass rerank config (though PineconeAdapter handles the logic)
+                    rerank_results=pinecone_config.get("use_reranking", False),
+                    rerank_top_k=pinecone_config.get("rerank_top_k", 3),
+                    # Pass splitter config
+                    splitter_buffer_size=splitter_config.get("buffer_size", 1),
+                    splitter_breakpoint_percentile=splitter_config.get(
+                        "breakpoint_percentile", 95
+                    ),
+                )
+                logger.info(
+                    "Knowledge Base Service initialized successfully."
+                )  # Use logger.info
+
+            except Exception as e:
+                # Use logger.exception to include traceback automatically
+                logger.exception(f"Failed to initialize Knowledge Base: {e}")
+                knowledge_base = None  # Ensure KB is None if init fails

         # Create and return the query service
         query_service = QueryService(
             agent_service=agent_service,
             routing_service=routing_service,
             memory_provider=memory_provider,
+            knowledge_base=knowledge_base,  # Pass the potentially created KB
+            kb_results_count=kb_config.get("results_count", 3) if kb_config else 3,
+            input_guardrails=input_guardrails,
         )

         return query_service
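For orientation, the factory changes above read all of their settings from one nested config dict. The sketch below is not taken from the package's documentation: the section and key names are the ones the diff itself queries via config.get(...), while every value (connection strings, API keys, agent and tool names) is a placeholder.

# Hypothetical config for SolanaAgentFactory.create_from_config; keys mirror the
# lookups in the diff above, values are placeholders only.
config = {
    "openai": {"api_key": "sk-..."},  # or "grok": {"api_key": "...", "model": "..."}
    "mongo": {
        "connection_string": "mongodb://localhost:27017",
        "database": "solana_agent",
    },
    "zep": {"api_key": "zep-..."},  # optional conversational memory
    "business": {
        "mission": "Help users build on Solana",
        "values": {"clarity": "Explain decisions plainly"},
        "goals": ["Resolve questions in one session"],
        "voice": "friendly",
    },
    "guardrails": {
        "input": [
            {"class": "solana_agent.guardrails.pii.PII", "config": {"locale": "en_US"}}
        ],
        "output": [],
    },
    "knowledge_base": {
        "pinecone": {"api_key": "pc-...", "index_name": "kb"},
        "openai_embeddings": {"model_name": "text-embedding-3-large"},
        "splitter": {"buffer_size": 1, "breakpoint_percentile": 95},
        "results_count": 3,
    },
    "agents": [
        {
            "name": "researcher",
            "instructions": "Answer questions about the Solana ecosystem.",
            "specialization": "research",
            "tools": ["search_internet"],  # tool name is hypothetical
        }
    ],
}

# query_service = SolanaAgentFactory.create_from_config(config)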
solana_agent/guardrails/pii.py

@@ -0,0 +1,107 @@
+import logging
+from typing import Dict, Any, Optional, List
+import scrubadub
+from solana_agent.interfaces.guardrails.guardrails import (
+    InputGuardrail,
+    OutputGuardrail,
+)
+
+logger = logging.getLogger(__name__)
+
+
+class PII(InputGuardrail, OutputGuardrail):
+    """
+    A guardrail using Scrubadub to detect and remove PII.
+
+    Requires 'scrubadub'. Install with: pip install solana-agent[guardrails]
+    """
+
+    DEFAULT_REPLACEMENT = "[REDACTED_{detector_name}]"
+    DEFAULT_LANG = "en_US"  # Scrubadub uses locale format
+
+    def __init__(self, config: Dict[str, Any] = None):
+        super().__init__(config)
+        self.replacement_format = self.config.get(
+            "replacement", self.DEFAULT_REPLACEMENT
+        )
+        self.locale = self.config.get("locale", self.DEFAULT_LANG)
+        # Optional: Specify detectors to use, None uses defaults
+        self.detector_list: Optional[List[str]] = self.config.get("detectors")
+        # Optional: Add custom detectors if needed via config
+        self.extra_detector_list = self.config.get(
+            "extra_detectors", []
+        )  # List of detector classes/instances
+
+        try:
+            # Initialize Scrubber
+            # Note: detector_list expects instances, not names. Need mapping or direct instantiation if customizing.
+            # For simplicity, we'll use defaults or allow passing instances via config (advanced).
+            # Using default detectors if self.detector_list is None.
+            if self.detector_list is not None:
+                logger.warning(
+                    "Customizing 'detectors' by name list is not directly supported here yet. Using defaults."
+                )
+                # TODO: Add logic to map names to detector classes if needed.
+                self.scrubber = scrubadub.Scrubber(locale=self.locale)
+            else:
+                self.scrubber = scrubadub.Scrubber(locale=self.locale)
+
+            # Add any extra detectors passed via config (e.g., custom regex detectors)
+            for detector in self.extra_detector_list:
+                # Assuming extra_detectors are already instantiated objects
+                # Or add logic here to instantiate them based on class paths/names
+                if isinstance(detector, scrubadub.detectors.Detector):
+                    self.scrubber.add_detector(detector)
+                else:
+                    logger.warning(f"Invalid item in extra_detectors: {detector}")
+
+            logger.info(f"ScrubadubPIIFilter initialized for locale '{self.locale}'")
+
+        except ImportError:
+            logger.error(
+                "Scrubadub not installed. Please install with 'pip install solana-agent[guardrails]'"
+            )
+            raise
+        except Exception as e:
+            logger.error(f"Failed to initialize Scrubadub: {e}", exc_info=True)
+            raise
+
+    async def process(self, text: str) -> str:
+        """Clean text using Scrubadub."""
+        try:
+            # Scrubadub's clean method handles the replacement logic.
+            # We need to customize the replacement format per detector.
+            # This requires iterating through filth found first.
+
+            clean_text = text
+            filth_list = list(self.scrubber.iter_filth(text))  # Get all findings
+
+            if not filth_list:
+                return text
+
+            # Sort by start index to handle replacements correctly
+            filth_list.sort(key=lambda f: f.beg)
+
+            offset = 0
+            for filth in filth_list:
+                start = filth.beg + offset
+                end = filth.end + offset
+                replacement_text = self.replacement_format.format(
+                    detector_name=filth.detector_name,
+                    text=filth.text,
+                    locale=filth.locale,
+                    # Add other filth attributes if needed in format string
+                )
+
+                clean_text = clean_text[:start] + replacement_text + clean_text[end:]
+                offset += len(replacement_text) - (filth.end - filth.beg)
+
+            if clean_text != text:
+                logger.debug(
+                    f"ScrubadubPIIFilter redacted {len(filth_list)} pieces of filth."
+                )
+            return clean_text
+
+        except Exception as e:
+            logger.error(f"Error during Scrubadub cleaning: {e}", exc_info=True)
+            return text  # Return original text on error
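As a quick, illustrative use of the new guardrail on its own (assuming scrubadub is installed via the solana-agent[guardrails] extra; the sample e-mail string below is made up), the PII class reads the "locale" and "replacement" keys shown in its __init__ and exposes a single async process method:

import asyncio

from solana_agent.guardrails.pii import PII

# Illustrative only; config keys match those read in PII.__init__ above.
pii = PII(config={"locale": "en_US", "replacement": "[REDACTED_{detector_name}]"})


async def main() -> None:
    # Any detected filth is replaced using the configured format string.
    cleaned = await pii.process("Contact me at jane.doe@example.com")
    print(cleaned)


asyncio.run(main())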
solana_agent/interfaces/client/client.py

@@ -1,43 +1,126 @@
 from abc import ABC, abstractmethod
-from typing import
+from typing import AsyncGenerator, Dict, Any, List, Literal, Optional, Type, Union

+from pydantic import BaseModel
 from solana_agent.interfaces.plugins.plugins import Tool
+from solana_agent.interfaces.services.routing import RoutingService as RoutingInterface
+from solana_agent.interfaces.providers.realtime import RealtimeChunk


 class SolanaAgent(ABC):
-    """Interface for the Solana
+    """Interface for the Solana Agent client."""

     @abstractmethod
     async def process(
         self,
         user_id: str,
         message: Union[str, bytes],
+        prompt: Optional[str] = None,
         output_format: Literal["text", "audio"] = "text",
-
-
-
-
-
+        capture_schema: Optional[Dict[str, Any]] = None,
+        capture_name: Optional[str] = None,
+        realtime: bool = False,
+        vad: bool = False,
+        rt_encode_input: bool = False,
+        rt_encode_output: bool = False,
+        rt_output_modalities: Optional[List[Literal["audio", "text"]]] = None,
+        rt_voice: Literal[
+            "alloy",
+            "ash",
+            "ballad",
+            "cedar",
+            "coral",
+            "echo",
+            "marin",
+            "sage",
+            "shimmer",
+            "verse",
+        ] = "marin",
+        audio_voice: Literal[
+            "alloy",
+            "ash",
+            "ballad",
+            "coral",
+            "echo",
+            "fable",
+            "onyx",
+            "nova",
+            "sage",
+            "shimmer",
+        ] = "nova",
+        audio_output_format: Literal[
+            "mp3", "opus", "aac", "flac", "wav", "pcm"
+        ] = "aac",
         audio_input_format: Literal[
             "flac", "mp3", "mp4", "mpeg", "mpga", "m4a", "ogg", "wav", "webm"
         ] = "mp4",
-
-
+        router: Optional[RoutingInterface] = None,
+        images: Optional[List[Union[str, bytes]]] = None,
+        output_model: Optional[Type[BaseModel]] = None,
+    ) -> AsyncGenerator[Union[str, bytes, BaseModel, RealtimeChunk], None]:
         """Process a user message and return the response stream."""
         pass

+    @abstractmethod
+    async def delete_user_history(self, user_id: str) -> None:
+        """Delete the conversation history for a user."""
+        pass
+
     @abstractmethod
     async def get_user_history(
         self,
         user_id: str,
         page_num: int = 1,
         page_size: int = 20,
-        sort_order: str = "desc"
+        sort_order: str = "desc",
     ) -> Dict[str, Any]:
         """Get paginated message history for a user."""
        pass

     @abstractmethod
-    def register_tool(self, tool: Tool) -> bool:
-        """Register a tool
+    def register_tool(self, agent_name: str, tool: Tool) -> bool:
+        """Register a tool with the agent system."""
+        pass
+
+    @abstractmethod
+    async def kb_add_document(
+        self,
+        text: str,
+        metadata: Dict[str, Any],
+        document_id: Optional[str] = None,
+        namespace: Optional[str] = None,
+    ) -> str:
+        """Add a document to the knowledge base."""
+        pass
+
+    @abstractmethod
+    async def kb_query(
+        self,
+        query_text: str,
+        filter: Optional[Dict[str, Any]] = None,
+        top_k: int = 5,
+        namespace: Optional[str] = None,
+        include_content: bool = True,
+        include_metadata: bool = True,
+    ) -> List[Dict[str, Any]]:
+        """Query the knowledge base."""
+        pass
+
+    @abstractmethod
+    async def kb_delete_document(
+        self, document_id: str, namespace: Optional[str] = None
+    ) -> bool:
+        """Delete a document from the knowledge base."""
+        pass
+
+    @abstractmethod
+    async def kb_add_pdf_document(
+        self,
+        pdf_data: Union[bytes, str],
+        metadata: Dict[str, Any],
+        document_id: Optional[str] = None,
+        namespace: Optional[str] = None,
+        chunk_batch_size: int = 50,
+    ) -> str:
+        """Add a PDF document to the knowledge base."""
         pass
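A minimal sketch of consuming the streaming process() contract above; `client` stands for whichever concrete SolanaAgent implementation is constructed elsewhere (the concrete client in solana_agent/client/solana_agent.py is not shown in this section), and the helper name `ask` is hypothetical.

from typing import List

from solana_agent.interfaces.client.client import SolanaAgent


# Illustrative helper: collect the text deltas yielded by the async generator.
async def ask(client: SolanaAgent, user_id: str, question: str) -> str:
    parts: List[str] = []
    async for chunk in client.process(user_id, question, output_format="text"):
        # Text mode yields str deltas; audio/realtime modes yield bytes or RealtimeChunk.
        if isinstance(chunk, str):
            parts.append(chunk)
    return "".join(parts)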
solana_agent/interfaces/guardrails/guardrails.py

@@ -0,0 +1,26 @@
+from abc import ABC, abstractmethod
+from typing import Any, Dict
+
+
+class Guardrail(ABC):
+    """Base class for all guardrails."""
+
+    def __init__(self, config: Dict[str, Any] = None):
+        self.config = config or {}
+
+    @abstractmethod
+    async def process(self, text: str) -> str:
+        """Process the text and return the modified text."""
+        pass
+
+
+class InputGuardrail(Guardrail):
+    """Interface for guardrails applied to user input."""
+
+    pass
+
+
+class OutputGuardrail(Guardrail):
+    """Interface for guardrails applied to agent output."""
+
+    pass
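The interface above reduces a custom guardrail to one async method plus a config dict. The class below is an illustrative sketch, not part of the package: an output guardrail that truncates long responses, with a hypothetical "max_chars" config key; its import path could then be listed under the guardrails "output" section that the factory's _create_guardrails resolves via importlib.

from typing import Any, Dict

from solana_agent.interfaces.guardrails.guardrails import OutputGuardrail


class TruncateOutput(OutputGuardrail):
    """Illustrative guardrail (not part of the package): caps response length."""

    def __init__(self, config: Dict[str, Any] = None):
        super().__init__(config)
        # Hypothetical config key; defaults to 4000 characters.
        self.max_chars = int(self.config.get("max_chars", 4000))

    async def process(self, text: str) -> str:
        # Pass text through unchanged unless it exceeds the configured limit.
        return text if len(text) <= self.max_chars else text[: self.max_chars]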
solana_agent/interfaces/plugins/plugins.py

@@ -4,8 +4,9 @@ Plugin system interfaces.
 These interfaces define the contracts for the plugin system,
 enabling extensibility through tools and plugins.
 """
+
 from abc import ABC, abstractmethod
-from typing import Dict, List, Any, Optional
+from typing import Dict, List, Any, Optional


 class Tool(ABC):

File without changes