praisonaiagents 0.0.109__py3-none-any.whl → 0.0.111__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- praisonaiagents/__init__.py +6 -0
- praisonaiagents/agent/__init__.py +2 -1
- praisonaiagents/agent/agent.py +37 -0
- praisonaiagents/agent/handoff.py +317 -0
- praisonaiagents/llm/llm.py +217 -68
- praisonaiagents/mcp/mcp.py +27 -7
- praisonaiagents/memory/memory.py +48 -0
- praisonaiagents/telemetry/integration.py +1 -1
- praisonaiagents/tools/duckdb_tools.py +47 -16
- praisonaiagents/tools/file_tools.py +52 -10
- praisonaiagents/tools/python_tools.py +84 -4
- praisonaiagents/tools/shell_tools.py +18 -8
- praisonaiagents/tools/spider_tools.py +55 -0
- {praisonaiagents-0.0.109.dist-info → praisonaiagents-0.0.111.dist-info}/METADATA +1 -1
- {praisonaiagents-0.0.109.dist-info → praisonaiagents-0.0.111.dist-info}/RECORD +17 -16
- {praisonaiagents-0.0.109.dist-info → praisonaiagents-0.0.111.dist-info}/WHEEL +0 -0
- {praisonaiagents-0.0.109.dist-info → praisonaiagents-0.0.111.dist-info}/top_level.txt +0 -0
praisonaiagents/llm/llm.py
CHANGED
@@ -276,6 +276,41 @@ class LLM:
        ]

        return self.model in legacy_o1_models
+
+    def _supports_streaming_tools(self) -> bool:
+        """
+        Check if the current provider supports streaming with tools.
+
+        Most providers that support tool calling also support streaming with tools,
+        but some providers (like Ollama and certain local models) require non-streaming
+        calls when tools are involved.
+
+        Returns:
+            bool: True if provider supports streaming with tools, False otherwise
+        """
+        if not self.model:
+            return False
+
+        # Ollama doesn't reliably support streaming with tools
+        if self._is_ollama_provider():
+            return False
+
+        # OpenAI models support streaming with tools
+        if any(self.model.startswith(prefix) for prefix in ["gpt-", "o1-", "o3-"]):
+            return True
+
+        # Anthropic Claude models support streaming with tools
+        if self.model.startswith("claude-"):
+            return True
+
+        # Google Gemini models support streaming with tools
+        if any(self.model.startswith(prefix) for prefix in ["gemini-", "gemini/"]):
+            return True
+
+        # For other providers, default to False to be safe
+        # This ensures we make a single non-streaming call rather than risk
+        # missing tool calls or making duplicate calls
+        return False

    def get_response(
        self,
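The new helper is a plain prefix allowlist. A standalone sketch of the same pattern (the constant and function names here are illustrative, not part of the package API; the real method additionally consults `_is_ollama_provider()`):

    # Standalone rendition of the prefix-based capability check above.
    STREAMING_TOOL_PREFIXES = ("gpt-", "o1-", "o3-", "claude-", "gemini-", "gemini/")

    def supports_streaming_tools(model: str) -> bool:
        # Unknown or empty models default to False, trading a bit of latency
        # for correctness: one non-streaming call instead of missed tool calls.
        return bool(model) and model.startswith(STREAMING_TOOL_PREFIXES)

    assert supports_streaming_tools("gpt-4o")
    assert not supports_streaming_tools("ollama/llama3")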
@@ -480,49 +515,110 @@ class LLM:
 
            # Otherwise do the existing streaming approach
            else:
-                if
-
-
+                # Determine if we should use streaming based on tool support
+                use_streaming = stream
+                if formatted_tools and not self._supports_streaming_tools():
+                    # Provider doesn't support streaming with tools, use non-streaming
+                    use_streaming = False
+
+                if use_streaming:
+                    # Streaming approach (with or without tools)
+                    tool_calls = []
+                    response_text = ""
+
+                    if verbose:
+                        with Live(display_generating("", current_time), console=console, refresh_per_second=4) as live:
+                            for chunk in litellm.completion(
+                                **self._build_completion_params(
+                                    messages=messages,
+                                    tools=formatted_tools,
+                                    temperature=temperature,
+                                    stream=True,
+                                    **kwargs
+                                )
+                            ):
+                                if chunk and chunk.choices and chunk.choices[0].delta:
+                                    delta = chunk.choices[0].delta
+                                    if delta.content:
+                                        response_text += delta.content
+                                        live.update(display_generating(response_text, current_time))
+
+                                    # Capture tool calls from streaming chunks if provider supports it
+                                    if formatted_tools and self._supports_streaming_tools() and hasattr(delta, 'tool_calls') and delta.tool_calls:
+                                        for tc in delta.tool_calls:
+                                            if tc.index >= len(tool_calls):
+                                                tool_calls.append({
+                                                    "id": tc.id,
+                                                    "type": "function",
+                                                    "function": {"name": "", "arguments": ""}
+                                                })
+                                            if tc.function.name:
+                                                tool_calls[tc.index]["function"]["name"] = tc.function.name
+                                            if tc.function.arguments:
+                                                tool_calls[tc.index]["function"]["arguments"] += tc.function.arguments
+                    else:
+                        # Non-verbose streaming
                        for chunk in litellm.completion(
                            **self._build_completion_params(
                                messages=messages,
                                tools=formatted_tools,
                                temperature=temperature,
-                                stream=
+                                stream=True,
                                **kwargs
                            )
                        ):
-                            if chunk and chunk.choices and chunk.choices[0].delta
-
-
-
+                            if chunk and chunk.choices and chunk.choices[0].delta:
+                                delta = chunk.choices[0].delta
+                                if delta.content:
+                                    response_text += delta.content
+
+                                # Capture tool calls from streaming chunks if provider supports it
+                                if formatted_tools and self._supports_streaming_tools() and hasattr(delta, 'tool_calls') and delta.tool_calls:
+                                    for tc in delta.tool_calls:
+                                        if tc.index >= len(tool_calls):
+                                            tool_calls.append({
+                                                "id": tc.id,
+                                                "type": "function",
+                                                "function": {"name": "", "arguments": ""}
+                                            })
+                                        if tc.function.name:
+                                            tool_calls[tc.index]["function"]["name"] = tc.function.name
+                                        if tc.function.arguments:
+                                            tool_calls[tc.index]["function"]["arguments"] += tc.function.arguments
+
+                    response_text = response_text.strip()
+
+                    # Create a mock final_response with the captured data
+                    final_response = {
+                        "choices": [{
+                            "message": {
+                                "content": response_text,
+                                "tool_calls": tool_calls if tool_calls else None
+                            }
+                        }]
+                    }
                else:
-                    # Non-
-
-                    for chunk in litellm.completion(
+                    # Non-streaming approach (when tools require it or streaming is disabled)
+                    final_response = litellm.completion(
                        **self._build_completion_params(
                            messages=messages,
                            tools=formatted_tools,
                            temperature=temperature,
-                            stream=
+                            stream=False,
                            **kwargs
                        )
-                    ):
-                        if chunk and chunk.choices and chunk.choices[0].delta.content:
-                            response_text += chunk.choices[0].delta.content
-
-                    response_text = response_text.strip()
-
-                    # Get final completion to check for tool calls
-                    final_response = litellm.completion(
-                        **self._build_completion_params(
-                            messages=messages,
-                            tools=formatted_tools,
-                            temperature=temperature,
-                            stream=False, # No streaming for tool call check
-                            **kwargs
                    )
-
+                    response_text = final_response["choices"][0]["message"]["content"]
+
+                    if verbose:
+                        # Display the complete response at once
+                        display_interaction(
+                            original_prompt,
+                            response_text,
+                            markdown=markdown,
+                            generation_time=time.time() - current_time,
+                            console=console
+                        )
 
                tool_calls = final_response["choices"][0]["message"].get("tool_calls")
 
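Both branches accumulate partial tool calls by their `index` field, since a single call's name and arguments arrive spread across chunks. A minimal self-contained illustration of that merge (the `SimpleNamespace` objects stand in for litellm's streaming deltas):

    from types import SimpleNamespace

    def merge_tool_call_delta(tool_calls, tc):
        # First fragment for this index: create an empty slot
        if tc.index >= len(tool_calls):
            tool_calls.append({"id": tc.id, "type": "function",
                               "function": {"name": "", "arguments": ""}})
        if tc.function.name:
            tool_calls[tc.index]["function"]["name"] = tc.function.name
        if tc.function.arguments:
            tool_calls[tc.index]["function"]["arguments"] += tc.function.arguments

    fragments = [
        SimpleNamespace(index=0, id="call_1",
                        function=SimpleNamespace(name="get_weather", arguments='{"city": "Par')),
        SimpleNamespace(index=0, id=None,
                        function=SimpleNamespace(name="", arguments='is"}')),
    ]
    calls = []
    for tc in fragments:
        merge_tool_call_delta(calls, tc)
    print(calls[0]["function"]["arguments"])  # {"city": "Paris"}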
@@ -1198,53 +1294,106 @@ Output MUST be JSON with 'reflection' and 'satisfactory'.
 
                            console=console
                        )
                    else:
-                        if
-
-
-                        #
-
-
-
-
-
-
-
-
-
-
-
-
+                        # Determine if we should use streaming based on tool support
+                        use_streaming = stream
+                        if formatted_tools and not self._supports_streaming_tools():
+                            # Provider doesn't support streaming with tools, use non-streaming
+                            use_streaming = False
+
+                        if use_streaming:
+                            # Streaming approach (with or without tools)
+                            tool_calls = []
+
+                            if verbose:
+                                async for chunk in await litellm.acompletion(
+                                    **self._build_completion_params(
+                                        messages=messages,
+                                        temperature=temperature,
+                                        stream=True,
+                                        tools=formatted_tools,
+                                        **kwargs
+                                    )
+                                ):
+                                    if chunk and chunk.choices and chunk.choices[0].delta:
+                                        delta = chunk.choices[0].delta
+                                        if delta.content:
+                                            response_text += delta.content
+                                            print("\033[K", end="\r")
+                                            print(f"Generating... {time.time() - start_time:.1f}s", end="\r")
+
+                                        # Capture tool calls from streaming chunks if provider supports it
+                                        if formatted_tools and self._supports_streaming_tools() and hasattr(delta, 'tool_calls') and delta.tool_calls:
+                                            for tc in delta.tool_calls:
+                                                if tc.index >= len(tool_calls):
+                                                    tool_calls.append({
+                                                        "id": tc.id,
+                                                        "type": "function",
+                                                        "function": {"name": "", "arguments": ""}
+                                                    })
+                                                if tc.function.name:
+                                                    tool_calls[tc.index]["function"]["name"] = tc.function.name
+                                                if tc.function.arguments:
+                                                    tool_calls[tc.index]["function"]["arguments"] += tc.function.arguments
+                            else:
+                                # Non-verbose streaming
+                                async for chunk in await litellm.acompletion(
+                                    **self._build_completion_params(
+                                        messages=messages,
+                                        temperature=temperature,
+                                        stream=True,
+                                        tools=formatted_tools,
+                                        **kwargs
+                                    )
+                                ):
+                                    if chunk and chunk.choices and chunk.choices[0].delta:
+                                        delta = chunk.choices[0].delta
+                                        if delta.content:
+                                            response_text += delta.content
+
+                                        # Capture tool calls from streaming chunks if provider supports it
+                                        if formatted_tools and self._supports_streaming_tools() and hasattr(delta, 'tool_calls') and delta.tool_calls:
+                                            for tc in delta.tool_calls:
+                                                if tc.index >= len(tool_calls):
+                                                    tool_calls.append({
+                                                        "id": tc.id,
+                                                        "type": "function",
+                                                        "function": {"name": "", "arguments": ""}
+                                                    })
+                                                if tc.function.name:
+                                                    tool_calls[tc.index]["function"]["name"] = tc.function.name
+                                                if tc.function.arguments:
+                                                    tool_calls[tc.index]["function"]["arguments"] += tc.function.arguments
+
+                            response_text = response_text.strip()
+
+                            # We already have tool_calls from streaming if supported
+                            # No need for a second API call!
                        else:
-                            # Non-
-
+                            # Non-streaming approach (when tools require it or streaming is disabled)
+                            tool_response = await litellm.acompletion(
                                **self._build_completion_params(
                                    messages=messages,
                                    temperature=temperature,
-                                    stream=
-
+                                    stream=False,
+                                    tools=formatted_tools,
+                                    **{k:v for k,v in kwargs.items() if k != 'reasoning_steps'}
                                )
-                            ):
-                                if chunk and chunk.choices and chunk.choices[0].delta.content:
-                                    response_text += chunk.choices[0].delta.content
-
-                            response_text = response_text.strip()
-
-                            # ----------------------------------------------------
-                            # 2) If tool calls are needed, do a non-streaming call
-                            # ----------------------------------------------------
-                            if tools and execute_tool_fn:
-                                # Next call with tools if needed
-                                tool_response = await litellm.acompletion(
-                                    **self._build_completion_params(
-                                        messages=messages,
-                                        temperature=temperature,
-                                        stream=False,
-                                        tools=formatted_tools, # We safely pass tools here
-                                        **{k:v for k,v in kwargs.items() if k != 'reasoning_steps'}
                            )
-
-
-
+                            response_text = tool_response.choices[0].message.get("content", "")
+                            tool_calls = tool_response.choices[0].message.get("tool_calls", [])
+
+                            if verbose:
+                                # Display the complete response at once
+                                display_interaction(
+                                    original_prompt,
+                                    response_text,
+                                    markdown=markdown,
+                                    generation_time=time.time() - start_time,
+                                    console=console
+                                )
+
+                        # Now handle tools if we have them (either from streaming or non-streaming)
+                        if tools and execute_tool_fn and tool_calls:
 
                            if tool_calls:
                                # Convert tool_calls to a serializable format for all providers
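The async branch mirrors the synchronous one: tool calls are now captured during the stream itself, which removes the follow-up `acompletion` round trip the old code made. A minimal usage sketch of the streaming call shape the diff relies on (model and prompt are illustrative; requires litellm and a configured API key):

    import asyncio
    import litellm

    async def main():
        response_text = ""
        async for chunk in await litellm.acompletion(
            model="gpt-4o-mini",
            messages=[{"role": "user", "content": "Say hi"}],
            stream=True,
        ):
            delta = chunk.choices[0].delta
            if delta.content:
                response_text += delta.content
        print(response_text)

    asyncio.run(main())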
praisonaiagents/mcp/mcp.py
CHANGED
@@ -7,6 +7,7 @@ import shlex
 import logging
 import os
 import re
+import platform
 from typing import Any, List, Optional, Callable, Iterable, Union
 from functools import wraps, partial
 
@@ -199,7 +200,13 @@ class MCP:
         # Handle the single string format for stdio client
         if isinstance(command_or_string, str) and args is None:
             # Split the string into command and args using shell-like parsing
-            parts = shlex.split(command_or_string)
+            if platform.system() == 'Windows':
+                # Use shlex with posix=False for Windows to handle quotes and paths with spaces
+                parts = shlex.split(command_or_string, posix=False)
+                # Remove quotes from parts if present (Windows shlex keeps them)
+                parts = [part.strip('"') for part in parts]
+            else:
+                parts = shlex.split(command_or_string)
             if not parts:
                 raise ValueError("Empty command string")
 
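The platform split exists because POSIX-mode `shlex` strips quotes while non-POSIX mode keeps them attached to the token, as a quick check shows (the server package name is made up):

    import shlex

    cmd = '"C:/Program Files/nodejs/npx.cmd" -y @example/mcp-server'
    print(shlex.split(cmd))               # ['C:/Program Files/nodejs/npx.cmd', '-y', '@example/mcp-server']
    print(shlex.split(cmd, posix=False))  # ['"C:/Program Files/nodejs/npx.cmd"', '-y', '@example/mcp-server']
    # Hence the extra strip step on Windows:
    print([p.strip('"') for p in shlex.split(cmd, posix=False)])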
@@ -217,11 +224,17 @@ class MCP:
             env = kwargs.get('env', {})
             if not env:
                 env = os.environ.copy()
-
-
-
-
-
+
+            # Always set Python encoding
+            env['PYTHONIOENCODING'] = 'utf-8'
+
+            # Only set locale variables on Unix systems
+            if platform.system() != 'Windows':
+                env.update({
+                    'LC_ALL': 'C.UTF-8',
+                    'LANG': 'C.UTF-8'
+                })
+
             kwargs['env'] = env
 
             self.server_params = StdioServerParameters(
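The hunk forces UTF-8 for the child process everywhere but applies `LC_ALL`/`LANG` only on Unix, since those locale variables are not meaningful on Windows and `C.UTF-8` is a Unix locale name. The resulting logic, as a standalone sketch:

    import os
    import platform

    env = os.environ.copy()
    env['PYTHONIOENCODING'] = 'utf-8'  # always force UTF-8 in the child process
    if platform.system() != 'Windows':
        # C.UTF-8 is available on most modern Unix systems
        env.update({'LC_ALL': 'C.UTF-8', 'LANG': 'C.UTF-8'})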
@@ -236,7 +249,14 @@ class MCP:
             print(f"Warning: MCP initialization timed out after {self.timeout} seconds")
 
         # Automatically detect if this is an NPX command
-
+        base_cmd = os.path.basename(cmd) if isinstance(cmd, str) else cmd
+        # Check for npx with or without Windows extensions
+        npx_variants = ['npx', 'npx.cmd', 'npx.exe']
+        if platform.system() == 'Windows' and isinstance(base_cmd, str):
+            # Case-insensitive comparison on Windows
+            self.is_npx = base_cmd.lower() in [v.lower() for v in npx_variants]
+        else:
+            self.is_npx = base_cmd in npx_variants
 
         # For NPX-based MCP servers, use a different approach
         if self.is_npx:
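Detection now compares only the basename and, on Windows, ignores case and the `.cmd`/`.exe` wrappers. A standalone rendition of that check:

    import os
    import platform

    NPX_VARIANTS = ['npx', 'npx.cmd', 'npx.exe']

    def is_npx(cmd: str) -> bool:
        base_cmd = os.path.basename(cmd)
        if platform.system() == 'Windows':
            return base_cmd.lower() in [v.lower() for v in NPX_VARIANTS]
        return base_cmd in NPX_VARIANTS

    print(is_npx('/usr/local/bin/npx'))   # True
    print(is_npx('C:\\nodejs\\NPX.CMD'))  # True only on Windows (backslash paths and case fold)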
praisonaiagents/memory/memory.py
CHANGED
@@ -741,6 +741,54 @@ class Memory:
             filtered.append(h)
         return filtered[:limit]
 
+    def search(self, query: str, user_id: Optional[str] = None, agent_id: Optional[str] = None,
+               run_id: Optional[str] = None, limit: int = 5, rerank: bool = False, **kwargs) -> List[Dict[str, Any]]:
+        """
+        Generic search method that delegates to appropriate specific search methods.
+        Provides compatibility with mem0.Memory interface.
+
+        Args:
+            query: The search query string
+            user_id: Optional user ID for user-specific search
+            agent_id: Optional agent ID for agent-specific search
+            run_id: Optional run ID for run-specific search
+            limit: Maximum number of results to return
+            rerank: Whether to use advanced reranking
+            **kwargs: Additional search parameters
+
+        Returns:
+            List of search results
+        """
+        # If using mem0, pass all parameters directly
+        if self.use_mem0 and hasattr(self, "mem0_client"):
+            search_params = {
+                "query": query,
+                "limit": limit,
+                "rerank": rerank
+            }
+
+            # Add optional parameters if provided
+            if user_id is not None:
+                search_params["user_id"] = user_id
+            if agent_id is not None:
+                search_params["agent_id"] = agent_id
+            if run_id is not None:
+                search_params["run_id"] = run_id
+
+            # Include any additional kwargs
+            search_params.update(kwargs)
+
+            return self.mem0_client.search(**search_params)
+
+        # For local memory, use specific search methods
+        if user_id:
+            # Use user-specific search
+            return self.search_user_memory(user_id, query, limit=limit, rerank=rerank, **kwargs)
+        else:
+            # Default to long-term memory search
+            # Note: agent_id and run_id filtering could be added to metadata filtering in the future
+            return self.search_long_term(query, limit=limit, rerank=rerank, **kwargs)
+
     def reset_user_memory(self):
         """
         Clear all user-based info. For simplicity, we do a full LTM reset.
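The new `search()` is a thin router: mem0 gets every parameter forwarded, while the local store delegates to `search_user_memory` or `search_long_term`. Hypothetical usage (the import path and constructor arguments are illustrative, not taken from this diff):

    from praisonaiagents.memory.memory import Memory

    memory = Memory(config={"provider": "rag"})
    # user_id present -> routed to search_user_memory()
    hits = memory.search("favorite color", user_id="user-123", limit=3)
    # no user_id -> routed to search_long_term()
    hits = memory.search("project deadlines", rerank=True)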
praisonaiagents/telemetry/integration.py
CHANGED
@@ -140,7 +140,7 @@ def instrument_workflow(workflow: 'PraisonAIAgents', telemetry: Optional['Minima
         task = None
         try:
             # Get task info
-            if hasattr(workflow, 'tasks') and task_id < len(workflow.tasks):
+            if hasattr(workflow, 'tasks') and isinstance(task_id, int) and task_id < len(workflow.tasks):
                 task = workflow.tasks[task_id]
 
             result = original_execute_task(task_id, *args, **kwargs)
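The added `isinstance` guard matters because `execute_task` may receive a non-integer id; in Python 3, comparing a string with an int raises rather than returning False:

    tasks = ["draft", "review"]
    task_id = "task-uuid"
    # task_id < len(tasks)  # TypeError: '<' not supported between 'str' and 'int'
    if isinstance(task_id, int) and task_id < len(tasks):
        task = tasks[task_id]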
praisonaiagents/tools/duckdb_tools.py
CHANGED
@@ -13,6 +13,7 @@ import logging
 from typing import List, Dict, Any, Optional, Union, TYPE_CHECKING
 from importlib import util
 import json
+import re
 
 if TYPE_CHECKING:
     import duckdb
@@ -29,6 +30,25 @@ class DuckDBTools:
         """
         self.database = database
         self._conn = None
+
+    def _validate_identifier(self, identifier: str) -> str:
+        """Validate and quote a SQL identifier to prevent injection.
+
+        Args:
+            identifier: Table or column name to validate
+
+        Returns:
+            Quoted identifier safe for SQL
+
+        Raises:
+            ValueError: If identifier contains invalid characters
+        """
+        # Only allow alphanumeric characters, underscores, and dots
+        if not re.match(r'^[a-zA-Z_][a-zA-Z0-9_]*(\.[a-zA-Z_][a-zA-Z0-9_]*)?$', identifier):
+            raise ValueError(f"Invalid identifier: {identifier}")
+
+        # Quote the identifier to handle reserved words
+        return f'"{identifier}"'
 
     def _get_duckdb(self) -> Optional['duckdb']:
         """Get duckdb module, installing if needed"""
@@ -125,39 +145,50 @@ class DuckDBTools:
         if conn is None:
             return False
 
-        # Check if table exists
-        exists = conn.execute(
+        # Check if table exists using parameterized query
+        exists = conn.execute("""
             SELECT name FROM sqlite_master
-            WHERE type='table' AND name
-        """).fetchone() is not None
+            WHERE type='table' AND name=?
+        """, [table_name]).fetchone() is not None
 
         if exists:
             if if_exists == 'fail':
                 raise ValueError(f"Table {table_name} already exists")
             elif if_exists == 'replace':
-
+                # Validate and quote table name to prevent injection
+                safe_table = self._validate_identifier(table_name)
+                conn.execute(f"DROP TABLE IF EXISTS {safe_table}")
             elif if_exists != 'append':
                 raise ValueError("if_exists must be 'fail', 'replace', or 'append'")
 
         # Create table if needed
         if not exists or if_exists == 'replace':
+            safe_table = self._validate_identifier(table_name)
             if schema:
-                # Create table with schema
-
-
+                # Create table with schema - validate column names
+                column_defs = []
+                for col_name, col_type in schema.items():
+                    safe_col = self._validate_identifier(col_name)
+                    # Validate column type to prevent injection
+                    if not re.match(r'^[A-Z][A-Z0-9_]*(\([0-9,]+\))?$', col_type.upper()):
+                        raise ValueError(f"Invalid column type: {col_type}")
+                    column_defs.append(f"{safe_col} {col_type}")
+                columns = ', '.join(column_defs)
+                conn.execute(f"CREATE TABLE {safe_table} ({columns})")
             else:
-                # Infer schema from CSV
+                # Infer schema from CSV - use parameterized query for filepath
                 conn.execute(f"""
-                    CREATE TABLE {
-                    SELECT * FROM read_csv_auto(
+                    CREATE TABLE {safe_table} AS
+                    SELECT * FROM read_csv_auto(?)
                     WHERE 1=0
-                """)
+                """, [filepath])
 
-        # Load data
+        # Load data - use validated table name and parameterized filepath
+        safe_table = self._validate_identifier(table_name)
         conn.execute(f"""
-            INSERT INTO {
-            SELECT * FROM read_csv_auto(
-            """)
+            INSERT INTO {safe_table}
+            SELECT * FROM read_csv_auto(?)
+        """, [filepath])
 
         return True
 
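The principle behind this hunk: SQL identifiers cannot be bound as parameters, so they are validated and quoted, while values such as the file path go through `?` placeholders. A small sketch of that split using the duckdb package (table and values are illustrative):

    import duckdb

    conn = duckdb.connect()          # in-memory database
    safe_table = '"events"'          # validated + quoted identifier
    conn.execute(f"CREATE TABLE {safe_table} (id INTEGER, name VARCHAR)")
    # values go through placeholders, never through string formatting
    conn.execute(f"INSERT INTO {safe_table} VALUES (?, ?)", [1, "alpha"])
    print(conn.execute(f"SELECT * FROM {safe_table}").fetchall())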