econagents 0.0.5__py3-none-any.whl → 0.0.8__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,45 @@
1
+ import json
2
+ from typing import Any, Dict, Optional
3
+
4
+ from econagents import AgentRole
5
+ from econagents.core.events import Message
6
+ from econagents.core.manager.phase import PhaseManager
7
+ from econagents.core.state.game import GameState
8
+ from econagents.config_parser.base import BaseConfigParser
9
+
10
+
11
class BasicConfigParser(BaseConfigParser):
    """
    Config parser whose managers answer the server's ``assign-name`` event with a
    ``player-is-ready`` message.
    """

    def create_manager(
        self, game_id: int, state: GameState, agent_role: Optional[AgentRole], auth_kwargs: Dict[str, Any]
    ) -> PhaseManager:
        """
        Build the base manager and attach an ``assign-name`` handler to it.

        Args:
            game_id: The game ID, echoed back as ``gameId`` in the ready message
            state: The game state instance
            agent_role: The agent role instance
            auth_kwargs: Authentication mechanism keyword arguments; ``agent_id`` is read from it

        Returns:
            The PhaseManager produced by the base parser, with the extra handler registered
        """
        phase_manager = super().create_manager(
            game_id=game_id, state=state, agent_role=agent_role, auth_kwargs=auth_kwargs
        )

        async def _send_ready_on_name_assignment(message: Message) -> None:
            """Reply to ``assign-name`` with a ``player-is-ready`` message carrying the agent ID."""
            payload = {
                "gameId": game_id,
                "type": "player-is-ready",
                "agentId": auth_kwargs.get("agent_id"),
            }
            await phase_manager.send_message(json.dumps(payload))

        phase_manager.register_event_handler("assign-name", _send_ready_on_name_assignment)
        return phase_manager
@@ -0,0 +1,243 @@
1
+ import json
2
+ from pathlib import Path
3
+ from typing import Any, Dict, Optional, Type, List, Callable, cast
4
+ from pydantic import create_model
5
+
6
+ from econagents import AgentRole
7
+ from econagents.core.events import Message
8
+ from econagents.core.manager.phase import PhaseManager
9
+ from econagents.core.state.game import GameState
10
+ from econagents.core.game_runner import GameRunner
11
+ from econagents.config_parser.base import BaseConfigParser
12
+ from econagents.llm.observability import get_observability_provider
13
+
14
+
15
def handle_market_event_impl(self: "GameState", event_type: str, data: dict[str, Any]) -> None:
    """
    Forward a market event to the MarketState field of this game state.

    The field name comes from ``self.meta._market_state_variable_name``. The state section that
    holds the field is read from ``self.meta._market_state_section_name``, which is one of
    ``"public_information"``, ``"private_information"`` or ``"meta"``. When that attribute is not
    set, the section defaults to ``"public_information"``. The field's ``process_event`` is then
    called with the event type and payload.

    Args:
        event_type: The event type (e.g. ``"add-order"``)
        data: The event payload

    Raises:
        ValueError: If the MarketState cannot be located or ``process_event`` fails.
    """
    try:
        # The MarketState may live in any state section; the default keeps older setups working.
        section_name = getattr(self.meta, "_market_state_section_name", "public_information")
        section = getattr(self, section_name)
        market_state = getattr(section, self.meta._market_state_variable_name)  # type: ignore
        market_state.process_event(event_type=event_type, data=data)
    except Exception as e:
        raise ValueError(f"Error processing market event: {e}") from e
23
+
24
+
25
def handle_asset_movement_event_impl(self: "GameState", event_type: str, data: dict[str, Any]) -> None:
    """
    Apply an asset-movement event to the agent's wallet.

    Selects the entry of ``self.private_information.wallet`` keyed by
    ``self.public_information.winning_condition``. Its ``"balance"`` and ``"shares"`` values are
    overwritten with ``data["balance"]`` and ``data["shares"]``. ``event_type`` is accepted so the
    signature matches the other event handlers, but it is not used.

    Args:
        event_type: The event type (unused)
        data: The event payload; must contain ``"balance"`` and ``"shares"``

    Raises:
        ValueError: If the wallet entry or a payload value is missing, or on any other error.
            ``balance`` may already have been updated when ``shares`` is missing.
    """
    try:
        winning_condition = self.public_information.winning_condition  # type: ignore
        wallet_entry = self.private_information.wallet[winning_condition]  # type: ignore
        wallet_entry["balance"] = data["balance"]
        wallet_entry["shares"] = data["shares"]
    except Exception as e:
        raise ValueError(f"Error processing asset-movement event: {e}") from e
