artificialbrains-sdk 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- ab_sdk/__init__.py +63 -0
- ab_sdk/client.py +270 -0
- ab_sdk/contract_scaffold.py +315 -0
- ab_sdk/endpoints.py +48 -0
- ab_sdk/input_streamer.py +174 -0
- ab_sdk/plugins/__init__.py +69 -0
- ab_sdk/plugins/decoder.py +331 -0
- ab_sdk/plugins/deviation.py +59 -0
- ab_sdk/plugins/reward.py +55 -0
- ab_sdk/robot_loop.py +184 -0
- ab_sdk/run_session.py +305 -0
- artificialbrains_sdk-0.1.0.dist-info/METADATA +370 -0
- artificialbrains_sdk-0.1.0.dist-info/RECORD +15 -0
- artificialbrains_sdk-0.1.0.dist-info/WHEEL +5 -0
- artificialbrains_sdk-0.1.0.dist-info/top_level.txt +1 -0
ab_sdk/robot_loop.py
ADDED
|
@@ -0,0 +1,184 @@
|
|
|
1
|
+
"""High level loop for controlling a robot during a run.
|
|
2
|
+
|
|
3
|
+
The :class:`RobotLoop` coordinates sending the robot's observed state to
|
|
4
|
+
the server, decoding the output spikes returned from the brain and
|
|
5
|
+
applying the resulting command to your hardware. It integrates with
|
|
6
|
+
the :class:`~ab_sdk.run_session.RunSession` lifecycle and uses
|
|
7
|
+
callbacks supplied by the user for state acquisition and command
|
|
8
|
+
execution.
|
|
9
|
+
|
|
10
|
+
Example usage::
|
|
11
|
+
|
|
12
|
+
def get_robot_state():
|
|
13
|
+
return { 'q': current_joint_positions(), 'dq': current_joint_vels(), 'grip': {'pos': gripper_pos}, 'dt': dt }
|
|
14
|
+
|
|
15
|
+
def apply_command(cmd):
|
|
16
|
+
set_joint_targets(cmd['dq'])
|
|
17
|
+
set_gripper(cmd['dg'])
|
|
18
|
+
|
|
19
|
+
loop = RobotLoop(session, state_provider=get_robot_state, command_executor=apply_command)
|
|
20
|
+
loop.run_forever()
|
|
21
|
+
|
|
22
|
+
In this example the brain's decoded commands are applied directly to a
|
|
23
|
+
hardware or simulated robot. If you are still letting the server
|
|
24
|
+
generate ``robot:cmd`` events then you can omit the decoder plugin
|
|
25
|
+
and use the command handler registered on the session instead.
|
|
26
|
+
"""
|
|
27
|
+
|
|
28
|
+
from __future__ import annotations
|
|
29
|
+
|
|
30
|
+
import logging
|
|
31
|
+
import threading
|
|
32
|
+
import time
|
|
33
|
+
from typing import Any, Callable, Dict, List, Optional
|
|
34
|
+
|
|
35
|
+
from .run_session import RunSession
|
|
36
|
+
|
|
37
|
+
logger = logging.getLogger(__name__)
|
|
38
|
+
|
|
39
|
+
|
|
40
|
+
class RobotLoop:
|
|
41
|
+
"""Manage the control loop for a robot.
|
|
42
|
+
|
|
43
|
+
The loop periodically collects the robot's current state and sends
|
|
44
|
+
it to the server via the session's realtime channel. When cycle
|
|
45
|
+
update events arrive the associated decoder plugin (if attached) is
|
|
46
|
+
invoked to produce a command which is then passed to the
|
|
47
|
+
user‑supplied command executor.
|
|
48
|
+
|
|
49
|
+
Parameters
|
|
50
|
+
----------
|
|
51
|
+
session: RunSession
|
|
52
|
+
The active run session to which robot states and commands should
|
|
53
|
+
be associated.
|
|
54
|
+
state_provider: Callable[[], Dict[str, Any]]
|
|
55
|
+
A callback returning the current robot state. This should
|
|
56
|
+
return a dictionary with keys ``q`` (joint positions), ``dq``
|
|
57
|
+
(joint velocities), ``grip`` (dict with ``pos``) and ``dt``
|
|
58
|
+
(time delta since last call). All fields are optional;
|
|
59
|
+
missing values are simply omitted from the state payload. The
|
|
60
|
+
provider is invoked on a background thread at the configured
|
|
61
|
+
tick frequency.
|
|
62
|
+
command_executor: Callable[[Dict[str, Any]], None]
|
|
63
|
+
A callback invoked with a command dictionary returned by the
|
|
64
|
+
decoder plugin. The dictionary has keys ``dq`` (array of
|
|
65
|
+
joint deltas), ``dg`` (scalar gripper command) and any
|
|
66
|
+
additional keys defined by your decoder. This callback should
|
|
67
|
+
apply the command to the actual robot.
|
|
68
|
+
tick_hz: float, optional
|
|
69
|
+
The frequency in Hz at which to send robot states to the
|
|
70
|
+
server. Defaults to 20Hz. Set to 0 to disable periodic
|
|
71
|
+
sending (state must then be sent manually).
|
|
72
|
+
"""
|
|
73
|
+
|
|
74
|
+
def __init__(self, session: RunSession,
|
|
75
|
+
state_provider: Callable[[], Dict[str, Any]],
|
|
76
|
+
command_executor: Callable[[Dict[str, Any]], None],
|
|
77
|
+
tick_hz: float = 20.0) -> None:
|
|
78
|
+
self.session = session
|
|
79
|
+
self.state_provider = state_provider
|
|
80
|
+
self.command_executor = command_executor
|
|
81
|
+
self.tick_hz = tick_hz
|
|
82
|
+
self._running = False
|
|
83
|
+
self._thread: Optional[threading.Thread] = None
|
|
84
|
+
# register to receive cycle updates and decode commands
|
|
85
|
+
self.session.on_cycle_update(self._on_cycle_update)
|
|
86
|
+
|
|
87
|
+
def _on_cycle_update(self, telemetry: Dict[str, Any]) -> None:
|
|
88
|
+
"""Handle cycle update events by decoding commands and applying them.
|
|
89
|
+
|
|
90
|
+
If a decoder plugin is attached to the session then this
|
|
91
|
+
callback will build a dictionary of output matrices keyed by
|
|
92
|
+
output ID and invoke the plugin's ``decode`` method. The
|
|
93
|
+
resulting command dictionary is passed to the supplied
|
|
94
|
+
``command_executor``. Any exceptions raised by the decoder
|
|
95
|
+
are caught and logged; command execution is skipped on error.
|
|
96
|
+
"""
|
|
97
|
+
decoder = self.session.decoder_plugin
|
|
98
|
+
if decoder is None:
|
|
99
|
+
# nothing to do; maybe server will send robot:cmd
|
|
100
|
+
return
|
|
101
|
+
outputs = telemetry.get("outputs", [])
|
|
102
|
+
# build a mapping from output id to matrix (gamma x outputN)
|
|
103
|
+
output_matrices: Dict[str, List[List[int]]] = {}
|
|
104
|
+
for entry in outputs:
|
|
105
|
+
try:
|
|
106
|
+
t_step, out_id, bits = entry
|
|
107
|
+
except ValueError:
|
|
108
|
+
continue
|
|
109
|
+
if out_id not in output_matrices:
|
|
110
|
+
# initialize matrix with zeros
|
|
111
|
+
output_matrices[out_id] = [[0] * self.session.output_n for _ in range(self.session.gamma)]
|
|
112
|
+
# assign row
|
|
113
|
+
row = output_matrices[out_id][int(t_step)]
|
|
114
|
+
# bits may be shorter than output_n; pad
|
|
115
|
+
for i in range(min(len(bits), self.session.output_n)):
|
|
116
|
+
row[i] = 1 if bits[i] else 0
|
|
117
|
+
try:
|
|
118
|
+
command = decoder.decode(output_matrices, context={
|
|
119
|
+
"telemetry": telemetry,
|
|
120
|
+
"session": self.session,
|
|
121
|
+
})
|
|
122
|
+
if command is not None:
|
|
123
|
+
logger.debug("Decoded command: %s", command)
|
|
124
|
+
self.command_executor(command)
|
|
125
|
+
except Exception as exc:
|
|
126
|
+
logger.exception("Decoder error: %s", exc)
|
|
127
|
+
|
|
128
|
+
def _send_robot_state(self) -> None:
|
|
129
|
+
"""Collect the current robot state and emit it to the server."""
|
|
130
|
+
try:
|
|
131
|
+
state = self.state_provider() or {}
|
|
132
|
+
payload = {"runId": self.session.run_id, "state": state}
|
|
133
|
+
self.session.socket.emit("robot:state", payload, namespace=self.session.namespace)
|
|
134
|
+
logger.debug("Sent robot state: %s", payload)
|
|
135
|
+
except Exception as exc:
|
|
136
|
+
logger.exception("Error collecting or sending robot state: %s", exc)
|
|
137
|
+
|
|
138
|
+
def run_forever(self) -> None:
|
|
139
|
+
"""Start the control loop and block until stopped.
|
|
140
|
+
|
|
141
|
+
This method spawns a background thread which periodically
|
|
142
|
+
acquires robot state and sends it to the server. It then
|
|
143
|
+
blocks on the main thread, sleeping indefinitely. To stop the
|
|
144
|
+
loop call :meth:`stop` from another thread or signal handler.
|
|
145
|
+
"""
|
|
146
|
+
if self._running:
|
|
147
|
+
logger.warning("RobotLoop.run_forever() called while already running")
|
|
148
|
+
return
|
|
149
|
+
self._running = True
|
|
150
|
+
if self.tick_hz > 0:
|
|
151
|
+
interval = 1.0 / float(self.tick_hz)
|
|
152
|
+
else:
|
|
153
|
+
interval = 0.0
|
|
154
|
+
# define worker function
|
|
155
|
+
def _worker() -> None:
|
|
156
|
+
while self._running:
|
|
157
|
+
if interval > 0:
|
|
158
|
+
start = time.time()
|
|
159
|
+
self._send_robot_state()
|
|
160
|
+
elapsed = time.time() - start
|
|
161
|
+
sleep_time = max(0.0, interval - elapsed)
|
|
162
|
+
time.sleep(sleep_time)
|
|
163
|
+
else:
|
|
164
|
+
time.sleep(0.1)
|
|
165
|
+
# start worker thread
|
|
166
|
+
self._thread = threading.Thread(target=_worker, name="RobotLoopWorker")
|
|
167
|
+
self._thread.daemon = True
|
|
168
|
+
self._thread.start()
|
|
169
|
+
logger.info("Robot loop started")
|
|
170
|
+
try:
|
|
171
|
+
while self._running:
|
|
172
|
+
time.sleep(1)
|
|
173
|
+
finally:
|
|
174
|
+
self.stop()
|
|
175
|
+
|
|
176
|
+
def stop(self) -> None:
|
|
177
|
+
"""Stop the control loop."""
|
|
178
|
+
if not self._running:
|
|
179
|
+
return
|
|
180
|
+
self._running = False
|
|
181
|
+
logger.info("Stopping robot loop")
|
|
182
|
+
if self._thread is not None:
|
|
183
|
+
self._thread.join(timeout=5.0)
|
|
184
|
+
self._thread = None
|
ab_sdk/run_session.py
ADDED
|
@@ -0,0 +1,305 @@
|
|
|
1
|
+
"""Per‑run state container and event router.
|
|
2
|
+
|
|
3
|
+
The :class:`RunSession` encapsulates all of the information about a
|
|
4
|
+
running experiment (identified by a unique run ID) and exposes
|
|
5
|
+
conveniences for emitting input chunks, feedback rasters and reward
|
|
6
|
+
signals. It also manages registration of event handlers for
|
|
7
|
+
telemetry and other realtime notifications.
|
|
8
|
+
|
|
9
|
+
You do not create a `RunSession` directly; instead it is returned
|
|
10
|
+
from :meth:`~ab_sdk.client.ABClient.start`. Once created it holds
|
|
11
|
+
references to the originating :class:`~ab_sdk.client.ABClient`, the
|
|
12
|
+
HTTP API contract describing the IO interface and a live
|
|
13
|
+
Socket.IO client joined to the appropriate room. You can attach
|
|
14
|
+
custom decoders, deviation policies and reward policies via
|
|
15
|
+
:meth:`set_decoder`, :meth:`set_deviation` and :meth:`set_reward`.
|
|
16
|
+
|
|
17
|
+
Instances of this class are not thread safe. If you plan to
|
|
18
|
+
consume realtime events in multiple threads you should implement
|
|
19
|
+
appropriate synchronization in your handlers.
|
|
20
|
+
"""
|
|
21
|
+
|
|
22
|
+
from __future__ import annotations
|
|
23
|
+
|
|
24
|
+
import logging
|
|
25
|
+
import threading
|
|
26
|
+
import time
|
|
27
|
+
from typing import Any, Callable, Dict, Iterable, List, Optional
|
|
28
|
+
|
|
29
|
+
import socketio
|
|
30
|
+
|
|
31
|
+
logger = logging.getLogger(__name__)
|
|
32
|
+
|
|
33
|
+
|
|
34
|
+
class RunSession:
    """Represents a single running brain session.

    Parameters
    ----------
    client: ABClient
        Reference to the client that created this session.  It is only
        stored on the instance here.
    project_id: str
        The project identifier.
    run_id: str
        Unique identifier for this run returned by the server.
    contract: dict
        The run contract returned by the server on start.  Its
        ``constants`` and ``io`` keys are read to configure the session;
        missing constants fall back to ``gamma=64``,
        ``outputWindowN=32`` and ``feedbackN=128``.
    socket: socketio.Client
        A connected Socket.IO client already joined to the run room.
    namespace: str
        The namespace on the server to emit/receive events on (e.g. ``"/ab"``).
    """

    def __init__(self, client: Any, project_id: str, run_id: str,
                 contract: Dict[str, Any], socket: "socketio.Client",
                 namespace: str) -> None:
        self.client = client
        self.project_id = project_id
        self.run_id = run_id
        self.contract = contract
        self.socket = socket
        self.namespace = namespace

        # parse constants (``or {}`` also covers an explicit null section)
        consts = contract.get("constants") or {}
        self.gamma: int = int(consts.get("gamma", 64))
        self.output_n: int = int(consts.get("outputWindowN", 32))
        self.feedback_n: int = int(consts.get("feedbackN", 128))

        # keep track of IO manifest, indexed by each item's "id"
        io_manifest = contract.get("io") or {}
        self.io_inputs = {item["id"]: item for item in io_manifest.get("inputs") or []}
        self.io_outputs = {item["id"]: item for item in io_manifest.get("outputs") or []}
        self.io_feedback = {item["id"]: item for item in io_manifest.get("feedback") or []}
        self.stdp_layers: List[str] = list((io_manifest.get("stdp3") or {}).get("layers") or [])

        # plugin holders
        self.decoder_plugin: Optional[Any] = None
        self.deviation_plugin: Optional[Any] = None
        self.reward_plugin: Optional[Any] = None

        # event handlers registry
        self._cycle_handlers: List[Callable[[Dict[str, Any]], None]] = []
        self._io_need_handlers: List[Callable[[Dict[str, Any]], None]] = []
        self._cmd_handlers: List[Callable[[Dict[str, Any]], None]] = []

        # hook the dispatchers into the socket as soon as the session exists
        self._register_socket_events()

    # ----------------------------------------------------------------------
    # Event registration API
    #
    # The run session receives events from the server via Socket.IO.  You
    # can register additional callbacks for cycle updates, IO needs and
    # command messages.  Callbacks for one event are called sequentially
    # from whichever thread the Socket.IO client delivers the event on, so
    # you should avoid blocking operations inside handlers.
    #
    def on_cycle_update(self, handler: Callable[[Dict[str, Any]], None]) -> None:
        """Register a handler for cycle update events.

        The handler is called with the full telemetry payload as
        delivered by the server.  Multiple handlers can be registered;
        they will be invoked in the order they were added.

        Parameters
        ----------
        handler: Callable[[dict], None]
            A callable accepting a telemetry dictionary.
        """
        self._cycle_handlers.append(handler)

    def on_io_need(self, handler: Callable[[Dict[str, Any]], None]) -> None:
        """Register a handler for IO need events.

        The handler is called with a payload of the form::

            {"runId": ..., "cycle": ..., "needs": [...], "deadlineMs": ...}

        Your handler should respond by calling :meth:`send_input_chunk` or
        :meth:`send_feedback_raster` for each requested input.  The SDK
        provides :class:`~ab_sdk.input_streamer.InputStreamer` which
        implements this logic for you.
        """
        self._io_need_handlers.append(handler)

    def on_robot_cmd(self, handler: Callable[[Dict[str, Any]], None]) -> None:
        """Register a handler for robot command events.

        This is only necessary if your server still returns
        `robot:cmd` events (legacy behaviour).  When mapping and
        decoding move into the SDK the server will stop sending
        commands and instead only emit output spikes via
        ``cycle:update``.
        """
        self._cmd_handlers.append(handler)

    def set_decoder(self, decoder: Any) -> None:
        """Attach a decoder plugin.

        The decoder must implement a `decode(outputs, context)` method
        which receives a dictionary mapping output IDs to a matrix
        ``(gamma x outputN)`` and returns a command dictionary
        ``{'dq': [...], 'dg': float}``.  See
        :class:`~ab_sdk.plugins.decoder.BaseDecoder` for details.

        Parameters
        ----------
        decoder: Any
            An object implementing a ``decode`` method.
        """
        self.decoder_plugin = decoder

    def set_deviation(self, deviation_policy: Any) -> None:
        """Attach a deviation policy plugin.

        The deviation policy must implement a ``compute(telemetry)``
        method returning a mapping from feedback input IDs to lists of
        floats of length ``gamma`` in the range ``[-1,1]``.  See
        :class:`~ab_sdk.plugins.deviation.BaseDeviation`.
        """
        self.deviation_plugin = deviation_policy

    def set_reward(self, reward_policy: Any) -> None:
        """Attach a reward policy plugin.

        The reward policy must implement a ``compute(telemetry)``
        method returning a tuple ``(global_reward, by_layer_dict)``
        where ``global_reward`` is a float in ``[0,1]`` and
        ``by_layer_dict`` maps layer names to floats in ``[0,1]``.  See
        :class:`~ab_sdk.plugins.reward.BaseReward`.
        """
        self.reward_plugin = reward_policy

    # ----------------------------------------------------------------------
    # Emission helpers
    #
    def send_input_chunk(self, input_id: str, kind: str, seq: int, t: float,
                         fmt: str, meta: Dict[str, Any], data: bytes) -> None:
        """Emit a raw input chunk over the realtime channel.

        Builds an ``io:chunk`` payload from the arguments (``seq`` is
        coerced to ``int`` and ``t`` to ``float``) and emits it on the
        session namespace.  You normally should not call this directly;
        use :class:`~ab_sdk.input_streamer.InputStreamer` instead.
        """
        payload = {
            "runId": self.run_id,
            "inputId": input_id,
            "kind": kind,
            "seq": int(seq),
            "t": float(t),
            "format": fmt,
            "meta": meta,
            "data": data,
        }
        # The binary body is left out of the debug log on purpose.
        logger.debug("Sending input chunk: %s", {k: payload[k] for k in payload if k != "data"})
        self.socket.emit("io:chunk", payload, namespace=self.namespace)

    def send_feedback_raster(self, input_id: str, raster: Iterable[float],
                             cycle: int) -> None:
        """Send a feedback raster for the given feedback input.

        The raster should be a flat iterable of length ``gamma * feedbackN``
        (e.g. a list or a numpy array).  Values should be floats in
        ``[-1,1]``.  Neither the length nor the range is checked here;
        the caller is responsible for constructing the raster.

        Parameters
        ----------
        input_id: str
            The identifier of the feedback input to send.
        raster: Iterable[float]
            A flat sequence containing ``gamma * feedbackN`` floats.
        cycle: int
            The cycle number associated with this feedback (included in
            ``meta``).
        """
        import array
        import sys

        values = array.array('f', raster)
        # array.array serialises in native byte order; swap on big-endian
        # hosts so the wire format is always little-endian.
        if sys.byteorder != "little":
            values.byteswap()
        meta = {"T": self.gamma, "N": self.feedback_n, "cycle": cycle}
        # Read the clock once so ``seq`` and ``t`` describe the same instant.
        now = time.time()
        self.send_input_chunk(input_id=input_id, kind="Feedback",
                              seq=int(now * 1000),
                              t=now, fmt="raster_f32",
                              meta=meta, data=values.tobytes())

    @staticmethod
    def _clamp_unit(value: Any) -> float:
        """Return ``value`` as a float clamped to ``[0, 1]``; NaN becomes 0.0."""
        number = float(value)
        # NaN is the only float unequal to itself.  A plain min/max clamp
        # would turn it into 1.0, i.e. the maximum reward.
        if number != number:
            return 0.0
        return max(0.0, min(1.0, number))

    def send_reward(self, global_reward: float, by_layer: Dict[str, float],
                    cycle: int) -> None:
        """Send reward information to the server.

        Only STDP3 layers listed in ``self.stdp_layers`` are included
        in the payload; missing entries are filled with ``global_reward``.
        All values are clamped to ``[0,1]`` and NaN is sent as ``0.0``.
        ``by_layer`` may be ``None``, which is treated as empty.
        """
        layer_rewards = by_layer or {}
        payload_layers: Dict[str, float] = {
            layer: self._clamp_unit(layer_rewards.get(layer, global_reward))
            for layer in self.stdp_layers
        }
        payload = {
            "runId": self.run_id,
            "cycle": cycle,
            "globalReward": self._clamp_unit(global_reward),
            "byLayer": payload_layers,
        }
        logger.debug("Sending reward: %s", payload)
        self.socket.emit("learn:reward", payload, namespace=self.namespace)

    def close(self) -> None:
        """Disconnect the Socket.IO client.

        This method should be called when you are finished with the
        session.  Registered handlers stay attached to the socket
        object.  Errors raised while disconnecting are logged as
        warnings rather than propagated.  Subsequent operations on this
        session may fail.
        """
        try:
            if self.socket.connected:
                logger.info("Disconnecting session %s", self.run_id)
                # Client.disconnect() takes no arguments; passing a
                # namespace raised TypeError and left the socket open.
                self.socket.disconnect()
        except Exception as exc:
            logger.warning("Error disconnecting session: %s", exc)

    # ----------------------------------------------------------------------
    # Internal: register socket event handlers
    #
    def _dispatch(self, handlers: List[Callable[[Dict[str, Any]], None]],
                  payload: Dict[str, Any], error_message: str) -> None:
        """Call every handler with ``payload``, logging failures individually.

        A snapshot of the list is iterated so a handler that registers
        another handler does not affect the current dispatch.
        """
        for handler in list(handlers):
            try:
                handler(payload)
            except Exception as exc:
                logger.exception(error_message, exc)

    def _register_socket_events(self) -> None:
        """Setup internal Socket.IO event dispatching.

        This method attaches handlers to the underlying Socket.IO client
        for the known event types (``cycle:update``, ``io:need``,
        ``robot:cmd``).  When events are received the registered
        callbacks are invoked sequentially; an exception in one callback
        is logged and does not prevent the remaining ones from running.
        """

        @self.socket.on("cycle:update", namespace=self.namespace)
        def _on_cycle_update(payload: Dict[str, Any]) -> None:
            logger.debug("Received cycle update: cycle=%s", payload.get("cycle"))
            self._dispatch(self._cycle_handlers, payload,
                           "Error in cycle update handler: %s")

        @self.socket.on("io:need", namespace=self.namespace)
        def _on_io_need(payload: Dict[str, Any]) -> None:
            logger.debug("Received IO need: %s", payload)
            self._dispatch(self._io_need_handlers, payload,
                           "Error in IO need handler: %s")

        @self.socket.on("robot:cmd", namespace=self.namespace)
        def _on_robot_cmd(payload: Dict[str, Any]) -> None:
            logger.debug("Received robot command: %s", payload)
            self._dispatch(self._cmd_handlers, payload,
                           "Error in robot command handler: %s")
|