praisonaiagents 0.0.108__py3-none-any.whl → 0.0.110__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- praisonaiagents/__init__.py +6 -0
- praisonaiagents/agent/__init__.py +2 -1
- praisonaiagents/agent/agent.py +79 -3
- praisonaiagents/agent/handoff.py +317 -0
- praisonaiagents/llm/llm.py +217 -68
- praisonaiagents/mcp/mcp.py +27 -7
- praisonaiagents/memory/memory.py +48 -0
- praisonaiagents/tools/duckdb_tools.py +47 -16
- praisonaiagents/tools/file_tools.py +52 -10
- praisonaiagents/tools/python_tools.py +84 -4
- praisonaiagents/tools/shell_tools.py +18 -8
- praisonaiagents/tools/spider_tools.py +55 -0
- {praisonaiagents-0.0.108.dist-info → praisonaiagents-0.0.110.dist-info}/METADATA +1 -1
- {praisonaiagents-0.0.108.dist-info → praisonaiagents-0.0.110.dist-info}/RECORD +16 -15
- {praisonaiagents-0.0.108.dist-info → praisonaiagents-0.0.110.dist-info}/WHEEL +0 -0
- {praisonaiagents-0.0.108.dist-info → praisonaiagents-0.0.110.dist-info}/top_level.txt +0 -0
praisonaiagents/llm/llm.py
CHANGED
@@ -276,6 +276,41 @@ class LLM:
|
|
276
276
|
]
|
277
277
|
|
278
278
|
return self.model in legacy_o1_models
|
279
|
+
|
280
|
+
def _supports_streaming_tools(self) -> bool:
|
281
|
+
"""
|
282
|
+
Check if the current provider supports streaming with tools.
|
283
|
+
|
284
|
+
Most providers that support tool calling also support streaming with tools,
|
285
|
+
but some providers (like Ollama and certain local models) require non-streaming
|
286
|
+
calls when tools are involved.
|
287
|
+
|
288
|
+
Returns:
|
289
|
+
bool: True if provider supports streaming with tools, False otherwise
|
290
|
+
"""
|
291
|
+
if not self.model:
|
292
|
+
return False
|
293
|
+
|
294
|
+
# Ollama doesn't reliably support streaming with tools
|
295
|
+
if self._is_ollama_provider():
|
296
|
+
return False
|
297
|
+
|
298
|
+
# OpenAI models support streaming with tools
|
299
|
+
if any(self.model.startswith(prefix) for prefix in ["gpt-", "o1-", "o3-"]):
|
300
|
+
return True
|
301
|
+
|
302
|
+
# Anthropic Claude models support streaming with tools
|
303
|
+
if self.model.startswith("claude-"):
|
304
|
+
return True
|
305
|
+
|
306
|
+
# Google Gemini models support streaming with tools
|
307
|
+
if any(self.model.startswith(prefix) for prefix in ["gemini-", "gemini/"]):
|
308
|
+
return True
|
309
|
+
|
310
|
+
# For other providers, default to False to be safe
|
311
|
+
# This ensures we make a single non-streaming call rather than risk
|
312
|
+
# missing tool calls or making duplicate calls
|
313
|
+
return False
|
279
314
|
|
280
315
|
def get_response(
|
281
316
|
self,
|
@@ -480,49 +515,110 @@ class LLM:
|
|
480
515
|
|
481
516
|
# Otherwise do the existing streaming approach
|
482
517
|
else:
|
483
|
-
if
|
484
|
-
|
485
|
-
|
518
|
+
# Determine if we should use streaming based on tool support
|
519
|
+
use_streaming = stream
|
520
|
+
if formatted_tools and not self._supports_streaming_tools():
|
521
|
+
# Provider doesn't support streaming with tools, use non-streaming
|
522
|
+
use_streaming = False
|
523
|
+
|
524
|
+
if use_streaming:
|
525
|
+
# Streaming approach (with or without tools)
|
526
|
+
tool_calls = []
|
527
|
+
response_text = ""
|
528
|
+
|
529
|
+
if verbose:
|
530
|
+
with Live(display_generating("", current_time), console=console, refresh_per_second=4) as live:
|
531
|
+
for chunk in litellm.completion(
|
532
|
+
**self._build_completion_params(
|
533
|
+
messages=messages,
|
534
|
+
tools=formatted_tools,
|
535
|
+
temperature=temperature,
|
536
|
+
stream=True,
|
537
|
+
**kwargs
|
538
|
+
)
|
539
|
+
):
|
540
|
+
if chunk and chunk.choices and chunk.choices[0].delta:
|
541
|
+
delta = chunk.choices[0].delta
|
542
|
+
if delta.content:
|
543
|
+
response_text += delta.content
|
544
|
+
live.update(display_generating(response_text, current_time))
|
545
|
+
|
546
|
+
# Capture tool calls from streaming chunks if provider supports it
|
547
|
+
if formatted_tools and self._supports_streaming_tools() and hasattr(delta, 'tool_calls') and delta.tool_calls:
|
548
|
+
for tc in delta.tool_calls:
|
549
|
+
if tc.index >= len(tool_calls):
|
550
|
+
tool_calls.append({
|
551
|
+
"id": tc.id,
|
552
|
+
"type": "function",
|
553
|
+
"function": {"name": "", "arguments": ""}
|
554
|
+
})
|
555
|
+
if tc.function.name:
|
556
|
+
tool_calls[tc.index]["function"]["name"] = tc.function.name
|
557
|
+
if tc.function.arguments:
|
558
|
+
tool_calls[tc.index]["function"]["arguments"] += tc.function.arguments
|
559
|
+
else:
|
560
|
+
# Non-verbose streaming
|
486
561
|
for chunk in litellm.completion(
|
487
562
|
**self._build_completion_params(
|
488
563
|
messages=messages,
|
489
564
|
tools=formatted_tools,
|
490
565
|
temperature=temperature,
|
491
|
-
stream=
|
566
|
+
stream=True,
|
492
567
|
**kwargs
|
493
568
|
)
|
494
569
|
):
|
495
|
-
if chunk and chunk.choices and chunk.choices[0].delta
|
496
|
-
|
497
|
-
|
498
|
-
|
570
|
+
if chunk and chunk.choices and chunk.choices[0].delta:
|
571
|
+
delta = chunk.choices[0].delta
|
572
|
+
if delta.content:
|
573
|
+
response_text += delta.content
|
574
|
+
|
575
|
+
# Capture tool calls from streaming chunks if provider supports it
|
576
|
+
if formatted_tools and self._supports_streaming_tools() and hasattr(delta, 'tool_calls') and delta.tool_calls:
|
577
|
+
for tc in delta.tool_calls:
|
578
|
+
if tc.index >= len(tool_calls):
|
579
|
+
tool_calls.append({
|
580
|
+
"id": tc.id,
|
581
|
+
"type": "function",
|
582
|
+
"function": {"name": "", "arguments": ""}
|
583
|
+
})
|
584
|
+
if tc.function.name:
|
585
|
+
tool_calls[tc.index]["function"]["name"] = tc.function.name
|
586
|
+
if tc.function.arguments:
|
587
|
+
tool_calls[tc.index]["function"]["arguments"] += tc.function.arguments
|
588
|
+
|
589
|
+
response_text = response_text.strip()
|
590
|
+
|
591
|
+
# Create a mock final_response with the captured data
|
592
|
+
final_response = {
|
593
|
+
"choices": [{
|
594
|
+
"message": {
|
595
|
+
"content": response_text,
|
596
|
+
"tool_calls": tool_calls if tool_calls else None
|
597
|
+
}
|
598
|
+
}]
|
599
|
+
}
|
499
600
|
else:
|
500
|
-
# Non-
|
501
|
-
|
502
|
-
for chunk in litellm.completion(
|
601
|
+
# Non-streaming approach (when tools require it or streaming is disabled)
|
602
|
+
final_response = litellm.completion(
|
503
603
|
**self._build_completion_params(
|
504
604
|
messages=messages,
|
505
605
|
tools=formatted_tools,
|
506
606
|
temperature=temperature,
|
507
|
-
stream=
|
607
|
+
stream=False,
|
508
608
|
**kwargs
|
509
609
|
)
|
510
|
-
):
|
511
|
-
if chunk and chunk.choices and chunk.choices[0].delta.content:
|
512
|
-
response_text += chunk.choices[0].delta.content
|
513
|
-
|
514
|
-
response_text = response_text.strip()
|
515
|
-
|
516
|
-
# Get final completion to check for tool calls
|
517
|
-
final_response = litellm.completion(
|
518
|
-
**self._build_completion_params(
|
519
|
-
messages=messages,
|
520
|
-
tools=formatted_tools,
|
521
|
-
temperature=temperature,
|
522
|
-
stream=False, # No streaming for tool call check
|
523
|
-
**kwargs
|
524
610
|
)
|
525
|
-
|
611
|
+
response_text = final_response["choices"][0]["message"]["content"]
|
612
|
+
|
613
|
+
if verbose:
|
614
|
+
# Display the complete response at once
|
615
|
+
display_interaction(
|
616
|
+
original_prompt,
|
617
|
+
response_text,
|
618
|
+
markdown=markdown,
|
619
|
+
generation_time=time.time() - current_time,
|
620
|
+
console=console
|
621
|
+
)
|
526
622
|
|
527
623
|
tool_calls = final_response["choices"][0]["message"].get("tool_calls")
|
528
624
|
|
@@ -1198,53 +1294,106 @@ Output MUST be JSON with 'reflection' and 'satisfactory'.
|
|
1198
1294
|
console=console
|
1199
1295
|
)
|
1200
1296
|
else:
|
1201
|
-
if
|
1202
|
-
|
1203
|
-
|
1204
|
-
#
|
1205
|
-
|
1206
|
-
|
1207
|
-
|
1208
|
-
|
1209
|
-
|
1210
|
-
|
1211
|
-
|
1212
|
-
|
1213
|
-
|
1214
|
-
|
1215
|
-
|
1216
|
-
|
1297
|
+
# Determine if we should use streaming based on tool support
|
1298
|
+
use_streaming = stream
|
1299
|
+
if formatted_tools and not self._supports_streaming_tools():
|
1300
|
+
# Provider doesn't support streaming with tools, use non-streaming
|
1301
|
+
use_streaming = False
|
1302
|
+
|
1303
|
+
if use_streaming:
|
1304
|
+
# Streaming approach (with or without tools)
|
1305
|
+
tool_calls = []
|
1306
|
+
|
1307
|
+
if verbose:
|
1308
|
+
async for chunk in await litellm.acompletion(
|
1309
|
+
**self._build_completion_params(
|
1310
|
+
messages=messages,
|
1311
|
+
temperature=temperature,
|
1312
|
+
stream=True,
|
1313
|
+
tools=formatted_tools,
|
1314
|
+
**kwargs
|
1315
|
+
)
|
1316
|
+
):
|
1317
|
+
if chunk and chunk.choices and chunk.choices[0].delta:
|
1318
|
+
delta = chunk.choices[0].delta
|
1319
|
+
if delta.content:
|
1320
|
+
response_text += delta.content
|
1321
|
+
print("\033[K", end="\r")
|
1322
|
+
print(f"Generating... {time.time() - start_time:.1f}s", end="\r")
|
1323
|
+
|
1324
|
+
# Capture tool calls from streaming chunks if provider supports it
|
1325
|
+
if formatted_tools and self._supports_streaming_tools() and hasattr(delta, 'tool_calls') and delta.tool_calls:
|
1326
|
+
for tc in delta.tool_calls:
|
1327
|
+
if tc.index >= len(tool_calls):
|
1328
|
+
tool_calls.append({
|
1329
|
+
"id": tc.id,
|
1330
|
+
"type": "function",
|
1331
|
+
"function": {"name": "", "arguments": ""}
|
1332
|
+
})
|
1333
|
+
if tc.function.name:
|
1334
|
+
tool_calls[tc.index]["function"]["name"] = tc.function.name
|
1335
|
+
if tc.function.arguments:
|
1336
|
+
tool_calls[tc.index]["function"]["arguments"] += tc.function.arguments
|
1337
|
+
else:
|
1338
|
+
# Non-verbose streaming
|
1339
|
+
async for chunk in await litellm.acompletion(
|
1340
|
+
**self._build_completion_params(
|
1341
|
+
messages=messages,
|
1342
|
+
temperature=temperature,
|
1343
|
+
stream=True,
|
1344
|
+
tools=formatted_tools,
|
1345
|
+
**kwargs
|
1346
|
+
)
|
1347
|
+
):
|
1348
|
+
if chunk and chunk.choices and chunk.choices[0].delta:
|
1349
|
+
delta = chunk.choices[0].delta
|
1350
|
+
if delta.content:
|
1351
|
+
response_text += delta.content
|
1352
|
+
|
1353
|
+
# Capture tool calls from streaming chunks if provider supports it
|
1354
|
+
if formatted_tools and self._supports_streaming_tools() and hasattr(delta, 'tool_calls') and delta.tool_calls:
|
1355
|
+
for tc in delta.tool_calls:
|
1356
|
+
if tc.index >= len(tool_calls):
|
1357
|
+
tool_calls.append({
|
1358
|
+
"id": tc.id,
|
1359
|
+
"type": "function",
|
1360
|
+
"function": {"name": "", "arguments": ""}
|
1361
|
+
})
|
1362
|
+
if tc.function.name:
|
1363
|
+
tool_calls[tc.index]["function"]["name"] = tc.function.name
|
1364
|
+
if tc.function.arguments:
|
1365
|
+
tool_calls[tc.index]["function"]["arguments"] += tc.function.arguments
|
1366
|
+
|
1367
|
+
response_text = response_text.strip()
|
1368
|
+
|
1369
|
+
# We already have tool_calls from streaming if supported
|
1370
|
+
# No need for a second API call!
|
1217
1371
|
else:
|
1218
|
-
# Non-
|
1219
|
-
|
1372
|
+
# Non-streaming approach (when tools require it or streaming is disabled)
|
1373
|
+
tool_response = await litellm.acompletion(
|
1220
1374
|
**self._build_completion_params(
|
1221
1375
|
messages=messages,
|
1222
1376
|
temperature=temperature,
|
1223
|
-
stream=
|
1224
|
-
|
1377
|
+
stream=False,
|
1378
|
+
tools=formatted_tools,
|
1379
|
+
**{k:v for k,v in kwargs.items() if k != 'reasoning_steps'}
|
1225
1380
|
)
|
1226
|
-
):
|
1227
|
-
if chunk and chunk.choices and chunk.choices[0].delta.content:
|
1228
|
-
response_text += chunk.choices[0].delta.content
|
1229
|
-
|
1230
|
-
response_text = response_text.strip()
|
1231
|
-
|
1232
|
-
# ----------------------------------------------------
|
1233
|
-
# 2) If tool calls are needed, do a non-streaming call
|
1234
|
-
# ----------------------------------------------------
|
1235
|
-
if tools and execute_tool_fn:
|
1236
|
-
# Next call with tools if needed
|
1237
|
-
tool_response = await litellm.acompletion(
|
1238
|
-
**self._build_completion_params(
|
1239
|
-
messages=messages,
|
1240
|
-
temperature=temperature,
|
1241
|
-
stream=False,
|
1242
|
-
tools=formatted_tools, # We safely pass tools here
|
1243
|
-
**{k:v for k,v in kwargs.items() if k != 'reasoning_steps'}
|
1244
1381
|
)
|
1245
|
-
|
1246
|
-
|
1247
|
-
|
1382
|
+
response_text = tool_response.choices[0].message.get("content", "")
|
1383
|
+
tool_calls = tool_response.choices[0].message.get("tool_calls", [])
|
1384
|
+
|
1385
|
+
if verbose:
|
1386
|
+
# Display the complete response at once
|
1387
|
+
display_interaction(
|
1388
|
+
original_prompt,
|
1389
|
+
response_text,
|
1390
|
+
markdown=markdown,
|
1391
|
+
generation_time=time.time() - start_time,
|
1392
|
+
console=console
|
1393
|
+
)
|
1394
|
+
|
1395
|
+
# Now handle tools if we have them (either from streaming or non-streaming)
|
1396
|
+
if tools and execute_tool_fn and tool_calls:
|
1248
1397
|
|
1249
1398
|
if tool_calls:
|
1250
1399
|
# Convert tool_calls to a serializable format for all providers
|
praisonaiagents/mcp/mcp.py
CHANGED
@@ -7,6 +7,7 @@ import shlex
|
|
7
7
|
import logging
|
8
8
|
import os
|
9
9
|
import re
|
10
|
+
import platform
|
10
11
|
from typing import Any, List, Optional, Callable, Iterable, Union
|
11
12
|
from functools import wraps, partial
|
12
13
|
|
@@ -199,7 +200,13 @@ class MCP:
|
|
199
200
|
# Handle the single string format for stdio client
|
200
201
|
if isinstance(command_or_string, str) and args is None:
|
201
202
|
# Split the string into command and args using shell-like parsing
|
202
|
-
|
203
|
+
if platform.system() == 'Windows':
|
204
|
+
# Use shlex with posix=False for Windows to handle quotes and paths with spaces
|
205
|
+
parts = shlex.split(command_or_string, posix=False)
|
206
|
+
# Remove surrounding double quotes from parts if present (shlex keeps them when posix=False)
|
207
|
+
parts = [part.strip('"') for part in parts]
|
208
|
+
else:
|
209
|
+
parts = shlex.split(command_or_string)
|
203
210
|
if not parts:
|
204
211
|
raise ValueError("Empty command string")
|
205
212
|
|
@@ -217,11 +224,17 @@ class MCP:
|
|
217
224
|
env = kwargs.get('env', {})
|
218
225
|
if not env:
|
219
226
|
env = os.environ.copy()
|
220
|
-
|
221
|
-
|
222
|
-
|
223
|
-
|
224
|
-
|
227
|
+
|
228
|
+
# Always set Python encoding
|
229
|
+
env['PYTHONIOENCODING'] = 'utf-8'
|
230
|
+
|
231
|
+
# Only set locale variables on Unix systems
|
232
|
+
if platform.system() != 'Windows':
|
233
|
+
env.update({
|
234
|
+
'LC_ALL': 'C.UTF-8',
|
235
|
+
'LANG': 'C.UTF-8'
|
236
|
+
})
|
237
|
+
|
225
238
|
kwargs['env'] = env
|
226
239
|
|
227
240
|
self.server_params = StdioServerParameters(
|
@@ -236,7 +249,14 @@ class MCP:
|
|
236
249
|
print(f"Warning: MCP initialization timed out after {self.timeout} seconds")
|
237
250
|
|
238
251
|
# Automatically detect if this is an NPX command
|
239
|
-
|
252
|
+
base_cmd = os.path.basename(cmd) if isinstance(cmd, str) else cmd
|
253
|
+
# Check for npx with or without Windows extensions
|
254
|
+
npx_variants = ['npx', 'npx.cmd', 'npx.exe']
|
255
|
+
if platform.system() == 'Windows' and isinstance(base_cmd, str):
|
256
|
+
# Case-insensitive comparison on Windows
|
257
|
+
self.is_npx = base_cmd.lower() in [v.lower() for v in npx_variants]
|
258
|
+
else:
|
259
|
+
self.is_npx = base_cmd in npx_variants
|
240
260
|
|
241
261
|
# For NPX-based MCP servers, use a different approach
|
242
262
|
if self.is_npx:
|
praisonaiagents/memory/memory.py
CHANGED
@@ -741,6 +741,54 @@ class Memory:
|
|
741
741
|
filtered.append(h)
|
742
742
|
return filtered[:limit]
|
743
743
|
|
744
|
+
def search(self, query: str, user_id: Optional[str] = None, agent_id: Optional[str] = None,
|
745
|
+
run_id: Optional[str] = None, limit: int = 5, rerank: bool = False, **kwargs) -> List[Dict[str, Any]]:
|
746
|
+
"""
|
747
|
+
Generic search method that delegates to appropriate specific search methods.
|
748
|
+
Provides compatibility with mem0.Memory interface.
|
749
|
+
|
750
|
+
Args:
|
751
|
+
query: The search query string
|
752
|
+
user_id: Optional user ID for user-specific search
|
753
|
+
agent_id: Optional agent ID for agent-specific search
|
754
|
+
run_id: Optional run ID for run-specific search
|
755
|
+
limit: Maximum number of results to return
|
756
|
+
rerank: Whether to use advanced reranking
|
757
|
+
**kwargs: Additional search parameters
|
758
|
+
|
759
|
+
Returns:
|
760
|
+
List of search results
|
761
|
+
"""
|
762
|
+
# If using mem0, pass all parameters directly
|
763
|
+
if self.use_mem0 and hasattr(self, "mem0_client"):
|
764
|
+
search_params = {
|
765
|
+
"query": query,
|
766
|
+
"limit": limit,
|
767
|
+
"rerank": rerank
|
768
|
+
}
|
769
|
+
|
770
|
+
# Add optional parameters if provided
|
771
|
+
if user_id is not None:
|
772
|
+
search_params["user_id"] = user_id
|
773
|
+
if agent_id is not None:
|
774
|
+
search_params["agent_id"] = agent_id
|
775
|
+
if run_id is not None:
|
776
|
+
search_params["run_id"] = run_id
|
777
|
+
|
778
|
+
# Include any additional kwargs
|
779
|
+
search_params.update(kwargs)
|
780
|
+
|
781
|
+
return self.mem0_client.search(**search_params)
|
782
|
+
|
783
|
+
# For local memory, use specific search methods
|
784
|
+
if user_id:
|
785
|
+
# Use user-specific search
|
786
|
+
return self.search_user_memory(user_id, query, limit=limit, rerank=rerank, **kwargs)
|
787
|
+
else:
|
788
|
+
# Default to long-term memory search
|
789
|
+
# Note: agent_id and run_id filtering could be added to metadata filtering in the future
|
790
|
+
return self.search_long_term(query, limit=limit, rerank=rerank, **kwargs)
|
791
|
+
|
744
792
|
def reset_user_memory(self):
|
745
793
|
"""
|
746
794
|
Clear all user-based info. For simplicity, we do a full LTM reset.
|
@@ -13,6 +13,7 @@ import logging
|
|
13
13
|
from typing import List, Dict, Any, Optional, Union, TYPE_CHECKING
|
14
14
|
from importlib import util
|
15
15
|
import json
|
16
|
+
import re
|
16
17
|
|
17
18
|
if TYPE_CHECKING:
|
18
19
|
import duckdb
|
@@ -29,6 +30,25 @@ class DuckDBTools:
|
|
29
30
|
"""
|
30
31
|
self.database = database
|
31
32
|
self._conn = None
|
33
|
+
|
34
|
+
def _validate_identifier(self, identifier: str) -> str:
|
35
|
+
"""Validate and quote a SQL identifier to prevent injection.
|
36
|
+
|
37
|
+
Args:
|
38
|
+
identifier: Table or column name to validate
|
39
|
+
|
40
|
+
Returns:
|
41
|
+
Quoted identifier safe for SQL
|
42
|
+
|
43
|
+
Raises:
|
44
|
+
ValueError: If identifier contains invalid characters
|
45
|
+
"""
|
46
|
+
# Only allow letters, digits, and underscores (not starting with a digit), with at most one dot-separated qualifier (e.g. schema.table)
|
47
|
+
if not re.match(r'^[a-zA-Z_][a-zA-Z0-9_]*(\.[a-zA-Z_][a-zA-Z0-9_]*)?$', identifier):
|
48
|
+
raise ValueError(f"Invalid identifier: {identifier}")
|
49
|
+
|
50
|
+
# Quote the identifier to handle reserved words
|
51
|
+
return f'"{identifier}"'
|
32
52
|
|
33
53
|
def _get_duckdb(self) -> Optional['duckdb']:
|
34
54
|
"""Get duckdb module, installing if needed"""
|
@@ -125,39 +145,50 @@ class DuckDBTools:
|
|
125
145
|
if conn is None:
|
126
146
|
return False
|
127
147
|
|
128
|
-
# Check if table exists
|
129
|
-
exists = conn.execute(
|
148
|
+
# Check if table exists using parameterized query
|
149
|
+
exists = conn.execute("""
|
130
150
|
SELECT name FROM sqlite_master
|
131
|
-
WHERE type='table' AND name
|
132
|
-
""").fetchone() is not None
|
151
|
+
WHERE type='table' AND name=?
|
152
|
+
""", [table_name]).fetchone() is not None
|
133
153
|
|
134
154
|
if exists:
|
135
155
|
if if_exists == 'fail':
|
136
156
|
raise ValueError(f"Table {table_name} already exists")
|
137
157
|
elif if_exists == 'replace':
|
138
|
-
|
158
|
+
# Validate and quote table name to prevent injection
|
159
|
+
safe_table = self._validate_identifier(table_name)
|
160
|
+
conn.execute(f"DROP TABLE IF EXISTS {safe_table}")
|
139
161
|
elif if_exists != 'append':
|
140
162
|
raise ValueError("if_exists must be 'fail', 'replace', or 'append'")
|
141
163
|
|
142
164
|
# Create table if needed
|
143
165
|
if not exists or if_exists == 'replace':
|
166
|
+
safe_table = self._validate_identifier(table_name)
|
144
167
|
if schema:
|
145
|
-
# Create table with schema
|
146
|
-
|
147
|
-
|
168
|
+
# Create table with schema - validate column names
|
169
|
+
column_defs = []
|
170
|
+
for col_name, col_type in schema.items():
|
171
|
+
safe_col = self._validate_identifier(col_name)
|
172
|
+
# Validate column type to prevent injection
|
173
|
+
if not re.match(r'^[A-Z][A-Z0-9_]*(\([0-9,]+\))?$', col_type.upper()):
|
174
|
+
raise ValueError(f"Invalid column type: {col_type}")
|
175
|
+
column_defs.append(f"{safe_col} {col_type}")
|
176
|
+
columns = ', '.join(column_defs)
|
177
|
+
conn.execute(f"CREATE TABLE {safe_table} ({columns})")
|
148
178
|
else:
|
149
|
-
# Infer schema from CSV
|
179
|
+
# Infer schema from CSV - use parameterized query for filepath
|
150
180
|
conn.execute(f"""
|
151
|
-
CREATE TABLE {
|
152
|
-
SELECT * FROM read_csv_auto(
|
181
|
+
CREATE TABLE {safe_table} AS
|
182
|
+
SELECT * FROM read_csv_auto(?)
|
153
183
|
WHERE 1=0
|
154
|
-
""")
|
184
|
+
""", [filepath])
|
155
185
|
|
156
|
-
# Load data
|
186
|
+
# Load data - use validated table name and parameterized filepath
|
187
|
+
safe_table = self._validate_identifier(table_name)
|
157
188
|
conn.execute(f"""
|
158
|
-
INSERT INTO {
|
159
|
-
SELECT * FROM read_csv_auto(
|
160
|
-
""")
|
189
|
+
INSERT INTO {safe_table}
|
190
|
+
SELECT * FROM read_csv_auto(?)
|
191
|
+
""", [filepath])
|
161
192
|
|
162
193
|
return True
|
163
194
|
|
@@ -21,6 +21,32 @@ from ..approval import require_approval
|
|
21
21
|
class FileTools:
|
22
22
|
"""Tools for file operations including read, write, list, and information."""
|
23
23
|
|
24
|
+
@staticmethod
|
25
|
+
def _validate_path(filepath: str) -> str:
|
26
|
+
"""
|
27
|
+
Validate and normalize a file path to prevent path traversal attacks.
|
28
|
+
|
29
|
+
Args:
|
30
|
+
filepath: Path to validate
|
31
|
+
|
32
|
+
Returns:
|
33
|
+
str: Normalized absolute path
|
34
|
+
|
35
|
+
Raises:
|
36
|
+
ValueError: If path contains suspicious patterns
|
37
|
+
"""
|
38
|
+
# Normalize the path
|
39
|
+
normalized = os.path.normpath(filepath)
|
40
|
+
absolute = os.path.abspath(normalized)
|
41
|
+
|
42
|
+
# Check for suspicious patterns
|
43
|
+
if '..' in filepath or filepath.startswith('~'):
|
44
|
+
raise ValueError(f"Suspicious path pattern detected: {filepath}")
|
45
|
+
|
46
|
+
# NOTE: no boundary check is performed here - the resolved path is returned without being confined to an allowed directory
|
47
|
+
# This is a basic check - in production, you'd want to define allowed directories
|
48
|
+
return absolute
|
49
|
+
|
24
50
|
@staticmethod
|
25
51
|
def read_file(filepath: str, encoding: str = 'utf-8') -> str:
|
26
52
|
"""
|
@@ -34,7 +60,9 @@ class FileTools:
|
|
34
60
|
str: Content of the file
|
35
61
|
"""
|
36
62
|
try:
|
37
|
-
|
63
|
+
# Validate path to prevent traversal attacks
|
64
|
+
safe_path = FileTools._validate_path(filepath)
|
65
|
+
with open(safe_path, 'r', encoding=encoding) as f:
|
38
66
|
return f.read()
|
39
67
|
except Exception as e:
|
40
68
|
error_msg = f"Error reading file {filepath}: {str(e)}"
|
@@ -56,9 +84,11 @@ class FileTools:
|
|
56
84
|
bool: True if successful, False otherwise
|
57
85
|
"""
|
58
86
|
try:
|
87
|
+
# Validate path to prevent traversal attacks
|
88
|
+
safe_path = FileTools._validate_path(filepath)
|
59
89
|
# Create directory if it doesn't exist
|
60
|
-
os.makedirs(os.path.dirname(
|
61
|
-
with open(
|
90
|
+
os.makedirs(os.path.dirname(safe_path), exist_ok=True)
|
91
|
+
with open(safe_path, 'w', encoding=encoding) as f:
|
62
92
|
f.write(content)
|
63
93
|
return True
|
64
94
|
except Exception as e:
|
@@ -79,7 +109,9 @@ class FileTools:
|
|
79
109
|
List[Dict]: List of file information dictionaries
|
80
110
|
"""
|
81
111
|
try:
|
82
|
-
|
112
|
+
# Validate directory path
|
113
|
+
safe_dir = FileTools._validate_path(directory)
|
114
|
+
path = Path(safe_dir)
|
83
115
|
if pattern:
|
84
116
|
files = path.glob(pattern)
|
85
117
|
else:
|
@@ -114,7 +146,9 @@ class FileTools:
|
|
114
146
|
Dict: File information including size, dates, etc.
|
115
147
|
"""
|
116
148
|
try:
|
117
|
-
|
149
|
+
# Validate file path
|
150
|
+
safe_path = FileTools._validate_path(filepath)
|
151
|
+
path = Path(safe_path)
|
118
152
|
if not path.exists():
|
119
153
|
return {'error': f'File not found: {filepath}'}
|
120
154
|
|
@@ -149,9 +183,12 @@ class FileTools:
|
|
149
183
|
bool: True if successful, False otherwise
|
150
184
|
"""
|
151
185
|
try:
|
186
|
+
# Validate paths to prevent traversal attacks
|
187
|
+
safe_src = FileTools._validate_path(src)
|
188
|
+
safe_dst = FileTools._validate_path(dst)
|
152
189
|
# Create destination directory if it doesn't exist
|
153
|
-
os.makedirs(os.path.dirname(
|
154
|
-
shutil.copy2(
|
190
|
+
os.makedirs(os.path.dirname(safe_dst), exist_ok=True)
|
191
|
+
shutil.copy2(safe_src, safe_dst)
|
155
192
|
return True
|
156
193
|
except Exception as e:
|
157
194
|
error_msg = f"Error copying file from {src} to {dst}: {str(e)}"
|
@@ -172,9 +209,12 @@ class FileTools:
|
|
172
209
|
bool: True if successful, False otherwise
|
173
210
|
"""
|
174
211
|
try:
|
212
|
+
# Validate paths to prevent traversal attacks
|
213
|
+
safe_src = FileTools._validate_path(src)
|
214
|
+
safe_dst = FileTools._validate_path(dst)
|
175
215
|
# Create destination directory if it doesn't exist
|
176
|
-
os.makedirs(os.path.dirname(
|
177
|
-
shutil.move(
|
216
|
+
os.makedirs(os.path.dirname(safe_dst), exist_ok=True)
|
217
|
+
shutil.move(safe_src, safe_dst)
|
178
218
|
return True
|
179
219
|
except Exception as e:
|
180
220
|
error_msg = f"Error moving file from {src} to {dst}: {str(e)}"
|
@@ -194,7 +234,9 @@ class FileTools:
|
|
194
234
|
bool: True if successful, False otherwise
|
195
235
|
"""
|
196
236
|
try:
|
197
|
-
|
237
|
+
# Validate path to prevent traversal attacks
|
238
|
+
safe_path = FileTools._validate_path(filepath)
|
239
|
+
os.remove(safe_path)
|
198
240
|
return True
|
199
241
|
except Exception as e:
|
200
242
|
error_msg = f"Error deleting file {filepath}: {str(e)}"
|