33
+
34
+
35
def get_custom_handlers_impl(self: GameState) -> Dict[str, Callable[[Message], None]]:
    """
    Map event types to this state's market-related handlers.

    ``"asset-movement"`` is routed to ``_handle_asset_movement_event``. The order-book events
    (``add-order``, ``update-order``, ``delete-order``, ``contract-fulfilled``) are routed to
    ``_handle_market_event``.
    """
    handlers: Dict[str, Callable[[Message], None]] = {
        "asset-movement": self._handle_asset_movement_event  # type: ignore
    }
    for event_name in ("add-order", "update-order", "delete-order", "contract-fulfilled"):
        handlers[event_name] = self._handle_market_event  # type: ignore
    return handlers
46
+
47
+
48
class IbexTudelftConfigParser(BaseConfigParser):
    """
    IBEX-TUDelft configuration parser built on BaseConfigParser.

    Compared with the base parser it:

    * replies to the server's ``assign-name`` event with a ``player-is-ready`` message;
    * creates agents without a role, and initialises the role when the server sends
      ``assign-role``, matching ``role_id`` against ``config.agent_roles``;
    * when the state configuration declares a ``MarketState`` field, runs the game with a
      GameState subclass that handles market and ``asset-movement`` events.
    """

    def __init__(self, config_path: Path):
        """
        Initialize the IBEX-TUDelft config parser.

        Args:
            config_path: Path to the YAML configuration file
        """
        super().__init__(config_path)
        self._role_classes: Dict[int, Type[AgentRole]] = {}

    def register_role_class(self, role_id: int, role_class: Type[AgentRole]) -> None:
        """
        Register a role class for a specific role ID.

        NOTE(review): nothing in this class reads ``_role_classes``. Roles are built from
        ``config.agent_roles`` in ``_initialize_agent``. Confirm whether this is still needed.

        Args:
            role_id: The role ID
            role_class: The agent role class
        """
        self._role_classes[role_id] = role_class

    def create_manager(
        self, game_id: int, state: GameState, agent_role: Optional[AgentRole], auth_kwargs: Dict[str, Any]
    ) -> PhaseManager:
        """
        Create a role-less manager with handlers for the assign-name and assign-role events.

        The ``agent_role`` argument is ignored: the manager is always created with
        ``agent_role=None`` and receives its role when the server sends ``assign-role``.

        Args:
            game_id: The game ID
            state: The game state instance
            agent_role: Ignored; kept for interface compatibility with BaseConfigParser
            auth_kwargs: Authentication mechanism keyword arguments; ``agent_id`` is read from
                it for the ready message

        Returns:
            A PhaseManager instance with the custom event handlers registered
        """
        manager = super().create_manager(
            game_id=game_id,
            state=state,
            agent_role=None,
            auth_kwargs=auth_kwargs,
        )

        async def handle_name_assignment(message: Message) -> None:
            """Reply to assign-name with a player-is-ready message carrying the agent ID."""
            agent_id = auth_kwargs.get("agent_id")
            ready_msg = {"gameId": game_id, "type": "player-is-ready", "agentId": agent_id}
            await manager.send_message(json.dumps(ready_msg))

        async def handle_role_assignment(message: Message) -> None:
            """Initialise the agent for the role ID in the assign-role payload (0 if absent)."""
            role_id = int(message.data.get("role", 0))
            manager.logger.info(f"Role assigned: {role_id}")
            self._initialize_agent(manager, role_id)

        manager.register_event_handler("assign-role", handle_role_assignment)
        manager.register_event_handler("assign-name", handle_name_assignment)

        return manager

    def _detect_market_state_in_config(self) -> tuple[bool, Optional[tuple[str, str]]]:
        """
        Detect whether a MarketState field is declared in the state configuration.

        Sections are searched in order: public information, private information, then meta.

        Returns:
            A tuple ``(has_market_state_field, market_state_details)``.
            ``market_state_details`` is ``(field_name_on_section, section_attribute_name_on_gamestate)``
            for the first match, or None if no MarketState field is declared.
        """
        state_conf = self.config.state
        sections = (
            (state_conf.public_information, "public_information"),
            (state_conf.private_information, "private_information"),
            (state_conf.meta_information, "meta"),
        )
        for field_confs, section_name in sections:
            for field_conf in field_confs:
                if field_conf.type == "MarketState":
                    return True, (field_conf.name, section_name)
        return False, None

    def _create_enhanced_state_class(self, base_class: Type[GameState]) -> Type[GameState]:
        """
        Create a subclass of ``base_class`` with the market event handlers injected.

        The subclass is named ``Enhanced<BaseName>``. It gets ``_handle_market_event``,
        ``_handle_asset_movement_event`` and a ``get_custom_handlers`` that routes the
        market events to them.
        """
        enhanced_class_name = f"Enhanced{base_class.__name__}"
        enhanced_class = create_model(
            enhanced_class_name,
            __base__=base_class,
        )
        setattr(enhanced_class, "_handle_market_event", handle_market_event_impl)
        setattr(enhanced_class, "_handle_asset_movement_event", handle_asset_movement_event_impl)
        setattr(enhanced_class, "get_custom_handlers", get_custom_handlers_impl)
        return cast(Type[GameState], enhanced_class)

    def _check_additional_required_fields(self, base_dynamic_state_class: Type[GameState]) -> None:
        """
        Ensure the state declares the fields that the asset-movement handler relies on.

        Raises:
            ValueError: If ``public_information`` lacks ``winning_condition`` or
                ``private_information`` lacks ``wallet``.
        """
        # Instantiate once; both sections come from the same default state instance.
        state = base_dynamic_state_class()
        public_fields = state.public_information.model_json_schema().get("properties", {})  # type: ignore
        private_fields = state.private_information.model_json_schema().get("properties", {})  # type: ignore

        if "winning_condition" not in public_fields or "wallet" not in private_fields:
            raise ValueError("Winning condition or wallet is not present in the config")

    async def run_experiment(self, login_payloads: List[Dict[str, Any]], game_id: int) -> None:
        """
        Run the experiment described by this configuration.

        Steps:

        1. Build the state class from the config, enhanced with market event handlers when a
           MarketState field is declared.
        2. Create one role-less manager per login payload.
        3. Compile inline prompts if any role defines them.
        4. Run the game.

        The temporary prompts directory is removed afterwards, even if the game raises.

        Args:
            login_payloads: One authentication payload per agent
            game_id: The game ID

        Raises:
            ValueError: If the configuration has no agent roles, or declares a MarketState
                without the required ``winning_condition``/``wallet`` fields.
        """
        import shutil

        base_dynamic_state_class = self.config.state.create_state_class()
        has_market_state_field, market_state_details = self._detect_market_state_in_config()

        if has_market_state_field and market_state_details:
            self._check_additional_required_fields(base_dynamic_state_class)
            final_state_class = self._create_enhanced_state_class(base_dynamic_state_class)
        else:
            final_state_class = base_dynamic_state_class

        if not self.config.agent_roles:
            raise ValueError("Configuration has no 'agent_roles'.")

        agents_for_runner = []
        for payload in login_payloads:
            current_agent_manager = self.create_manager(
                game_id=game_id,
                state=final_state_class(),
                agent_role=None,
                auth_kwargs=payload,
            )
            meta = current_agent_manager.state.meta
            meta.game_id = game_id
            if market_state_details:
                field_name, section_name = market_state_details
                # Read by the injected market handler to locate the MarketState field.
                setattr(meta, "_market_state_variable_name", field_name)
                setattr(meta, "_market_state_section_name", section_name)
            agents_for_runner.append(current_agent_manager)

        runner_config_instance = self.config.runner.create_runner_config()
        runner_config_instance.state_class = final_state_class
        runner_config_instance.game_id = game_id

        try:
            if any(hasattr(role, "prompts") and role.prompts for role in self.config.agent_roles):
                runner_config_instance.prompts_dir = self.config._compile_inline_prompts()

            runner = GameRunner(config=runner_config_instance, agents=agents_for_runner)
            await runner.run_game()
        finally:
            # Clean up compiled inline prompts even when the game fails, so temp dirs don't leak.
            temp_prompts_dir = self.config._temp_prompts_dir
            if temp_prompts_dir and temp_prompts_dir.exists():
                shutil.rmtree(temp_prompts_dir)

    def _initialize_agent(self, manager: PhaseManager, role_id: int) -> None:
        """
        Attach the configured agent role matching ``role_id`` to ``manager``.

        The role's logger is set to the manager's logger. When the runner config names an
        observability provider, the role's LLM observability is set from it.

        Args:
            manager: The phase manager instance
            role_id: The role ID

        Raises:
            ValueError: If no configured agent role has this ``role_id``.
        """
        agent_role = next((role for role in self.config.agent_roles if role.role_id == role_id), None)
        if agent_role is None:
            manager.logger.error("Invalid role assigned; cannot initialize agent.")
            raise ValueError("Invalid role for agent initialization.")

        manager.agent_role = agent_role.create_agent_role()
        manager.agent_role.logger = manager.logger  # type: ignore
        if self.config.runner.observability_provider:
            manager.agent_role.llm.observability = get_observability_provider(
                self.config.runner.observability_provider
            )
