smarta2a 0.3.1__py3-none-any.whl → 0.4.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- smarta2a/agent/a2a_agent.py +25 -15
- smarta2a/agent/a2a_human.py +56 -0
- smarta2a/archive/smart_mcp_client.py +47 -0
- smarta2a/archive/subscription_service.py +85 -0
- smarta2a/{server → archive}/task_service.py +17 -8
- smarta2a/client/a2a_client.py +33 -6
- smarta2a/history_update_strategies/rolling_window_strategy.py +16 -0
- smarta2a/model_providers/__init__.py +1 -1
- smarta2a/model_providers/base_llm_provider.py +3 -3
- smarta2a/model_providers/openai_provider.py +126 -89
- smarta2a/server/json_rpc_request_processor.py +130 -0
- smarta2a/server/nats_client.py +49 -0
- smarta2a/server/request_handler.py +667 -0
- smarta2a/server/send_task_handler.py +174 -0
- smarta2a/server/server.py +124 -726
- smarta2a/server/state_manager.py +171 -20
- smarta2a/server/webhook_request_processor.py +112 -0
- smarta2a/state_stores/base_state_store.py +3 -3
- smarta2a/state_stores/inmemory_state_store.py +21 -7
- smarta2a/utils/agent_discovery_manager.py +121 -0
- smarta2a/utils/prompt_helpers.py +1 -1
- smarta2a/{client → utils}/tools_manager.py +39 -8
- smarta2a/utils/types.py +17 -3
- {smarta2a-0.3.1.dist-info → smarta2a-0.4.1.dist-info}/METADATA +7 -4
- smarta2a-0.4.1.dist-info/RECORD +40 -0
- smarta2a-0.4.1.dist-info/licenses/LICENSE +35 -0
- smarta2a/examples/__init__.py +0 -0
- smarta2a/examples/echo_server/__init__.py +0 -0
- smarta2a/examples/echo_server/curl.txt +0 -1
- smarta2a/examples/echo_server/main.py +0 -39
- smarta2a/examples/openai_airbnb_agent/__init__.py +0 -0
- smarta2a/examples/openai_airbnb_agent/main.py +0 -33
- smarta2a/examples/openai_delegator_agent/__init__.py +0 -0
- smarta2a/examples/openai_delegator_agent/main.py +0 -51
- smarta2a/examples/openai_weather_agent/__init__.py +0 -0
- smarta2a/examples/openai_weather_agent/main.py +0 -32
- smarta2a/server/subscription_service.py +0 -109
- smarta2a-0.3.1.dist-info/RECORD +0 -42
- smarta2a-0.3.1.dist-info/licenses/LICENSE +0 -21
- {smarta2a-0.3.1.dist-info → smarta2a-0.4.1.dist-info}/WHEEL +0 -0
@@ -1,13 +1,16 @@
|
|
1
1
|
# Library imports
|
2
2
|
import json
|
3
|
+
import httpx
|
3
4
|
from typing import AsyncGenerator, List, Dict, Optional, Union, Any
|
4
5
|
from openai import AsyncOpenAI
|
6
|
+
from pydantic import HttpUrl, ValidationError
|
5
7
|
|
6
8
|
# Local imports
|
7
|
-
from smarta2a.utils.types import Message, TextPart, FilePart, DataPart, Part, AgentCard
|
9
|
+
from smarta2a.utils.types import Message, TextPart, FilePart, DataPart, Part, AgentCard, StateData
|
8
10
|
from smarta2a.model_providers.base_llm_provider import BaseLLMProvider
|
9
|
-
from smarta2a.
|
11
|
+
from smarta2a.utils.tools_manager import ToolsManager
|
10
12
|
from smarta2a.utils.prompt_helpers import build_system_prompt
|
13
|
+
from smarta2a.utils.agent_discovery_manager import AgentDiscoveryManager
|
11
14
|
|
12
15
|
class OpenAIProvider(BaseLLMProvider):
|
13
16
|
def __init__(
|
@@ -17,22 +20,39 @@ class OpenAIProvider(BaseLLMProvider):
|
|
17
20
|
base_system_prompt: Optional[str] = None,
|
18
21
|
mcp_server_urls_or_paths: Optional[List[str]] = None,
|
19
22
|
agent_cards: Optional[List[AgentCard]] = None,
|
20
|
-
|
23
|
+
agent_base_urls: Optional[List[HttpUrl]] = None,
|
24
|
+
discovery_endpoint: Optional[HttpUrl] = None,
|
25
|
+
timeout: float = 5.0,
|
26
|
+
retries: int = 2
|
21
27
|
):
|
22
28
|
self.client = AsyncOpenAI(api_key=api_key)
|
23
29
|
self.model = model
|
24
30
|
self.mcp_server_urls_or_paths = mcp_server_urls_or_paths
|
25
|
-
self.agent_cards = agent_cards
|
26
31
|
# Store the base system prompt; will be enriched by tool descriptions
|
27
32
|
self.base_system_prompt = base_system_prompt
|
28
33
|
self.supported_media_types = [
|
29
34
|
"image/png", "image/jpeg", "image/gif", "image/webp"
|
30
35
|
]
|
36
|
+
|
37
|
+
# Initialize discovery manager
|
38
|
+
self.agent_discovery = AgentDiscoveryManager(
|
39
|
+
agent_cards=agent_cards,
|
40
|
+
agent_base_urls=agent_base_urls,
|
41
|
+
discovery_endpoint=discovery_endpoint,
|
42
|
+
timeout=timeout,
|
43
|
+
retries=retries
|
44
|
+
)
|
45
|
+
|
46
|
+
self.agent_cards: List[AgentCard] = []
|
31
47
|
# Initialize ToolsManager
|
32
48
|
self.tools_manager = ToolsManager()
|
33
49
|
|
34
50
|
|
35
51
|
async def load(self):
|
52
|
+
"""Async initialization of resources"""
|
53
|
+
# Discover agents first
|
54
|
+
self.agent_cards = await self.agent_discovery.discover_agents()
|
55
|
+
|
36
56
|
if self.mcp_server_urls_or_paths:
|
37
57
|
await self.tools_manager.load_mcp_tools(self.mcp_server_urls_or_paths)
|
38
58
|
|
@@ -116,102 +136,108 @@ class OpenAIProvider(BaseLLMProvider):
|
|
116
136
|
return openai_messages
|
117
137
|
|
118
138
|
|
119
|
-
def
|
139
|
+
def _format_openai_functions(self) -> List[dict]:
|
120
140
|
"""
|
121
141
|
Convert internal tools metadata to OpenAI's function-call schema.
|
122
142
|
"""
|
123
|
-
|
143
|
+
functions = []
|
124
144
|
for tool in self.tools_manager.get_tools():
|
125
|
-
|
126
|
-
"
|
127
|
-
"
|
128
|
-
|
129
|
-
"description": tool.description,
|
130
|
-
"parameters": tool.inputSchema
|
131
|
-
}
|
145
|
+
functions.append({
|
146
|
+
"name": tool.key,
|
147
|
+
"description": tool.description,
|
148
|
+
"parameters": tool.inputSchema,
|
132
149
|
})
|
133
|
-
return
|
134
|
-
|
150
|
+
return functions
|
151
|
+
|
135
152
|
|
136
|
-
async def generate(self, state: StateData, **kwargs) -> str:
    """
    Generate a complete response, invoking tools as needed.

    The conversation in ``state.context_history`` is converted to OpenAI chat
    messages and sent to ``self.client.chat.completions.create`` together with
    the registered tools (via the ``functions`` parameter). Whenever the model
    replies with a ``function_call``, the named tool is executed through
    ``self.tools_manager.call_tool`` and its result is appended as a
    ``function`` message before the model is queried again.

    Args:
        state: Task state. ``context_history``, ``task_id`` and
            ``task.sessionId`` are read.
        **kwargs: Extra keyword arguments forwarded to every
            ``chat.completions.create`` call.

    Returns:
        The ``content`` of the first assistant message that carries no
        function call.

    Raises:
        RuntimeError: if the model is still requesting tool calls after
            ``max_iterations`` round trips.
    """
    # Prepare history
    messages = [msg if isinstance(msg, Message) else Message(**msg) for msg in state.context_history]
    openai_messages = self._convert_messages(messages)
    max_iterations = 30

    # The tool schema does not change while generating, so build it once.
    # An empty ``functions`` array is rejected by the API, so the parameter
    # is only sent when at least one tool is registered.
    functions = self._format_openai_functions()
    function_kwargs = {"functions": functions} if functions else {}

    for _ in range(max_iterations):
        response = await self.client.chat.completions.create(
            model=self.model,
            messages=openai_messages,
            **function_kwargs,
            **kwargs,
        )
        msg = response.choices[0].message

        # No function call means the model produced its final answer.
        if not msg.function_call:
            return msg.content

        fn_name = msg.function_call.name
        try:
            fn_args = json.loads(msg.function_call.arguments or "{}")
        except json.JSONDecodeError:
            # Model-produced arguments can be malformed JSON; fall back to no
            # arguments instead of aborting the whole generation.
            fn_args = {}

        # Echo the assistant's function call so the next turn has context.
        openai_messages.append({
            "role": "assistant",
            "content": None,
            "function_call": {
                "name": fn_name,
                "arguments": msg.function_call.arguments,
            },
        })

        # Call the actual tool; any failure is reported back to the model
        # as the tool's result rather than raised.
        try:
            override_args = {
                "id": state.task_id,
                "sessionId": state.task.sessionId,
            }
            tool_result = await self.tools_manager.call_tool(fn_name, fn_args, override_args)
        except Exception as e:
            tool_result = {"content": f"Error calling {fn_name}: {e}"}

        if isinstance(tool_result, list) and tool_result:
            # List of content items (e.g. TextContent): keep the text of every
            # item, falling back to str() for items without a `text` attribute.
            result = "\n".join(
                item.text if hasattr(item, "text") else str(item)
                for item in tool_result
            )
        elif isinstance(tool_result, dict):
            # Error case above, or any dict-shaped result.
            result = tool_result.get("content", str(tool_result))
        else:
            result = str(tool_result)

        openai_messages.append({
            "role": "function",
            "name": fn_name,
            "content": result,
        })

    raise RuntimeError("Max tool iteration depth reached in generate().")
