agno 2.3.5__py3-none-any.whl → 2.3.7__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- agno/agent/agent.py +200 -30
- agno/db/postgres/async_postgres.py +37 -11
- agno/db/postgres/postgres.py +9 -3
- agno/db/sqlite/async_sqlite.py +1 -1
- agno/db/sqlite/sqlite.py +3 -4
- agno/db/utils.py +2 -0
- agno/eval/accuracy.py +8 -4
- agno/integrations/discord/client.py +1 -1
- agno/models/base.py +34 -4
- agno/models/cerebras/cerebras.py +11 -12
- agno/models/response.py +1 -1
- agno/os/interfaces/whatsapp/security.py +3 -1
- agno/os/routers/evals/utils.py +13 -3
- agno/os/schema.py +2 -1
- agno/run/agent.py +17 -0
- agno/run/requirement.py +98 -0
- agno/run/team.py +10 -0
- agno/session/team.py +0 -1
- agno/table.py +1 -1
- agno/team/team.py +98 -14
- agno/tools/google_drive.py +4 -3
- agno/tools/postgres.py +76 -36
- agno/tools/redshift.py +406 -0
- agno/tools/spotify.py +922 -0
- agno/tools/toolkit.py +25 -0
- agno/utils/agent.py +2 -2
- agno/utils/events.py +5 -1
- agno/utils/mcp.py +1 -1
- agno/workflow/workflow.py +5 -2
- {agno-2.3.5.dist-info → agno-2.3.7.dist-info}/METADATA +40 -32
- {agno-2.3.5.dist-info → agno-2.3.7.dist-info}/RECORD +34 -31
- {agno-2.3.5.dist-info → agno-2.3.7.dist-info}/WHEEL +0 -0
- {agno-2.3.5.dist-info → agno-2.3.7.dist-info}/licenses/LICENSE +0 -0
- {agno-2.3.5.dist-info → agno-2.3.7.dist-info}/top_level.txt +0 -0
agno/team/team.py
CHANGED

@@ -723,6 +723,8 @@ class Team:
 
         # List of MCP tools that were initialized on the last run
         self._mcp_tools_initialized_on_run: List[Any] = []
+        # List of connectable tools that were initialized on the last run
+        self._connectable_tools_initialized_on_run: List[Any] = []
 
         # Lazy-initialized shared thread pool executor for background tasks (memory, cultural knowledge, etc.)
         self._background_executor: Optional[Any] = None
@@ -1046,16 +1048,48 @@ class Team:
                 and any(c.__name__ in ["MCPTools", "MultiMCPTools"] for c in type(tool).__mro__)
                 and not tool.initialized # type: ignore
             ):
-                # Connect the MCP server
-                await tool.connect() # type: ignore
-                self._mcp_tools_initialized_on_run.append(tool)
+                try:
+                    # Connect the MCP server
+                    await tool.connect() # type: ignore
+                    self._mcp_tools_initialized_on_run.append(tool)
+                except Exception as e:
+                    log_warning(f"Error connecting tool: {str(e)}")
 
     async def _disconnect_mcp_tools(self) -> None:
         """Disconnect the MCP tools from the agent."""
         for tool in self._mcp_tools_initialized_on_run:
-            await tool.close()
+            try:
+                await tool.close()
+            except Exception as e:
+                log_warning(f"Error disconnecting tool: {str(e)}")
         self._mcp_tools_initialized_on_run = []
 
+    def _connect_connectable_tools(self) -> None:
+        """Connect tools that require connection management (e.g., database connections)."""
+        if self.tools:
+            for tool in self.tools:
+                if (
+                    hasattr(tool, "requires_connect")
+                    and tool.requires_connect
+                    and hasattr(tool, "connect")
+                    and tool not in self._connectable_tools_initialized_on_run
+                ):
+                    try:
+                        tool.connect() # type: ignore
+                        self._connectable_tools_initialized_on_run.append(tool)
+                    except Exception as e:
+                        log_warning(f"Error connecting tool: {str(e)}")
+
+    def _disconnect_connectable_tools(self) -> None:
+        """Disconnect tools that require connection management."""
+        for tool in self._connectable_tools_initialized_on_run:
+            if hasattr(tool, "close"):
+                try:
+                    tool.close() # type: ignore
+                except Exception as e:
+                    log_warning(f"Error disconnecting tool: {str(e)}")
+        self._connectable_tools_initialized_on_run = []
+
     def _execute_pre_hooks(
         self,
         hooks: Optional[List[Callable[..., Any]]],
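
The two helpers added above rely only on duck typing: before tools are prepared, _connect_connectable_tools() picks up any tool that exposes a truthy requires_connect attribute plus a connect() method, and _disconnect_connectable_tools() calls close() on those tools when the run ends. The sketch below is a hypothetical toolkit (not part of agno) written against that contract; the class name, the sqlite3 backend, and the run_select function are illustrative assumptions, while the Toolkit constructor call mirrors the pattern PostgresTools uses further down.

import sqlite3
from typing import Optional

from agno.tools import Toolkit


class SqliteReaderTools(Toolkit):
    # Team._connect_connectable_tools() only checks hasattr(tool, "requires_connect")
    # and that it is truthy, so a plain class attribute is enough for this sketch.
    requires_connect: bool = True

    def __init__(self, db_file: str = "example.db", **kwargs):
        self.db_file = db_file
        self._connection: Optional[sqlite3.Connection] = None
        super().__init__(name="sqlite_reader_tools", tools=[self.run_select], **kwargs)

    def connect(self) -> sqlite3.Connection:
        # Called before tools are prepared; reuse an open connection if one exists.
        if self._connection is None:
            self._connection = sqlite3.connect(self.db_file)
        return self._connection

    def close(self) -> None:
        # Called by _disconnect_connectable_tools() in the run's finally block.
        if self._connection is not None:
            self._connection.close()
            self._connection = None

    def run_select(self, query: str) -> str:
        """Run a read-only SELECT statement and return the rows as text."""
        rows = self.connect().execute(query).fetchall()
        return "\n".join(str(row) for row in rows)
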
@@ -1625,6 +1659,8 @@ class Team:
             self._cleanup_and_store(run_response=run_response, session=session)
             return run_response
         finally:
+            # Always disconnect connectable tools
+            self._disconnect_connectable_tools()
             cleanup_run(run_response.run_id) # type: ignore
 
     def _run_stream(
@@ -1911,6 +1947,8 @@ class Team:
             # Add the RunOutput to Team Session even when cancelled
             self._cleanup_and_store(run_response=run_response, session=session)
         finally:
+            # Always disconnect connectable tools
+            self._disconnect_connectable_tools()
             # Always clean up the run tracking
             cleanup_run(run_response.run_id) # type: ignore
 
@@ -2060,7 +2098,10 @@ class Team:
 
         # Initialize session state
         session_state = self._initialize_session_state(
-            session_state=session_state
+            session_state=session_state if session_state is not None else {},
+            user_id=user_id,
+            session_id=session_id,
+            run_id=run_id,
         )
         # Update session state from DB
         session_state = self._load_session_state(session=team_session, session_state=session_state)
@@ -2293,7 +2334,7 @@ class Team:
         self._update_metadata(session=team_session)
         # Initialize session state
         run_context.session_state = self._initialize_session_state(
-            session_state=run_context.session_state
+            session_state=run_context.session_state if run_context.session_state is not None else {},
            user_id=user_id,
            session_id=session_id,
            run_id=run_response.run_id,
@@ -2472,6 +2513,8 @@ class Team:
 
             return run_response
         finally:
+            # Always disconnect connectable tools
+            self._disconnect_connectable_tools()
             await self._disconnect_mcp_tools()
             # Cancel the memory task if it's still running
             if memory_task is not None and not memory_task.done():
@@ -2533,7 +2576,7 @@ class Team:
         self._update_metadata(session=team_session)
         # Initialize session state
         run_context.session_state = self._initialize_session_state(
-            session_state=run_context.session_state
+            session_state=run_context.session_state if run_context.session_state is not None else {},
            user_id=user_id,
            session_id=session_id,
            run_id=run_response.run_id,
@@ -2798,6 +2841,8 @@ class Team:
             await self._acleanup_and_store(run_response=run_response, session=team_session)
 
         finally:
+            # Always disconnect connectable tools
+            self._disconnect_connectable_tools()
             await self._disconnect_mcp_tools()
             # Cancel the memory task if it's still running
             if memory_task is not None and not memory_task.done():
@@ -5330,6 +5375,9 @@ class Team:
         add_session_state_to_context: Optional[bool] = None,
         check_mcp_tools: bool = True,
     ) -> List[Union[Function, dict]]:
+        # Connect tools that require connection management
+        self._connect_connectable_tools()
+
         # Prepare tools
         _tools: List[Union[Toolkit, Callable, Function, Dict]] = []
 
@@ -5379,6 +5427,7 @@ class Team:
                         run_response=run_response,
                         knowledge_filters=run_context.knowledge_filters,
                         async_mode=async_mode,
+                        run_context=run_context,
                     )
                 )
             else:
@@ -5387,6 +5436,7 @@ class Team:
                         run_response=run_response,
                         knowledge_filters=run_context.knowledge_filters,
                         async_mode=async_mode,
+                        run_context=run_context,
                     )
                 )
 
@@ -6576,7 +6626,10 @@ class Team:
             retrieval_timer = Timer()
             retrieval_timer.start()
             docs_from_knowledge = self.get_relevant_docs_from_knowledge(
-                query=user_msg_content,
+                query=user_msg_content,
+                filters=run_context.knowledge_filters,
+                run_context=run_context,
+                **kwargs,
             )
             if docs_from_knowledge is not None:
                 references = MessageReferences(
@@ -6731,7 +6784,10 @@ class Team:
             retrieval_timer = Timer()
             retrieval_timer.start()
             docs_from_knowledge = await self.aget_relevant_docs_from_knowledge(
-                query=user_msg_content,
+                query=user_msg_content,
+                filters=run_context.knowledge_filters,
+                run_context=run_context,
+                **kwargs,
             )
             if docs_from_knowledge is not None:
                 references = MessageReferences(
@@ -6877,7 +6933,7 @@ class Team:
             return message
         # Should already be resolved and passed from run() method
         format_variables = ChainMap(
-            session_state
+            session_state if session_state is not None else {},
             dependencies or {},
             metadata or {},
             {"user_id": user_id} if user_id is not None else {},
@@ -8997,11 +9053,15 @@ class Team:
         query: str,
         num_documents: Optional[int] = None,
         filters: Optional[Union[Dict[str, Any], List[FilterExpr]]] = None,
+        run_context: Optional[RunContext] = None,
         **kwargs,
     ) -> Optional[List[Union[Dict[str, Any], str]]]:
         """Return a list of references from the knowledge base"""
         from agno.knowledge.document import Document
 
+        # Extract dependencies from run_context if available
+        dependencies = run_context.dependencies if run_context else None
+
         if num_documents is None and self.knowledge is not None:
             num_documents = self.knowledge.max_results
 
@@ -9033,6 +9093,11 @@ class Team:
                 knowledge_retriever_kwargs = {"team": self}
                 if "filters" in sig.parameters:
                     knowledge_retriever_kwargs["filters"] = filters
+                if "run_context" in sig.parameters:
+                    knowledge_retriever_kwargs["run_context"] = run_context
+                elif "dependencies" in sig.parameters:
+                    # Backward compatibility: support dependencies parameter
+                    knowledge_retriever_kwargs["dependencies"] = dependencies
                 knowledge_retriever_kwargs.update({"query": query, "num_documents": num_documents, **kwargs})
                 return self.knowledge_retriever(**knowledge_retriever_kwargs)
             except Exception as e:
@@ -9064,11 +9129,15 @@ class Team:
         query: str,
         num_documents: Optional[int] = None,
         filters: Optional[Union[Dict[str, Any], List[FilterExpr]]] = None,
+        run_context: Optional[RunContext] = None,
         **kwargs,
     ) -> Optional[List[Union[Dict[str, Any], str]]]:
         """Get relevant documents from knowledge base asynchronously."""
         from agno.knowledge.document import Document
 
+        # Extract dependencies from run_context if available
+        dependencies = run_context.dependencies if run_context else None
+
         if num_documents is None and self.knowledge is not None:
             num_documents = self.knowledge.max_results
 
@@ -9100,6 +9169,11 @@ class Team:
                 knowledge_retriever_kwargs = {"team": self}
                 if "filters" in sig.parameters:
                     knowledge_retriever_kwargs["filters"] = filters
+                if "run_context" in sig.parameters:
+                    knowledge_retriever_kwargs["run_context"] = run_context
+                elif "dependencies" in sig.parameters:
+                    # Backward compatibility: support dependencies parameter
+                    knowledge_retriever_kwargs["dependencies"] = dependencies
                 knowledge_retriever_kwargs.update({"query": query, "num_documents": num_documents, **kwargs})
 
                 result = self.knowledge_retriever(**knowledge_retriever_kwargs)
@@ -9184,6 +9258,7 @@ class Team:
         run_response: TeamRunOutput,
         knowledge_filters: Optional[Union[Dict[str, Any], List[FilterExpr]]] = None,
         async_mode: bool = False,
+        run_context: Optional[RunContext] = None,
    ) -> Function:
        """Factory function to create a search_knowledge_base function with filters."""
 
@@ -9199,7 +9274,9 @@ class Team:
             # Get the relevant documents from the knowledge base, passing filters
             retrieval_timer = Timer()
             retrieval_timer.start()
-            docs_from_knowledge = self.get_relevant_docs_from_knowledge(
+            docs_from_knowledge = self.get_relevant_docs_from_knowledge(
+                query=query, filters=knowledge_filters, run_context=run_context
+            )
             if docs_from_knowledge is not None:
                 references = MessageReferences(
                     query=query, references=docs_from_knowledge, time=round(retrieval_timer.elapsed, 4)
@@ -9226,7 +9303,9 @@ class Team:
             """
             retrieval_timer = Timer()
             retrieval_timer.start()
-            docs_from_knowledge = await self.aget_relevant_docs_from_knowledge(
+            docs_from_knowledge = await self.aget_relevant_docs_from_knowledge(
+                query=query, filters=knowledge_filters, run_context=run_context
+            )
             if docs_from_knowledge is not None:
                 references = MessageReferences(
                     query=query, references=docs_from_knowledge, time=round(retrieval_timer.elapsed, 4)
@@ -9253,6 +9332,7 @@ class Team:
         run_response: TeamRunOutput,
         knowledge_filters: Optional[Union[Dict[str, Any], List[FilterExpr]]] = None,
         async_mode: bool = False,
+        run_context: Optional[RunContext] = None,
    ) -> Function:
        """Factory function to create a search_knowledge_base function with filters."""
 
@@ -9272,7 +9352,9 @@ class Team:
             # Get the relevant documents from the knowledge base, passing filters
             retrieval_timer = Timer()
             retrieval_timer.start()
-            docs_from_knowledge = self.get_relevant_docs_from_knowledge(
+            docs_from_knowledge = self.get_relevant_docs_from_knowledge(
+                query=query, filters=search_filters, run_context=run_context
+            )
             if docs_from_knowledge is not None:
                 references = MessageReferences(
                     query=query, references=docs_from_knowledge, time=round(retrieval_timer.elapsed, 4)
@@ -9303,7 +9385,9 @@ class Team:
 
             retrieval_timer = Timer()
             retrieval_timer.start()
-            docs_from_knowledge = await self.aget_relevant_docs_from_knowledge(
+            docs_from_knowledge = await self.aget_relevant_docs_from_knowledge(
+                query=query, filters=search_filters, run_context=run_context
+            )
             if docs_from_knowledge is not None:
                 references = MessageReferences(
                     query=query, references=docs_from_knowledge, time=round(retrieval_timer.elapsed, 4)
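
The retriever changes above are signature-driven: get_relevant_docs_from_knowledge and aget_relevant_docs_from_knowledge inspect the custom knowledge_retriever and forward run_context only when the callable declares that parameter, falling back to a dependencies parameter for older retrievers. A retriever written against 2.3.7 could therefore look like the hedged sketch below; the parameter names (team, query, num_documents, filters, run_context) come from the diff, while the body and the return payload are purely illustrative.

from typing import Any, Dict, List, Optional


def my_knowledge_retriever(
    team: Any,  # the Team instance; always passed as team=self
    query: str,
    num_documents: Optional[int] = None,
    filters: Optional[Dict[str, Any]] = None,
    run_context: Optional[Any] = None,  # a RunContext; forwarded because the parameter is declared
) -> Optional[List[Dict[str, Any]]]:
    # As of 2.3.7, per-run dependencies are reachable through run_context.dependencies.
    dependencies = run_context.dependencies if run_context is not None else None
    # Query your own store here; the payload below is a placeholder result.
    return [
        {
            "content": f"documents matching {query!r}",
            "filters": filters,
            "dependencies": dependencies,
            "num_documents": num_documents,
        }
    ]

Retrievers that still declare dependencies instead of run_context keep working: the new elif branch passes run_context.dependencies under that name.
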
agno/tools/google_drive.py
CHANGED

@@ -69,6 +69,7 @@ from pathlib import Path
 from typing import Any, List, Optional, Union
 
 from agno.tools import Toolkit
+from agno.utils.log import log_error
 
 try:
     from google.auth.transport.requests import Request
@@ -202,7 +203,7 @@ class GoogleDriveTools(Toolkit):
             items = results.get("files", [])
             return items
         except Exception as error:
-
+            log_error(f"Could not list files: {error}")
             return []
 
     @authenticate
@@ -238,7 +239,7 @@ class GoogleDriveTools(Toolkit):
             )
             return uploaded_file
         except Exception as error:
-
+            log_error(f"Could not upload file '{file_path}': {error}")
             return None
 
     @authenticate
@@ -266,5 +267,5 @@ class GoogleDriveTools(Toolkit):
             print(f"Download progress: {int(status.progress() * 100)}%.")
             return dest_path
         except Exception as error:
-
+            log_error(f"Could not download file '{file_id}': {error}")
             return None

agno/tools/postgres.py
CHANGED

@@ -14,6 +14,21 @@ from agno.utils.log import log_debug, log_error
 
 
 class PostgresTools(Toolkit):
+    """
+    A toolkit for interacting with PostgreSQL databases.
+
+    Args:
+        connection (Optional[PgConnection[DictRow]]): Existing database connection to reuse.
+        db_name (Optional[str]): Database name to connect to.
+        user (Optional[str]): Username for authentication.
+        password (Optional[str]): Password for authentication.
+        host (Optional[str]): PostgreSQL server hostname.
+        port (Optional[int]): PostgreSQL server port number.
+        table_schema (str): Default schema for table operations. Default is "public".
+    """
+
+    _requires_connect: bool = True
+
     def __init__(
         self,
         connection: Optional[PgConnection[DictRow]] = None,
@@ -44,50 +59,71 @@ class PostgresTools(Toolkit):
 
         super().__init__(name="postgres_tools", tools=tools, **kwargs)
 
-
-    def connection(self) -> PgConnection[DictRow]:
-        """
-        Returns the Postgres psycopg connection.
-        :return psycopg.connection.Connection: psycopg connection
+    def connect(self) -> PgConnection[DictRow]:
         """
-
-        log_debug("Establishing new PostgreSQL connection.")
-        connection_kwargs: Dict[str, Any] = {"row_factory": dict_row}
-        if self.db_name:
-            connection_kwargs["dbname"] = self.db_name
-        if self.user:
-            connection_kwargs["user"] = self.user
-        if self.password:
-            connection_kwargs["password"] = self.password
-        if self.host:
-            connection_kwargs["host"] = self.host
-        if self.port:
-            connection_kwargs["port"] = self.port
-
-        connection_kwargs["options"] = f"-c search_path={self.table_schema}"
-
-        self._connection = psycopg.connect(**connection_kwargs)
-        self._connection.read_only = True
+        Establish a connection to the PostgreSQL database.
 
+        Returns:
+            The database connection object.
+        """
+        if self._connection is not None and not self._connection.closed:
+            log_debug("Connection already established, reusing existing connection")
+            return self._connection
+
+        log_debug("Establishing new PostgreSQL connection.")
+        connection_kwargs: Dict[str, Any] = {"row_factory": dict_row}
+        if self.db_name:
+            connection_kwargs["dbname"] = self.db_name
+        if self.user:
+            connection_kwargs["user"] = self.user
+        if self.password:
+            connection_kwargs["password"] = self.password
+        if self.host:
+            connection_kwargs["host"] = self.host
+        if self.port:
+            connection_kwargs["port"] = self.port
+
+        connection_kwargs["options"] = f"-c search_path={self.table_schema}"
+
+        self._connection = psycopg.connect(**connection_kwargs)
+        self._connection.read_only = True
         return self._connection
 
-    def __enter__(self):
-        return self
-
-    def __exit__(self, exc_type, exc_val, exc_tb):
-        self.close()
-
-    def close(self):
+    def close(self) -> None:
         """Closes the database connection if it's open."""
         if self._connection and not self._connection.closed:
             log_debug("Closing PostgreSQL connection.")
             self._connection.close()
             self._connection = None
 
+    @property
+    def is_connected(self) -> bool:
+        """Check if a connection is currently established."""
+        return self._connection is not None and not self._connection.closed
+
+    def _ensure_connection(self) -> PgConnection[DictRow]:
+        """
+        Ensure a connection exists, creating one if necessary.
+
+        Returns:
+            The database connection object.
+        """
+        if not self.is_connected:
+            return self.connect()
+        return self._connection
+
+    def __enter__(self):
+        return self.connect()
+
+    def __exit__(self, exc_type, exc_val, exc_tb):
+        if self.is_connected:
+            self.close()
+
     def _execute_query(self, query: str, params: Optional[tuple] = None) -> str:
         try:
-
-
+            connection = self._ensure_connection()
+            with connection.cursor() as cursor:
+                log_debug("Running PostgreSQL query")
                 cursor.execute(query, params)
 
                 if cursor.description is None:
@@ -105,8 +141,8 @@ class PostgresTools(Toolkit):
 
         except psycopg.Error as e:
             log_error(f"Database error: {e}")
-            if self.
-            self.
+            if self._connection and not self._connection.closed:
+                self._connection.rollback()
             return f"Error executing query: {e}"
         except Exception as e:
             log_error(f"An unexpected error occurred: {e}")
@@ -146,7 +182,8 @@ class PostgresTools(Toolkit):
             A string containing a summary of the table.
         """
         try:
-
+            connection = self._ensure_connection()
+            with connection.cursor() as cursor:
                 # First, get column information using a parameterized query
                 schema_query = """
                     SELECT column_name, data_type
@@ -230,7 +267,8 @@ class PostgresTools(Toolkit):
         stmt = sql.SQL("SELECT * FROM {tbl};").format(tbl=table_identifier)
 
         try:
-
+            connection = self._ensure_connection()
+            with connection.cursor() as cursor:
                 cursor.execute(stmt)
 
                 if cursor.description is None:
@@ -245,6 +283,8 @@ class PostgresTools(Toolkit):
 
             return f"Successfully exported table '{table}' to '{path}'."
         except (psycopg.Error, IOError) as e:
+            if self._connection and not self._connection.closed:
+                self._connection.rollback()
             return f"Error exporting table: {e}"
 
     def run_query(self, query: str) -> str:
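
Taken together, the PostgresTools rewrite replaces the old connection property with an explicit connect()/close() pair, adds an is_connected property and an _ensure_connection() helper that the query tools call lazily, and flags the toolkit with _requires_connect = True so the team-level lifecycle shown earlier connects it before tools are prepared and disconnects it when the run ends. It can also be driven by hand; the usage sketch below assumes placeholder connection parameters and a locally reachable database.

from agno.tools.postgres import PostgresTools

# Placeholder credentials for illustration only.
postgres_tools = PostgresTools(
    db_name="mydb",
    user="readonly_user",
    password="secret",
    host="localhost",
    port=5432,
)

# __enter__ now calls connect() (reusing any open connection) and __exit__ calls close().
with postgres_tools:
    print(postgres_tools.is_connected)                    # True inside the block
    print(postgres_tools.run_query("SELECT version();"))  # tools call _ensure_connection() lazily

print(postgres_tools.is_connected)                        # False: the connection was closed on exit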