238
+
239
+
240
async def run_experiment_from_yaml(yaml_path: Path, login_payloads: List[Dict[str, Any]], game_id: int) -> None:
    """
    Load an IBEX-TUDelft experiment configuration from *yaml_path* and run it.

    Args:
        yaml_path: Path to the YAML configuration file
        login_payloads: One authentication payload per agent
        game_id: The game ID
    """
    await IbexTudelftConfigParser(yaml_path).run_experiment(login_payloads, game_id)
@@ -4,7 +4,7 @@ import queue
4
4
  from contextvars import ContextVar
5
5
  from logging.handlers import QueueHandler, QueueListener
6
6
  from pathlib import Path
7
- from typing import Literal, Optional, Type
7
+ from typing import Literal, Optional, Type, List
8
8
 
9
9
  from pydantic import BaseModel, Field
10
10
 
@@ -63,6 +63,11 @@ class GameRunnerConfig(BaseModel):
63
63
  observability_provider: Optional[Literal["langsmith", "langfuse"]] = None
64
64
  """Name of the observability provider to use. Options: 'langsmith' or 'langfuse'"""
65
65
 
66
+ max_game_duration: int = Field(
67
+ default=600,
68
+ description="Maximum game duration in seconds. Default is 600 (10 minutes). Set to 0 or a negative value to disable the timeout.",
69
+ )
70
+
66
71
  # Agent stop configuration