|
200
223
|
|
201
224
|
|
202
|
-
|
203
|
-
async def generate_stream(
|
204
|
-
self, messages: List[Message], **kwargs
|
205
|
-
) -> AsyncGenerator[str, None]:
|
225
|
+
async def generate_stream(self, state: StateData, **kwargs) -> AsyncGenerator[str, None]:
|
206
226
|
"""
|
207
|
-
Stream response chunks,
|
227
|
+
Stream response chunks, invoking tools as needed.
|
208
228
|
"""
|
209
|
-
|
210
|
-
|
211
|
-
|
229
|
+
context_history = state.context_history
|
230
|
+
# Normalize incoming messages to your Message model
|
231
|
+
msgs = [
|
232
|
+
msg if isinstance(msg, Message) else Message(**msg)
|
233
|
+
for msg in context_history
|
234
|
+
]
|
235
|
+
# Convert to OpenAI schema, including any prior tool results
|
236
|
+
converted_messages = self._convert_messages(msgs)
|
237
|
+
max_iterations = 30
|
212
238
|
|
213
239
|
for _ in range(max_iterations):
|
214
|
-
#
|
240
|
+
# Kick off the streaming completion
|
215
241
|
stream = await self.client.chat.completions.create(
|
216
242
|
model=self.model,
|
217
243
|
messages=converted_messages,
|
@@ -221,25 +247,28 @@ class OpenAIProvider(BaseLLMProvider):
|
|
221
247
|
**kwargs
|
222
248
|
)
|
223
249
|
|
224
|
-
full_content =
|
250
|
+
full_content = ""
|
225
251
|
tool_calls: List[Dict[str, Any]] = []
|
226
252
|
|
227
|
-
#
|
253
|
+
# As chunks arrive, yield them and collect any tool_call deltas
|
228
254
|
async for chunk in stream:
|
229
255
|
delta = chunk.choices[0].delta
|
230
|
-
|
231
|
-
|
232
|
-
|
256
|
+
|
257
|
+
# 1) Stream content immediately
|
258
|
+
if hasattr(delta, "content") and delta.content:
|
233
259
|
yield delta.content
|
260
|
+
full_content += delta.content
|
234
261
|
|
235
|
-
#
|
236
|
-
if hasattr(delta,
|
262
|
+
# 2) Buffer up any function/tool calls for after the stream
|
263
|
+
if hasattr(delta, "tool_calls") and delta.tool_calls:
|
237
264
|
for d in delta.tool_calls:
|
238
265
|
idx = d.index
|
239
|
-
# Ensure
|
266
|
+
# Ensure list is long enough
|
240
267
|
while len(tool_calls) <= idx:
|
241
|
-
tool_calls.append({
|
242
|
-
|
268
|
+
tool_calls.append({
|
269
|
+
"id": "",
|
270
|
+
"function": {"name": "", "arguments": ""}
|
271
|
+
})
|
243
272
|
if d.id:
|
244
273
|
tool_calls[idx]["id"] = d.id
|
245
274
|
if d.function.name:
|
@@ -247,35 +276,39 @@ class OpenAIProvider(BaseLLMProvider):
|
|
247
276
|
if d.function.arguments:
|
248
277
|
tool_calls[idx]["function"]["arguments"] += d.function.arguments
|
249
278
|
|
250
|
-
# If
|
279
|
+
# If the assistant didn't invoke any tools, we're done
|
251
280
|
if not tool_calls:
|
252
281
|
return
|
253
282
|
|
254
|
-
#
|
283
|
+
# Otherwise, append the assistant's outgoing call and loop for tool execution
|
255
284
|
converted_messages.append({
|
256
285
|
"role": "assistant",
|
257
|
-
"content":
|
286
|
+
"content": full_content,
|
258
287
|
"tool_calls": [
|
259
|
-
{
|
260
|
-
|
261
|
-
|
262
|
-
|
288
|
+
{
|
289
|
+
"id": tc["id"],
|
290
|
+
"type": "function",
|
291
|
+
"function": {
|
292
|
+
"name": tc["function"]["name"],
|
293
|
+
"arguments": tc["function"]["arguments"]
|
294
|
+
}
|
263
295
|
}
|
264
296
|
for tc in tool_calls
|
265
297
|
]
|
266
298
|
})
|
267
299
|
|
268
|
-
# Execute each tool
|
300
|
+
# Execute each tool in turn and append its result
|
269
301
|
for tc in tool_calls:
|
270
302
|
name = tc["function"]["name"]
|
271
303
|
try:
|
272
|
-
args = json.loads(tc["function"]["arguments"])
|
304
|
+
args = json.loads(tc["function"]["arguments"] or "{}")
|
273
305
|
except json.JSONDecodeError:
|
274
306
|
args = {}
|
275
|
-
|
276
307
|
try:
|
277
|
-
|
278
|
-
result_content =
|
308
|
+
tool_res = await self.tools_manager.call_tool(name, args)
|
309
|
+
result_content = getattr(tool_res, "content", None) or (
|
310
|
+
tool_res.get("content") if isinstance(tool_res, dict) else str(tool_res)
|
311
|
+
)
|
279
312
|
except Exception as e:
|
280
313
|
result_content = f"Error executing {name}: {e}"
|
281
314
|
|
@@ -284,5 +317,9 @@ class OpenAIProvider(BaseLLMProvider):
|
|
284
317
|
"content": result_content,
|
285
318
|
"tool_call_id": tc["id"]
|
286
319
|
})
|
287
|
-
|
288
|
-
raise RuntimeError("Max tool iteration depth reached in generate_stream().")
|
320
|
+
|
321
|
+
raise RuntimeError("Max tool iteration depth reached in generate_stream().")
|
322
|
+
|
323
|
+
|
324
|
+
|
325
|
+
|
@@ -0,0 +1,130 @@
|
|
1
|
+
# Library imports
|
2
|
+
from typing import Optional, Any
|
3
|
+
from pydantic import BaseModel, ValidationError
|
4
|
+
from fastapi import HTTPException
|
5
|
+
|
6
|
+
# Local imports
|
7
|
+
from smarta2a.utils.types import (
|
8
|
+
JSONRPCRequest,
|
9
|
+
JSONRPCResponse,
|
10
|
+
SendTaskRequest,
|
11
|
+
SendTaskStreamingRequest,
|
12
|
+
GetTaskRequest,
|
13
|
+
CancelTaskRequest,
|
14
|
+
SetTaskPushNotificationRequest,
|
15
|
+
GetTaskPushNotificationRequest,
|
16
|
+
GetTaskResponse,
|
17
|
+
CancelTaskResponse,
|
18
|
+
SetTaskPushNotificationResponse,
|
19
|
+
GetTaskPushNotificationResponse,
|
20
|
+
TaskStatus,
|
21
|
+
TaskState
|
22
|
+
)
|
23
|
+
from smarta2a.server.state_manager import StateManager
|
24
|
+
from smarta2a.server.request_handler import RequestHandler
|
25
|
+
from smarta2a.server.handler_registry import HandlerRegistry
|
26
|
+
from smarta2a.utils.types import (
|
27
|
+
TaskNotFoundError,
|
28
|
+
MethodNotFoundError,
|
29
|
+
InvalidParamsError,
|
30
|
+
UnsupportedOperationError,
|
31
|
+
InternalError,
|
32
|
+
InvalidRequestError
|
33
|
+
)
|
34
|
+
|
35
|
+
class JSONRPCRequestProcessor:
|
36
|
+
def __init__(self, registry: HandlerRegistry, state_manager: Optional[StateManager] = None):
|
37
|
+
self.request_handler = RequestHandler(registry, state_manager)
|
38
|
+
self.state_manager = state_manager
|
39
|
+
|
40
|
+
async def process_request(self, request: JSONRPCRequest) -> JSONRPCResponse:
|
41
|
+
|
42
|
+
try:
|
43
|
+
method = request.method
|
44
|
+
params = request.params
|
45
|
+
|
46
|
+
|
47
|
+
match method:
|
48
|
+
case "tasks/send":
|
49
|
+
send_task_request = self._validate_request(request, SendTaskRequest)
|
50
|
+
|
51
|
+
if self.state_manager:
|
52
|
+
state_data = await self.state_manager.get_or_create_and_update_state(send_task_request.params.id, send_task_request.params.sessionId, send_task_request.params.message, send_task_request.params.metadata, send_task_request.params.pushNotification)
|
53
|
+
return await self.request_handler.handle_send_task(send_task_request, state_data)
|
54
|
+
else:
|
55
|
+
return await self.request_handler.handle_send_task(send_task_request)
|
56
|
+
|
57
|
+
case "tasks/sendSubscribe":
|
58
|
+
send_subscribe_request = self._validate_request(request, SendTaskStreamingRequest)
|
59
|
+
if self.state_manager:
|
60
|
+
state_data = await self.state_manager.get_or_create_and_update_state(send_subscribe_request.params.id, send_subscribe_request.params.sessionId, send_subscribe_request.params.message, send_subscribe_request.params.metadata, send_subscribe_request.params.pushNotification)
|
61
|
+
return await self.request_handler.handle_subscribe_task(send_subscribe_request, state_data)
|
62
|
+
else:
|
63
|
+
return await self.request_handler.handle_subscribe_task(send_subscribe_request)
|
64
|
+
|
65
|
+
case "tasks/get":
|
66
|
+
get_task_request = self._validate_request(request, GetTaskRequest)
|
67
|
+
if self.state_manager:
|
68
|
+
state_data = self.state_manager.get_state(get_task_request.id)
|
69
|
+
if state_data:
|
70
|
+
return GetTaskResponse(
|
71
|
+
id=get_task_request.id,
|
72
|
+
result=state_data.task
|
73
|
+
)
|
74
|
+
else:
|
75
|
+
return JSONRPCResponse(id=request.id, error=TaskNotFoundError())
|
76
|
+
else:
|
77
|
+
return self.request_handler.handle_get_task(get_task_request)
|
78
|
+
|
79
|
+
case "tasks/cancel":
|
80
|
+
cancel_task_request = self._validate_request(request, CancelTaskRequest)
|
81
|
+
if self.state_manager:
|
82
|
+
state_data = self.state_manager.get_state(cancel_task_request.id)
|
83
|
+
if state_data:
|
84
|
+
state_data.task.status = TaskStatus(state=TaskState.CANCELLED)
|
85
|
+
self.state_manager.update_state(cancel_task_request.id, state_data)
|
86
|
+
return CancelTaskResponse(id=cancel_task_request.id)
|
87
|
+
else:
|
88
|
+
return JSONRPCResponse(id=request.id, error=TaskNotFoundError())
|
89
|
+
else:
|
90
|
+
return self.request_handler.handle_cancel_task(cancel_task_request)
|
91
|
+
|
92
|
+
case "tasks/pushNotification/set":
|
93
|
+
set_push_notification_request = self._validate_request(request, SetTaskPushNotificationRequest)
|
94
|
+
if self.state_manager:
|
95
|
+
state_data = self.state_manager.get_state(set_push_notification_request.id)
|
96
|
+
if state_data:
|
97
|
+
state_data.push_notification_config = set_push_notification_request.pushNotificationConfig
|
98
|
+
self.state_manager.update_state(set_push_notification_request.id, state_data)
|
99
|
+
return SetTaskPushNotificationResponse(id=set_push_notification_request.id, result=state_data.push_notification_config)
|
100
|
+
else:
|
101
|
+
return JSONRPCResponse(id=request.id, error=TaskNotFoundError())
|
102
|
+
else:
|
103
|
+
return self.request_handler.handle_set_notification(set_push_notification_request)
|
104
|
+
|
105
|
+
case "tasks/pushNotification/get":
|
106
|
+
get_push_notification_request = self._validate_request(request, GetTaskPushNotificationRequest)
|
107
|
+
if self.state_manager:
|
108
|
+
state_data = self.state_manager.get_state(get_push_notification_request.id)
|
109
|
+
if state_data:
|
110
|
+
return GetTaskPushNotificationResponse(id=get_push_notification_request.id, result=state_data.push_notification_config)
|
111
|
+
else:
|
112
|
+
return JSONRPCResponse(id=request.id, error=TaskNotFoundError())
|
113
|
+
else:
|
114
|
+
return self.request_handler.handle_get_notification(get_push_notification_request)
|
115
|
+
|
116
|
+
case _:
|
117
|
+
return JSONRPCResponse(id=request.id, error=MethodNotFoundError()).model_dump()
|
118
|
+
|
119
|
+
except ValidationError as e:
|
120
|
+
return JSONRPCResponse(id=request.id, error=InvalidParamsError(data=e.errors())).model_dump()
|
121
|
+
except HTTPException as e:
|
122
|
+
err = UnsupportedOperationError() if e.status_code == 405 else InternalError(data=str(e))
|
123
|
+
return JSONRPCResponse(id=request.id, error=err).model_dump()
|
124
|
+
|
125
|
+
|
126
|
+
def _validate_request(self, request: JSONRPCRequest, validation_schema: BaseModel) -> Any:
|
127
|
+
try:
|
128
|
+
return validation_schema.model_validate(request.model_dump())
|
129
|
+
except ValidationError as e:
|
130
|
+
return JSONRPCResponse(id=request.id, error=InvalidRequestError(data=e.errors())).model_dump()
|
@@ -0,0 +1,49 @@
|
|
1
|
+
# Library imports
|
2
|
+
import asyncio
|
3
|
+
import json
|
4
|
+
from typing import Any, Dict
|
5
|
+
from nats.aio.client import Client as NATS
|
6
|
+
|
7
|
+
# Local imports
|
8
|
+
|
9
|
+
class NATSClient:
    """Small async wrapper that publishes JSON-encoded payloads over NATS.

    The connection is opened lazily: ``publish`` calls ``connect`` first when
    no connection has been established yet.
    """

    def __init__(self, server_url: str = "nats://localhost:4222"):
        self.server_url = server_url
        self.nats = NATS()
        # True between a successful connect() and the next close().
        self._connected = False

    async def connect(self) -> None:
        """Connect to ``server_url``; does nothing if already connected.

        A connection failure is printed and then re-raised to the caller.
        """
        if self._connected:
            return
        try:
            await self.nats.connect(self.server_url)
            self._connected = True
            print(f"Connected to NATS at {self.server_url}")
        except Exception as err:
            print(f"Failed to connect to NATS: {err}")
            raise

    async def publish(self, subject: str, payload: Dict[str, Any]) -> None:
        """Serialize *payload* to JSON bytes and publish it on *subject*.

        Connects first when needed. Serialization or publish failures are
        printed and then re-raised.
        """
        if not self._connected:
            await self.connect()

        try:
            encoded = json.dumps(payload).encode()
            await self.nats.publish(subject, encoded)
        except Exception as err:
            print(f"Failed to publish message: {err}")
            raise

    async def close(self) -> None:
        """Close the connection if one is open; otherwise a no-op."""
        if not self._connected:
            return
        await self.nats.close()
        self._connected = False
        print("NATS connection closed")

    @property
    def is_connected(self) -> bool:
        """Whether connect() has succeeded since construction or the last close()."""
        return self._connected
|