unaiverse 0.1.11__cp311-cp311-macosx_11_0_arm64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.

This version of unaiverse might be problematic.

Files changed (50)
  1. unaiverse/__init__.py +19 -0
  2. unaiverse/agent.py +2090 -0
  3. unaiverse/agent_basics.py +1948 -0
  4. unaiverse/clock.py +221 -0
  5. unaiverse/dataprops.py +1236 -0
  6. unaiverse/hsm.py +1892 -0
  7. unaiverse/modules/__init__.py +18 -0
  8. unaiverse/modules/cnu/__init__.py +17 -0
  9. unaiverse/modules/cnu/cnus.py +536 -0
  10. unaiverse/modules/cnu/layers.py +261 -0
  11. unaiverse/modules/cnu/psi.py +60 -0
  12. unaiverse/modules/hl/__init__.py +15 -0
  13. unaiverse/modules/hl/hl_utils.py +411 -0
  14. unaiverse/modules/networks.py +1509 -0
  15. unaiverse/modules/utils.py +710 -0
  16. unaiverse/networking/__init__.py +16 -0
  17. unaiverse/networking/node/__init__.py +18 -0
  18. unaiverse/networking/node/connpool.py +1308 -0
  19. unaiverse/networking/node/node.py +2499 -0
  20. unaiverse/networking/node/profile.py +446 -0
  21. unaiverse/networking/node/tokens.py +79 -0
  22. unaiverse/networking/p2p/__init__.py +187 -0
  23. unaiverse/networking/p2p/go.mod +127 -0
  24. unaiverse/networking/p2p/go.sum +548 -0
  25. unaiverse/networking/p2p/golibp2p.py +18 -0
  26. unaiverse/networking/p2p/golibp2p.pyi +135 -0
  27. unaiverse/networking/p2p/lib.go +2662 -0
  28. unaiverse/networking/p2p/lib.go.sha256 +1 -0
  29. unaiverse/networking/p2p/lib_types.py +312 -0
  30. unaiverse/networking/p2p/message_pb2.py +50 -0
  31. unaiverse/networking/p2p/messages.py +362 -0
  32. unaiverse/networking/p2p/mylogger.py +77 -0
  33. unaiverse/networking/p2p/p2p.py +871 -0
  34. unaiverse/networking/p2p/proto-go/message.pb.go +846 -0
  35. unaiverse/networking/p2p/unailib.cpython-311-darwin.so +0 -0
  36. unaiverse/stats.py +1481 -0
  37. unaiverse/streamlib/__init__.py +15 -0
  38. unaiverse/streamlib/streamlib.py +210 -0
  39. unaiverse/streams.py +776 -0
  40. unaiverse/utils/__init__.py +16 -0
  41. unaiverse/utils/lone_wolf.json +24 -0
  42. unaiverse/utils/misc.py +310 -0
  43. unaiverse/utils/sandbox.py +293 -0
  44. unaiverse/utils/server.py +435 -0
  45. unaiverse/world.py +335 -0
  46. unaiverse-0.1.11.dist-info/METADATA +367 -0
  47. unaiverse-0.1.11.dist-info/RECORD +50 -0
  48. unaiverse-0.1.11.dist-info/WHEEL +6 -0
  49. unaiverse-0.1.11.dist-info/licenses/LICENSE +43 -0
  50. unaiverse-0.1.11.dist-info/top_level.txt +1 -0
unaiverse/agent.py ADDED
@@ -0,0 +1,2090 @@
+ """
+ █████ █████ ██████ █████ █████ █████ █████ ██████████ ███████████ █████████ ██████████
+ ░░███ ░░███ ░░██████ ░░███ ░░███ ░░███ ░░███ ░░███░░░░░█░░███░░░░░███ ███░░░░░███░░███░░░░░█
+ ░███ ░███ ░███░███ ░███ ██████ ░███ ░███ ░███ ░███ █ ░ ░███ ░███ ░███ ░░░ ░███ █ ░
+ ░███ ░███ ░███░░███░███ ░░░░░███ ░███ ░███ ░███ ░██████ ░██████████ ░░█████████ ░██████
+ ░███ ░███ ░███ ░░██████ ███████ ░███ ░░███ ███ ░███░░█ ░███░░░░░███ ░░░░░░░░███ ░███░░█
+ ░███ ░███ ░███ ░░█████ ███░░███ ░███ ░░░█████░ ░███ ░ █ ░███ ░███ ███ ░███ ░███ ░ █
+ ░░████████ █████ ░░█████░░████████ █████ ░░███ ██████████ █████ █████░░█████████ ██████████
+ ░░░░░░░░ ░░░░░ ░░░░░ ░░░░░░░░ ░░░░░ ░░░ ░░░░░░░░░░ ░░░░░ ░░░░░ ░░░░░░░░░ ░░░░░░░░░░
+ A Collectionless AI Project (https://collectionless.ai)
+ Registration/Login: https://unaiverse.io
+ Code Repositories: https://github.com/collectionlessai/
+ Main Developers: Stefano Melacci (Project Leader), Christian Di Maio, Tommaso Guidi
+ """
+ import copy
+ import json
+ import uuid
+ import torch
+ from unaiverse.stats import Stats
+ from unaiverse.dataprops import DataProps
+ from unaiverse.agent_basics import AgentBasics
+ from unaiverse.streams import BufferedDataStream
+ from unaiverse.networking.p2p.messages import Msg
+
+
+ class Agent(AgentBasics):
+     """This class contains those basic actions that can be performed by every agent."""
+
+     def __init__(self, *args, **kwargs):
+         super().__init__(*args, **kwargs)
+
+         # Status variables (assumed to start with "_"): Agent exchanges
+         self._available = True  # It will be automatically set/changed during the agent's life
+         self._found_agents = set()  # Peer IDs discovered
+         self._valid_cmp_agents = set()  # Agents for which the last evaluation was positive
+         self._engaged_agents = set()
+         self._agents_who_completed_what_they_were_asked = set()
+         self._agents_who_were_asked = set()
+         self._eval_results = {}
+
+         # Status variables (assumed to start with "_"): Recordings
+         self._last_recorded_stream_num = 1
+         self._last_recorded_stream_dict = None
+         self._last_recording_stream_dict = None
+
+         # Status variables (assumed to start with "_"): Playlist
+         self._preferred_streams = []  # List of preferred streams
+         self._cur_preferred_stream = 0  # ID of the current preferred stream from the list
+         self._repeat = 1  # Number of repetitions of the playlist
+
+         # Stats
+         self.stats = Stats(is_world=False)
+         self.overwrite_stats = False  # Whether to overwrite stats when receiving the next STATS_RESPONSE from the world
+
+     def remove_peer_from_agent_status_attrs(self, peer_id):
+         super().remove_peer_from_agent_status_attrs(peer_id)
+         self._available = len(self._engaged_agents) == 0
+
+     def reset_agent_status_attrs(self):
+         super().reset_agent_status_attrs()  # This sets status vars to [], {}, 0, 0., False, depending on their type
+         self._available = True
+         self._repeat = 1
+         self._last_recorded_stream_num = 1
+
+     async def set_next_action(self, agent: str | None, action: str, args: dict | None = None,
+                               ref_uuid: str | None = None):
+         """Try to tell another agent what is the next action it should run (async).
+
+         Args:
+             agent: The ID of the agent to send the action to or a valid wildcard like "<valid_cmp>" for a set of
+                 agents (if None the agents in self._engaged_agents will be considered).
+             action: The name of the action to be executed by the agent.
+             args: A dictionary of arguments for the action. Defaults to None.
+             ref_uuid: An optional UUID for referencing the action. Defaults to None.
+
+         Returns:
+             True if the action was successfully sent to the target agent or to at least one of the
+             involved agents (wildcard case).
+         """
+
+         # - if "agent" is a peer ID, the involved agents will be a list with one element.
+         # - if "agent" is a known wildcard, as "<valid_cmp>", then involved agents will be self._valid_cmp_agents
+         # - if "agent" is None, then the current agent in self._engaged_agents will be returned
+         involved_agents = self.__involved_agents(agent)
+         if len(involved_agents) == 0:
+             return False
+
+         at_least_one_completed = False
+         _, private_peer_id = self.get_peer_ids()
+         for _peer_id in involved_agents:
+             ret = await self._node_conn.send(_peer_id, channel_trail=None,
+                                              content={"action_name": action, "args": args, "uuid": ref_uuid},
+                                              content_type=Msg.ACTION_REQUEST)
+             at_least_one_completed = at_least_one_completed or ret
+             self.deb(f"[set_next_action] {self._node_name} sent action: {action}, with args: {args}, "
+                      f"and result of sending is {ret}")
+         return at_least_one_completed
+
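A minimal usage sketch of the request side (the peer ID and the surrounding coroutine are invented placeholders; "nop" is the no-op action defined later in this file):

    async def demo(agent: Agent):
        # Ask one hypothetical peer to run the "nop" action with a custom message
        ok = await agent.set_next_action("QmSomePeerId", action="nop",
                                         args={"message": "ping", "delay": -1.})
        print("Delivered to at least one agent:", ok)
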
+     async def set_engaged_partner(self, agent: str):
+         """Virtually forces the engagement with a single agent, clearing all existing engagements (async).
+
+         Returns:
+             Always True.
+         """
+         self._engaged_agents.clear()
+         self._engaged_agents.add(agent)
+         return True
+
+     async def send_engagement(self):
+         """Offer engagement to the agents whose identifiers are in self._found_agents (async).
+
+         Returns:
+             True if engagement requests were successfully sent to at least one found agent, False otherwise.
+         """
+         at_least_one_sent = False
+
+         if len(self._found_agents) > 0:
+             self.out(f"Sending engagement request to {', '.join(self._found_agents)}")
+             my_role_str = self._node_profile.get_dynamic_profile()['connections']['role']
+             for found_agent in self._found_agents:
+                 if await self.set_next_action(found_agent, action="get_engagement",
+                                               args={"sender_role": my_role_str}):
+                     at_least_one_sent = True
+                 else:
+                     self.err(f"Unable to send engagement to {found_agent}")
+
+         return at_least_one_sent
+
+     async def get_engagement(self, acceptable_role: str | None = None, sender_role: str | None = None,
+                              _requester: str | None = None):
+         """Receive engagement from another agent whose authority is in the specified range (async).
+
+         Args:
+             acceptable_role: The role that the sender must have for engagement to be accepted. Defaults to None.
+             sender_role: The role of the agent sending the engagement request. Defaults to None.
+             _requester: The ID of the agent requesting engagement (automatically set by the action calling routine).
+
+         Returns:
+             True if the engagement was successfully received and confirmed, False otherwise.
+         """
+         self.out(f"Getting engagement from {_requester}, whose role is {sender_role} (looking for {acceptable_role})")
+         if _requester not in self.world_agents and _requester not in self.world_masters:
+             self.err(f"Unknown agent: {_requester}")
+             return False
+
+         if sender_role is None:
+             self.err(f"Unknown role of {_requester}")
+             return False
+
+         # Confirming
+         if self._available:
+             acceptable_role_int = self.ROLE_STR_TO_BITS[acceptable_role]
+             if "~" not in acceptable_role:
+                 sender_role_int = (self.ROLE_STR_TO_BITS[sender_role] >> 2) << 2
+             else:
+                 sender_role_int = self.ROLE_STR_TO_BITS[sender_role]
+
+             if acceptable_role_int == sender_role_int:
+                 if await self.set_next_action(_requester, "got_engagement"):
+                     self._engaged_agents.add(_requester)
+
+                     # Marking this agent as not available since it engaged with another one
+                     self._available = False
+                     return True
+                 else:
+                     self.err(f"Unable to confirm engagement to {_requester}")
+                     return False
+             else:
+                 self.err(f"Cannot engage with {_requester}: role mismatch")
+                 return False
+         else:
+             self.err(f"Cannot engage with {_requester}: this agent is not available")
+             return False
+
+     async def got_engagement(self, _requester: str | None = None):
+         """Confirm an engagement (async).
+
+         Args:
+             _requester: The ID of the agent confirming the engagement (automatically set by the action calling routine).
+
+         Returns:
+             True if the engagement was successfully confirmed, False otherwise.
+         """
+         self.out(f"Confirming engagement with {_requester}")
+         if _requester in self._found_agents:
+             self._engaged_agents.add(_requester)
+
+             # Marking this agent as not available since it engaged with another one
+             self._available = False
+
+             # Removing the engaged agent from the list of found agents, to avoid sending it another engagement request
+             self._found_agents.discard(_requester)
+             return True
+         else:
+             self.err(f"Unable to confirm engagement with {_requester}")
+             return False
+
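Taken together, send_engagement, get_engagement and got_engagement form a two-way handshake. A rough sketch of the sequence, with invented peers A and B (role strings R are placeholders):

    # A (initiator)                            B (receiver)
    # A.send_engagement()  --ACTION_REQUEST--> B.get_engagement(acceptable_role=R, sender_role=A_role)
    #                                          role check passes: B adds A to _engaged_agents,
    #                                          B._available = False
    # A.got_engagement(B)  <--ACTION_REQUEST-- B.set_next_action(A, "got_engagement")
    # A adds B to _engaged_agents, sets A._available = False, and drops B from _found_agents
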
+     async def send_disengagement(self, send_disconnection_too: bool = False):
+         """Ask for disengagement (async).
+
+         Args:
+             send_disconnection_too: Whether to send a disconnect-suggestion together with the disengagement.
+
+         Returns:
+             True if disengagement requests were successfully sent to at least one engaged agent, False otherwise.
+         """
+         at_least_one_sent = False
+
+         if len(self._engaged_agents) > 0:
+             self.out(f"Sending disengagement request to {', '.join(self._engaged_agents)}")
+             for agent in self._engaged_agents:
+                 if await self.set_next_action(agent, action="get_disengagement",
+                                               args={"disconnect_too": send_disconnection_too}):
+                     at_least_one_sent = True
+                 else:
+                     self.err(f"Unable to send disengagement to {agent}")
+
+         return at_least_one_sent
+
+     async def get_disengagement(self, disconnect_too: bool = False, _requester: str | None = None):
+         """Get a disengagement request from an agent (async).
+
+         Args:
+             disconnect_too: Whether to disconnect the agent who sent the disengagement.
+             _requester: The ID of the agent requesting disengagement. Defaults to None.
+
+         Returns:
+             True if the disengagement request was successfully processed, False otherwise.
+         """
+         self.out(f"Getting a disengagement request from {_requester}")
+         if _requester not in self.world_agents and _requester not in self.world_masters:
+             self.err(f"Unknown agent: {_requester}")
+             return False
+
+         if _requester not in self._engaged_agents:
+             self.err(f"Not previously engaged to {_requester}")
+             return False
+
+         if disconnect_too:
+             await self._node_purge_fcn(_requester)
+
+         self._engaged_agents.discard(_requester)  # Remove if present
+
+         # Marking this agent as available if not engaged to any agent
+         self._available = len(self._engaged_agents) == 0
+         return True
+
+     async def disengage_all(self):
+         """Disengage all the previously engaged agents (async).
+
+         Returns:
+             Always True (the set of engaged agents is simply cleared).
+         """
+         self.out("Disengaging all agents")
+         self._engaged_agents = set()
+
+         # Marking this agent as available
+         self._available = True
+         return True
+
+     async def disconnect_by_role(self, role: str | list[str]):
+         """Disconnects from all agents that match a specified role (async).
+         It finds the agents and calls the node's purge function on each.
+
+         Args:
+             role: A string or list of strings representing the role(s) of agents to disconnect from.
+
+         Returns:
+             Always True.
+         """
+         self.out(f"Disconnecting agents with role: {role}")
+         if await self.find_agents(role):
+             found_agents = copy.deepcopy(self._found_agents)
+             for agent in found_agents:
+                 await self._node_purge_fcn(agent)  # This will also call remove_agent, that will call remove_streams
+         return True
+
+     async def disconnected(self, agent: str | None = None, delay: float = -1.):
+         """Checks if a specific set of agents (by ID or wildcard) are no longer connected to the agent.
+         It returns False if any of the specified agents are still connected (async).
+
+         Args:
+             agent: The ID of the agent or a wildcard to check.
+             delay: The time (seconds) to be spent in the current state before actually considering this action.
+
+         Returns:
+             True if all involved agents are disconnected, False otherwise.
+         """
+         assert delay is not None, "Missing basic action information"
+
+         # - if "agent" is a peer ID, the involved agents will be a list with one element.
+         # - if "agent" is a known wildcard, as "<valid_cmp>", then involved agents will be self._valid_cmp_agents
+         # - if "agent" is None, then the current agent in self._engaged_agents will be returned
+         involved_agents = self.__involved_agents(agent)
+         if len(involved_agents) == 0:
+             return False
+
+         self.out(f"Checking if all these agents are not connected to me anymore: {involved_agents}")
+         all_disconnected = True
+         for agent in involved_agents:
+             if agent in self.world_agents or agent in self.public_agents or agent in self._node_agents_waiting \
+                     or self._node_conn.is_connected(agent):
+                 all_disconnected = False
+                 break
+         return all_disconnected
+
+     async def received_some_asked_data(self, processing_fcn: str | None = None):
+         """Checks if any of the agents that were previously asked for data (e.g., via `ask_gen`) have sent a stream
+         sample back. Optionally, it can process the received data with a specified function (async).
+
+         Args:
+             processing_fcn: The name of a function to process the received data.
+
+         Returns:
+             True if at least one data sample was received, False otherwise.
+         """
+         _processing_fcn = None
+         if processing_fcn is not None:
+             if hasattr(self, processing_fcn):
+                 _processing_fcn = getattr(self, processing_fcn)
+                 if not callable(_processing_fcn):
+                     _processing_fcn = None
+             if _processing_fcn is None:
+                 self.err(f"Processing function not found: {processing_fcn}")
+
+         got_something = False
+         for agent in self._agents_who_were_asked:
+             net_hash_to_stream_dict = self.find_streams(agent, "processor")
+             for stream_dict in net_hash_to_stream_dict.values():
+                 for stream_obj in stream_dict.values():
+                     if not stream_obj.props.is_public():
+                         data = stream_obj.get("received_some_asked_data")
+                         data_tag = stream_obj.get_tag()
+
+                         if data is not None:
+                             if _processing_fcn is None:
+                                 return True
+                             else:
+                                 got_something = True
+                                 _processing_fcn(agent, stream_obj.props, data, data_tag)
+         return got_something
+
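The processing function is looked up by name on the agent itself and, judging from the call site above, is invoked as fcn(agent, props, data, data_tag). A minimal sketch of such a callback on a hypothetical subclass:

    class MyAgent(Agent):

        def log_received(self, agent, props, data, data_tag):
            # Illustrative only: the signature mirrors the call in received_some_asked_data
            print(f"Sample tagged {data_tag} from {agent} on stream {props.get_name()}")

    # Usage, inside an async behavior: await self.received_some_asked_data("log_received")
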
+     async def nop(self, message: str | None = None, delay: float = -1.):
+         """Do nothing (async).
+
+         Args:
+             message: An optional message to print. Defaults to None.
+             delay: The time (seconds) to be spent in the current state before actually considering this action.
+
+         Returns:
+             Always True.
+         """
+         assert delay is not None, "Missing basic action information"
+         if message is not None:
+             self.out(message)
+         return True
+
+     async def wait_for_actions(self, agent: str, from_state: str, to_state: str, wait: bool):
+         """Lock or unlock every action between a pair of states in the state machine of a target agent (async).
+
+         Args:
+             agent: The ID of the agent to send the action locking request to, or a valid wildcard like "<valid_cmp>"
+                 for a set of agents (if None the agents in self._engaged_agents will be considered).
+             from_state: The starting state of the actions to be locked/unlocked.
+             to_state: The ending state of the actions to be locked/unlocked.
+             wait: A boolean indicating whether the actions must wait (wait == not ready).
+
+         Returns:
+             True if the request was successfully sent to at least one involved agent, False otherwise.
+         """
+
+         # - if "agent" is a peer ID, the involved agents will be a list with one element.
+         # - if "agent" is a known wildcard, as "<valid_cmp>", then involved agents will be self._valid_cmp_agents
+         # - if "agent" is None, then the current agent in self._engaged_agents will be returned
+         involved_agents = self.__involved_agents(agent)
+         if len(involved_agents) == 0:
+             return False
+
+         at_least_one_completed = False
+         for _agent in involved_agents:
+             self.out(f"Telling {_agent} to alter its HSM {from_state} -> {to_state} (wait: {wait}) "
+                      f"by calling method 'wait_for_actions' on it")
+             ret = await self._node_conn.send(_agent, channel_trail=None,
+                                              content={'method': 'wait_for_actions',
+                                                       'args': (from_state, to_state, wait)},
+                                              content_type=Msg.HSM)
+             at_least_one_completed = at_least_one_completed or ret
+         return at_least_one_completed
+
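A hedged usage sketch (the state names are invented and must exist in the target agent's state machine; passing None targets the currently engaged agents):

    # Inside an async behavior: freeze the actions on the "idle" -> "serving"
    # transition of all engaged agents, then release them later
    await agent.wait_for_actions(None, from_state="idle", to_state="serving", wait=True)
    await agent.wait_for_actions(None, from_state="idle", to_state="serving", wait=False)
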
+     async def ask_gen(self, agent: str | None = None, u_hashes: list[str] | None = None,
+                       samples: int = 100, time: float = -1., timeout: float = -1., ask_uuid: str | None = None,
+                       ignore_uuid: bool = False):
+         """Asking for generation (async).
+
+         Args:
+             agent: The ID of the agent to ask for generation, or a valid wildcard like "<valid_cmp>"
+                 for a set of agents (if None the agents in self._engaged_agents will be considered).
+             u_hashes: A list of input stream hashes for generation. Defaults to None.
+             samples: The number of samples to generate. Defaults to 100.
+             time: The time duration for generation. Defaults to -1.
+             timeout: The timeout for the generation request. Defaults to -1.
+             ask_uuid: Specify the UUID of the action (if None, the default, it is randomly generated).
+             ignore_uuid: Force a None UUID instead of generating a random one.
+
+         Returns:
+             True if the generation request was successfully sent to at least one involved agent, False otherwise.
+         """
+         assert samples is not None and time is not None and timeout is not None, "Missing basic action information"
+
+         # - if "agent" is a peer ID, the involved agents will be a list with one element.
+         # - if "agent" is a known wildcard, as "<valid_cmp>", then involved agents will be self._valid_cmp_agents
+         # - if "agent" is None, then the current agent in self._engaged_agents will be returned
+         involved_agents = self.__involved_agents(agent)
+         self.deb(f"[ask_gen] Involved_agents: {involved_agents}")
+
+         if len(involved_agents) == 0:
+             self.deb(f"[ask_gen] No involved agents, action ask_gen returns False")
+             return False
+
+         # Create a copy of the input hashes, normalizing them in the appropriate way
+         u_hashes_copy: list[str | None] = [None] * len(u_hashes)
+         for i in range(len(u_hashes)):
+             if u_hashes[i] == "<playlist>":
+
+                 # From <playlist> to the current element of the playlist
+                 u_hashes_copy[i] = self._preferred_streams[self._cur_preferred_stream]
+             else:
+
+                 # From a user specified hash to a net hash (e.g., peer_id:name_or_group to peer_id::ps:name_or_group)
+                 u_hashes_copy[i] = self.user_stream_hash_to_net_hash(u_hashes[i])
+
+         # Generate a new UUID for this request
+         ref_uuid = uuid.uuid4().hex[0:8] if ask_uuid is None else ask_uuid
+         if ignore_uuid:
+             ref_uuid = None
+
+         # If the input streams are not all owned by this agent, discard the UUID
+         all_owned = True
+         for i in range(len(u_hashes_copy)):
+             if u_hashes_copy[i] not in self.owned_streams:
+                 all_owned = False
+                 break
+         if not all_owned:
+             ref_uuid = None
+
+         for i in range(len(u_hashes_copy)):
+
+             # If there are our own streams involved, and they are buffered, let's plan to restart them when we will
+             # start sending them through the net; moreover, let's set the local stream UUID appropriately to
+             # the generated UUID
+             if u_hashes_copy[i] in self.owned_streams:
+                 stream_dict = self.known_streams[u_hashes_copy[i]]
+                 for stream_name, stream_obj in stream_dict.items():
+
+                     # Plan to restart buffered streams
+                     if isinstance(stream_obj, BufferedDataStream):
+                         stream_obj.plan_restart_before_next_get(requested_by="send_stream_samples")
+
+                     # Activate the stream (if it was off)
+                     stream_obj.enable()
+
+                     # Set UUID to the generated one
+                     stream_obj.set_uuid(ref_uuid=ref_uuid, expected=False)
+                     stream_obj.set_uuid(ref_uuid=None, expected=True)
+
+         self.deb(f"[ask_gen] Input streams u_hashes: {u_hashes_copy}")
+
+         self.out(f"Asking {', '.join(involved_agents)} to generate signal given {u_hashes_copy} (ref_uuid: {ref_uuid})")
+         self._agents_who_completed_what_they_were_asked = set()
+         self._agents_who_were_asked = set()
+         correctly_asked = []
+         for peer_id in involved_agents:
+             ret = await self.__ask_gen_or_learn(for_what="gen", agent=peer_id,
+                                                 u_hashes=u_hashes_copy,
+                                                 yhat_hashes=None,
+                                                 samples=samples, time=time, timeout=timeout, ref_uuid=ref_uuid)
+             self.deb(f"[ask_gen] Asking {peer_id} returned {ret}")
+             if ret:
+                 correctly_asked.append(peer_id)
+
+         # Preparing the buffered stream where to store data, if needed
+         if len(correctly_asked) > 0:
+
+             # Saving
+             self.last_ref_uuid = ref_uuid
+
+             # For each agent that we involved in this request...
+             for peer_id in correctly_asked:
+
+                 # Finding the streams generated by the processor of the agent we asked to generate
+                 processor_streams = self.find_streams(peer_id, name_or_group="processor")
+
+                 # For each stream generated by the processor of the agent we asked to generate...
+                 for net_hash, stream_dict in processor_streams.items():
+
+                     # Set the appropriate UUID to the one we created in this method
+                     for stream in stream_dict.values():
+                         stream.set_uuid(None, expected=False)
+                         stream.set_uuid(ref_uuid, expected=True)  # Setting the "expected" one
+
+         self.deb(f"[ask_gen] Overall, the action ask_gen will return {len(correctly_asked) > 0}")
+         return len(correctly_asked) > 0
+
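A sketch of the asking side (inside an async behavior; the "<playlist>" wildcard is resolved against this agent's playlist as shown above):

    # Ask the currently engaged agents to generate 32 samples, feeding them
    # the current entry of this agent's playlist
    ok = await agent.ask_gen(agent=None, u_hashes=["<playlist>"], samples=32)
    if ok:
        # Replies arrive on the asked agents' "processor" streams and can be
        # polled with received_some_asked_data(), defined earlier in this file
        got = await agent.received_some_asked_data()
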
+     async def do_gen(self, u_hashes: list[str] | None = None,
+                      samples: int = 100, time: float = -1., timeout: float = -1.,
+                      _requester: str | list | None = None, _request_time: float = -1., _request_uuid: str | None = None,
+                      _completed: bool = False) -> bool:
+         """Generate a signal (async).
+
+         Args:
+             u_hashes: A list of input stream hashes for generation. Defaults to None.
+             samples: The number of samples to generate. Defaults to 100.
+             time: The max time duration for the whole generation process. Defaults to -1.
+             timeout: The timeout for generation attempts: if calling the generate action fails for more than "timeout"
+                 seconds, it is declared as complete. Defaults to -1.
+             _requester: The ID of the agent who requested generation (automatically set by the action calling routine).
+             _request_time: The time the generation was requested (automatically set by the action calling routine).
+             _request_uuid: The UUID of the generation request (automatically set by the action calling routine).
+             _completed: A boolean indicating if the generation is already completed (automatically set by the action
+                 calling routine). This will tell that it is time to run a final procedure.
+
+         Returns:
+             True if the signal generation was successful, False otherwise.
+         """
+         assert samples is not None and time is not None and timeout is not None, "Missing basic action information"
+
+         self.deb(f"[do_gen] Samples: {samples}, time: {time}, timeout: {timeout}, "
+                  f"requester: {_requester}, request_time: {_request_time}, request_uuid: {_request_uuid}, "
+                  f"completed: {_completed}")
+
+         if _requester is not None:
+             if isinstance(_requester, list):
+                 for _r in _requester:
+                     if self.behaving_in_world():
+                         if _r not in self.world_agents and _r not in self.world_masters:
+                             self.err(f"Unknown agent: {_r} in list {_requester} (fully skipping generation)")
+                             return False
+                     else:
+                         if _r not in self.public_agents:
+                             self.err(f"Unknown agent: {_r} in list {_requester} (fully skipping generation)")
+                             return False
+             else:
+                 if self.behaving_in_world():
+                     if _requester not in self.world_agents and _requester not in self.world_masters:
+                         self.err(f"Unknown agent: {_requester} (fully skipping generation)")
+                         return False
+                 else:
+                     if _requester not in self.public_agents:
+                         self.err(f"Unknown agent: {_requester} (fully skipping generation)")
+                         return False
+
+         # Check what is the step ID of the multistep action
+         k = self.get_action_step()
+
+         # In the first step of this action, we change the UUID of the local stream associated to the input data we will
+         # use to handle this action, setting expectations to avoid handling tags of old data
+         if k == 0:
+
+             # Warning: we are not normalizing the hashes, we should do it if this action is called directly
+             if u_hashes is not None:
+                 for net_hash in u_hashes:
+                     if net_hash in self.known_streams:
+                         for stream_name, stream_obj in self.known_streams[net_hash].items():
+
+                             # If the data arrived before this action, then the UUID is already set, and here there is
+                             # no need to do anything; if the data has not yet arrived (common case) ...
+                             if stream_obj.get_uuid(expected=False) != _request_uuid:
+                                 stream_obj.set_uuid(None, expected=False)  # Clearing UUID
+                                 stream_obj.set_uuid(_request_uuid, expected=True)  # Setting expectations
+                     else:
+                         self.out(f"Unknown stream mentioned in u_hashes: {net_hash}")
+                         return False
+
+         if not _completed:
+             self.out(f"Generating signal")
+             ret = self.__process_streams(u_hashes=u_hashes, yhat_hashes=None, learn=False,
+                                          recipient=_requester, ref_uuid=_request_uuid)
+             if not ret:
+                 self.out(f"Generating signal failed")
+             else:
+                 if not self.is_multi_steps_action():
+                     self.out(f"Completing signal generation (degenerate single-step case of a multi-step action)")
+                     ret = await self.__complete_do(do_what="gen", peer_id_who_asked=_requester, all_hashes=u_hashes,
+                                                    send_back_confirmation=False)
+                     if not ret:
+                         self.out(f"Completing signal generation failed")
+             return ret
+         else:
+             self.out(f"Completing signal generation")
+             ret = await self.__complete_do(do_what="gen", peer_id_who_asked=_requester, all_hashes=u_hashes)
+             if not ret:
+                 self.out(f"Completing signal generation failed")
+             return ret
+
+     async def done_gen(self, _requester: str | None = None):
+         """This is a way to get back the confirmation of a completed generation (async).
+
+         Args:
+             _requester: The ID of the agent who completed the generation. Defaults to None.
+
+         Returns:
+             True if the generation confirmation was successfully handled by this agent, False if something went wrong.
+         """
+         self.out(f"Agent {_requester} finished generation")
+
+         # Searching for the processor-streams of the agent who generated data
+         processor_streams = self.find_streams(_requester, name_or_group="processor")
+         if processor_streams is None or len(processor_streams) == 0:
+             self.err("Unexpected confirmation of finished generation")
+             return False
+
+         # Remembering that the agent that invoked this action is the one who generated the data, and what it
+         # generated could be used in future actions (for example, in evaluation processes)
+         self._agents_who_completed_what_they_were_asked.add(_requester)
+
+         # Clearing the UUID of the local streams associated to the agent who generated
+         for net_hash, stream_dict in processor_streams.items():
+             for stream_obj in stream_dict.values():
+                 stream_obj.set_uuid(None, expected=False)
+                 stream_obj.set_uuid(None, expected=True)
+
+         # If one or more of my streams were used as arguments of the generation request I did (ask_gen), then their
+         # UUID must be cleared... we clear them all
+         for net_hash, stream_dict in self.owned_streams.items():
+             for stream_obj in stream_dict.values():
+                 if stream_obj.props.is_public() != self.behaving_in_world():
+                     stream_obj.set_uuid(None, expected=False)
+                     stream_obj.set_uuid(None, expected=True)
+         return True
+
+     async def ask_learn(self, agent: str | None = None,
+                         u_hashes: list[str] | None = None, yhat_hashes: list[str] | None = None,
+                         samples: int = 100, time: float = -1., timeout: float = -1., ask_uuid: str | None = None,
+                         ignore_uuid: bool = False):
+         """Asking for learning to generate (async).
+
+         Args:
+             agent: The ID of the agent to ask for learning, or a valid wildcard like "<valid_cmp>"
+                 for a set of agents (if None the agents in self._engaged_agents will be considered).
+             u_hashes: A list of input stream hashes for inference. Defaults to None.
+             yhat_hashes: A list of target stream hashes to be used for loss computation. Defaults to None.
+             samples: The number of samples to learn from. Defaults to 100.
+             time: The time duration for learning. Defaults to -1.
+             timeout: The timeout for the learning request. Defaults to -1.
+             ask_uuid: Specify the action UUID (default = None, i.e., it is automatically generated).
+             ignore_uuid: If True, the UUID is fully ignored (i.e., forced to None).
+
+         Returns:
+             True if the learning request was successfully sent to at least one involved agent, False otherwise.
+         """
+         assert samples is not None and time is not None and timeout is not None, "Missing basic action information"
+
+         # - if "agent" is a peer ID, the involved agents will be a list with one element.
+         # - if "agent" is a known wildcard, as "<valid_cmp>", then involved agents will be self._valid_cmp_agents
+         # - if "agent" is None, then the current agent in self._engaged_agents will be returned
+         involved_agents = self.__involved_agents(agent)
+         self.deb(f"[ask_learn] Involved agents: {involved_agents}")
+
+         if len(involved_agents) == 0:
+             self.deb(f"[ask_learn] No involved agents, action will return False")
+             return False
+
+         # Create a copy of the input hashes, normalizing them in the appropriate way
+         u_hashes_copy = [x for x in u_hashes]
+         for i in range(len(u_hashes_copy)):
+             if u_hashes_copy[i] == "<playlist>":
+
+                 # From <playlist> to the current element of the playlist
+                 u_hashes_copy[i] = self._preferred_streams[self._cur_preferred_stream]
+             else:
+
+                 # From a user specified hash to a net hash (e.g., peer_id:name_or_group to peer_id::ps:name_or_group)
+                 u_hashes_copy[i] = self.user_stream_hash_to_net_hash(u_hashes_copy[i])
+
+         # Create a copy of the target hashes, normalizing them in the appropriate way
+         yhat_hashes_copy = [x for x in yhat_hashes]
+         for i in range(len(yhat_hashes_copy)):
+             if yhat_hashes_copy[i] == "<playlist>":
+
+                 # From <playlist> to the current element of the playlist
+                 yhat_hashes_copy[i] = self._preferred_streams[self._cur_preferred_stream]
+             else:
+
+                 # From a user specified hash to a net hash (e.g., peer_id:name_or_group to peer_id::ps:name_or_group)
+                 yhat_hashes_copy[i] = self.user_stream_hash_to_net_hash(yhat_hashes_copy[i])
+
+         # Generate a new UUID for this request
+         ref_uuid = uuid.uuid4().hex[0:8] if ask_uuid is None else ask_uuid
+         if ignore_uuid:
+             ref_uuid = None
+
+         # If the input and target streams are not all owned by this agent, discard the UUID
+         all_owned = True
+         for i in range(len(u_hashes_copy)):
+             if u_hashes_copy[i] not in self.owned_streams:
+                 all_owned = False
+                 break
+         if all_owned:
+             for i in range(len(yhat_hashes_copy)):
+                 if yhat_hashes_copy[i] not in self.owned_streams:
+                     all_owned = False
+                     break
+         if not all_owned:
+             ref_uuid = None
+
+         for i in range(len(u_hashes_copy)):
+
+             # If there are our own streams involved, and they are buffered, let's plan to restart them when we will
+             # start sending them through the net; moreover, let's set the local stream UUID appropriately to
+             # the generated UUID
+             if u_hashes_copy[i] in self.owned_streams:
+                 stream_dict = self.known_streams[u_hashes_copy[i]]
+                 for stream_name, stream_obj in stream_dict.items():
+
+                     # Plan to restart buffered streams
+                     if isinstance(stream_obj, BufferedDataStream):
+                         stream_obj.plan_restart_before_next_get(requested_by="send_stream_samples")
+
+                     # Activate the stream (if it was off)
+                     stream_obj.enable()
+
+                     # Set UUID to the generated one
+                     stream_obj.set_uuid(ref_uuid=ref_uuid, expected=False)
+                     stream_obj.set_uuid(ref_uuid=None, expected=True)
+
+         for i in range(len(yhat_hashes_copy)):
+
+             # If there are our own streams involved, and they are buffered, let's plan to restart them when we will
+             # start sending them through the net; moreover, let's set the local stream UUID appropriately to
+             # the generated UUID
+             if yhat_hashes_copy[i] in self.owned_streams:
+                 stream_dict = self.known_streams[yhat_hashes_copy[i]]
+                 for stream_name, stream_obj in stream_dict.items():
+
+                     # Plan to restart buffered streams
+                     if isinstance(stream_obj, BufferedDataStream):
+                         stream_obj.plan_restart_before_next_get(requested_by="send_stream_samples")
+
+                     # Activate the stream (if it was off)
+                     stream_obj.enable()
+
+                     # Set UUID to the generated one
+                     stream_obj.set_uuid(ref_uuid=ref_uuid, expected=False)
+                     stream_obj.set_uuid(ref_uuid=None, expected=True)
+
+         self.out(f"Asking {', '.join(involved_agents)} to learn to generate signal {yhat_hashes_copy}, "
+                  f"given {u_hashes_copy} (ref_uuid: {ref_uuid})")
+         self._agents_who_completed_what_they_were_asked = set()
+         self._agents_who_were_asked = set()
+         correctly_asked = []
+         for peer_id in involved_agents:
+             ret = await self.__ask_gen_or_learn(for_what="learn", agent=peer_id,
+                                                 u_hashes=u_hashes_copy,
+                                                 yhat_hashes=yhat_hashes_copy,
+                                                 samples=samples, time=time, timeout=timeout, ref_uuid=ref_uuid)
+             self.deb(f"[ask_learn] Asking {peer_id} returned {ret}")
+             if ret:
+                 correctly_asked.append(peer_id)
+
+         # Preparing the buffered stream where to store data, if needed
+         if len(correctly_asked) > 0:
+
+             # Saving
+             self.last_ref_uuid = ref_uuid
+
+             # For each agent that we involved in this request...
+             for peer_id in correctly_asked:
+
+                 # Finding the streams generated by the processor of the agent we asked to learn
+                 processor_streams = self.find_streams(peer_id, name_or_group="processor")
+
+                 # For each stream generated by the processor of the agent we asked to learn...
+                 for net_hash, stream_dict in processor_streams.items():
+
+                     # Set the appropriate UUID to the one we created in this method
+                     for stream in stream_dict.values():
+                         stream.set_uuid(None, expected=False)
+                         stream.set_uuid(ref_uuid, expected=True)  # Setting the "expected" one
+
+         self.deb(f"[ask_learn] Overall, the action ask_learn will return {len(correctly_asked) > 0}")
+         return len(correctly_asked) > 0
+
+     async def do_learn(self, yhat_hashes: list[str] | None = None, u_hashes: list[str] | None = None,
+                        samples: int = 100, time: float = -1., timeout: float = -1.,
+                        _requester: str | None = None, _request_time: float = -1., _request_uuid: str | None = None,
+                        _completed: bool = False) -> bool:
+         """Learn to generate a signal (async).
+
+         Args:
+             yhat_hashes: A list of target stream hashes to be used for loss computation. Defaults to None.
+             u_hashes: A list of input stream hashes for inference. Defaults to None.
+             samples: The number of samples to learn from. Defaults to 100.
+             time: The max time duration of the learning procedure. Defaults to -1.
+             timeout: The timeout for learning attempts: if calling the learning action fails for more than "timeout"
+                 seconds, it is declared as complete. Defaults to -1.
+             _requester: The ID of the agent who requested learning (automatically set by the action calling routine).
+             _request_time: The time learning was requested (automatically set by the action calling routine).
+             _request_uuid: The UUID of the learning request (automatically set by the action calling routine).
+             _completed: A boolean indicating if the learning is already completed (automatically set by the action
+                 calling routine). This will tell that it is time to run a final procedure.
+
+         Returns:
+             True if the learning step was successful, False otherwise.
+         """
+         assert samples is not None and time is not None and timeout is not None, "Missing basic action information"
+
+         self.deb(f"[do_learn] samples: {samples}, time: {time}, timeout: {timeout}, "
+                  f"requester: {_requester}, request_time: {_request_time}, request_uuid: {_request_uuid} "
+                  f"completed: {_completed}")
+
+         if _requester not in self.world_agents and _requester not in self.world_masters:
+             self.err(f"Unknown agent: {_requester}")
+             return False
+
+         # Check what is the step ID of the multistep action
+         k = self.get_action_step()
+
+         # In the first step of this action, we change the UUID of the local stream associated to the input data we will
+         # use to handle this action, setting expectations to avoid handling tags of old data
+         if k == 0:
+
+             # Warning: we are not normalizing the hashes, we should do it if this action is called directly
+             if u_hashes is not None:
+                 for net_hash in u_hashes:
+                     if net_hash in self.known_streams:
+                         for stream_obj in self.known_streams[net_hash].values():
+
+                             # If the data arrived before this action, then the UUID is already set, and here there is
+                             # no need to do anything; if the data has not yet arrived (common case) ...
+                             if stream_obj.get_uuid(expected=False) != _request_uuid:
+                                 stream_obj.set_uuid(None, expected=False)  # Clearing UUID
+                                 stream_obj.set_uuid(_request_uuid, expected=True)  # Setting expectations
+
+             # Warning: we are not normalizing the hashes, we should do it if this action is called directly
+             if yhat_hashes is not None:
+                 for net_hash in yhat_hashes:
+                     if net_hash in self.known_streams:
+                         for stream_obj in self.known_streams[net_hash].values():
+                             if stream_obj.get_uuid(expected=False) != _request_uuid:
+                                 stream_obj.set_uuid(None, expected=False)  # Clearing UUID
+                                 stream_obj.set_uuid(_request_uuid, expected=True)  # Setting expectations
+
+         if not _completed:
+             self.out(f"Learning to generate signal {yhat_hashes}")
+             ret = self.__process_streams(u_hashes=u_hashes, yhat_hashes=yhat_hashes, learn=True,
+                                          recipient=_requester, ref_uuid=_request_uuid)
+             if not ret:
+                 self.out(f"Learning to generate signal {yhat_hashes} failed")
+             return ret
+         else:
+             self.out(f"Completing learning to generate signal {yhat_hashes}")
+             all_hashes = (u_hashes if u_hashes is not None else []) + (yhat_hashes if yhat_hashes is not None else [])
+             ret = await self.__complete_do(do_what="learn", peer_id_who_asked=_requester, all_hashes=all_hashes)
+             if not ret:
+                 self.out(f"Completing learning to generate signal {yhat_hashes} failed")
+             return ret
+
+     async def done_learn(self, _requester: str | None = None):
+         """This is a way to get back the confirmation of a completed learning procedure (async).
+
+         Args:
+             _requester: The ID of the agent who completed the learning procedure. Defaults to None.
+
+         Returns:
+             True if the learning-complete confirmation was successfully handled by this agent, False otherwise.
+         """
+         self.out(f"Agent {_requester} finished learning")
+         self._agents_who_completed_what_they_were_asked.add(_requester)
+
+         # Searching for the processor-streams of the agent who generated the (inference) data
+         processor_streams = self.find_streams(_requester, name_or_group="processor")
+         if processor_streams is None or len(processor_streams) == 0:
+             self.err("Unexpected confirmation of finished learning")
+             return False
+
+         # Warning: differently from the case of done_gen, we are not considering the streams generated by the
+         # learning agents as something we could use for evaluation (this might be changed in the future)
+
+         # Clearing the UUID of the local streams associated to the agent who learned
+         for net_hash, stream_dict in processor_streams.items():
+             for stream_obj in stream_dict.values():
+                 stream_obj.set_uuid(None, expected=False)
+                 stream_obj.set_uuid(None, expected=True)
+
+         # If one or more of my streams were used as arguments of the learning request I did (ask_learn), then their
+         # UUID must be cleared... we clear them all
+         for net_hash, stream_dict in self.owned_streams.items():
+             for stream_obj in stream_dict.values():
+                 if stream_obj.props.is_public() != self.behaving_in_world():
+                     stream_obj.set_uuid(None, expected=False)
+                     stream_obj.set_uuid(None, expected=True)
+         return True
+
+     async def all_asked_finished(self):
+         """Checks if all agents that were previously asked to perform a task (e.g., generate or learn) have sent a
+         completion confirmation. It compares the set of agents asked with the set of agents that have completed
+         the task (async).
+
+         Returns:
+             True if all agents are done, False otherwise.
+         """
+         return self._agents_who_were_asked == self._agents_who_completed_what_they_were_asked
+
+     async def all_engagements_completed(self):
+         """Checks if all engagement requests that were sent have been confirmed. It returns True if there are no
+         agents remaining in the `_found_agents` set, implying all have been engaged with or discarded (async).
+
+         Returns:
+             True if all engagements are complete, False otherwise.
+         """
+         return len(self._found_agents) == 0
+
+     async def agents_are_waiting(self):
+         """Checks if there are any agents who have connected but have not yet been fully processed or added to the
+         agent's known lists. This indicates that new agents are waiting to be managed (async).
+
+         Returns:
+             True if there are waiting agents, False otherwise.
+         """
+         self.out(f"Current set of {len(self._node_agents_waiting)} connected peer IDs not yet managed: "
+                  f"{self._node_agents_waiting}")
+         for found_agent in self._found_agents:
+             if found_agent in self._node_agents_waiting:
+                 return True
+         return False
+
+     async def ask_subscribe(self, agent: str | None = None,
+                             stream_hashes: list[str] | None = None, unsubscribe: bool = False):
+         """Requests a remote agent or a group of agents to subscribe to or unsubscribe from a list of specified
+         PubSub streams. It normalizes the stream hashes and sends an action request containing the stream
+         properties (async).
+
+         Args:
+             agent: The target agent's ID or a wildcard.
+             stream_hashes: A list of streams to subscribe to or unsubscribe from.
+             unsubscribe: A boolean to indicate if it's an unsubscription request.
+
+         Returns:
+             True if the request was sent to at least one agent, False otherwise.
+         """
+
+         # - if "agent" is a peer ID, the involved agents will be a list with one element.
+         # - if "agent" is a known wildcard, as "<valid_cmp>", then involved agents will be self._valid_cmp_agents
+         # - if "agent" is None, then the current agent in self._engaged_agents will be returned
+         involved_agents = self.__involved_agents(agent)
+         self.deb(f"[ask_subscribe] Involved_agents: {involved_agents}")
+
+         if len(involved_agents) == 0:
+             self.deb(f"[ask_subscribe] No involved agents, action ask_subscribe returns False")
+             return False
+
+         # Create a copy of the stream hashes, normalizing them in the appropriate way
+         stream_hashes_copy: list[str | None] = [None] * len(stream_hashes)
+         for i in range(len(stream_hashes)):
+             if stream_hashes[i] == "<playlist>":
+
+                 # From <playlist> to the current element of the playlist
+                 stream_hashes_copy[i] = self._preferred_streams[self._cur_preferred_stream]
+             else:
+
+                 # From a user specified hash to a net hash (e.g., peer_id:name_or_group to peer_id::ps:name_or_group)
+                 stream_hashes_copy[i] = self.user_stream_hash_to_net_hash(stream_hashes[i])
+
+         # Getting properties
+         stream_owners = []
+         stream_props = []
+         for i in range(len(stream_hashes_copy)):
+             stream_dict = self.known_streams[stream_hashes_copy[i]]
+             peer_id = DataProps.peer_id_from_net_hash(stream_hashes_copy[i])
+             for name, stream_obj in stream_dict.items():
+                 stream_owners.append(peer_id)
+                 stream_props.append(json.dumps(stream_obj.props.to_dict()))
+
+         what = "subscribe to" if not unsubscribe else "unsubscribe from"
+         self.out(f"Asking {', '.join(involved_agents)} to {what} {stream_hashes}")
+         self._agents_who_completed_what_they_were_asked = set()
+         self._agents_who_were_asked = set()
+         correctly_asked = []
+         for agent in involved_agents:
+             if await self.set_next_action(agent, action="do_subscribe", args={"stream_owners": stream_owners,
+                                                                               "stream_props": stream_props,
+                                                                               "unsubscribe": unsubscribe}):
+                 self._agents_who_were_asked.add(agent)
+                 ret = True
+             else:
+                 what = "subscribe" if not unsubscribe else "unsubscribe"
+                 self.err(f"Unable to ask {agent} to {what}")
+                 ret = False
+             self.deb(f"[ask_subscribe] Asking {agent} returned {ret}")
+             if ret:
+                 correctly_asked.append(agent)
+
+         self.deb(f"[ask_subscribe] Overall, the action ask_subscribe (unsubscribe: {unsubscribe})"
+                  f" will return {len(correctly_asked) > 0}")
+         return len(correctly_asked) > 0
+
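A hedged usage sketch (the stream hash is an invented example of the peer_id:name_or_group user form mentioned in the comments above):

    # Ask all engaged agents to subscribe to one of our PubSub streams, e.g. a
    # stream recorded earlier under the group "recorded1"
    await agent.ask_subscribe(agent=None, stream_hashes=["QmOwnerPeerId:recorded1"],
                              unsubscribe=False)
    # Completions can then be tracked with all_asked_finished()
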
+     async def do_subscribe(self, stream_owners: list[str] | None = None, stream_props: list[str] | None = None,
+                            unsubscribe: bool = False,
+                            _requester: str | list | None = None, _request_time: float = -1.):
+         """Executes a subscription or unsubscription request received from another agent. It processes the stream
+         properties, adds or removes the streams from the agent's known streams, and handles the underlying PubSub
+         topic subscriptions (async).
+
+         Args:
+             stream_owners: A list of peer IDs who own the streams.
+             stream_props: A list of JSON-serialized stream properties.
+             unsubscribe: A boolean to indicate unsubscription.
+             _requester: The ID of the requesting agent.
+             _request_time: The time the request was made.
+
+         Returns:
+             True if the action is successful, False otherwise.
+         """
+         self.deb(f"[do_subscribe] unsubscribe: {unsubscribe}, "
+                  f"stream_owners: {stream_owners}, stream_props: ... ({len(stream_props)} props)")
+
+         if _requester is not None:
+             if isinstance(_requester, list):
+                 for _r in _requester:
+                     if self.behaving_in_world():
+                         if _r not in self.world_agents and _r not in self.world_masters:
+                             self.err(f"Unknown agent: {_r} in list {_requester} (fully skipping do_subscribe)")
+                             return False
+                     else:
+                         if _r not in self.public_agents:
+                             self.err(f"Unknown agent: {_r} in list {_requester} (fully skipping do_subscribe)")
+                             return False
+             else:
+                 if self.behaving_in_world():
+                     if _requester not in self.world_agents and _requester not in self.world_masters:
+                         self.err(f"Unknown agent: {_requester} (fully skipping do_subscribe)")
+                         return False
+                 else:
+                     if _requester not in self.public_agents:
+                         self.err(f"Unknown agent: {_requester} (fully skipping do_subscribe)")
+                         return False
+         else:
+             self.err("Unknown requester (None)")
+             return False
+
+         # Building properties
+         props_dicts = []
+         props_objs = []
+         for i in range(len(stream_props)):
+             p_dict = json.loads(stream_props[i])
+             props = DataProps.from_dict(p_dict)
+             if props.is_pubsub():
+                 props_dicts.append(p_dict)
+                 props_objs.append(props)
+             else:
+                 self.err(f"Expecting a pubsub stream, got a stream named {props.get_name()} "
+                          f"(group is {props.get_group()}), which is not pubsub")
+                 return False
+
+         # Adding new streams and subscribing (if compatible with our processor)
+         for stream_owner, prop_dict, prop_obj in zip(stream_owners, props_dicts, props_objs):
+             if not unsubscribe:
+                 if not (await self.add_compatible_streams(peer_id=stream_owner, streams_in_profile=[prop_dict],
+                                                           buffered=False, public=False)):
+                     self.out(f"Unable to add a pubsub stream ({prop_obj.get_name()}) from agent {stream_owner}: "
+                              f"no compatible streams were found")
+             else:
+                 if not (await self.remove_streams(peer_id=stream_owner, name=prop_obj.get_name())):
+                     self.out(f"Unable to unsubscribe from pubsub stream ({prop_obj.get_name()}) "
+                              f"of agent {stream_owner}")
+         return True
+
+     async def done_subscribe(self, unsubscribe: bool = False, _requester: str | None = None):
+         """Handles the confirmation that a subscription or unsubscription request has been completed by another
+         agent. It adds the requester to the set of agents that have completed their asked tasks (async).
+
+         Args:
+             unsubscribe: A boolean indicating if it was an unsubscription.
+             _requester: The ID of the agent who completed the task.
+
+         Returns:
+             Always True.
+         """
+         what = "unsubscribing" if unsubscribe else "subscribing"
+         self.out(f"Agent {_requester} finished {what}")
+
+         # Remembering that the agent that invoked this action is the one who actually subscribed
+         self._agents_who_completed_what_they_were_asked.add(_requester)
+         return True
+
+     async def record(self, net_hash: str, samples: int = 100, time: float = -1., timeout: float = -1.):
+         """Records data from a specified stream into a new, owned `BufferedDataStream`. This is a multistep action
+         that captures a sequence of samples over time and then adds the new recorded stream to the agent's profile
+         (async).
+
+         Args:
+             net_hash: The hash of the stream to record.
+             samples: The number of samples to record.
+             time: The time duration for recording.
+             timeout: The timeout for each recording attempt.
+
+         Returns:
+             True if a sample was successfully recorded, False otherwise.
+         """
+         assert samples is not None and time is not None and timeout is not None, "Missing basic action information"
+
+         k = self.get_action_step()
+
+         self.out(f"Recording stream {net_hash}")
+
+         if k == 0:
+
+             # Getting stream(s)
+             _net_hash = self.user_stream_hash_to_net_hash(net_hash)  # In case of ambiguity, it yields the first one
+             if _net_hash is None:
+                 self.err(f"Unknown stream {net_hash}")
+                 return False
+             else:
+                 net_hash = _net_hash
+
+             stream_src_dict = self.known_streams[net_hash]
+
+             # Creating the new recorded stream (same props of the recorded one, just owned now)
+             stream_dest_dict = {}
+             for name, stream_obj in stream_src_dict.items():
+                 props = stream_obj.props.clone()
+                 props.set_group("recorded" + str(self._last_recorded_stream_num))
+                 stream_dest_dict[name] = BufferedDataStream(props=props, clock=self._node_clock)
+             self._last_recorded_stream_dict = stream_dest_dict
+             self._last_recording_stream_dict = stream_src_dict
+
+         else:
+
+             # Retrieving the stream(s)
+             stream_dest_dict = self._last_recorded_stream_dict
+             stream_src_dict = self._last_recording_stream_dict
+
+         # Recording
+         for name, stream_obj in stream_src_dict.items():
+             x = stream_obj.get(requested_by="record")
+             if x is None:
+                 self.deb("[record] data sample missing, returning False")
+                 return False
+             else:
+                 self.deb(f"[record] data_tag: {stream_obj.get_tag()}, data_uuid: {stream_obj.get_uuid()}")
+                 stream_dest_dict[name].set(x, k)  # Saving with data tags 0, 1, 2, ..., up to the number of steps - 1
+
+         # Updating profile
+         if self.is_last_action_step():
+             self.deb("[record] last action step detected, finishing")
+
+             # Dummy get to ensure that the next get will return None (i.e., we only PubSub if somebody restarts this)
+             for stream_obj in stream_dest_dict.values():
+                 stream_obj.get(requested_by="send_stream_samples")
+
+             self.add_streams(list(stream_dest_dict.values()), owned=True)
+             self.update_streams_in_profile()
+             await self.subscribe_to_pubsub_owned_streams()
+             await self.send_profile_to_all()
+
+             # New recorded stream
+             self._last_recorded_stream_num += 1
+
+         return True
+
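A sketch of driving this multistep action (the hash uses the peer_id:name_or_group user form with invented names; in normal operation the agent's state machine performs the stepping and supplies the step index):

    # Capture 16 samples of a peer's stream into a new owned BufferedDataStream;
    # the copy is published under a group named "recorded<N>"
    await agent.record("QmOwnerPeerId:camera", samples=16)
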
+    async def connect_by_role(self, role: str | list[str], filter_fcn: str | None = None,
+                              time: float = -1., timeout: float = -1.):
+        """Finds and attempts to connect with agents whose profiles match a specific role. The search can optionally
+        be filtered by a custom function. It returns True if at least one valid agent is found (async).
+
+        Args:
+            role: The role or list of roles to search for.
+            filter_fcn: The name of an optional filter function.
+            time: The time duration for the action.
+            timeout: The action timeout.
+
+        Returns:
+            True if at least one agent is found and a connection request is made, False otherwise.
+        """
+        self.out(f"Asking to get in touch with all agents whose role is {role}")
+        assert time is not None and timeout is not None, "Missing basic action information"
+
+        if self.get_action_step() == 0:
+            role_list = role if isinstance(role, list) else [role]
+            self._found_agents = set()
+            at_least_one_is_valid = False
+
+            for role in role_list:
+                role = self.ROLE_STR_TO_BITS[role]
+
+                found_addresses1, found_peer_ids1 = self._node_conn.find_addrs_by_role(Agent.ROLE_WORLD_MASTER | role,
+                                                                                       return_peer_ids_too=True)
+                found_addresses2, found_peer_ids2 = self._node_conn.find_addrs_by_role(Agent.ROLE_WORLD_AGENT | role,
+                                                                                       return_peer_ids_too=True)
+                found_addresses = found_addresses1 + found_addresses2
+                found_peer_ids = found_peer_ids1 + found_peer_ids2
+
+                if filter_fcn is not None:
+                    if isinstance(filter_fcn, str) and hasattr(self, filter_fcn):
+                        filter_fcn = getattr(self, filter_fcn)  # Resolved once; skipped on later iterations
+                    if callable(filter_fcn):
+                        found_addresses, found_peer_ids = filter_fcn(found_addresses, found_peer_ids)
+                    else:
+                        self.err(f"Filter function not found: {filter_fcn}")
+
+                self.out(f"Found addresses ({len(found_addresses)}) with role: {role}")
+                for f_addr, f_peer_id in zip(found_addresses, found_peer_ids):
+                    if not self._node_conn.is_connected(f_peer_id):
+                        self.out(f"Asking to get in touch with {f_addr}...")
+                        peer_id = await self._node_ask_to_get_in_touch_fcn(addresses=f_addr, public=False)
+                    else:
+                        self.out(f"Not asking to get in touch with {f_addr}, "
+                                 f"since I am already connected to the corresponding peer...")
+                        peer_id = f_peer_id
+                    if peer_id is not None:
+                        at_least_one_is_valid = True
+                        self._found_agents.add(peer_id)
+                        self.out(f"...returned {peer_id}")
+            return at_least_one_is_valid
+        else:
+            return True
+
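Since `filter_fcn` is resolved by name via `getattr(self, ...)` and called with the two parallel lists, a compatible filter method might look like the following sketch (the method name and the locality criterion are hypothetical):

    # Illustrative only: a filter compatible with connect_by_role(filter_fcn=...).
    def keep_local_peers(self, addresses, peer_ids):
        kept = [(a, p) for a, p in zip(addresses, peer_ids) if "127.0.0.1" in a]
        return [a for a, _ in kept], [p for _, p in kept]

    # ...requested by name, e.g. (role string is a placeholder):
    #     await agent.connect_by_role("artist", filter_fcn="keep_local_peers")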
+    async def find_agents(self, role: str | list[str], engage: bool = False):
+        """Locally searches through the agent's known peers (world and public agents) to find agents with a specific
+        role. It populates the `_found_agents` set with the peer IDs of matching agents (async).
+
+        Args:
+            role: The role or list of roles to search for.
+            engage: If True, the found agents also become the agents this agent is engaged with.
+
+        Returns:
+            True if at least one agent is found, False otherwise.
+        """
+        self.out(f"Finding an available agent whose role is {role}")
+        role_list = role if isinstance(role, list) else [role]
+        self._found_agents = set()
+
+        for role_str in role_list:
+            agents = self.all_agents
+            role_int = self.ROLE_STR_TO_BITS[role_str]
+            role_clean = (role_int >> 2) << 2
+            for peer_id, profile in agents.items():
+                _role_int = self.ROLE_STR_TO_BITS[profile.get_dynamic_profile()['connections']['role']]
+                _role_clean = (_role_int >> 2) << 2
+                if _role_clean == role_clean:
+                    self._found_agents.add(peer_id)  # Peer IDs here
+
+        self.deb(f"[find_agents] Found these agents: {self._found_agents}")
+        if engage:
+            self._engaged_agents = copy.deepcopy(self._found_agents)
+        return len(self._found_agents) > 0
+
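The `(bits >> 2) << 2` idiom used above (and again in `suggest_role_to_world`) clears the two lowest bits before comparing roles; the bit layout below is an assumption for illustration, with the low bits taken to carry the world-master/world-agent flags:

    # Illustrative only: why (bits >> 2) << 2 makes roles comparable.
    role_a = 0b1101  # hypothetical: role 0b11xx, flag bits 0b01
    role_b = 0b1110  # same role, different flag bits
    assert (role_a >> 2) << 2 == (role_b >> 2) << 2 == 0b1100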
+    async def next_pref_stream(self):
+        """Moves the internal pointer to the next stream in the list of preferred streams, which is often used for
+        playlist-like operations. It wraps around to the beginning if it reaches the end (async).
+
+        Returns:
+            True if the move is successful, False if the list is empty.
+        """
+        if len(self._preferred_streams) == 0:
+            self.err("Cannot move to the next stream because the list of preferred streams is empty")
+            return False
+
+        self._cur_preferred_stream = (self._cur_preferred_stream + 1) % len(self._preferred_streams)
+        suffix = ", warning: restarted" if self._cur_preferred_stream == 0 else ""
+        self.out(f"Moving to the next preferred stream "
+                 f"({self._preferred_streams[self._cur_preferred_stream]}){suffix}")
+        return True
+
+    async def first_pref_stream(self):
+        """Resets the internal pointer to the first stream in the list of preferred streams. This is useful for
+        restarting a playback or processing loop (async).
+
+        Returns:
+            True if the move is successful, False if the list is empty.
+        """
+        if len(self._preferred_streams) == 0:
+            self.err("Cannot move to the first stream because the list of preferred streams is empty")
+            return False
+
+        self._cur_preferred_stream = 0
+        self.out(f"Moving to the first preferred stream ({self._preferred_streams[self._cur_preferred_stream]})")
+        return True
+
+    async def check_pref_stream(self, what: str = "last"):
+        """Checks the position of the current preferred stream within the list. It can check whether it is the first
+        or the last one, whether it has completed a full round, among other checks (async).
+
+        Args:
+            what: A string specifying the type of check to perform (e.g., 'first', 'last', 'last_round').
+
+        Returns:
+            True if the condition is met, False otherwise.
+        """
+        valid = ['first', 'last', 'not_first', 'not_last', 'last_round', 'not_last_round', 'last_song', 'not_last_song']
+        assert what in valid, f"The what argument can only be one of {valid}"
+
+        self.out(f"Checking if the current preferred playlist item "
+                 f"(id: {self._cur_preferred_stream}) is the '{what}' one")
+        if what == "first":
+            return self._cur_preferred_stream == 0
+        elif what == "last":
+            return self._cur_preferred_stream == len(self._preferred_streams) - 1
+        elif what == "not_first":
+            return self._cur_preferred_stream != 0
+        elif what == "not_last":
+            return self._cur_preferred_stream != len(self._preferred_streams) - 1
+        elif what == "last_round":
+            return (self._cur_preferred_stream + len(self._preferred_streams) // self._repeat >=
+                    len(self._preferred_streams))
+        elif what == "not_last_round":
+            return (self._cur_preferred_stream + len(self._preferred_streams) // self._repeat <
+                    len(self._preferred_streams))
+        elif what == "last_song":
+            num_streams_in_playlist = len(self._preferred_streams) // self._repeat
+            return (self._cur_preferred_stream + 1) % num_streams_in_playlist == 0
+        elif what == "not_last_song":
+            num_streams_in_playlist = len(self._preferred_streams) // self._repeat
+            return (self._cur_preferred_stream + 1) % num_streams_in_playlist != 0
+
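Since `_preferred_streams` holds the playlist repeated `_repeat` times, the 'last_song' check reduces to a modulo test over the per-round length; a small worked sketch (values are made up):

    # Illustrative only: 'last_song' semantics for 3 distinct streams repeated
    # twice (6 entries overall, indices 0..5).
    preferred, repeat = ["s1", "s2", "s3"] * 2, 2
    per_round = len(preferred) // repeat            # 3 streams per round
    for cur in range(len(preferred)):
        is_last_song = (cur + 1) % per_round == 0   # True at indices 2 and 5
        print(cur, is_last_song)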
+    async def set_pref_streams(self, net_hashes: list[str], repeat: int = 1):
+        """Fills the agent's list of preferred streams (a playlist). It can repeat the playlist a specified number of
+        times and resolves user-provided stream hashes to their full network hashes (async).
+
+        Args:
+            net_hashes: A list of stream hashes to add to the playlist.
+            repeat: The number of times to repeat the playlist.
+
+        Returns:
+            Always True.
+        """
+        self.out(f"Setting up a list of {len(net_hashes)} preferred streams")
+        self._cur_preferred_stream = 0
+        self._preferred_streams = []
+        self._repeat = repeat
+        for i in range(0, self._repeat):
+            for net_hash in net_hashes:
+
+                # Both peer_id:name_or_group and peer_id::ps:name_or_group are tolerated here
+                components = net_hash.split(":")
+                peer_id = components[0]
+                name_or_group = components[-1]
+                net_hash_to_streams = self.find_streams(peer_id=peer_id, name_or_group=name_or_group)
+                for _net_hash in net_hash_to_streams.keys():
+                    self._preferred_streams.append(_net_hash)
+
+        return True
+
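Putting the playlist API together, a hedged usage sketch (peer IDs and stream names are hypothetical placeholders):

    # Illustrative only: build a two-pass playlist, then walk it to the end.
    await agent.set_pref_streams(["QmPeerA:songs", "QmPeerB::ps:voice"], repeat=2)
    await agent.first_pref_stream()
    while not await agent.check_pref_stream("last"):
        await agent.next_pref_stream()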
+    async def evaluate(self, stream_hash: str, how: str, steps: int = 100, re_offset: bool = False):
+        """Evaluates the performance of agents that have completed a generation task. It compares the generated data
+        from each agent with a local stream (which can be a ground truth or reference stream) using a specified
+        comparison method (async).
+
+        Args:
+            stream_hash: The hash of the local stream to use for comparison.
+            how: The name of the comparison method to use.
+            steps: The number of steps over which to perform the evaluation.
+            re_offset: A boolean indicating whether to re-offset the stream tags before comparing.
+
+        Returns:
+            True if the evaluation is successful, False otherwise.
+        """
+        if not self.buffer_generated_by_others:
+            self.err("Cannot evaluate if not buffering data generated by others")
+            return False
+
+        if stream_hash == "<playlist>":
+            net_hash = self._preferred_streams[self._cur_preferred_stream]
+        else:
+            net_hash = self.user_stream_hash_to_net_hash(stream_hash)
+
+        self._eval_results = {}
+        self.deb(f"[eval] Agents returning streams: {self._agents_who_completed_what_they_were_asked}")
+        for peer_id in self._agents_who_completed_what_they_were_asked:
+            received_net_hash = self.last_buffered_peer_id_to_info[peer_id]["net_hash"]
+            self.out(f"Comparing {net_hash} with {received_net_hash}")
+            eval_result, ret = self.__compare_streams(net_hash_a=net_hash,
+                                                      net_hash_b=received_net_hash,
+                                                      how=how, steps=steps, re_offset=re_offset)
+            self.out(f"Result of the comparison: {eval_result}")
+            if not ret:
+                return False
+            else:
+                peer_id = DataProps.peer_id_from_net_hash(received_net_hash)
+                self._eval_results[peer_id] = eval_result
+
+        return True
+
+    async def compare_eval(self, cmp: str, thres: float, good_if_true: bool = True):
+        """Compares the results of a previous evaluation to a given threshold, or finds the best result among all
+        agents. It can check for minimum, maximum, or simple threshold-based comparisons, and it populates a list of
+        'valid' agents that passed the comparison (async).
+
+        Args:
+            cmp: The comparison operator (e.g., '<', '>', 'min').
+            thres: The threshold value for comparison.
+            good_if_true: A boolean to invert the pass/fail logic.
+
+        Returns:
+            True if at least one agent passed the comparison, False otherwise.
+        """
+        assert cmp in ["<", ">", ">=", "<=", "min", "max"], f"Invalid comparison operator: {cmp}"
+        assert thres >= 0. or cmp in ["min", "max"], f"Invalid evaluation threshold: {thres} (it must be >= 0.)"
+
+        self._valid_cmp_agents = set()
+        msgs = []
+        best_so_far = -1
+
+        min_or_max = None
+        leq_or_geq = None
+        if cmp in ["min", "max"]:
+            min_or_max = "minimum" if cmp == "min" else "maximum"
+            leq_or_geq = "<=" if cmp == "min" else ">="
+
+        for agent_peer_id, eval_result in self._eval_results.items():
+            if cmp not in ["min", "max"]:
+                self.out(f"Checking if result {eval_result} {cmp} {thres}, for agent {agent_peer_id}")
+            else:
+                if thres >= 0:
+                    self.out(f"Checking if result {eval_result} is the {min_or_max} so far, "
+                             f"only if {leq_or_geq} {thres}, for agent {agent_peer_id}")
+                else:
+                    self.out(f"Checking if result {eval_result} is the {min_or_max} so far, for agent {agent_peer_id}")
+
+            if eval_result < 0.:
+                self.err(f"Invalid evaluation result: {eval_result}")
+                return False
+
+            if cmp != "min" and cmp != "max":
+                outcome = False
+                if cmp == "<" and eval_result < thres:
+                    outcome = True
+                elif cmp == "<=" and eval_result <= thres:
+                    outcome = True
+                elif cmp == ">" and eval_result > thres:
+                    outcome = True
+                elif cmp == ">=" and eval_result >= thres:
+                    outcome = True
+
+                if cmp in ("<", "<="):
+                    alias = 'error level' if good_if_true else 'mark'
+                else:
+                    alias = 'mark' if good_if_true else 'error level'
+
+                if good_if_true:
+                    if outcome:
+                        msgs.append(f"Agent {agent_peer_id} passed with {alias} {eval_result}/{thres}")
+                        self._valid_cmp_agents.add(agent_peer_id)
+                    else:
+                        msgs.append(f"Agent {agent_peer_id} did not pass")
+                else:
+                    if outcome:
+                        msgs.append(f"Agent {agent_peer_id} did not pass")
+                    else:
+                        msgs.append(f"Agent {agent_peer_id} passed with {alias} {eval_result}/{thres}")
+                        self._valid_cmp_agents.add(agent_peer_id)
+
+                if len(msgs) > 1:
+                    msgs[-1] = msgs[-1][0].lower() + msgs[-1][1:]  # So the joined messages read as one sentence
+            else:
+                if ((cmp == "min" and (thres < 0 or eval_result <= thres) and
+                     (eval_result < best_so_far or best_so_far < 0)) or
+                        (cmp == "max" and (thres < 0 or eval_result >= thres) and
+                         (eval_result > best_so_far or best_so_far < 0))):
+                    best_so_far = eval_result
+                    self._valid_cmp_agents = {agent_peer_id}
+                    msgs = [f"The best agent is {agent_peer_id}"]
+                elif best_so_far < 0:
+                    msgs = [f"No best agent found for the considered threshold ({thres})"]
+
+        if len(self._valid_cmp_agents) == 0:
+
+            # # cheating (hack):
+            # self._valid_cmp_agents.append(agent_peer_id)
+            # self.out(", ".join(msgs))
+            # return True
+            self.err("The evaluation was not passed by any agent")
+            return False
+        else:
+            self.out(", ".join(msgs))
+            return True
+
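A hedged sketch of the typical evaluate/compare_eval pair inside a behavior, assuming the engaged agents already completed a generation task (the stream hash is a hypothetical placeholder):

    # Illustrative only: compare each returned stream against a local reference
    # by MSE over 50 steps, then keep agents whose error stays below 0.1.
    if await agent.evaluate("QmSelf:reference", how="mse", steps=50):
        passed = await agent.compare_eval(cmp="<", thres=0.1, good_if_true=True)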
+    def collect_and_store_own_stats(self):
+        """Collects this agent's own stats and pushes them to the stats recorder."""
+        if self.stats is None:
+            return
+
+        _, own_private_pid = self.get_peer_ids()
+        t = self._node_clock.get_time_ms()
+        try:
+            info = self._node_conn['p2p_world'].get_connected_peers_info()
+            peers_list = [i['id'] for i in info]
+            self.stats.store_stat('connected_peers', peers_list, own_private_pid, t)
+        except Exception:
+            self.stats.store_stat('connected_peers', [], own_private_pid, t)
+
+        try:
+            behav = self.behav
+            self.stats.store_stat('state', behav.get_state_name(), own_private_pid, t)
+            self.stats.store_stat('action', behav.get_action_name(), own_private_pid, t)
+            self.stats.store_stat('last_action', behav.get_last_completed_action_name(), own_private_pid, t)
+        except Exception as e:
+            self.err(f"[Stats] Error storing HSM stats: {e}")
+
+    async def send_stats_to_world(self):
+        """Sends the agent's currently buffered stats to the world and clears them (async)."""
+        if not self.in_world():
+            self.deb("[send_stats_to_world] Not in a world, skipping stats send.")
+            return
+
+        world_peer_id = self._node_conn.get_world_peer_id()
+        if world_peer_id is None:
+            self.err("[send_stats_to_world] In world, but world_peer_id is None.")
+            return
+
+        self.collect_and_store_own_stats()  # Update own stats
+        payload = self.stats.get_payload_for_world()
+        if not payload:
+            self.deb("[send_stats_to_world] No stats to send.")
+            return
+
+        # Send all stats
+        self.out(f"[AGENT] Sending stats update to world {world_peer_id}...")
+        if not (await self._node_conn.send(world_peer_id,
+                                           channel_trail=None,
+                                           content=payload,
+                                           content_type=Msg.STATS_UPDATE)):
+            self.err("Failed to send stats update to world.")
+
+        # Ask the world for updates (no overwrite required)
+        self.out(f"[AGENT] Requesting stats update from world {world_peer_id}...")
+        if not (await self._node_conn.send(world_peer_id,
+                                           channel_trail=None,
+                                           content={'time_range': self.stats._max_seen_timestamp},
+                                           content_type=Msg.STATS_REQUEST)):
+            self.err("Failed to request stats from the world.")
+
+    def update_stats_view(self, received_view, overwrite: bool = False):
+        """Updates the `_world_view` attribute of the Stats object."""
+        self.stats.update_view(received_view, overwrite)
+
+    async def suggest_role_to_world(self, agent: str | None, role: str):
+        """Suggests a role change for one or more agents to the world master. It iterates through the involved agents,
+        checks if their current role differs from the suggested one, and sends a role suggestion message to the
+        world master (async).
+
+        Args:
+            agent: The ID of the agent or a wildcard to suggest the role for.
+            role: The new role to suggest (as a string).
+
+        Returns:
+            True if the suggestion was sent successfully, False otherwise.
+        """
+        self.out("Suggesting role to world")
+
+        agents = self.__involved_agents(agent)
+        role_bits = (self.ROLE_STR_TO_BITS[role] >> 2) << 2
+
+        content = []
+
+        for _agent in agents:
+            cur_role_bits = self.ROLE_STR_TO_BITS[self.all_agents[_agent].get_dynamic_profile()['connections']['role']]
+            cur_role_bits = (cur_role_bits >> 2) << 2
+            if cur_role_bits == role_bits:
+                self.out(f"Not suggesting to change the role of {_agent} "
+                         f"since it already has that role")
+            else:
+                self.out(f"Suggesting to change the role of {_agent} to {self.ROLE_BITS_TO_STR[role_bits]}")
+                content.append({'peer_id': _agent, 'role': role_bits})
+
+        if len(content) > 0:
+            world_peer_id = self._node_conn.get_world_peer_id()
+            if not (await self._node_conn.send(world_peer_id, channel_trail=None,
+                                               content=content,
+                                               content_type=Msg.ROLE_SUGGESTION)):
+                self.err("Failed to send role suggestion to the world")
+                return False
+        return True
+
+    async def suggest_badges_to_world(self, agent: str | None = None,
+                                      score: float = -1.0, badge_type: str = "completed",
+                                      badge_description: str | None = None):
+        """Suggests one or more badges to the world master for specific agents. This is typically used to reward agents
+        for completing tasks, such as for a competition. It sends a message with the badge details, including the score
+        and type, to the world master (async).
+
+        Args:
+            agent: The ID of the agent or a wildcard for which to suggest the badge.
+            score: The score associated with the badge.
+            badge_type: The type of badge (e.g., 'completed').
+            badge_description: An optional description for the badge.
+
+        Returns:
+            True if the badge suggestion was sent successfully, False otherwise.
+        """
+        self.out("Suggesting one or more badges to world")
+
+        if score < 0.:
+            self.err("Invalid score (did you specify the 'score' argument? It must be non-negative)")
+            return False
+
+        agents = self.__involved_agents(agent)
+        world_peer_id = self._node_conn.get_world_peer_id()
+
+        if badge_type not in Agent.BADGE_TYPES:
+            self.err(f"Unknown badge type: {badge_type}")
+            return False
+
+        list_of_badge_dictionaries = []
+        for peer_id in agents:
+            list_of_badge_dictionaries.append({'peer_id': peer_id,
+                                               'score': score,
+                                               'badge_type': badge_type,
+                                               'badge_description': badge_description,
+                                               'agent_token': self._node_conn.get_last_token(peer_id)})
+
+        if not (await self._node_conn.send(world_peer_id, channel_trail=None,
+                                           content=list_of_badge_dictionaries,
+                                           content_type=Msg.BADGE_SUGGESTIONS)):
+            self.err("Failed to send badge suggestions to the world")
+            return False
+        else:
+            return True
+
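For instance, a hedged sketch of rewarding the agents that passed the last comparison (the `<valid_cmp>` wildcard is resolved by `__involved_agents` further below; score and description are made up):

    # Illustrative only: badge every agent kept by the last compare_eval call.
    await agent.suggest_badges_to_world(agent="<valid_cmp>", score=0.95,
                                        badge_type="completed",
                                        badge_description="MSE below threshold")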
+    async def __ask_gen_or_learn(self, for_what: str, agent: str,
+                                 u_hashes: list[str] | None,
+                                 yhat_hashes: list[str] | None,
+                                 samples: int = 100, time: float = -1., timeout: float = -1.,
+                                 ref_uuid: str | None = None):
+        """A private helper method that encapsulates the logic for sending a 'do_gen' or 'do_learn' action request to
+        another agent. It handles the normalization of stream hashes, sets up recipients for direct messages, and adds
+        the target agent to the list of agents that were asked (async).
+
+        Args:
+            for_what: A string indicating whether to ask for 'gen' or 'learn'.
+            agent: The ID of the agent to send the request to.
+            u_hashes: A list of input stream hashes.
+            yhat_hashes: A list of target stream hashes (for learning).
+            samples: The number of samples.
+            time: The time duration.
+            timeout: The request timeout.
+            ref_uuid: The UUID for the request.
+
+        Returns:
+            True if the request was sent successfully, False otherwise.
+        """
+        if agent not in self.all_agents:
+            self.err(f"Unknown agent: {agent}")
+            return False
+
+        assert for_what in ["gen", "learn"]
+
+        if for_what == "learn":
+            for yhat_hash in yhat_hashes:
+                yhat_stream_dict = self.known_streams[yhat_hash]
+                for yhat_stream in yhat_stream_dict.values():
+                    if isinstance(yhat_stream, BufferedDataStream):
+                        y_text = yhat_stream.to_text_snippet(length=200)
+                        if y_text is not None and len(y_text) > 0:
+                            self.out("Asking to learn: \"" + y_text + "\"")
+
+        # Setting the recipient in the case of direct messages
+        # (differently, in the case of pubsub, the agent is already sending messages to all)
+        if u_hashes is not None:
+            for u_hash in u_hashes:
+                if not DataProps.is_pubsub_from_net_hash(u_hash):
+                    self.recipients[u_hash] = agent
+        if yhat_hashes is not None:
+            for yhat_hash in yhat_hashes:
+                if not DataProps.is_pubsub_from_net_hash(yhat_hash):
+                    self.recipients[yhat_hash] = agent
+
+        # Triggering
+        if for_what == "gen":
+            if await self.set_next_action(agent, action="do_gen", args={"u_hashes": u_hashes,
+                                                                        "samples": samples, "time": time,
+                                                                        "timeout": timeout},
+                                          ref_uuid=ref_uuid):
+                self._agents_who_were_asked.add(agent)
+                return True
+            else:
+                self.err(f"Unable to ask {agent} to generate")
+                return False
+        elif for_what == "learn":
+            if await self.set_next_action(agent, action="do_learn", args={"u_hashes": u_hashes,
+                                                                          "yhat_hashes": yhat_hashes,
+                                                                          "samples": samples, "time": time,
+                                                                          "timeout": timeout},
+                                          ref_uuid=ref_uuid):
+                self._agents_who_were_asked.add(agent)
+                return True
+            else:
+                self.err(f"Unable to ask {agent} to learn to generate")
+                return False
+
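A hedged sketch of how a master-side method inside the class might issue these requests (peer ID, hashes, and the wrapper name are hypothetical; `set_next_action` above dispatches the actual message):

    # Illustrative only: ask a connected peer to generate from one input
    # stream, then to learn a target stream.
    async def ask_peer_to_generate_and_learn(self):
        ok = await self.__ask_gen_or_learn("gen", "QmPeerA",
                                           u_hashes=["QmSelf:prompt"],
                                           yhat_hashes=None, samples=50)
        if ok:
            await self.__ask_gen_or_learn("learn", "QmPeerA",
                                          u_hashes=["QmSelf:prompt"],
                                          yhat_hashes=["QmSelf:target"],
                                          samples=50)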
+    def __process_streams(self,
+                          u_hashes: list[str] | None,
+                          yhat_hashes: list[str] | None,
+                          learn: bool = False,
+                          recipient: str | None = None,
+                          ref_uuid: str | None = None):
+        """A private helper method that contains the core logic for processing data streams, either for generation or
+        learning. It reads input streams, passes them to the agent's processor, and handles the output streams.
+        It is designed to be called repeatedly by multistep actions like `do_gen` and `do_learn`.
+
+        Args:
+            u_hashes: A list of input stream hashes.
+            yhat_hashes: A list of target stream hashes (for learning).
+            learn: A boolean to indicate if the task is a learning task.
+            recipient: The ID of the agent to send data back to.
+            ref_uuid: The UUID for the request.
+
+        Returns:
+            True if the stream processing is successful, False otherwise.
+        """
+
+        # Getting current step index
+        k = self.get_action_step()
+
+        # Checking data and creating new buffered streams
+        if k == 0:
+            self.deb("[__process_streams] First action step")
+
+            # Checking data
+            if u_hashes is not None:
+                for u_hash in u_hashes:
+                    if u_hash is not None and u_hash not in self.known_streams:
+                        self.err(f"Unknown stream (u_hash): {u_hash}")
+                        return False
+            if yhat_hashes is not None:
+                for yhat_hash in yhat_hashes:
+                    if yhat_hash is not None and yhat_hash not in self.known_streams:
+                        self.err(f"Unknown stream (yhat_hash): {yhat_hash}")
+                        return False
+
+        if self.is_last_action_step():
+            self.deb("[__process_streams] Last action step detected")
+
+        self.deb(f"[__process_streams] Generating data, step {k}")
+
+        # Generate output
+        outputs, data_tag_from_inputs = (
+            self.generate(input_net_hashes=u_hashes, first=(k == 0), last=self.is_last_action_step(),
+                          ref_uuid=ref_uuid))
+        if outputs is None:
+            return False
+        self.deb(f"[__process_streams] data_tag_from_inputs: {data_tag_from_inputs}")
+        if data_tag_from_inputs is None:
+            data_tag_from_inputs = -1
+            self.deb(f"[__process_streams] data_tag_from_inputs (forced): {data_tag_from_inputs}")
+
+        # Learn
+        if learn:
+            self.deb(f"[__process_streams] learning, step {k}")
+            loss_values, data_tags_from_targets = self.learn_generate(outputs=outputs, targets_net_hashes=yhat_hashes)
+            self.deb(f"[__process_streams] data_tags_from_targets: {data_tags_from_targets}")
+
+            if loss_values is None:
+                return False
+            else:
+                self.out(f"Losses: {loss_values}")
+
+            # Fusing data tags
+            data_tags = [data_tag_from_inputs if _data_tag == -1 else _data_tag
+                         for _data_tag in data_tags_from_targets]
+        else:
+            data_tags = [data_tag_from_inputs] * len(outputs)
+        self.deb(f"[__process_streams] data_tags (final): {data_tags}")
+
+        # Set each data sample in "outputs" to the right stream
+        i = 0
+        for net_hash, stream_dict in self.proc_streams.items():
+
+            # Setting the data sample
+            for name, stream_obj in stream_dict.items():
+
+                # Public output streams are only considered if the agent IS NOT acting in a world;
+                # private output streams are only considered if the agent IS acting in a world
+                if self.behaving_in_world() != stream_obj.props.is_public():
+
+                    # Guessing the recipient of the communication
+                    if i == 0:
+                        self.recipients[net_hash] = recipient \
+                            if not DataProps.is_pubsub_from_net_hash(net_hash) else None
+
+                    self.deb(f"[__process_streams] Setting the {i}-th network output to stream with "
+                             f"net_hash: {net_hash}, name: {name}")
+
+                    # Here we exploit the fact that streams were inserted in order
+                    try:
+                        stream_obj.set(stream_obj.props.check_and_postprocess(outputs[i]), data_tags[i])
+                    except Exception as e:
+                        self.err(f"Error while post-processing the processor output\nException: {e}")
+                        return False
+
+                    if k == 0:
+                        stream_obj.set_uuid(ref_uuid, expected=False)
+                        stream_obj.set_uuid(None, expected=True)
+                    i += 1
+
+        return True
+
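The visibility rule above feeds an output stream only when its public/private flag is the "opposite" of the current context; a tiny truth-table sketch of that condition:

    # Illustrative only: the routing rule behaving_in_world() != is_public().
    # Public streams are fed outside a world, private streams inside one.
    for in_world, is_public in [(True, True), (True, False),
                                (False, True), (False, False)]:
        feeds = in_world != is_public
        print(f"in_world={in_world}, public={is_public} -> feeds={feeds}")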
+    async def __complete_do(self, do_what: str, peer_id_who_asked: str, all_hashes: list[str] | None,
+                            send_back_confirmation: bool = True):
+        """A private helper method to be called at the end of a `do_gen` or `do_learn` action. It performs cleanup
+        tasks, such as clearing UUIDs on streams, and sends a confirmation message back to the requesting agent
+        (async).
+
+        Args:
+            do_what: A string ('gen' or 'learn') indicating which task was completed.
+            peer_id_who_asked: The ID of the agent who requested the task.
+            all_hashes: A list of all stream hashes involved in the task.
+            send_back_confirmation: A boolean to indicate if a confirmation message should be sent.
+
+        Returns:
+            True if the completion process is successful, False otherwise.
+        """
+        assert do_what in ["gen", "learn"]
+
+        if do_what == "gen":
+            for net_hash, stream_dict in self.proc_streams.items():
+                for stream in stream_dict.values():
+                    if isinstance(stream, BufferedDataStream):
+                        y_text = stream.to_text_snippet(length=200)
+                        if y_text is not None:
+                            self.out("Generated: \"" + y_text + "\"")
+
+        for stream_dict in self.proc_streams.values():
+            for stream_obj in stream_dict.values():
+                if stream_obj.props.is_public() != self.behaving_in_world():
+                    stream_obj.mark_uuid_as_clearable()
+
+        if all_hashes is not None:
+            for net_hash in all_hashes:
+                for stream_obj in self.known_streams[net_hash].values():
+                    stream_obj.set_uuid(None, expected=False)
+                    stream_obj.set_uuid(None, expected=True)
+
+        # Confirming
+        if send_back_confirmation:
+            if await self.set_next_action(peer_id_who_asked, action="done_" + do_what, args={}):
+                return True
+            else:
+                self.err(f"Unable to confirm '{do_what}' to {peer_id_who_asked}")
+                return False
+        else:
+            return True
+
+    def __compare_streams(self, net_hash_a: str, net_hash_b: str,
+                          how: str = "mse", steps: int = 100, re_offset: bool = False):
+        """A private helper method that compares two buffered data streams based on a specified metric (e.g., MSE,
+        argmax accuracy). It handles stream compatibility checks, data retrieval, and the actual comparison, returning
+        a dissimilarity score.
+
+        Args:
+            net_hash_a: The network hash of the first stream.
+            net_hash_b: The network hash of the second stream.
+            how: The comparison metric ('mse', 'max', 'geqX').
+            steps: The number of samples to compare.
+            re_offset: A boolean to re-align stream tags before comparison.
+
+        Returns:
+            A tuple containing the dissimilarity score and a success flag (e.g., `(0.5, True)`).
+        """
+        if net_hash_a not in self.known_streams:
+            self.err(f"Unknown stream (net_hash_a): {net_hash_a}")
+            return -1., False
+
+        if net_hash_b not in self.known_streams:
+            self.err(f"Unknown stream (net_hash_b): {net_hash_b}")
+            return -1., False
+
+        if steps <= 0:
+            self.err(f"Invalid number of steps: {steps}")
+            return -1., False
+
+        if how not in ["mse", "max"] and not how.startswith("geq"):
+            self.err(f"Data can be compared by MSE, or by comparing the argmax ('max'), or by comparing the number "
+                     f"of corresponding bits (obtained by 'geqX', where 'X' is a number). Unknown: {how}")
+            return -1., False
+
+        stream_dict_a = self.known_streams[net_hash_a]
+        stream_dict_b = self.known_streams[net_hash_b]
+
+        if len(stream_dict_a) == 1 and len(stream_dict_b) == 1:
+
+            # If there is only 1 stream in each group, things are easy
+            stream_a = next(iter(stream_dict_a.values()))
+            stream_b = next(iter(stream_dict_b.values()))
+        elif len(stream_dict_a) == 1 and len(stream_dict_b) > 1:
+
+            # If there is only 1 stream in one of the groups, we look for a compatible stream in the other group,
+            # giving priority to streams with labels
+            stream_a = next(iter(stream_dict_a.values()))
+            stream_b = None
+            for stream_obj in stream_dict_b.values():
+                if (stream_a.get_props().has_tensor_labels() and stream_obj.get_props().has_tensor_labels() and
+                        stream_obj.get_props().is_compatible(stream_a.get_props())):
+                    stream_b = stream_obj
+                    break
+            if stream_b is None:
+                for stream_obj in stream_dict_b.values():
+                    if stream_obj.get_props().is_compatible(stream_a.get_props()):
+                        stream_b = stream_obj
+                        break
+        elif len(stream_dict_a) > 1 and len(stream_dict_b) == 1:
+
+            # Symmetric case: we look for a compatible stream in the first group,
+            # giving priority to streams with labels
+            stream_a = None
+            stream_b = next(iter(stream_dict_b.values()))
+            for stream_obj in stream_dict_a.values():
+                if (stream_b.get_props().has_tensor_labels() and stream_obj.get_props().has_tensor_labels() and
+                        stream_obj.get_props().is_compatible(stream_b.get_props())):
+                    stream_a = stream_obj
+                    break
+            if stream_a is None:
+                for stream_obj in stream_dict_a.values():
+                    if stream_obj.get_props().is_compatible(stream_b.get_props()):
+                        stream_a = stream_obj
+                        break
+        else:
+
+            # If both groups have more than one stream, we give priority to streams with labels to find a match
+            stream_a = None
+            stream_b = None
+            for stream_obj_a in stream_dict_a.values():
+                if not stream_obj_a.get_props().has_tensor_labels():
+                    continue
+                if stream_a is not None and stream_b is not None:
+                    break
+                for stream_obj_b in stream_dict_b.values():
+                    if (stream_obj_b.get_props().has_tensor_labels() and
+                            stream_obj_a.get_props().is_compatible(stream_obj_b.get_props())):
+                        stream_a = stream_obj_a
+                        stream_b = stream_obj_b
+                        break
+            if stream_a is None and stream_b is None:
+                for stream_obj_a in stream_dict_a.values():
+                    if stream_a is not None and stream_b is not None:
+                        break
+                    for stream_obj_b in stream_dict_b.values():
+                        if stream_obj_a.get_props().is_compatible(stream_obj_b.get_props()):
+                            stream_a = stream_obj_a
+                            stream_b = stream_obj_b
+                            break
+
+        if stream_a is None:
+            self.err(f"Cannot find the data stream to consider in the comparison, {net_hash_a}")
+            return -1., False
+        if stream_b is None:
+            self.err(f"Cannot find the data stream to consider in the comparison, {net_hash_b}")
+            return -1., False
+
+        if not isinstance(stream_a, BufferedDataStream):
+            self.err(f"Can only compare buffered streams and {net_hash_a} is not buffered")
+            return -1., False
+
+        if not isinstance(stream_b, BufferedDataStream):
+            self.err(f"Can only compare buffered streams and {net_hash_b} is not buffered")
+            return -1., False
+
+        if steps > len(stream_a) and steps > len(stream_b):
+            self.err(f"Cannot compare streams for {steps} steps, since both of them are shorter "
+                     f"(length of the first stream is {len(stream_a)}, of the second stream is {len(stream_b)})")
+            return -1., False
+
+        if not stream_a.get_props().is_compatible(stream_b.get_props()):
+            self.err("Cannot compare incompatible streams")
+            return -1., False
+
+        stream_a.restart()
+        stream_b.restart()
+
+        def compare(_a: torch.Tensor | str, _b: torch.Tensor | str, _how: str = "mse") -> float:
+            """Compare two samples of signals or descriptors, returning a dissimilarity score >= 0."""
+
+            assert _how in ['mse', 'max', 'same'] or _how.startswith("geq"), f"Invalid comparison in terms of {_how}"
+
+            if isinstance(_a, torch.Tensor) and isinstance(_b, torch.Tensor):
+                if _a.dtype == torch.long and _b.dtype == torch.long:  # Token IDs
+                    return 1. - float((_a == _b).sum().item()) / _a.numel()  # 1 - accuracy
+                elif _how == "mse":
+                    ret = torch.nn.functional.mse_loss(_a, _b, reduction='mean')
+                elif _how == "max":
+                    ret = 1. - float((torch.argmax(_a) == torch.argmax(_b)).sum().item()) / _a.numel()
+                elif _how == "same":
+                    ret = 1. - float(torch.eq(_a, _b).sum()) / _a.numel()
+                else:
+                    thres = float(_how[3:])
+                    ret = 1. - float(torch.sum((_a > thres) == (_b > thres)).item()) / _a.numel()
+            else:
+                ret = 1. - float(_a == _b)  # Strings (always handled as 'same')
+            return ret
+
+        # Comparing data (averaging)
+        o = 0.
+        k_b = 0
+        a_tag_offset = 0
+        b_tag_offset = 0
+        a_tag = None
+        a_tag_prev = None
+        for k_a in range(0, steps):
+
+            restart_detected = False
+            if a_tag is not None:
+                a_tag_prev = a_tag
+
+            # Signals or descriptors
+            a, a_tag = stream_a[k_a]
+            b, b_tag = stream_b[k_b]
+
+            # If we were asked to re-offset, we take the first tags of both streams as offsets,
+            # so that the offset-adjusted tags start from zero on both
+            if k_a == 0 and k_b == 0 and re_offset:
+                a_tag_offset = a_tag
+                b_tag_offset = b_tag
+
+            # Offset-based tags
+            a_tag_w_offset = a_tag - a_tag_offset
+            b_tag_w_offset = b_tag - b_tag_offset
+
+            # Checking
+            if a is None:
+                self.err("Cannot compare stream samples if the reference stream yields None")
+                return -1., False
+
+            # Some streams might have been pre-buffered in advance, with increasing data tags belonging to a finite,
+            # fixed set (such as 0, 1, 2, ..., N). When continuously streaming them, we will go from tag N to tag 0
+            # at a certain point, which is a "restart".
+            # We have to remember that this happened, and we do it for stream "a", our "reference" stream.
+            # Then, below, we will fix tags on stream "b" if needed, considering that such a restart happened.
+            if a_tag_prev is not None and a_tag < a_tag_prev:
+                restart_detected = True
+
+            # Some streams might have been pre-buffered in advance, and have a fixed data tag (usually -1).
+            # Since it is negative, the data tag will be replaced by a clock cycle, but this function does not
+            # change clock cycles at all, so all samples would have the exact same data tag.
+            # The following code automatically advances the tag by 1 for stream "a", which is expected to be the
+            # reference stream (i.e., the one for which the agent has all samples, with no missing data in between)
+            if a_tag_prev is not None and a_tag <= a_tag_prev:
+                a_tag = a_tag_prev + 1  # Fixed tag detected (patching)
+                a_tag_w_offset = a_tag - a_tag_offset
+
+            # Fixing
+            if b is None:
+                o = o + (1. if how != "mse" else (o / steps) * 1.1)
+                self.deb("[__compare_streams] The second stream yields None")
+            else:
+                if b_tag_w_offset == a_tag_w_offset:
+                    o += compare(a, b, how)
+                    k_b += 1
+                    self.deb(f"[__compare_streams] Comparing tags: {a_tag} vs {b_tag} "
+                             f"(with offsets: {a_tag_w_offset} vs {b_tag_w_offset}), samples: {a} vs {b}")
+                elif b_tag_w_offset > a_tag_w_offset:
+                    if not restart_detected:
+                        o = o + (1. if how != "mse" else (o / steps) * 1.1)  # Don't change k_b, some samples missing
+                        self.deb(f"[__compare_streams] (b) Comparing tags: {a_tag} vs {b_tag} -> "
+                                 f"expected one was missing "
+                                 f"(with offsets: {a_tag_w_offset} vs {b_tag_w_offset}) "
+                                 f"samples: {a} vs {b}")
+                    else:
+                        o = o + (1. if how != "mse" else (o / steps) * 1.1)
+                        self.deb(f"[__compare_streams] (c) Comparing tags: {a_tag} vs {b_tag} -> "
+                                 f"expected one was missing "
+                                 f"(with offsets: {a_tag_w_offset} vs {b_tag_w_offset}) "
+                                 f"samples: {a} vs {b}")
+                        k_b += 1  # A restart was detected: "stream_b" is behind, let's move it ahead
+                elif b_tag_w_offset < a_tag_w_offset:
+                    self.deb(f"[__compare_streams] (d) Comparing tags: {a_tag} vs {b_tag} -> "
+                             f"too early w.r.t. expected "
+                             f"(with offsets: {a_tag_w_offset} vs {b_tag_w_offset}) "
+                             f"samples: {a} vs {b}")
+                    return -1., False
+
+        self.deb(f"[__compare_streams] Error: {o / steps}")
+
+        # Input("*** press enter to continue ***")
+        return o / steps, True
+
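The `re_offset` flag only subtracts each stream's first tag so that both sequences are compared from a common origin; a minimal worked sketch with made-up tags:

    # Illustrative only: how re_offset aligns tags. Stream "a" starts at tag 10,
    # stream "b" at tag 3; subtracting the first tags makes them comparable.
    a_tags, b_tags = [10, 11, 12], [3, 4, 5]
    a_off, b_off = a_tags[0], b_tags[0]
    assert [t - a_off for t in a_tags] == [t - b_off for t in b_tags]  # [0, 1, 2]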
+    def __involved_agents(self, agent: str | None):
+        """A private helper method that resolves an agent ID or a wildcard into a list of specific peer IDs.
+        It can resolve a single agent, the group of agents that passed a previous comparison (`<valid_cmp>`), or all
+        currently engaged agents.
+
+        Args:
+            agent: The agent ID or wildcard string.
+
+        Returns:
+            A list of peer IDs corresponding to the involved agents.
+        """
+        peer_id = agent
+        if peer_id is None:
+            involved_agents = self._engaged_agents
+        elif peer_id == "<valid_cmp>":
+            involved_agents = self._valid_cmp_agents
+        else:
+            involved_agents = [peer_id]
+        if len(involved_agents) == 0:
+            self.err("Not engaged to any agents or no agent specified")
+        return involved_agents
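To summarize the three resolutions, a hedged sketch as seen from another method of the class (the peer ID is a hypothetical placeholder):

    # Illustrative only: how the wildcard resolution above behaves.
    self.__involved_agents("QmPeerA")       # -> ["QmPeerA"]
    self.__involved_agents("<valid_cmp>")   # -> agents that passed compare_eval
    self.__involved_agents(None)            # -> currently engaged agents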