67
72
  end_game_event: str = "game-over"
68
73
  """Event type that signals the end of the game and should stop the agent."""
@@ -301,6 +306,8 @@ class GameRunner:
301
306
 
302
307
  if not agent_manager.state and self.config.state_class:
303
308
  agent_manager.state = self.config.state_class()
309
+ # Set the game_id in the state
310
+ agent_manager.state.meta.game_id = self.config.game_id
304
311
  agent_manager.logger.debug(f"Injected default state: {agent_manager.state}")
305
312
 
306
313
  if not agent_manager.auth_mechanism:
@@ -338,8 +345,6 @@ class GameRunner:
338
345
  agent_id (int): Agent identifier
339
346
  """
340
347
  agent_logger = self.get_agent_logger(agent_id, self.config.game_id)
341
- ctx_agent_id.set(str(agent_id)) # Convert int to str for context variable
342
-
343
348
  agent_manager.logger = agent_logger
344
349
 
345
350
  async def spawn_agent(self, agent_manager: PhaseManager, agent_id: int) -> None:
@@ -350,32 +355,126 @@ class GameRunner:
350
355
  agent_manager (PhaseManager): Agent manager to spawn
351
356
  agent_id (int): Agent identifier
352
357
  """
358
+ token = ctx_agent_id.set(str(agent_id))
353
359
  try:
354
360
  self._inject_agent_logger(agent_manager, agent_id)
355
361
  self._inject_default_config(agent_manager)
356
362
 
357
363
  agent_manager.logger.info(f"Connecting to WebSocket URL: {agent_manager.url}")
358
364
  await agent_manager.start()
