artificialbrains-sdk 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
ab_sdk/robot_loop.py ADDED
@@ -0,0 +1,184 @@
1
+ """High level loop for controlling a robot during a run.
2
+
3
+ The :class:`RobotLoop` coordinates sending the robot's observed state to
4
+ the server, decoding the output spikes returned from the brain and
5
+ applying the resulting command to your hardware. It integrates with
6
+ the :class:`~ab_sdk.run_session.RunSession` lifecycle and uses
7
+ callbacks supplied by the user for state acquisition and command
8
+ execution.
9
+
10
+ Example usage::
11
+
12
+ def get_robot_state():
13
+ return { 'q': current_joint_positions(), 'dq': current_joint_vels(), 'grip': {'pos': gripper_pos}, 'dt': dt }
14
+
15
+ def apply_command(cmd):
16
+ set_joint_targets(cmd['dq'])
17
+ set_gripper(cmd['dg'])
18
+
19
+ loop = RobotLoop(session, state_provider=get_robot_state, command_executor=apply_command)
20
+ loop.run_forever()
21
+
22
+ In this example the brain's decoded commands are applied directly to a
23
+ hardware or simulated robot. If you are still letting the server
24
+ generate ``robot:cmd`` events then you can omit the decoder plugin
25
+ and use the command handler registered on the session instead.
26
+ """
27
+
28
+ from __future__ import annotations
29
+
30
+ import logging
31
+ import threading
32
+ import time
33
+ from typing import Any, Callable, Dict, Optional
34
+
35
+ from .run_session import RunSession
36
+
37
+ logger = logging.getLogger(__name__)
38
+
39
+
40
class RobotLoop:
    """Manage the control loop for a robot.

    The loop periodically collects the robot's current state and sends
    it to the server via the session's realtime channel.  When cycle
    update events arrive the associated decoder plugin (if attached) is
    invoked to produce a command which is then passed to the
    user-supplied command executor.

    Parameters
    ----------
    session: RunSession
        The active run session to which robot states and commands should
        be associated.
    state_provider: Callable[[], Dict[str, Any]]
        A callback returning the current robot state.  This should
        return a dictionary with keys ``q`` (joint positions), ``dq``
        (joint velocities), ``grip`` (dict with ``pos``) and ``dt``
        (time delta since last call).  All fields are optional;
        missing values are simply omitted from the state payload.  The
        provider is invoked on a background thread at the configured
        tick frequency.
    command_executor: Callable[[Dict[str, Any]], None]
        A callback invoked with a command dictionary returned by the
        decoder plugin.  The dictionary has keys ``dq`` (array of
        joint deltas), ``dg`` (scalar gripper command) and any
        additional keys defined by your decoder.  This callback should
        apply the command to the actual robot.
    tick_hz: float, optional
        The frequency in Hz at which to send robot states to the
        server.  Defaults to 20Hz.  Set to 0 (or any non-positive
        value) to disable periodic sending (state must then be sent
        manually).
    """

    def __init__(self, session: "RunSession",
                 state_provider: Callable[[], Dict[str, Any]],
                 command_executor: Callable[[Dict[str, Any]], None],
                 tick_hz: float = 20.0) -> None:
        self.session = session
        self.state_provider = state_provider
        self.command_executor = command_executor
        self.tick_hz = tick_hz
        self._running = False
        self._thread: Optional[threading.Thread] = None
        # Set by stop() so that sleeping threads wake up immediately
        # instead of finishing their current sleep interval.
        self._stop_event = threading.Event()
        # register to receive cycle updates and decode commands
        self.session.on_cycle_update(self._on_cycle_update)

    def _on_cycle_update(self, telemetry: Dict[str, Any]) -> None:
        """Handle cycle update events by decoding commands and applying them.

        If a decoder plugin is attached to the session then this
        callback builds a dictionary of output matrices keyed by output
        ID (``gamma`` rows of ``output_n`` 0/1 values each) and invokes
        the plugin's ``decode`` method.  A non-``None`` result is passed
        to ``command_executor``.

        Entries in ``telemetry["outputs"]`` are expected to be
        ``(t_step, out_id, bits)`` triples.  Malformed entries and
        entries whose ``t_step`` lies outside ``[0, gamma)`` are skipped
        so that one bad entry cannot abort decoding of the whole cycle.
        Exceptions raised by the decoder or by the executor are caught
        and logged.
        """
        decoder = self.session.decoder_plugin
        if decoder is None:
            # nothing to do; maybe server will send robot:cmd
            return

        gamma = self.session.gamma
        output_n = self.session.output_n
        # ``or []`` also covers an explicit ``"outputs": None``.
        outputs = telemetry.get("outputs") or []

        # Mapping from output id to matrix (gamma x output_n).
        output_matrices: Dict[str, list] = {}
        for entry in outputs:
            try:
                t_step, out_id, bits = entry
                step = int(t_step)
                # zip() truncates bits longer than output_n; shorter
                # rows keep the zero padding of the pre-filled matrix.
                row_bits = [1 if bit else 0
                            for _, bit in zip(range(output_n), bits)]
            except (TypeError, ValueError):
                logger.debug("Skipping malformed output entry: %r", entry)
                continue
            if not 0 <= step < gamma:
                # A negative index would silently address rows from the
                # end; an index >= gamma would raise IndexError.
                logger.debug("Skipping output entry with t_step=%r outside "
                             "[0, %d)", t_step, gamma)
                continue
            matrix = output_matrices.get(out_id)
            if matrix is None:
                matrix = [[0] * output_n for _ in range(gamma)]
                output_matrices[out_id] = matrix
            matrix[step][:len(row_bits)] = row_bits

        try:
            command = decoder.decode(output_matrices, context={
                "telemetry": telemetry,
                "session": self.session,
            })
        except Exception as exc:
            logger.exception("Decoder error: %s", exc)
            return
        if command is None:
            return
        logger.debug("Decoded command: %s", command)
        try:
            self.command_executor(command)
        except Exception as exc:
            # Reported separately so that hardware-side failures are not
            # mistaken for decoder failures.
            logger.exception("Command executor error: %s", exc)

    def _send_robot_state(self) -> None:
        """Collect the current robot state and emit it to the server.

        A falsy provider result is sent as an empty state dictionary.
        Any exception from the provider or from the socket is logged and
        swallowed so that the periodic worker keeps running.
        """
        try:
            state = self.state_provider() or {}
            payload = {"runId": self.session.run_id, "state": state}
            self.session.socket.emit("robot:state", payload,
                                     namespace=self.session.namespace)
            logger.debug("Sent robot state: %s", payload)
        except Exception as exc:
            logger.exception("Error collecting or sending robot state: %s", exc)

    def run_forever(self) -> None:
        """Start the control loop and block until stopped.

        This method spawns a daemon background thread which periodically
        acquires robot state and sends it to the server.  It then blocks
        the calling thread until :meth:`stop` is called (from another
        thread, a signal handler or one of the callbacks).  ``stop`` is
        always invoked on the way out, including when the wait is
        interrupted by an exception such as ``KeyboardInterrupt``.

        Calling this method while the loop is already running logs a
        warning and returns immediately.
        """
        if self._running:
            logger.warning("RobotLoop.run_forever() called while already running")
            return
        self._running = True
        self._stop_event.clear()
        interval = 1.0 / float(self.tick_hz) if self.tick_hz > 0 else 0.0

        def _worker() -> None:
            while self._running:
                if interval > 0:
                    # Monotonic clock: wall-clock adjustments must not
                    # stretch or shrink the tick period.
                    start = time.monotonic()
                    self._send_robot_state()
                    elapsed = time.monotonic() - start
                    self._stop_event.wait(max(0.0, interval - elapsed))
                else:
                    # Periodic sending disabled; idle until stopped.
                    self._stop_event.wait(0.1)

        worker = threading.Thread(target=_worker, name="RobotLoopWorker")
        worker.daemon = True
        self._thread = worker
        worker.start()
        logger.info("Robot loop started")
        try:
            while self._running:
                # Wakes as soon as stop() sets the event; the timeout is
                # only a fallback re-check of the running flag.
                self._stop_event.wait(1.0)
        finally:
            self.stop()

    def stop(self) -> None:
        """Stop the control loop.

        Safe to call when the loop is not running (no-op) and from any
        thread, including the worker thread itself (for example from
        inside ``state_provider``).  When called from another thread it
        waits up to five seconds for the worker to finish.
        """
        if not self._running:
            return
        self._running = False
        thread, self._thread = self._thread, None
        self._stop_event.set()
        logger.info("Stopping robot loop")
        # A thread cannot join itself (RuntimeError); the worker exits on
        # its own once the running flag is cleared.
        if thread is not None and thread is not threading.current_thread():
            thread.join(timeout=5.0)