365
+ except asyncio.CancelledError:
366
+ agent_manager.logger.info(f"Agent {agent_id} spawn_agent task was cancelled.")
367
+ raise
359
368
  except Exception:
360
369
  agent_manager.logger.exception(f"Error in game for Agent {agent_id}")
361
370
  raise
371
+ finally:
372
+ agent_manager.logger.info(
373
+ f"Agent {agent_id} spawn_agent task finished. Agent running state: {agent_manager.running}"
374
+ )
375
+ ctx_agent_id.reset(token)
376
+
377
+ async def _timeout_watchdog(self, game_logger: logging.Logger, agent_tasks: List[asyncio.Task]) -> None:
378
+ """
379
+ Timeout watchdog that monitors game duration and initiates shutdown when exceeded.
380
+
381
+ Args:
382
+ game_logger: Logger instance for the game
383
+ agent_tasks: List of agent tasks to cancel if timeout occurs
384
+ """
385
+ try:
386
+ await asyncio.sleep(self.config.max_game_duration)
387
+ game_logger.warning(
388
+ f"Game {self.config.game_id} reached maximum duration of {self.config.max_game_duration}s. Initiating shutdown."
389
+ )
390
+
391
+ stop_agent_manager_tasks = []
392
+ for idx, ag_mgr in enumerate(self.agents):
393
+ if ag_mgr.running:
394
+ game_logger.info(f"Timeout: Stopping agent {idx + 1} for game {self.config.game_id}.")
395
+ stop_agent_manager_tasks.append(
396
+ asyncio.create_task(ag_mgr.stop(), name=f"TimeoutStopAgent-{self.config.game_id}-{idx + 1}")
397
+ )
398
+ if stop_agent_manager_tasks:
399
+ await asyncio.gather(*stop_agent_manager_tasks, return_exceptions=True)
400
+
401
+ game_logger.info(f"Timeout: Checking for lingering agent tasks for game {self.config.game_id}.")
402
+ for agent_task_instance in agent_tasks:
403
+ if not agent_task_instance.done():
404
+ game_logger.warning(
405
+ f"Timeout: Agent task {agent_task_instance.get_name()} still running. Cancelling."
406
+ )
407
+ agent_task_instance.cancel()
408
+ # Await the cancellation to ensure it's processed
409
+ try:
410
+ await agent_task_instance
411
+ except asyncio.CancelledError:
412
+ game_logger.info(f"Agent task {agent_task_instance.get_name()} cancelled by timeout watchdog.")
413
+ except Exception as e_cancel:
414
+ game_logger.error(
415
+ f"Error awaiting cancelled agent task {agent_task_instance.get_name()}: {e_cancel}"
416
+ )
417
+
418
+ except asyncio.CancelledError:
419
+ game_logger.info(f"Timeout watchdog for game {self.config.game_id} was cancelled.")
420
+ raise
421
+ except Exception as e_watchdog:
422
+ game_logger.exception(f"Error in timeout watchdog for game {self.config.game_id}: {e_watchdog}")
362
423
 
363
424
  async def run_game(self) -> None:
364
425
  """Run a game using provided game data."""
365
426
 
366
427
  game_logger = self.get_game_logger(self.config.game_id)
367
- game_logger.info(f"Running game with ID: {self.config.game_id}")
428
+ game_logger.info(f"Running game with ID: {self.config.game_id}. Max duration: {self.config.max_game_duration}s")
429
+
430
+ agent_tasks: List[asyncio.Task] = []
431
+ timeout_monitor_task: Optional[asyncio.Task] = None
368
432
 
369
433
  try:
370
- tasks = []
371
- game_logger.info("Starting game")
372
-
373
- for i, agent_manager in enumerate(self.agents, start=1):
374
- tasks.append(self.spawn_agent(agent_manager, i))
375
- await asyncio.gather(*tasks)
376
- except Exception as e:
377
- game_logger.exception(f"Failed to run game: {e}")
434
+ game_logger.info("Configuring and starting agents...")
435
+ for i, agent_manager_instance in enumerate(self.agents, start=1):
436
+ task = asyncio.create_task(
437
+ self.spawn_agent(agent_manager_instance, i), name=f"AgentTask-{self.config.game_id}-{i}"
438
+ )
439
+ agent_tasks.append(task)
440
+
441
+ if self.config.max_game_duration is not None and self.config.max_game_duration > 0:
442
+ timeout_monitor_task = asyncio.create_task(
443
+ self._timeout_watchdog(game_logger, agent_tasks), name=f"TimeoutWatchdog-{self.config.game_id}"
444
+ )
445
+
446
+ if agent_tasks:
447
+ results = await asyncio.gather(*agent_tasks, return_exceptions=True)
448
+ for i, result in enumerate(results):
449
+ if isinstance(result, Exception):
450
+ game_logger.error(f"Agent task {agent_tasks[i].get_name()} failed with: {result}")
451
+ else:
452
+ game_logger.debug("No agent tasks to run.")
453
+
454
+ except Exception as e_run_game:
455
+ game_logger.exception(f"GameRunner.run_game failed: {e_run_game}")
378
456
  raise
379
457
  finally:
380
- game_logger.info("Game over")
458
+ if timeout_monitor_task and not timeout_monitor_task.done():
459
+ game_logger.debug(
460
+ f"Game {self.config.game_id} finished or errored before timeout. Cancelling timeout watchdog."
461
+ )
462
+ timeout_monitor_task.cancel()
463
+
464
+ game_logger.info(f"Game {self.config.game_id}: Final cleanup - ensuring all agents are stopped.")
465
+ final_stop_tasks = []
466
+ for idx, agent_manager_instance in enumerate(self.agents):
467
+ if agent_manager_instance.running:
468
+ game_logger.info(f"Final cleanup: Stopping agent {idx + 1} for game {self.config.game_id}.")
469
+ final_stop_tasks.append(
470
+ asyncio.create_task(
471
+ agent_manager_instance.stop(), name=f"FinalStopAgent-{self.config.game_id}-{idx + 1}"
472
+ )
473
+ )
474
+
475
+ if final_stop_tasks:
476
+ await asyncio.gather(*final_stop_tasks, return_exceptions=True)
477
+ game_logger.info(f"Game {self.config.game_id}: Final agent stop tasks completed.")
478
+
381
479
  self.cleanup_logging()
480
+ game_logger.info(f"Game {self.config.game_id} finished and cleaned up.")
@@ -142,7 +142,7 @@ class AgentManager(LoggerMixin):
142
142
  auth_mechanism_kwargs=self._auth_mechanism_kwargs,
143
143
  )
144
144
 
145
- def _raw_message_received(self, raw_message: str):
145
+ async def _raw_message_received(self, raw_message: str):
146
146
  """Process raw message from the transport layer"""
147
147
  msg = self._extract_message_data(raw_message)
148
148
  if msg:
@@ -218,19 +218,15 @@ class AgentManager(LoggerMixin):
218
218
  raise ValueError("URL must be set before starting the agent manager")
219
219
 
220
220
  self.running = True
221
- connected = await self.transport.connect()
222
- if connected:
223
- self.logger.info("Connected to WebSocket server. Receiving messages...")
224
- await self.transport.start_listening()
225
- else:
226
- self.logger.error("Failed to connect to WebSocket server")
221
+ self.logger.info("Starting agent manager. Receiving messages...")
222
+ await self.transport.start_listening()
227
223
 
228
224
  async def stop(self):
229
225
  """Stop the agent manager and close the connection."""
230
226
  self.running = False
231
227
  if self.transport:
232
228
  await self.transport.stop()
233
- self.logger.info("Agent manager stopped and connection closed.")
229
+ self.logger.info("Agent manager stopped and connection closed.")
234
230
 
235
231
  async def on_event(self, message: Message):
236
232
  """
@@ -69,14 +69,11 @@ class WebSocketTransport(LoggerMixin):
69
69
  self.on_message_callback = on_message_callback
70
70
  self.ws: Optional[ClientConnection] = None
71
71
  self._running = False
72
+ self._authenticated = False
72
73
 
73
- async def connect(self) -> bool:
74
- """Establish the WebSocket connection and authenticate."""
74
+ async def _authenticate(self) -> bool:
75
+ """Authenticate the connection."""
75
76
  try:
76
- self.ws = await websockets.connect(self.url, ping_interval=30, ping_timeout=10)
77
- self.logger.info("WebSocketTransport: connection opened.")
78
-
79
- # Perform authentication using the callback
80
77
  if self.auth_mechanism:
81
78
  if not self.auth_mechanism_kwargs:
82
79
  self.auth_mechanism_kwargs = {}
@@ -84,9 +81,9 @@ class WebSocketTransport(LoggerMixin):
84
81
  if not auth_success:
85
82
  self.logger.error("Authentication failed")
86
83
  await self.stop()
87
- self.ws = None # Ensure ws is set to None after stopping
84
+ self.ws = None
88
85
  return False
89
-
86
+ self._authenticated = True
90
87
  except Exception as e:
91
88
  self.logger.exception(f"Transport connection error: {e}")
92
89
  return False
@@ -95,24 +92,55 @@ class WebSocketTransport(LoggerMixin):
95
92
 
96
93
  async def start_listening(self):
97
94
  """Begin receiving messages in a loop."""
95
+ self.logger.info("WebSocketTransport: starting to listen.")
98
96
  self._running = True
99
- while self._running and self.ws:
100
- try:
101
- message_str = await self.ws.recv()
102
- if self.on_message_callback:
103
- # Call the callback, supporting both sync and async functions
104
- self.logger.debug(f"<-- Transport received: {message_str}")
105
- result = self.on_message_callback(message_str)
106
- # If the callback is a coroutine function, await it
107
- if asyncio.iscoroutine(result):
108
- asyncio.create_task(result)
109
- except ConnectionClosed:
110
- self.logger.info("WebSocket connection closed by remote.")
111
- break
112
- except Exception:
113
- self.logger.exception("Error in receive loop.")
114
- break
115
- self._running = False
97
+
98
+ try:
99
+ async for websocket in websockets.connect(self.url):
100
+ if not self._running:
101
+ self.logger.info("WebSocketTransport: stopping as requested.")
102
+ break
103
+
104
+ try:
105
+ self.ws = websocket
106
+
107
+ if not self._authenticated:
108
+ await self._authenticate()
109
+ if not self._authenticated:
110
+ self.logger.error("Authentication failed. Stopping transport.")
111
+ break
112
+
113
+ async for message in self.ws:
114
+ if not self._running:
115
+ break
116
+ if self.on_message_callback:
117
+ self.logger.info(f"<-- Transport received: {message}")
118
+ await self.on_message_callback(message)
119
+
120
+ except ConnectionClosed as e:
121
+ self.logger.info(f"WebSocketTransport: connection closed: ({e.code}) {e.reason}")
122
+ if not self._running:
123
+ self.logger.info("WebSocketTransport: connection closed by client. Stopping transport.")
124
+ break
125
+ self.logger.info("WebSocketTransport: reconnecting...")
126
+ continue
127
+ except Exception as e:
128
+ self.logger.exception(f"Error in receive loop: {e}")
129
+ break
130
+ finally:
131
+ if self.ws:
132
+ try:
133
+ await self.ws.close()
134
+ self.logger.info("WebSocketTransport: connection closed.")
135
+ except Exception as e:
136
+ self.logger.debug(f"Error closing websocket: {e}")
137
+ finally:
138
+ self.ws = None
139
+ except Exception as e:
140
+ self.logger.exception(f"Error in start_listening: {e}")
141
+ finally:
142
+ self._running = False
143
+ self.logger.info("WebSocketTransport: stopped listening.")
116
144
 
117
145
  async def send(self, message: str):
118
146
  """Send a raw string message to the WebSocket."""
@@ -125,8 +153,14 @@ class WebSocketTransport(LoggerMixin):
125
153
 
126
154
  async def stop(self):
127
155
  """Gracefully close the WebSocket connection."""
156
+ self.logger.info("WebSocketTransport: stopping...")
128
157
  self._running = False
129
158
  if self.ws:
130
- await self.ws.close()
131
- self.logger.info("WebSocketTransport: connection closed.")
132
- self.ws = None # Set ws to None after closing
159
+ try:
160
+ await self.ws.close()
161
+ self.logger.info("WebSocketTransport: connection closed.")
162
+ except Exception as e:
163
+ self.logger.debug(f"Error during stop: {e}")
164
+ finally:
165
+ self.ws = None
166
+ self._authenticated = False