ab_sdk/run_session.py ADDED
@@ -0,0 +1,305 @@
1
+ """Per‑run state container and event router.
2
+
3
+ The :class:`RunSession` encapsulates all of the information about a
4
+ running experiment (identified by a unique run ID) and exposes
5
+ conveniences for emitting input chunks, feedback rasters and reward
6
+ signals. It also manages registration of event handlers for
7
+ telemetry and other realtime notifications.
8
+
9
+ You do not create a `RunSession` directly; instead it is returned
10
+ from :meth:`~ab_sdk.client.ABClient.start`. Once created it holds
11
+ references to the originating :class:`~ab_sdk.client.ABClient`, the
12
+ HTTP API contract describing the IO interface and a live
13
+ Socket.IO client joined to the appropriate room. You can attach
14
+ custom decoders, deviation policies and reward policies via
15
+ :meth:`set_decoder`, :meth:`set_deviation` and :meth:`set_reward`.
16
+
17
+ Instances of this class are not thread safe. If you plan to
18
+ consume realtime events in multiple threads you should implement
19
+ appropriate synchronization in your handlers.
20
+ """
21
+
22
+ from __future__ import annotations
23
+
24
+ import logging
25
+ import threading
26
+ import time
27
+ from typing import Any, Callable, Dict, Iterable, List, Optional
28
+
29
+ import socketio
30
+
31
+ logger = logging.getLogger(__name__)
32
+
33
+
34
class RunSession:
    """Represents a single running brain session.

    Parameters
    ----------
    client: ABClient
        Reference to the client that created this session.  Used for
        fallback HTTP operations and error reporting.
    project_id: str
        The project identifier.
    run_id: str
        Unique identifier for this run returned by the server.
    contract: dict
        The run contract returned by the server on start.  Contains
        ``constants`` and ``io`` keys describing the IO interface.
        Missing or ``null`` sections fall back to defaults
        (``gamma=64``, ``outputWindowN=32``, ``feedbackN=128``, empty
        IO manifests).
    socket: socketio.Client
        A connected Socket.IO client already joined to the run room.
    namespace: str
        The namespace on the server to emit/receive events on (e.g. ``"/ab"``).
    """

    def __init__(self, client: Any, project_id: str, run_id: str,
                 contract: Dict[str, Any], socket: "socketio.Client",
                 namespace: str) -> None:
        self.client = client
        self.project_id = project_id
        self.run_id = run_id
        self.contract = contract
        self.socket = socket
        self.namespace = namespace

        # parse constants (``or {}`` tolerates an explicit null section)
        consts = contract.get("constants") or {}
        self.gamma: int = int(consts.get("gamma", 64))
        self.output_n: int = int(consts.get("outputWindowN", 32))
        self.feedback_n: int = int(consts.get("feedbackN", 128))

        # keep track of IO manifest, indexed by each item's "id"
        io_spec = contract.get("io") or {}
        self.io_inputs = {item["id"]: item for item in io_spec.get("inputs") or []}
        self.io_outputs = {item["id"]: item for item in io_spec.get("outputs") or []}
        self.io_feedback = {item["id"]: item for item in io_spec.get("feedback") or []}
        self.stdp_layers: List[str] = list((io_spec.get("stdp3") or {}).get("layers") or [])

        # plugin holders
        self.decoder_plugin: Optional[Any] = None
        self.deviation_plugin: Optional[Any] = None
        self.reward_plugin: Optional[Any] = None

        # event handlers registry
        self._cycle_handlers: List[Callable[[Dict[str, Any]], None]] = []
        self._io_need_handlers: List[Callable[[Dict[str, Any]], None]] = []
        self._cmd_handlers: List[Callable[[Dict[str, Any]], None]] = []

        # hook the socket events up to the handler registries above
        self._register_socket_events()

    # ----------------------------------------------------------------------
    # Event registration API
    #
    # The run session receives events from the server via Socket.IO.  You
    # can register additional callbacks for cycle updates, IO needs and
    # command messages.  These callbacks will be called sequentially
    # from the Socket.IO event thread, so you should avoid blocking
    # operations inside handlers.
    #
    def on_cycle_update(self, handler: Callable[[Dict[str, Any]], None]) -> None:
        """Register a handler for cycle update events.

        The handler is called with the full telemetry payload as
        delivered by the server.  You can decode outputs, compute
        rewards and send feedback from within this callback.  Multiple
        handlers can be registered; they will be invoked in the order
        they were added.

        Parameters
        ----------
        handler: Callable[[dict], None]
            A callable accepting a telemetry dictionary.
        """
        self._cycle_handlers.append(handler)

    def on_io_need(self, handler: Callable[[Dict[str, Any]], None]) -> None:
        """Register a handler for IO need events.

        The handler is called with a payload of the form::

            {"runId": ..., "cycle": ..., "needs": [...], "deadlineMs": ...}

        Your handler should respond by calling :meth:`send_input_chunk` or
        :meth:`send_feedback_raster` for each requested input.  The SDK
        provides :class:`~ab_sdk.input_streamer.InputStreamer` which
        implements this logic for you.
        """
        self._io_need_handlers.append(handler)

    def on_robot_cmd(self, handler: Callable[[Dict[str, Any]], None]) -> None:
        """Register a handler for robot command events.

        This is only necessary if your server still returns
        ``robot:cmd`` events (legacy behaviour).  When mapping and
        decoding move into the SDK the server will stop sending
        commands and instead only emit output spikes via
        ``cycle:update``.
        """
        self._cmd_handlers.append(handler)

    def set_decoder(self, decoder: Any) -> None:
        """Attach a decoder plugin.

        The decoder must implement a ``decode(outputs, context)`` method
        which receives a dictionary mapping output IDs to a matrix
        ``(gamma x outputN)`` and returns a command dictionary
        ``{'dq': [...], 'dg': float}``.  See
        :class:`~ab_sdk.plugins.decoder.BaseDecoder` for details.

        Parameters
        ----------
        decoder: Any
            An object implementing a ``decode`` method.
        """
        self.decoder_plugin = decoder

    def set_deviation(self, deviation_policy: Any) -> None:
        """Attach a deviation policy plugin.

        The deviation policy must implement a ``compute(telemetry)``
        method returning a mapping from feedback input IDs to lists of
        floats of length ``gamma`` in the range ``[-1,1]``.  See
        :class:`~ab_sdk.plugins.deviation.BaseDeviation`.
        """
        self.deviation_plugin = deviation_policy

    def set_reward(self, reward_policy: Any) -> None:
        """Attach a reward policy plugin.

        The reward policy must implement a ``compute(telemetry)``
        method returning a tuple ``(global_reward, by_layer_dict)``
        where ``global_reward`` is a float in ``[0,1]`` and
        ``by_layer_dict`` maps layer names to floats in ``[0,1]``.  See
        :class:`~ab_sdk.plugins.reward.BaseReward`.
        """
        self.reward_plugin = reward_policy

    # ----------------------------------------------------------------------
    # Emission helpers
    #
    def send_input_chunk(self, input_id: str, kind: str, seq: int, t: float,
                         fmt: str, meta: Dict[str, Any], data: bytes) -> None:
        """Emit a raw input chunk over the realtime channel.

        This is a generic helper used by the input streamer.  The
        payload is emitted as an ``io:chunk`` event on the session's
        namespace.  You normally should not call this directly; use
        :class:`~ab_sdk.input_streamer.InputStreamer` instead.
        """
        payload = {
            "runId": self.run_id,
            "inputId": input_id,
            "kind": kind,
            "seq": int(seq),
            "t": float(t),
            "format": fmt,
            "meta": meta,
            "data": data,
        }
        # The raw bytes are left out of the log line to keep it readable.
        logger.debug("Sending input chunk: %s",
                     {k: v for k, v in payload.items() if k != "data"})
        self.socket.emit("io:chunk", payload, namespace=self.namespace)

    def send_feedback_raster(self, input_id: str, raster: Iterable[float],
                             cycle: int) -> None:
        """Send a feedback raster for the given feedback input.

        The raster should be a flat iterable of length ``gamma * feedbackN``
        (e.g. a list or a numpy array).  Values should be floats in
        ``[-1,1]``.  The caller is responsible for constructing the
        raster using :func:`~ab_sdk.utils.feedback.build_feedback_raster`.
        The values are packed as little-endian float32 regardless of the
        host byte order.

        Parameters
        ----------
        input_id: str
            The identifier of the feedback input to send.
        raster: Iterable[float]
            A flat sequence containing ``gamma * feedbackN`` floats.
        cycle: int
            The cycle number associated with this feedback (included in
            ``meta`` for debugging).
        """
        import array
        import sys

        packed = array.array('f', raster)
        if sys.byteorder == "big":
            # array.array stores values in native byte order; swap so the
            # wire format is always little-endian.
            packed.byteswap()
        data_bytes = packed.tobytes()
        meta = {"T": self.gamma, "N": self.feedback_n, "cycle": cycle}
        # Read the clock once so ``seq`` and ``t`` describe the same instant.
        now = time.time()
        self.send_input_chunk(input_id=input_id, kind="Feedback",
                              seq=int(now * 1000),
                              t=now, fmt="raster_f32",
                              meta=meta, data=data_bytes)

    @staticmethod
    def _clamp_unit(value: Any) -> float:
        """Return ``value`` as a float clamped to ``[0, 1]``; NaN maps to 0.0."""
        number = float(value)
        # ``not number > 0.0`` is true for negatives, zero and NaN.  A
        # plain ``max(0.0, min(1.0, nan))`` would evaluate to 1.0, i.e.
        # report a maximal reward for an undefined value.
        if not number > 0.0:
            return 0.0
        return min(number, 1.0)

    def send_reward(self, global_reward: float, by_layer: Dict[str, float],
                    cycle: int) -> None:
        """Send reward information to the server.

        Only STDP3 layers listed in ``self.stdp_layers`` are included
        in the payload; missing entries are filled with ``global_reward``.
        All values are clamped to ``[0, 1]`` (NaN is sent as ``0.0``).
        ``by_layer`` may be ``None`` or empty, in which case every layer
        receives the global reward.
        """
        by_layer = by_layer or {}
        payload_layers: Dict[str, float] = {
            layer: self._clamp_unit(by_layer.get(layer, global_reward))
            for layer in self.stdp_layers
        }
        payload = {
            "runId": self.run_id,
            "cycle": cycle,
            "globalReward": self._clamp_unit(global_reward),
            "byLayer": payload_layers,
        }
        logger.debug("Sending reward: %s", payload)
        self.socket.emit("learn:reward", payload, namespace=self.namespace)

    def close(self) -> None:
        """Disconnect the Socket.IO client.

        This method should be called when you are finished with the
        session.  Subsequent operations on this session may fail.
        Errors raised while disconnecting are logged as warnings and not
        propagated.
        """
        try:
            if self.socket.connected:
                logger.info("Disconnecting session %s", self.run_id)
                # python-socketio's Client.disconnect() accepts no
                # arguments; passing ``namespace=`` raised a TypeError
                # that was swallowed below, leaving the socket open.
                self.socket.disconnect()
        except Exception as exc:
            logger.warning("Error disconnecting session: %s", exc)

    # ----------------------------------------------------------------------
    # Internal: register socket event handlers
    #
    def _dispatch(self, handlers: List[Callable[[Dict[str, Any]], None]],
                  payload: Dict[str, Any], label: str) -> None:
        """Invoke every handler with ``payload``, isolating failures.

        A snapshot of the list is iterated so that a handler registered
        while an event is being dispatched first runs on the next event.
        An exception in one handler is logged and does not prevent the
        remaining handlers from running.
        """
        for handler in tuple(handlers):
            try:
                handler(payload)
            except Exception as exc:
                logger.exception("Error in %s handler: %s", label, exc)

    def _register_socket_events(self) -> None:
        """Setup internal Socket.IO event dispatching.

        This method attaches handlers to the underlying Socket.IO client
        for the known event types (``cycle:update``, ``io:need``,
        ``robot:cmd``).  When events are received the registered
        callbacks are invoked sequentially.
        """

        @self.socket.on("cycle:update", namespace=self.namespace)
        def _on_cycle_update(payload: Dict[str, Any]) -> None:
            # Guard the log line: a non-dict payload must not prevent the
            # handlers from being called.
            cycle = payload.get("cycle") if isinstance(payload, dict) else None
            logger.debug("Received cycle update: cycle=%s", cycle)
            self._dispatch(self._cycle_handlers, payload, "cycle update")

        @self.socket.on("io:need", namespace=self.namespace)
        def _on_io_need(payload: Dict[str, Any]) -> None:
            logger.debug("Received IO need: %s", payload)
            self._dispatch(self._io_need_handlers, payload, "IO need")

        @self.socket.on("robot:cmd", namespace=self.namespace)
        def _on_robot_cmd(payload: Dict[str, Any]) -> None:
            logger.debug("Received robot command: %s", payload)
            self._dispatch(self._cmd_handlers, payload, "robot command")