unaiverse 0.1.12__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (47) hide show
  1. unaiverse/__init__.py +19 -0
  2. unaiverse/agent.py +2226 -0
  3. unaiverse/agent_basics.py +2389 -0
  4. unaiverse/clock.py +234 -0
  5. unaiverse/dataprops.py +1282 -0
  6. unaiverse/hsm.py +2471 -0
  7. unaiverse/modules/__init__.py +18 -0
  8. unaiverse/modules/cnu/__init__.py +17 -0
  9. unaiverse/modules/cnu/cnus.py +536 -0
  10. unaiverse/modules/cnu/layers.py +261 -0
  11. unaiverse/modules/cnu/psi.py +60 -0
  12. unaiverse/modules/hl/__init__.py +15 -0
  13. unaiverse/modules/hl/hl_utils.py +411 -0
  14. unaiverse/modules/networks.py +1509 -0
  15. unaiverse/modules/utils.py +748 -0
  16. unaiverse/networking/__init__.py +16 -0
  17. unaiverse/networking/node/__init__.py +18 -0
  18. unaiverse/networking/node/connpool.py +1332 -0
  19. unaiverse/networking/node/node.py +2752 -0
  20. unaiverse/networking/node/profile.py +446 -0
  21. unaiverse/networking/node/tokens.py +79 -0
  22. unaiverse/networking/p2p/__init__.py +188 -0
  23. unaiverse/networking/p2p/go.mod +127 -0
  24. unaiverse/networking/p2p/go.sum +548 -0
  25. unaiverse/networking/p2p/golibp2p.py +18 -0
  26. unaiverse/networking/p2p/golibp2p.pyi +136 -0
  27. unaiverse/networking/p2p/lib.go +2765 -0
  28. unaiverse/networking/p2p/lib_types.py +311 -0
  29. unaiverse/networking/p2p/message_pb2.py +50 -0
  30. unaiverse/networking/p2p/messages.py +360 -0
  31. unaiverse/networking/p2p/mylogger.py +78 -0
  32. unaiverse/networking/p2p/p2p.py +900 -0
  33. unaiverse/networking/p2p/proto-go/message.pb.go +846 -0
  34. unaiverse/stats.py +1506 -0
  35. unaiverse/streamlib/__init__.py +15 -0
  36. unaiverse/streamlib/streamlib.py +210 -0
  37. unaiverse/streams.py +804 -0
  38. unaiverse/utils/__init__.py +16 -0
  39. unaiverse/utils/lone_wolf.json +28 -0
  40. unaiverse/utils/misc.py +441 -0
  41. unaiverse/utils/sandbox.py +292 -0
  42. unaiverse/world.py +384 -0
  43. unaiverse-0.1.12.dist-info/METADATA +366 -0
  44. unaiverse-0.1.12.dist-info/RECORD +47 -0
  45. unaiverse-0.1.12.dist-info/WHEEL +5 -0
  46. unaiverse-0.1.12.dist-info/licenses/LICENSE +177 -0
  47. unaiverse-0.1.12.dist-info/top_level.txt +1 -0
unaiverse/agent.py ADDED
@@ -0,0 +1,2226 @@
1
+ """
2
+ █████ █████ ██████ █████ █████ █████ █████ ██████████ ███████████ █████████ ██████████
3
+ ░░███ ░░███ ░░██████ ░░███ ░░███ ░░███ ░░███ ░░███░░░░░█░░███░░░░░███ ███░░░░░███░░███░░░░░█
4
+ ░███ ░███ ░███░███ ░███ ██████ ░███ ░███ ░███ ░███ █ ░ ░███ ░███ ░███ ░░░ ░███ █ ░
5
+ ░███ ░███ ░███░░███░███ ░░░░░███ ░███ ░███ ░███ ░██████ ░██████████ ░░█████████ ░██████
6
+ ░███ ░███ ░███ ░░██████ ███████ ░███ ░░███ ███ ░███░░█ ░███░░░░░███ ░░░░░░░░███ ░███░░█
7
+ ░███ ░███ ░███ ░░█████ ███░░███ ░███ ░░░█████░ ░███ ░ █ ░███ ░███ ███ ░███ ░███ ░ █
8
+ ░░████████ █████ ░░█████░░████████ █████ ░░███ ██████████ █████ █████░░█████████ ██████████
9
+ ░░░░░░░░ ░░░░░ ░░░░░ ░░░░░░░░ ░░░░░ ░░░ ░░░░░░░░░░ ░░░░░ ░░░░░ ░░░░░░░░░ ░░░░░░░░░░
10
+ A Collectionless AI Project (https://collectionless.ai)
11
+ Registration/Login: https://unaiverse.io
12
+ Code Repositories: https://github.com/collectionlessai/
13
+ Main Developers: Stefano Melacci (Project Leader), Christian Di Maio, Tommaso Guidi
14
+ """
15
+ import copy
16
+ import json
17
+ import torch
18
+ from unaiverse.stats import Stats
19
+ from unaiverse.dataprops import DataProps
20
+ from unaiverse.agent_basics import AgentBasics
21
+ from unaiverse.streams import BufferedDataStream
22
+ from unaiverse.networking.p2p.messages import Msg
23
+
24
+
25
+ class Agent(AgentBasics):
26
+ """This class contains those basic actions that can be performed by every agent."""
27
+
28
def __init__(self, *args, **kwargs):
    """Build the agent on top of the parent class and initialize its status attributes.

    Args:
        *args: Positional arguments forwarded unchanged to the parent constructor.
        **kwargs: Keyword arguments forwarded unchanged to the parent constructor.
    """
    super().__init__(*args, **kwargs)

    # --- Status variables (by convention their names start with "_"): exchanges with other agents ---
    self._available = True  # Automatically updated while the agent engages/disengages
    self._found_agents = set()  # Peer IDs that were discovered
    self._valid_cmp_agents = set()  # Agents whose last evaluation was positive
    self._engaged_agents = set()  # Agents this agent is currently engaged with
    self._agents_who_completed_what_they_were_asked = set()
    self._agents_who_were_asked = set()
    self._agents_who_received_set_next_action = set()
    self._eval_results = {}

    # --- Status variables (by convention their names start with "_"): recordings ---
    self._last_recorded_stream_num = 1
    self._last_recorded_stream_dict = None
    self._last_recording_stream_dict = None

    # --- Status variables (by convention their names start with "_"): playlist ---
    self._preferred_streams = []  # The list of preferred streams
    self._cur_preferred_stream = 0  # ID of the current preferred stream within the list
    self._repeat = 1  # How many times the playlist is repeated

    # --- Stats ---
    self.stats = Stats(is_world=False)
    self.overwrite_stats = False  # Whether the next STATS_RESPONSE from the world overwrites the stats
54
+
55
def remove_peer_from_agent_status_attrs(self, peer_id):
    """Delegate the removal of a peer to the parent class, then refresh the availability flag.

    Args:
        peer_id: The peer ID handed over to the parent implementation.
    """
    super().remove_peer_from_agent_status_attrs(peer_id)

    # The agent is available only when no engaged agents are left
    still_engaged = len(self._engaged_agents) > 0
    self._available = not still_engaged
58
+
59
def reset_agent_status_attrs(self):
    """Reset the status attributes through the parent class, then restore the non-trivial initial values."""

    # The parent resets each status variable to [], {}, 0, 0., False, depending on its type
    super().reset_agent_status_attrs()

    # These attributes start from values that differ from the type-based defaults
    self._last_recorded_stream_num = 1
    self._repeat = 1
    self._available = True
64
+
65
async def set_next_action(self, agent: str | None, action: str, args: dict | None = None,
                          from_state: str | None = None, to_state: str | None = None, ref_uuid: str | None = None):
    """Send an action request to one or more agents, telling them what to run next (async).

    Args:
        agent: A peer ID, or a wildcard such as "<valid_cmp>" that stands for a set of agents
            (if None, the agents in self._engaged_agents are considered).
        action: The name of the action the target agent should execute.
        args: The dictionary of arguments of the action. Defaults to None.
        from_state: Optional state from which the action should be executed.
        to_state: Optional state reached when the action is correctly executed.
        ref_uuid: Optional UUID used to reference the action. Defaults to None.

    Returns:
        A true value if the request was sent to the target agent, or to at least one of the involved agents
        in the wildcard case; False if no agents are involved.
    """

    # Resolving "agent" (peer ID, wildcard, or None) into the collection of peers to contact
    targets = self.__involved_agents(agent)
    if len(targets) == 0:
        return False

    sent_to_someone = False

    # NOTE(review): the returned peer IDs are not used in this method; the call is kept as it was
    _, private_peer_id = self.get_peer_ids()

    # Forgetting who received the previous request
    self._agents_who_received_set_next_action = set()

    for target in targets:
        request = {"action_name": action, "args": args,
                   "from_state": from_state, "to_state": to_state, "uuid": ref_uuid}
        ret = await self._node_conn.send(target, channel_trail=None, content=request,
                                         content_type=Msg.ACTION_REQUEST)
        sent_to_someone = sent_to_someone or ret
        self.deb(f"[set_next_action] {self._node_name} sent action: {action}, with args: {args}, "
                 f"and result of sending is {ret}")

        # Keeping track of the peers that were actually reached
        if ret:
            self._agents_who_received_set_next_action.add(target)
    return sent_to_someone
105
+
106
async def set_engaged_partner(self, agent: str | list[str] | set[str] | None, clear_found: bool = True):
    """Force the engagement with one agent (or a group of agents), dropping every existing engagement and,
    optionally, the results of previous find operations (async).

    Args:
        agent: The agent or the collection of agents to engage (if None, the engaged partners are only cleared).
        clear_found: Whether to also clear the results of previously run find operations (True by default).

    Returns:
        Always True.
    """
    if clear_found:
        self._found_agents.clear()
    self._engaged_agents.clear()

    if agent is not None:

        # A single peer ID is handled as a one-element group
        partners = [agent] if isinstance(agent, str) else agent
        self._engaged_agents.update(partners)

    # Available only if nobody is engaged at the end of the procedure
    self._available = not self._engaged_agents
    return True
127
+
128
async def send_engagement(self):
    """Ask every agent listed in self._found_agents to engage, by requesting their "get_engagement" action (async).

    Returns:
        True if the engagement request was sent to at least one found agent, False otherwise.
    """

    # Nothing to do if no agents were found
    if len(self._found_agents) == 0:
        return False

    self.out(f"Sending engagement request to {', '.join(self._found_agents)}")

    # The role of this agent is read from the dynamic profile and attached to the request
    my_role_str = self._node_profile.get_dynamic_profile()['connections']['role']

    sent_to_someone = False
    for peer in self._found_agents:  # The list of found agents will be cleared after this function
        delivered = await self.set_next_action(peer, action="get_engagement", args={"sender_role": my_role_str})
        if delivered:
            sent_to_someone = True
        else:
            self.err(f"Unable to send engagement to {peer}")
    return sent_to_someone
146
+
147
async def get_engagement(self, acceptable_role: str | None = None, sender_role: str | None = None,
                         _requester: str | None = None):
    """Handle an engagement request coming from another agent and, if acceptable, confirm it (async).

    The request is refused (returning False) when the requester is not a known world agent/master, when
    its role is missing or not listed in ROLE_STR_TO_BITS, when no valid acceptable role is given, when this
    agent is not available, or when the roles do not match.

    Args:
        acceptable_role: The role that the sender must have for the engagement to be accepted. When it does not
            contain "~", the two lowest bits of the sender role are ignored in the comparison.
        sender_role: The role of the agent sending the engagement request. Defaults to None.
        _requester: The ID of the agent requesting engagement (automatically set by the action calling routine).

    Returns:
        True if the engagement was successfully received and confirmed, False otherwise.
    """
    self.out(f"Getting engagement from {_requester}, whose role is {sender_role} (looking for {acceptable_role})")
    if _requester not in self.world_agents and _requester not in self.world_masters:
        self.err(f"Unknown agent: {_requester}")
        return False

    if sender_role is None:
        self.err(f"Unknown role of {_requester}")
        return False

    # An already engaged agent refuses further engagements
    if not self._available:
        self.err(f"Cannot engage to {_requester}")
        return False

    # Without an acceptable role there is nothing to compare the sender role with (this used to raise)
    if acceptable_role is None:
        self.err(f"Cannot engage to {_requester}: no acceptable role was specified")
        return False

    # The sender role is provided by the remote peer, so it might be an unknown (or even unhashable) value
    try:
        acceptable_role_int = self.ROLE_STR_TO_BITS[acceptable_role]
        sender_role_int = self.ROLE_STR_TO_BITS[sender_role]
    except (KeyError, TypeError):
        self.err(f"Cannot engage to {_requester}: unknown role ({sender_role} or {acceptable_role})")
        return False

    # If the acceptable role has no "~", the two lowest bits of the sender role are cleared before comparing
    if "~" not in acceptable_role:
        sender_role_int = (sender_role_int >> 2) << 2

    if acceptable_role_int != sender_role_int:
        self.err(f"Cannot engage to {_requester}")
        return False

    # Confirming
    if not await self.set_next_action(_requester, "got_engagement"):
        self.err(f"Unable to confirm engagement to {_requester}")
        return False

    self._engaged_agents.add(_requester)

    # Marking this agent as not available since it engaged with another one
    self._available = False
    return True
192
+
193
async def got_engagement(self, _requester: str | None = None):
    """Handle the confirmation of an engagement that this agent previously offered (async).

    Args:
        _requester: The ID of the agent confirming the engagement (automatically set by the action calling
            routine).

    Returns:
        True if the confirming agent was among the found agents (it becomes an engaged agent), False otherwise.
    """
    self.out(f"Confirming engagement with {_requester}")

    # Confirmations are accepted only from agents that are in the set of found agents
    if _requester not in self._found_agents:
        self.err(f"Unable to confirm engagement with {_requester}")
        return False

    self._engaged_agents.add(_requester)

    # This agent engaged with another one, so it is not available anymore
    self._available = False

    # The confirming agent is not a "found" agent anymore
    self._found_agents.discard(_requester)
    return True
215
+
216
async def send_disengagement(self, send_disconnection_too: bool = False):
    """Ask every engaged agent to disengage, by requesting their "get_disengagement" action (async).

    Args:
        send_disconnection_too: Whether to send a disconnect-suggestion together with the disengagement.

    Returns:
        True if the disengagement request was sent to at least one engaged agent, False otherwise.
    """

    # No engaged agents, nothing to send
    if len(self._engaged_agents) == 0:
        return False

    self.out(f"Sending disengagement request to {', '.join(self._engaged_agents)}")

    sent_to_someone = False
    for peer in self._engaged_agents:
        delivered = await self.set_next_action(peer, action="get_disengagement",
                                               args={"disconnect_too": send_disconnection_too})
        if delivered:
            sent_to_someone = True
        else:
            self.err(f"Unable to send disengagement to {peer}")

    # There is no "got_disengagement": engaged agents are dropped as soon as one request went through
    if sent_to_someone:
        self._engaged_agents.clear()
    return sent_to_someone
239
+
240
async def get_disengagement(self, disconnect_too: bool = False, _requester: str | None = None):
    """Handle a disengagement request coming from an engaged agent (async).

    Args:
        disconnect_too: Whether to also disconnect (purge) the agent who sent the disengagement.
        _requester: The ID of the agent requesting disengagement. Defaults to None.

    Returns:
        True if the request was processed, False if the requester is unknown or was not engaged.
    """
    self.out(f"Getting a disengagement request from {_requester}")

    known_peer = _requester in self.world_agents or _requester in self.world_masters
    if not known_peer:
        self.err(f"Unknown agent: {_requester}")
        return False

    if _requester not in self._engaged_agents:
        self.err(f"Not previously engaged to {_requester}")
        return False

    # Dropping the engagement
    self._engaged_agents.discard(_requester)

    if disconnect_too:
        await self._node_purge_fcn(_requester)

    # This agent is available again only if it is not engaged to any other agent
    self._available = not self._engaged_agents
    return True
267
+
268
async def disengage_all(self):
    """Locally drop all the engagements, without notifying the engaged agents (async).

    Returns:
        Always True.
    """
    self.out("Disengaging all agents")

    # A fresh empty set replaces the current one, and the agent becomes available again
    self._engaged_agents = set()
    self._available = True
    return True
280
+
281
async def disconnect(self, agent: str):
    """Disconnect an agent by purging it from the node (async).

    Args:
        agent: The peer ID of the agent to disconnect.

    Returns:
        Always True.
    """
    self.out(f"Disconnecting agent: {agent}")

    # The purge function will also call remove_agent, that will call remove_streams
    await self._node_purge_fcn(agent)
    return True
293
+
294
async def disconnect_by_role(self, role: str | list[str], disengage_too: bool = False):
    """Disconnect from all the agents with a given role (async).

    The agents are searched with find_agents, then the node's purge function is called on each of them.

    Args:
        role: A string or list of strings with the role(s) of the agents to disconnect from.
        disengage_too: Whether to send a disengagement (with disconnect-suggestion) before disconnecting.
            Defaults to False.

    Returns:
        Always True.
    """
    self.out(f"Disconnecting agents with role: {role}")

    if disengage_too:
        await self.send_disengagement(send_disconnection_too=True)

    role_was_found = await self.find_agents(role)
    if role_was_found:

        # Working on a snapshot of the found agents while purging them one by one
        snapshot = copy.deepcopy(self._found_agents)
        for peer in snapshot:
            await self._node_purge_fcn(peer)  # This will also call remove_agent, that will call remove_streams
    return True
313
+
314
async def disconnected(self, agent: str | None = None, handshake_completed: bool = False, delay: float = -1.):
    """Tell whether a set of agents (given by ID or wildcard) is no longer connected to this agent (async).

    Args:
        agent: The ID of the agent or a wildcard to check (if None, the engaged agents are considered).
        handshake_completed: If True, only membership in self.all_agents is checked; otherwise an agent
            also counts as connected when the node connection reports it as connected.
        delay: The time (seconds) to be spent in the current state before actually considering this action.

    Returns:
        True if all the involved agents are disconnected, False if at least one of them is still connected
        or if no agents are involved.
    """
    assert delay is not None, "Missing basic action information"

    # Resolving "agent" (peer ID, wildcard, or None) into the collection of peers to check
    peers = self.__involved_agents(agent)
    if len(peers) == 0:
        return False

    self.out(f"Checking if all these agents are not connected to me anymore: {peers}")

    def _still_connected(peer):
        if peer in self.all_agents:
            return True
        if handshake_completed:
            return False
        return self._node_conn.is_connected(peer)

    # The scan stops at the first agent that is still connected
    return not any(_still_connected(peer) for peer in peers)
348
+
349
async def received_some_asked_data(self, processing_fcn: str | None = None):
    """Check whether some of the agents that were previously asked for data sent back a stream sample,
    looking at the non-public streams of their "processor" group (async).

    Args:
        processing_fcn: The name of a method of this agent used to process each received sample. It is
            called with the agent ID, the stream properties, the data and the data tag.

    Returns:
        True if at least one data sample was received, False otherwise. When no valid processing function
        is available, the check stops at the first received sample.
    """
    handler = None
    if processing_fcn is not None:
        candidate = getattr(self, processing_fcn, None)
        handler = candidate if callable(candidate) else None
        if handler is None:
            self.err(f"Processing function not found: {processing_fcn}")

    received = False
    for asked_peer in self._agents_who_were_asked:
        for streams in self.find_streams(asked_peer, "processor").values():
            for stream in streams.values():

                # Public streams are skipped
                if stream.props.is_public():
                    continue

                sample = stream.get("received_some_asked_data")
                sample_tag = stream.get_tag()
                if sample is None:
                    continue

                # Without a processing function, a single sample is enough to answer
                if handler is None:
                    return True

                received = True
                handler(asked_peer, stream.props, sample, sample_tag)
    return received
384
+
385
async def nop(self, message: str | None = None, delay: float = -1.):
    """No-operation action that optionally prints a message (async).

    Args:
        message: The message to print, if any. Defaults to None.
        delay: The time (seconds) to be spent in the current state before actually considering this action.

    Returns:
        Always True.
    """
    assert delay is not None, "Missing basic action information"

    if message is None:
        return True

    self.out(message)
    return True
399
+
400
async def wait_for_actions(self, agent: str, from_state: str, to_state: str, wait: bool):
    """Ask one or more agents to lock/unlock the actions between two states of their state machine (async).

    The request is delivered as an HSM message that names the method 'wait_for_actions' and its arguments.

    Args:
        agent: The ID of the agent to send the request to, or a valid wildcard like "<valid_cmp>"
            for a set of agents (if None the agents in self._engaged_agents will be considered).
        from_state: The starting state of the actions to be locked/unlocked.
        to_state: The ending state of the actions to be locked/unlocked.
        wait: Whether the actions must wait (wait == !ready).

    Returns:
        A true value if the request was sent to at least one involved agent, False otherwise.
    """

    # Resolving "agent" (peer ID, wildcard, or None) into the collection of peers to contact
    targets = self.__involved_agents(agent)
    if len(targets) == 0:
        return False

    sent_to_someone = False
    for target in targets:
        self.out(f"Telling {target} to alter his HSM {from_state} -> {to_state} (wait: {wait}) "
                 f"by calling method 'wait_for_actions' on it")
        hsm_call = {'method': 'wait_for_actions', 'args': (from_state, to_state, wait)}
        ret = await self._node_conn.send(target, channel_trail=None, content=hsm_call, content_type=Msg.HSM)
        sent_to_someone = sent_to_someone or ret
    return sent_to_someone
431
+
432
async def ask_gen(self, agent: str | None = None, u_hashes: list[str] | None = None,
                  samples: int = 100, from_state: str | None = None, to_state: str | None = None,
                  time: float = -1., timeout: float = -1., ask_uuid: str | None = None,
                  ignore_uuid: bool = False):
    """Ask one or more agents to generate a signal, optionally given some input streams (async).

    Besides sending the request, this method prepares the local streams: the owned input streams are enabled
    and tagged with the request UUID, while the processor streams of the agents that were correctly asked are
    told to expect data marked with that UUID.

    Args:
        agent: The ID of the agent to ask for generation, or a valid wildcard like "<valid_cmp>"
            for a set of agents (if None the agents in self._engaged_agents will be considered).
        u_hashes: A list of input stream hashes for generation. Defaults to None.
        samples: The number of samples to generate. Defaults to 100.
        from_state: The optional starting state from which the generation should be executed.
        to_state: The optional destination state where the generation should lead if correctly executed.
        time: The time duration for generation. Defaults to -1.
        timeout: The timeout for the generation request. Defaults to -1.
        ask_uuid: Specify the UUID of the action (if None - default -, it is randomly generated).
        ignore_uuid: Force a None UUID instead of generating a random one.

    Returns:
        True if the generation request was successfully sent to at least one involved agent, False otherwise.
    """
    assert samples is not None and time is not None and timeout is not None, "Missing basic action information"

    # - if "agent" is a peer ID, the involved agents will be a list with one element.
    # - if "agent" is a known wildcard, as "<valid_cmp>", then involved agents will be self._valid_cmp_agents
    # - if "agent" is None, then the current agent in self._engaged_agents will be returned
    involved_agents = self.__involved_agents(agent)
    self.deb(f"[ask_gen] Involved_agents: {involved_agents}")

    if len(involved_agents) == 0:
        self.deb(f"[ask_gen] No involved agents, action ask_gen returns False")
        return False

    # Create a copy of the input hashes, normalizing them in the appropriate way
    u_hashes_copy = self.__normalize_user_hash(u_hashes)

    # Use the provided UUID for this request, or generate a new one (unless a None UUID is forced)
    ref_uuid = AgentBasics.generate_uuid() if ask_uuid is None else ask_uuid
    if ignore_uuid:
        ref_uuid = None

    # If at least one of the input streams is NOT owned by this agent, discard the UUID
    # (the UUID is kept only when all the input streams are owned, including the case of no input streams)
    all_owned = True
    for i in range(len(u_hashes_copy)):
        if u_hashes_copy[i] not in self.owned_streams:
            all_owned = False
            break
    if not all_owned:
        ref_uuid = None

    for i in range(len(u_hashes_copy)):

        # If there are our own streams involved, and they are buffered, let's plan to restart them when we will
        # start sending them through the net: moreover, let's set the local stream UUID appropriately to
        # the generated UUID
        if u_hashes_copy[i] in self.owned_streams:
            stream_dict = self.known_streams[u_hashes_copy[i]]
            for stream_name, stream_obj in stream_dict.items():

                # Plan to restart buffered streams
                if isinstance(stream_obj, BufferedDataStream):
                    stream_obj.plan_restart_before_next_get(requested_by="send_stream_samples")

                # Activate the stream (if it was off)
                # NOTE(review): assumed to apply to every owned stream, not only to buffered ones - confirm
                stream_obj.enable()

                # Set UUID to the generated one (and clear the "expected" UUID)
                stream_obj.set_uuid(ref_uuid=ref_uuid, expected=False)
                stream_obj.set_uuid(ref_uuid=None, expected=True)

    self.deb(f"[ask_gen] Input streams u_hashes: {u_hashes_copy}")

    self.out(f"Asking {', '.join(involved_agents)} to generate signal given {u_hashes_copy} (ref_uuid: {ref_uuid})")

    # Forgetting the agents involved in previous requests, then asking the involved ones, one by one
    self._agents_who_completed_what_they_were_asked = set()
    self._agents_who_were_asked = set()
    correctly_asked = []
    for peer_id in involved_agents:
        ret = await self.__ask_gen_or_learn(for_what="gen", agent=peer_id,
                                            u_hashes=u_hashes_copy,
                                            yhat_hashes=None,
                                            from_state=from_state, to_state=to_state,
                                            samples=samples, time=time, timeout=timeout, ref_uuid=ref_uuid)
        self.deb(f"[ask_gen] Asking {peer_id} returned {ret}")
        if ret:
            correctly_asked.append(peer_id)

    # Preparing the buffered stream where to store data, if needed
    if len(correctly_asked) > 0:

        # Saving the UUID of this request
        self.last_ref_uuid = ref_uuid

        # For each agent that we involve in this request....
        for peer_id in correctly_asked:

            # Finding the streams generated by the processor of the agent we asked to generate
            processor_streams = self.find_streams(peer_id, name_or_group="processor")

            # For each stream generated by the processor of the agent we asked to generate...
            for net_hash, stream_dict in processor_streams.items():

                # Set the appropriate UUID to the one we created in this method
                for stream in stream_dict.values():
                    stream.set_uuid(None, expected=False)
                    stream.set_uuid(ref_uuid, expected=True)  # Setting the "expected" one

                    # There will be no callbacks in the case of 1 sample, so mark the streams to clear UUID when
                    # getting such single sample
                    # NOTE(review): assumed to run for each processor stream (inside this loop) - confirm
                    if samples == 1:
                        stream.mark_uuid_as_clearable()

    self.deb(f"[ask_gen] Overall, the action ask_gen will return {len(correctly_asked) > 0}")
    return len(correctly_asked) > 0
545
+
546
async def do_gen(self, u_hashes: list[str] | None = None, extra_hashes: list[str] | None = None,
                 samples: int = 100, time: float = -1., timeout: float = -1.,
                 _requester: str | list | None = None, _request_time: float = -1., _request_uuid: str | None = None,
                 _completed: bool = False) -> bool:
    """Generate a signal (async).

    Args:
        u_hashes: A list of input stream hashes for generation. Defaults to None.
        extra_hashes: A list of streams that might be used in a custom manner when overloading this function
            (warning: they are not passed to the processor).
        samples: The number of samples to generate. Defaults to 100.
        time: The max time duration for whole generation process. Defaults to -1.
        timeout: The timeout for generation attempts: if calling the generate action fails for more than "timeout"
            seconds, it is declared as complete. Defaults to -1.
        _requester: The ID (or list of IDs) of the agent(s) who requested generation (automatically set by the
            action calling routine). Every requester must be known, otherwise generation is fully skipped.
        _request_time: The time the generation was requested (automatically set by the action calling routine).
        _request_uuid: The UUID of the generation request (automatically set by the action calling routine).
        _completed: A boolean indicating if the generation is already completed (automatically set by the action
            calling routine). This will tell that it is time to run a final procedure.

    Returns:
        True if the signal generation was successful, False otherwise.
    """
    assert samples is not None and time is not None and timeout is not None, "Missing basic action information"

    self.deb(f"[do_gen] u_hashes: {u_hashes}, extra_hashes: {extra_hashes}, "
             f"samples: {samples}, time: {time}, timeout: {timeout}, "
             f"requester: {_requester}, request_time: {_request_time}, request_uuid: {_request_uuid}, "
             f"completed: {_completed}")

    # Every requester must be known: world agents/masters when behaving in a world, public agents otherwise.
    # Each element of a list of requesters is checked individually (also against the world masters).
    if _requester is not None:
        from_list = isinstance(_requester, list)
        for _r in (_requester if from_list else [_requester]):
            if self.behaving_in_world():
                known = _r in self.world_agents or _r in self.world_masters
            else:
                known = _r in self.public_agents
            if not known:
                origin = f" in list {_requester}" if from_list else ""
                self.err(f"Unknown agent: {_r}{origin} (fully skipping generation)")
                return False

    # Create copies of the input and extra hashes, normalizing them in the appropriate way
    u_hashes_copy = self.__normalize_user_hash(u_hashes)
    extra_hashes_copy = self.__normalize_user_hash(extra_hashes)

    # Check what is the step ID of the multistep action
    k = self.get_action_step()

    # In the first step of this action, we change the UUID of the local streams associated to the data we will
    # use to handle this action, setting expectations to avoid handling tags of old data
    if k == 0:
        for arg_name, hashes in (("u_hashes", u_hashes_copy), ("extra_hashes", extra_hashes_copy)):
            for net_hash in hashes:
                if net_hash not in self.known_streams:
                    self.out(f"Unknown stream mentioned in {arg_name}: {net_hash}")
                    return False

                for stream_obj in self.known_streams[net_hash].values():

                    # If the data arrived before this action, then the UUID is already set, and here there is
                    # no need to do anything; if the data has not yet arrived (common case) ...
                    if stream_obj.get_uuid(expected=False) != _request_uuid:
                        stream_obj.set_uuid(None, expected=False)  # Clearing UUID
                        stream_obj.set_uuid(_request_uuid, expected=True)  # Setting expectations

    if not _completed:
        self.out(f"Generating signal")
        ret = self.__process_streams(u_hashes=u_hashes_copy, yhat_hashes=None, learn=False,
                                     recipient_info=(_requester, samples), ref_uuid=_request_uuid)
        if not ret:
            self.out(f"Generating signal failed")
        elif not self.is_multi_steps_action():

            # A single-step action has no later "completed" call, so the final procedure is run right away
            # (without sending back a confirmation)
            self.out(f"Completing signal generation (degenerate single-step case of a multi-step action)")
            all_hashes = u_hashes_copy + extra_hashes_copy
            ret = await self.__complete_do(do_what="gen", peer_id_who_asked=_requester,
                                           all_hashes=all_hashes,
                                           send_back_confirmation=False, ref_uuid=_request_uuid)
            if not ret:
                self.out(f"Completing signal generation failed")
        return ret
    else:
        self.out(f"Completing signal generation")
        all_hashes = u_hashes_copy + extra_hashes_copy
        ret = await self.__complete_do(do_what="gen", peer_id_who_asked=_requester, all_hashes=all_hashes,
                                       ref_uuid=_request_uuid)
        if not ret:
            self.out(f"Completing signal generation failed")
        return ret
659
+
660
+ async def done_gen(self, _requester: str | None = None):
661
+ """This is a way to get back the confirmation of a completed generation (async).
662
+
663
+ Args:
664
+ _requester: The ID of the agent who completed the generation. Defaults to None.
665
+
666
+ Returns:
667
+ True if the generation confirmation was successfully handled by this agent, False is something went wrong.
668
+ """
669
+ self.out(f"Agent {_requester} finished generation")
670
+
671
+ # Searching for the processor-streams of the agent who generated data
672
+ processor_streams = self.find_streams(_requester, name_or_group="processor")
673
+ if processor_streams is None or len(processor_streams) == 0:
674
+ self.err("Unexpected confirmation of finished generation")
675
+ return False
676
+
677
+ # Remembering that the agent that invoked this action is the one who generated the data, and what he generated
678
+ # could be used in future action (for example, in evaluation processes)
679
+ self._agents_who_completed_what_they_were_asked.add(_requester)
680
+
681
+ # Clearing the UUID of the local streams associated to the agent who generated
682
+ for net_hash, stream_dict in processor_streams.items():
683
+ self.remove_recipient(net_hash, _requester)
684
+ for stream_obj in stream_dict.values():
685
+ stream_obj.set_uuid(None, expected=False)
686
+ stream_obj.set_uuid(None, expected=True)
687
+
688
+ # If one or more of my streams where used as arguments of the generation request I did (ask_gen), then their
689
+ # UUID must be cleared...we clear them all
690
+ for net_hash, stream_dict in self.owned_streams.items():
691
+ self.remove_recipient(net_hash, _requester)
692
+ for stream_obj in stream_dict.values():
693
+ if stream_obj.props.is_public() != self.behaving_in_world():
694
+ stream_obj.set_uuid(None, expected=False)
695
+ stream_obj.set_uuid(None, expected=True)
696
+
697
+ return True
698
+
699
+ async def ask_learn(self, agent: str | None = None,
700
+ u_hashes: list[str] | None = None, yhat_hashes: list[str] | None = None,
701
+ samples: int = 100, from_state: str | None = None, to_state: str | None = None,
702
+ time: float = -1., timeout: float = -1., ask_uuid: str | None = None,
703
+ ignore_uuid: str | None = None):
704
+ """Asking for learning to generate (async).
705
+
706
+ Args:
707
+ agent: The ID of the agent to ask for generation, or a valid wildcard like "<valid_cmp>"
708
+ for a set of agents (if None the agents in self._engaged_agents will be considered).
709
+ u_hashes: A list of input stream hashes for inference. Defaults to None.
710
+ yhat_hashes: A list of target stream hashes to be used for loss computation. Defaults to None.
711
+ samples: The number of samples to learn from. Defaults to 100.
712
+ from_state: The optional starting state from which learning should be executed.
713
+ to_state: The optional destination state where learning should lead if correctly executed.
714
+ time: The time duration for generation. Defaults to -1.
715
+ timeout: The timeout for the generation request. Defaults to -1.
716
+ ask_uuid: Specify the action UUID (default = None, i.e., it is automatically generated).
717
+ ignore_uuid: If True, the UUID is fully ignored (i.e, forced to None).
718
+
719
+ Returns:
720
+ True if the learning request was successfully sent to at least one involved agent, False otherwise.
721
+ """
722
+ assert samples is not None and time is not None and timeout is not None, "Missing basic action information"
723
+
724
+ # - if "agent" is a peer ID, the involved agents will be a list with one element.
725
+ # - if "agent" is a known wildcard, as "<valid_cmp>", then involved agents will be self._valid_cmp_agents
726
+ # - if "agent" is None, then the current agent in self._engaged_agents will be returned
727
+ involved_agents = self.__involved_agents(agent)
728
+ self.deb(f"[ask_learn] Involved agents: {involved_agents}")
729
+
730
+ if len(involved_agents) == 0:
731
+ self.deb(f"[ask_learn] No involved agents, action will return False")
732
+ return False
733
+
734
+ # Create a copy of the input hashes, normalizing them in the appropriate way
735
+ u_hashes_copy = self.__normalize_user_hash(u_hashes)
736
+
737
+ # Create a copy of the target hashes, normalizing them in the appropriate way
738
+ yhat_hashes_copy = self.__normalize_user_hash(yhat_hashes)
739
+
740
+ # Generate a new UUID for this request
741
+ ref_uuid = AgentBasics.generate_uuid() if ask_uuid is None else ask_uuid
742
+ if ignore_uuid:
743
+ ref_uuid = None
744
+
745
+ # If the input streams are all owned by this agent, discard UUID
746
+ all_owned = True
747
+ for i in range(len(u_hashes_copy)):
748
+ if u_hashes_copy[i] not in self.owned_streams:
749
+ all_owned = False
750
+ break
751
+ if all_owned:
752
+ for i in range(len(yhat_hashes_copy)):
753
+ if yhat_hashes_copy[i] not in self.owned_streams:
754
+ all_owned = False
755
+ break
756
+ if not all_owned:
757
+ ref_uuid = None
758
+
759
+ for i in range(len(u_hashes_copy)):
760
+
761
+ # If there are our own streams involved, and they are buffered, let's plan to restart them when we will
762
+ # start sending them through the net: moreover, let's set the local stream UUID appropriately to
763
+ # the generated UUID
764
+ if u_hashes_copy[i] in self.owned_streams:
765
+ stream_dict = self.known_streams[u_hashes_copy[i]]
766
+ for stream_name, stream_obj in stream_dict.items():
767
+
768
+ # Plan to restart buffered streams
769
+ if isinstance(stream_obj, BufferedDataStream):
770
+ stream_obj.plan_restart_before_next_get(requested_by="send_stream_samples")
771
+
772
+ # Activate the stream (if it was off)
773
+ stream_obj.enable()
774
+
775
+ # Set UUID to the generated one
776
+ stream_obj.set_uuid(ref_uuid=ref_uuid, expected=False)
777
+ stream_obj.set_uuid(ref_uuid=None, expected=True)
778
+
779
+ for i in range(len(yhat_hashes_copy)):
780
+
781
+ # If there are our own streams involved, and they are buffered, let's plan to restart them when we will
782
+ # start sending them through the net: moreover, let's set the local stream UUID appropriately to
783
+ # the generated UUID
784
+ if yhat_hashes_copy[i] in self.owned_streams:
785
+ stream_dict = self.known_streams[yhat_hashes_copy[i]]
786
+ for stream_name, stream_obj in stream_dict.items():
787
+
788
+ # Plan to restart buffered streams
789
+ if isinstance(stream_obj, BufferedDataStream):
790
+ stream_obj.plan_restart_before_next_get(requested_by="send_stream_samples")
791
+
792
+ # Activate the stream (if it was off)
793
+ stream_obj.enable()
794
+
795
+ # Set UUID to the generated one
796
+ stream_obj.set_uuid(ref_uuid=ref_uuid, expected=False)
797
+ stream_obj.set_uuid(ref_uuid=None, expected=True)
798
+
799
+ self.out(f"Asking {', '.join(involved_agents)} to learn to generate signal {yhat_hashes_copy}, "
800
+ f"given {u_hashes_copy} (ref_uuid: {ref_uuid})")
801
+ self._agents_who_completed_what_they_were_asked = set()
802
+ self._agents_who_were_asked = set()
803
+ correctly_asked = []
804
+ for peer_id in involved_agents:
805
+ ret = await self.__ask_gen_or_learn(for_what="learn", agent=peer_id,
806
+ u_hashes=u_hashes_copy,
807
+ yhat_hashes=yhat_hashes_copy,
808
+ samples=samples,
809
+ from_state=from_state, to_state=to_state,
810
+ time=time, timeout=timeout, ref_uuid=ref_uuid)
811
+ self.deb(f"[ask_learn] Asking {peer_id} returned {ret}")
812
+ if ret:
813
+ correctly_asked.append(peer_id)
814
+
815
+ # Preparing the buffered stream where to store data, if needed
816
+ if len(correctly_asked) > 0:
817
+
818
+ # Saving
819
+ self.last_ref_uuid = ref_uuid
820
+
821
+ # For each agent that we involve in this request....
822
+ for peer_id in correctly_asked:
823
+
824
+ # Finding the streams generated by the processor of the agent we asked to generate
825
+ processor_streams = self.find_streams(peer_id, name_or_group="processor")
826
+
827
+ # For each stream generated by the processor of the agent we asked to generate...
828
+ for net_hash, stream_dict in processor_streams.items():
829
+
830
+ # Set the appropriate UUID to the one we created in this method
831
+ for stream in stream_dict.values():
832
+ stream.set_uuid(None, expected=False)
833
+ stream.set_uuid(ref_uuid, expected=True) # Setting the "expected" one
834
+
835
+ # There will be no callbacks in the case of 1 sample, so mark the streams to clear UUID when
836
+ # getting such single sample
837
+ if samples == 1:
838
+ stream.mark_uuid_as_clearable()
839
+
840
+ self.deb(f"[ask_learn] Overall the action ask_learn will return {len(correctly_asked) > 0}")
841
+ return len(correctly_asked) > 0
842
+
843
+ async def do_learn(self, yhat_hashes: list[str] | None = None, u_hashes: list[str] | None = None,
844
+ extra_hashes: list[str] | None = None,
845
+ samples: int = 100, time: float = -1., timeout: float = -1.,
846
+ _requester: str | None = None, _request_time: float = -1., _request_uuid: str | None = None,
847
+ _completed: bool = False) -> bool:
848
+ """Learn to generate a signal (async).
849
+
850
+ Args:
851
+ yhat_hashes: A list of target stream hashes to be used for loss computation. Defaults to None.
852
+ u_hashes: A list of input stream hashes for inference. Defaults to None.
853
+ extra_hashes: A list of streams that might be used in a custom manner when overloading this function
854
+ (warning: they are not passed to the processor).
855
+ samples: The number of samples to learn from. Defaults to 100.
856
+ time: The max time duration of the learning procedure. Defaults to -1.
857
+ timeout: The timeout for learning attempts: if calling the learning action fails for more than "timeout"
858
+ seconds, it is declared as complete. Defaults to -1.
859
+ _requester: The ID of the agent who requested learning (automatically set by the action calling routine).
860
+ _request_time: The time learning was requested (automatically set by the action calling routine).
861
+ _request_uuid: The UUID of the learning request (automatically set by the action calling routine).
862
+ _completed: A boolean indicating if the learning is already completed (automatically set by the action
863
+ calling routine). This will tell that it is time to run a final procedure.
864
+
865
+ Returns:
866
+ True if the signal generation was successful, False otherwise.
867
+ """
868
+ assert samples is not None and time is not None and timeout is not None, "Missing basic action information"
869
+
870
+ self.deb(f"[do_learn] samples: {samples}, time: {time}, timeout: {timeout}, "
871
+ f"requester: {_requester}, request_time: {_request_time}, request_uuid: {_request_uuid} "
872
+ f"completed: {_completed}")
873
+
874
+ if _requester not in self.world_agents and _requester not in self.world_masters:
875
+ self.err(f"Unknown agent: {_requester}")
876
+ return False
877
+
878
+ # Check what is the step ID of the multistep action
879
+ k = self.get_action_step()
880
+
881
+ # Create a copy of the input hashes, normalizing them in the appropriate way
882
+ u_hashes_copy = self.__normalize_user_hash(u_hashes)
883
+
884
+ # Create a copy of the input hashes, normalizing them in the appropriate way
885
+ yhat_hashes_copy = self.__normalize_user_hash(yhat_hashes)
886
+
887
+ # Create a copy of the input hashes, normalizing them in the appropriate way
888
+ extra_hashes_copy = self.__normalize_user_hash(extra_hashes)
889
+
890
+ # In the first step of this action, we change the UUID of the local stream associated to the input data we will
891
+ # use to handle this action, setting expectations to avoid handling tags of old data
892
+ if k == 0:
893
+ if u_hashes_copy is not None:
894
+ for net_hash in u_hashes_copy:
895
+ if net_hash in self.known_streams:
896
+ for stream_obj in self.known_streams[net_hash].values():
897
+
898
+ # If the data arrived before this action, then the UUID is already set, and here there is
899
+ # no need to do anything; if the data has not yet arrived (common case) ...
900
+ if stream_obj.get_uuid(expected=False) != _request_uuid:
901
+ stream_obj.set_uuid(None, expected=False) # Clearing UUID
902
+ stream_obj.set_uuid(_request_uuid, expected=True) # Setting expectations
903
+
904
+ if yhat_hashes_copy is not None:
905
+ for net_hash in yhat_hashes_copy:
906
+ if net_hash in self.known_streams:
907
+ for stream_obj in self.known_streams[net_hash].values():
908
+ if stream_obj.get_uuid(expected=False) != _request_uuid:
909
+ stream_obj.set_uuid(None, expected=False) # Clearing UUID
910
+ stream_obj.set_uuid(_request_uuid, expected=True) # Setting expectations
911
+
912
+ if not _completed:
913
+ self.out(f"Learning to generate signal {yhat_hashes_copy}")
914
+ ret = self.__process_streams(u_hashes=u_hashes_copy, yhat_hashes=yhat_hashes_copy, learn=True,
915
+ recipient_info=(_requester, samples), ref_uuid=_request_uuid)
916
+ if not ret:
917
+ self.out(f"Learning to generate signal {yhat_hashes_copy} failed")
918
+ self.proc_updated_since_last_save = True # Set it after complete?
919
+ return ret
920
+ else:
921
+ self.out(f"Completing learning to generate signal {yhat_hashes_copy}")
922
+ all_hashes = u_hashes_copy + yhat_hashes_copy + extra_hashes_copy
923
+ ret = await self.__complete_do(do_what="learn", peer_id_who_asked=_requester, all_hashes=all_hashes,
924
+ ref_uuid=_request_uuid)
925
+ if not ret:
926
+ self.out(f"Completing learning to generate signal {yhat_hashes} failed")
927
+ return ret
928
+
929
+ async def done_learn(self, _requester: str | None = None):
930
+ """This is a way to get back the confirmation of a completed learning procedure (async).
931
+
932
+ Args:
933
+ _requester: The ID of the agent who completed the learning procedure. Defaults to None.
934
+
935
+ Returns:
936
+ True if the learning-complete confirmation was successfully handled by this agent, False otherwise.
937
+ """
938
+ self.out(f"Agent {_requester} finished learning")
939
+ self._agents_who_completed_what_they_were_asked.add(_requester)
940
+
941
+ # Searching for the processor-streams of the agent who generated the (inference) data
942
+ processor_streams = self.find_streams(_requester, name_or_group="processor")
943
+ if processor_streams is None or len(processor_streams) == 0:
944
+ self.err("Unexpected confirmation of finished learning")
945
+ return False
946
+
947
+ # Warning: differently from the case of done_gen, we are not considering the streams generated by the
948
+ # learning agents as something we could use for evaluation (this might be changed in the future)
949
+
950
+ # Clearing the UUID of the local streams associated to the agent who learned
951
+ for net_hash, stream_dict in processor_streams.items():
952
+ self.remove_recipient(net_hash, _requester)
953
+ for stream_obj in stream_dict.values():
954
+ stream_obj.set_uuid(None, expected=False)
955
+ stream_obj.set_uuid(None, expected=True)
956
+
957
+ # If one or more of my streams where used as arguments of the learning request I did (ask_learn), then their
958
+ # UUID must be cleared...we clear them all
959
+ for net_hash, stream_dict in self.owned_streams.items():
960
+ self.remove_recipient(net_hash, _requester)
961
+ for stream_obj in stream_dict.values():
962
+ if stream_obj.props.is_public() != self.behaving_in_world():
963
+ stream_obj.set_uuid(None, expected=False)
964
+ stream_obj.set_uuid(None, expected=True)
965
+ return True
966
+
967
+ async def connected(self, agent: str | list[str] | None = None, handshake_completed: bool = False):
968
+ """Checks if an agent is connected to us or not.
969
+
970
+ Args:
971
+ agent: The agent to check (or None if there are other already set engaged agents).
972
+ handshake_completed: If True, only consider agents that have completed the handshake.
973
+
974
+ Returns:
975
+ True if all the requested agents are indeed connected.
976
+ """
977
+
978
+ # - if "agent" is a peer ID, the involved agents will be a list with one element.
979
+ # - if "agent" is a known wildcard, as "<valid_cmp>", then involved agents will be self._valid_cmp_agents
980
+ # - if "agent" is None, then the current agent in self._engaged_agents will be returned
981
+ involved_agents = self.__involved_agents(agent)
982
+ if len(involved_agents) == 0:
983
+ return False
984
+
985
+ self.out(f"Checking if all these agents are connected to me now: {involved_agents}")
986
+
987
+ for agent in involved_agents:
988
+ if handshake_completed:
989
+ if agent not in self.all_agents:
990
+ return False
991
+ else:
992
+ if not self._node_conn.is_connected(agent):
993
+ return False
994
+ return True
995
+
996
async def all_asked_finished(self):
    """Tell whether every agent that was asked to do something (e.g., generate or learn) has confirmed its
    completion, by comparing the set of asked agents with the set of agents that completed (async).

    Returns:
        True if the two sets are equal (all agents are done), False otherwise.
    """
    asked = self._agents_who_were_asked
    completed = self._agents_who_completed_what_they_were_asked
    return asked == completed
1005
+
1006
async def all_engagements_completed(self):
    """Tell whether no agents are left in `_found_agents`, which implies that all of them have been engaged with
    or discarded (async).

    Returns:
        True if `_found_agents` is empty, False otherwise.
    """
    still_pending = len(self._found_agents)
    return still_pending == 0
1015
+
1016
async def agents_are_waiting(self, timeout: float = -1.):
    """Tell whether at least one of the found agents is among the connected peers that were not managed yet
    (async).

    Args:
        timeout: Basic action information; it must not be None.

    Returns:
        True if one of the agents in `_found_agents` is in `_node_agents_waiting`, False otherwise.
    """
    assert timeout is not None, "Missing basic action information"

    waiting = self._node_agents_waiting
    self.out(f"Current set of {len(waiting)} connected peer IDs non managed yet: "
             f"{list(waiting.keys())}")

    # It is enough to find a single found-agent that is still waiting
    return any(candidate in self._node_agents_waiting for candidate in self._found_agents)
1031
+
1032
+ async def ask_subscribe(self, agent: str | None = None,
1033
+ stream_hashes: list[str] | None = None, unsubscribe: bool = False):
1034
+ """Requests a remote agent or a group of agents to subscribe to or unsubscribe from a list of specified PubSub
1035
+ streams. It normalizes the stream hashes and sends an action request containing the stream properties (async).
1036
+
1037
+ Args:
1038
+ agent: The target agent's ID or a wildcard.
1039
+ stream_hashes: A list of streams to subscribe to or unsubscribe from.
1040
+ unsubscribe: A boolean to indicate if it's an unsubscription request.
1041
+
1042
+ Returns:
1043
+ True if the request was sent to at least one agent, False otherwise.
1044
+ """
1045
+
1046
+ # - if "agent" is a peer ID, the involved agents will be a list with one element.
1047
+ # - if "agent" is a known wildcard, as "<valid_cmp>", then involved agents will be self._valid_cmp_agents
1048
+ # - if "agent" is None, then the current agent in self._engaged_agents will be returned
1049
+ involved_agents = self.__involved_agents(agent)
1050
+ self.deb(f"[ask_subscribe] Involved_agents: {involved_agents}")
1051
+
1052
+ if len(involved_agents) == 0:
1053
+ self.deb(f"[ask_subscribe] No involved agents, action ask_gen returns False")
1054
+ return False
1055
+
1056
+ # Create a copy of the stream hashes, normalizing them in the appropriate way
1057
+ stream_hashes_copy = self.__normalize_user_hash(stream_hashes)
1058
+
1059
+ # Getting properties
1060
+ stream_owners = []
1061
+ stream_props = []
1062
+ for i in range(len(stream_hashes_copy)):
1063
+ stream_dict = self.known_streams[stream_hashes_copy[i]]
1064
+ peer_id = DataProps.peer_id_from_net_hash(stream_hashes_copy[i])
1065
+ for name, stream_obj in stream_dict.items():
1066
+ stream_owners.append(peer_id)
1067
+ stream_props.append(json.dumps(stream_obj.props.to_dict()))
1068
+
1069
+ what = "subscribe to" if not unsubscribe else "unsubscribe from "
1070
+ self.out(f"Asking {', '.join(involved_agents)} to {what} {stream_hashes}")
1071
+ self._agents_who_completed_what_they_were_asked = set()
1072
+ self._agents_who_were_asked = set()
1073
+ correctly_asked = []
1074
+ for agent in involved_agents:
1075
+ if await self.set_next_action(agent, action="do_subscribe", args={"stream_owners": stream_owners,
1076
+ "stream_props": stream_props,
1077
+ "unsubscribe": unsubscribe}):
1078
+ self._agents_who_were_asked.add(agent)
1079
+ ret = True
1080
+ else:
1081
+ what = "subscribe" if not unsubscribe else "unsubscribe"
1082
+ self.err(f"Unable to ask {agent} to {what}")
1083
+ ret = False
1084
+ self.deb(f"[ask_subscribe] Asking {agent} returned {ret}")
1085
+ if ret:
1086
+ correctly_asked.append(agent)
1087
+
1088
+ self.deb(f"[ask_subscribe] Overall, the action ask_subscribe (unsubscribe: {unsubscribe})"
1089
+ f" will return {len(correctly_asked) > 0}")
1090
+ return len(correctly_asked) > 0
1091
+
1092
+ async def do_subscribe(self, stream_owners: list[str] | None = None, stream_props: list[str] | None = None,
1093
+ unsubscribe: bool = False,
1094
+ _requester: str | list | None = None, _request_time: float = -1.):
1095
+ """Executes a subscription or unsubscription request received from another agent. It processes the stream
1096
+ properties, adds or removes the streams from the agent's known streams, and handles the underlying PubSub topic
1097
+ subscriptions (async).
1098
+
1099
+ Args:
1100
+ stream_owners: A list of peer IDs who own the streams.
1101
+ stream_props: A list of JSON-serialized stream properties.
1102
+ unsubscribe: A boolean to indicate unsubscription.
1103
+ _requester: The ID of the requesting agent.
1104
+ _request_time: The time the request was made.
1105
+
1106
+ Returns:
1107
+ True if the action is successful, False otherwise.
1108
+ """
1109
+ self.deb(f"[do_subscribe] unsubscribe: {unsubscribe}, "
1110
+ f"stream_owners: {stream_owners}, stream_props: ... ({len(stream_props)} props)")
1111
+
1112
+ if _requester is not None:
1113
+ if isinstance(_requester, list):
1114
+ for _r in _requester:
1115
+ if self.behaving_in_world():
1116
+ if _r not in self.world_agents and _requester not in self.world_masters:
1117
+ self.err(f"Unknown agent: {_r} in list {_requester} (fully skipping do_subscribe)")
1118
+ return False
1119
+ else:
1120
+ if _r not in self.public_agents:
1121
+ self.err(f"Unknown agent: {_r} in list {_requester} (fully skipping do_subscribe)")
1122
+ return False
1123
+ else:
1124
+ if self.behaving_in_world():
1125
+ if _requester not in self.world_agents and _requester not in self.world_masters:
1126
+ self.err(f"Unknown agent: {_requester} (fully skipping do_subscribe)")
1127
+ return False
1128
+ else:
1129
+ if _requester not in self.public_agents:
1130
+ self.err(f"Unknown agent: {_requester} (fully skipping do_subscribe)")
1131
+ return False
1132
+ else:
1133
+ self.err("Unknown requester (None)")
1134
+ return False
1135
+
1136
+ # Building properties
1137
+ props_dicts = []
1138
+ props_objs = []
1139
+ for i in range(len(stream_props)):
1140
+ p_dict = json.loads(stream_props[i])
1141
+ props = DataProps.from_dict(p_dict)
1142
+ if props.is_pubsub():
1143
+ props_dicts.append(p_dict)
1144
+ props_objs.append(props)
1145
+ else:
1146
+ self.err(f"Expecting a pubsub stream, got a stream named {props.get_name()} "
1147
+ f"(group is {props.get_group()}), which is not pubsub")
1148
+ return False
1149
+
1150
+ # Adding new streams and subscribing (if compatible with our processor)
1151
+ for stream_owner, prop_dict, prop_obj in zip(stream_owners, props_dicts, props_objs):
1152
+ if not unsubscribe:
1153
+ if not (await self.add_compatible_streams(peer_id=stream_owner, streams_in_profile=[prop_dict],
1154
+ buffered=False, public=False)):
1155
+ self.out(f"Unable to add a pubsub stream ({prop_obj.get_name()}) from agent {stream_owner}: "
1156
+ f"no compatible streams were found")
1157
+ else:
1158
+ if not (await self.remove_streams(peer_id=stream_owner, name=prop_obj.get_name())):
1159
+ self.out(f"Unable to unsubscribe from pubsub stream ({prop_obj.get_name()}) "
1160
+ f"of agent {stream_owner}")
1161
+ return True
1162
+
1163
+ async def done_subscribe(self, unsubscribe: bool = False, _requester: str | None = None):
1164
+ """Handles the confirmation that a subscription or unsubscription request has been completed by another agent.
1165
+ It adds the requester to the set of agents that have completed their asked tasks (async).
1166
+
1167
+ Args:
1168
+ unsubscribe: A boolean indicating if it was an unsubscription.
1169
+ _requester: The ID of the agent who completed the task.
1170
+
1171
+ Returns:
1172
+ Always True.
1173
+ """
1174
+ what = "subscribing" if unsubscribe else "unsubscribing"
1175
+ self.out(f"Agent {_requester} finished {what}")
1176
+
1177
+ # Remembering that the agent that invoked this action is the one who actually subscribed
1178
+ self._agents_who_completed_what_they_were_asked.add(_requester)
1179
+ return True
1180
+
1181
async def record(self, net_hash: str, samples: int = 100, time: float = -1., timeout: float = -1.):
    """Record one sample per step from a stream into a new, owned `BufferedDataStream` (async).

    This is a multistep action: the destination stream(s) are created at the first step, one sample is stored at
    each step (tagged with the step index), and at the last step the recorded stream(s) are added to the owned
    ones and the updated profile is sent around.

    Args:
        net_hash: The hash of the stream to record.
        samples: The number of samples to record.
        time: The time duration for recording.
        timeout: The timeout for each recording attempt.

    Returns:
        True if a sample was successfully recorded, False otherwise (unknown stream or missing sample).
    """
    assert samples is not None and time is not None and timeout is not None, "Missing basic action information"

    k = self.get_action_step()

    self.out(f"Recording stream {net_hash}")

    if k != 0:

        # Later steps: reuse the source/destination stream(s) prepared at the first step
        recorded = self._last_recorded_stream_dict
        sources = self._last_recording_stream_dict

    else:

        # First step: resolving the stream to record (in case of ambiguity, the first match is used)
        resolved_hash = self.user_stream_hash_to_net_hash(net_hash)
        if resolved_hash is None:
            self.err(f"Unknown stream {net_hash}")
            return False

        sources = self.known_streams[resolved_hash]

        # Creating the new recorded stream(s): same props of the recorded one(s), with a dedicated group name
        recorded = {}
        for stream_name, source in sources.items():
            cloned_props = source.props.clone()
            cloned_props.set_group("recorded" + str(self._last_recorded_stream_num))
            recorded[stream_name] = BufferedDataStream(props=cloned_props, clock=self._node_clock)
        self._last_recorded_stream_dict = recorded
        self._last_recording_stream_dict = sources

    # Recording: a single missing sample interrupts the step
    for stream_name, source in sources.items():
        sample = source.get(requested_by="record")
        if sample is None:
            self.deb("[record] data sample missing, returning False")
            return False
        self.deb(f"[record] data_tag: {source.get_tag()}, data_uuid: {source.get_uuid()}")
        recorded[stream_name].set(sample, k)  # Saving specific data tags 0, 1, 2, ... #record_steps - 1

    if not self.is_last_action_step():
        return True

    # Last step: updating profile
    self.deb("[record] last action step detected, finishing")

    # Dummy get to ensure that the next get will return None (i.e., we only PubSub if somebody restarts this)
    for recorded_stream in recorded.values():
        recorded_stream.get(requested_by="send_stream_samples")

    self.add_streams(list(recorded.values()), owned=True)
    self.update_streams_in_profile()
    await self.subscribe_to_pubsub_owned_streams()
    await self.send_profile_to_all()

    # The next recording will get a new group name
    self._last_recorded_stream_num += 1

    return True
1255
+
1256
+ async def connect_by_role(self, role: str | list[str], filter_fcn: str | None = None,
1257
+ time: float = -1., timeout: float = -1.):
1258
+ """Finds and attempts to connect with agents whose profiles match a specific role. It can be optionally
1259
+ filtered by a custom function. It returns True if at least one valid agent is found (async).
1260
+
1261
+ Args:
1262
+ role: The role or list of roles to search for.
1263
+ filter_fcn: The name of an optional filter function.
1264
+ time: The time duration for the action.
1265
+ timeout: The action timeout.
1266
+
1267
+ Returns:
1268
+ True if at least one agent is found and a connection request is made, False otherwise.
1269
+ """
1270
+ self.out(f"Asking to get in touch with all agents whose role is {role}")
1271
+ assert time is not None and timeout is not None, "Missing basic action information"
1272
+
1273
+ if self.get_action_step() == 0:
1274
+ role_list = role if isinstance(role, list) else [role]
1275
+ self._found_agents.clear()
1276
+ at_least_one_is_valid = False
1277
+
1278
+ for role in role_list:
1279
+ role = self.ROLE_STR_TO_BITS[role]
1280
+
1281
+ found_addresses1, found_peer_ids1 = self._node_conn.find_addrs_by_role(Agent.ROLE_WORLD_MASTER | role,
1282
+ return_peer_ids_too=True)
1283
+ found_addresses2, found_peer_ids2 = self._node_conn.find_addrs_by_role(Agent.ROLE_WORLD_AGENT | role,
1284
+ return_peer_ids_too=True)
1285
+ found_addresses = found_addresses1 + found_addresses2
1286
+ found_peer_ids = found_peer_ids1 + found_peer_ids2
1287
+
1288
+ if filter_fcn is not None:
1289
+ if hasattr(self, filter_fcn):
1290
+ filter_fcn = getattr(self, filter_fcn)
1291
+ if callable(filter_fcn):
1292
+ found_addresses, found_peer_ids = filter_fcn(found_addresses, found_peer_ids)
1293
+ else:
1294
+ self.err(f"Filter function not found: {filter_fcn}")
1295
+
1296
+ self.out(f"Found addresses ({len(found_addresses)}) with role: {role}")
1297
+ for f_addr, f_peer_id in zip(found_addresses, found_peer_ids):
1298
+ if not self._node_conn.is_connected(f_peer_id):
1299
+ self.out(f"Asking to get in touch with {f_addr}...")
1300
+ peer_id = await self._node_ask_to_get_in_touch_fcn(addresses=f_addr, public=False)
1301
+ else:
1302
+ self.out(f"Not-asking to get in touch with {f_addr}, "
1303
+ f"since I am already connected to the corresponding peer...")
1304
+ peer_id = f_peer_id
1305
+ if peer_id is not None:
1306
+ at_least_one_is_valid = True
1307
+ self._found_agents.add(peer_id)
1308
+ self.out(f"...returned {peer_id}")
1309
+ return at_least_one_is_valid
1310
+ else:
1311
+ return True
1312
+
1313
+ async def find_agents(self, role: str | list[str], engage: bool = False, handshake_completed: bool = False):
1314
+ """Locally searches through the agent's known peers (world and public agents) to find agents with a specific
1315
+ role. It populates the `_found_agents` set with the peer IDs of matching agents (async).
1316
+
1317
+ Args:
1318
+ role: The role or list of roles to search for.
1319
+ handshake_completed: If True, only consider agents that have completed the handshake.
1320
+ engage: If you want to force the found agents to be the ones that you are engaged with.
1321
+
1322
+ Returns:
1323
+ True if at least one agent is found, False otherwise.
1324
+ """
1325
+ self.out(f"Finding an available agent whose role is {role}")
1326
+ role_list = role if isinstance(role, list) else [role]
1327
+ self._found_agents = set()
1328
+
1329
+ for role_str in role_list:
1330
+ role_int = self.ROLE_STR_TO_BITS[role_str]
1331
+
1332
+ _, found_peer_ids1 = self._node_conn.find_addrs_by_role(Agent.ROLE_WORLD_MASTER | role_int,
1333
+ return_peer_ids_too=True)
1334
+ _, found_peer_ids2 = self._node_conn.find_addrs_by_role(Agent.ROLE_WORLD_AGENT | role_int,
1335
+ return_peer_ids_too=True)
1336
+ found_peer_ids = found_peer_ids1 + found_peer_ids2
1337
+
1338
+ for peer_id in found_peer_ids:
1339
+ if not handshake_completed or peer_id in self.all_agents:
1340
+ self._found_agents.add(peer_id) # Peer IDs here
1341
+
1342
+ self.deb(f"[find_agents] Found these agents: {self._found_agents}")
1343
+ if engage:
1344
+ self._engaged_agents = copy.deepcopy(self._found_agents)
1345
+ return len(self._found_agents) > 0
1346
+
1347
async def next_pref_stream(self):
    """Advances the playlist pointer to the following preferred stream, wrapping to the first one after the
    last (async).

    Returns:
        True if the pointer was moved, False if the list of preferred streams is empty.
    """
    playlist_size = len(self._preferred_streams)
    if playlist_size == 0:
        self.err(f"Cannot move to the next stream because the list of preferred streams is empty")
        return False

    self._cur_preferred_stream = (self._cur_preferred_stream + 1) % playlist_size

    # Landing on index 0 means that the pointer wrapped around
    if self._cur_preferred_stream == 0:
        suffix = ", warning: restarted"
    else:
        suffix = ""
    self.out(f"Moving to the next preferred stream ({self._preferred_streams[self._cur_preferred_stream]}){suffix}")
    return True
1362
+
1363
async def first_pref_stream(self):
    """Rewinds the playlist pointer to the first preferred stream (async).

    Returns:
        True if the pointer was reset, False if the list of preferred streams is empty.
    """
    playlist = self._preferred_streams
    if len(playlist) == 0:
        self.err(f"Cannot move to the first stream because the list of preferred streams is empty")
        return False

    self._cur_preferred_stream = 0
    current = self._preferred_streams[self._cur_preferred_stream]
    self.out(f"Moving to the first preferred stream ({current})")
    return True
1377
+
1378
async def check_pref_stream(self, what: str = "last"):
    """Tests where the playlist pointer currently sits within the list of preferred streams (async).

    Args:
        what: The check to perform, one of 'first', 'last', 'last_round', 'last_song' or the same names prefixed
            by 'not_' (the negated check).

    Returns:
        True if the condition is met, False otherwise.
    """
    valid = ['first', 'last', 'not_first', 'not_last', 'last_round', 'not_last_round', 'last_song', 'not_last_song']
    assert what in valid, f"The what argument can only be one of {valid}"

    self.out(f"Checking if the current preferred playlist item "
             f"(id: {self._cur_preferred_stream}) is the '{what}' one")

    # Every check comes in a positive and a negated ('not_') flavour: evaluate the positive one, then flip if needed
    negated = what.startswith("not_")
    base_check = what[len("not_"):] if negated else what
    position = self._cur_preferred_stream

    if base_check == "first":
        holds = position == 0
    elif base_check == "last":
        holds = position == len(self._preferred_streams) - 1
    elif base_check == "last_round":

        # The playlist is stored repeated `_repeat` times: one round is as long as a single repetition
        total = len(self._preferred_streams)
        holds = position + total // self._repeat >= total
    elif base_check == "last_song":
        songs_per_round = len(self._preferred_streams) // self._repeat
        holds = (position + 1) % songs_per_round == 0
    else:
        return None

    return (not holds) if negated else holds
1413
+
1414
async def set_pref_streams(self, net_hashes: list[str], repeat: int = 1):
    """Rebuilds the playlist of preferred streams from the given hashes, repeated `repeat` times (async).

    Each provided hash is resolved through `find_streams`, and every network hash returned by it is appended to
    the playlist. The playlist pointer is reset to the first position.

    Args:
        net_hashes: The stream hashes to add to the playlist.
        repeat: How many times the whole sequence of hashes is appended.

    Returns:
        Always True.
    """
    self.out(f"Setting up a list of {len(net_hashes)} preferred streams")
    self._cur_preferred_stream = 0
    self._preferred_streams = []
    self._repeat = repeat

    for _ in range(self._repeat):
        for requested_hash in net_hashes:

            # We are tolerating both peer_id:name_or_group and also peer_id::ps:name_or_group
            parts = requested_hash.split(":")
            resolved = self.find_streams(peer_id=parts[0], name_or_group=parts[-1])
            self._preferred_streams.extend(resolved.keys())

    return True
1441
+
1442
async def evaluate(self, stream_hash: str, how: str, steps: int = 100, re_offset: bool = False):
    """Scores the streams returned by the agents that completed what they were asked, against a local stream (async).

    The comparison results are stored in `_eval_results`, keyed by the peer ID extracted from the network hash of
    each received stream.

    Args:
        stream_hash: The hash of the local stream to compare with, or "<playlist>" to use the current item of
            the list of preferred streams.
        how: The name of the comparison method to use.
        steps: The number of steps to perform the evaluation.
        re_offset: A boolean to indicate whether to re-offset the streams.

    Returns:
        False if data generated by others is not being buffered or if a comparison fails, True otherwise.
    """
    if not self.buffer_generated_by_others:
        self.err("Cannot evaluate if not buffering data generated by others")
        return False

    reference_hash = (self._preferred_streams[self._cur_preferred_stream] if stream_hash == "<playlist>"
                      else self.user_stream_hash_to_net_hash(stream_hash))

    self._eval_results = {}
    self.deb(f"[eval] Agents returning streams: {self._agents_who_completed_what_they_were_asked}")
    for responder in self._agents_who_completed_what_they_were_asked:

        # Agents with no buffered stream are skipped, not treated as a failure
        if responder not in self.last_buffered_peer_id_to_info:
            self.err(f"Missing buffered stream for {responder}, cannot evaluate!")
            continue

        received_hash = self.last_buffered_peer_id_to_info[responder]["net_hash"]
        self.out(f"Comparing {reference_hash} with {received_hash}")
        score, compared = self.__compare_streams(net_hash_a=reference_hash,
                                                 net_hash_b=received_hash,
                                                 how=how, steps=steps, re_offset=re_offset)
        self.out(f"Result of the comparison: {score}")
        if not compared:
            return False

        self._eval_results[DataProps.peer_id_from_net_hash(received_hash)] = score

    return True
1484
+
1485
async def compare_eval(self, cmp: str, thres: float, good_if_true: bool = True):
    """Compares the results of a previous evaluation to a given threshold or finds the best result among all
    agents. It can check for minimum, maximum, or simple threshold-based comparisons, and it populates the set of
    'valid' agents (`_valid_cmp_agents`) that passed the comparison (async).

    Args:
        cmp: The comparison operator, one of '<', '<=', '>', '>=', 'min', 'max'.
        thres: The threshold value for comparison. It must be >= 0, with the exception of 'min'/'max', where a
            negative value disables the threshold.
        good_if_true: A boolean to invert the pass/fail logic (only used by the threshold-based comparisons).

    Returns:
        True if at least one agent passed the comparison, False otherwise (also when a negative, thus invalid,
        evaluation result is found).
    """
    assert cmp in ["<", ">", ">=", "<=", "min", "max"], f"Invalid comparison operator: {cmp}"
    assert thres >= 0. or cmp in ["min", "max"], f"Invalid evaluation threshold: {thres} (it must be in >= 0.)"

    self._valid_cmp_agents = set()
    msgs = []
    best_so_far = -1  # Negative means "no best found yet" (valid evaluation results are never negative)

    looking_for_best = cmp in ["min", "max"]
    min_or_max = None
    leq_or_geq = None
    if looking_for_best:
        min_or_max = "minimum" if cmp == "min" else "maximum"
        leq_or_geq = "<=" if cmp == "min" else ">="

    for agent_peer_id, eval_result in self._eval_results.items():
        if not looking_for_best:
            self.out(f"Checking if result {eval_result} {cmp} {thres}, for agent {agent_peer_id}")
        elif thres >= 0:
            self.out(f"Checking if result {eval_result} is the {min_or_max} so far, "
                     f"only if {leq_or_geq} {thres}, for agent {agent_peer_id}")
        else:
            self.out(f"Checking if result {eval_result} is the {min_or_max} so far, for agent {agent_peer_id}")

        if eval_result < 0.:
            self.print(f"Invalid evaluation result: {eval_result}")
            return False

        static_profile = self.all_agents[agent_peer_id].get_static_profile()
        owner_account = static_profile['email']
        agent_name = static_profile['node_name']

        if not looking_for_best:
            if cmp == "<":
                outcome = eval_result < thres
            elif cmp == "<=":
                outcome = eval_result <= thres
            elif cmp == ">":
                outcome = eval_result > thres
            else:
                outcome = eval_result >= thres

            # A "lower than" check that is good when true is about an error level, otherwise it is about a mark
            # (and vice-versa for "greater than" checks)
            if cmp[0] == "<":
                alias = 'error level' if good_if_true else 'mark'
            else:
                alias = 'mark' if good_if_true else 'error level'

            if outcome == good_if_true:
                msgs.append(f"Agent {owner_account}/{agent_name} passed with {alias} {eval_result}/{thres}")
                self._valid_cmp_agents.add(agent_peer_id)
            else:
                msgs.append(f"Agent {owner_account}/{agent_name} did not pass, {alias} {eval_result}/{thres}")
        else:
            if cmp == "min":
                within_threshold = thres < 0 or eval_result <= thres
                improves = best_so_far < 0 or eval_result < best_so_far
            else:
                within_threshold = thres < 0 or eval_result >= thres
                improves = best_so_far < 0 or eval_result > best_so_far

            # An agent that does not improve the best-so-far must not discard the message about the current best
            if within_threshold and improves:
                best_so_far = eval_result
                self._valid_cmp_agents = {agent_peer_id}
                msgs = [f"The best agent is {owner_account}/{agent_name}"]

    # Reporting the lack of a best agent only if some agents were evaluated and none of them was selected
    if looking_for_best and len(self._eval_results) > 0 and len(self._valid_cmp_agents) == 0:
        msgs = [f"No best agent found for the considered threshold ({thres})"]

    for msg in msgs:
        self.print(msg)

    return len(self._valid_cmp_agents) > 0
1579
+
1580
def collect_and_store_own_stats(self):
    """Pushes this agent's own stats (connected peers and behaviour state/action names) to the stats recorder.

    Nothing is done when no stats recorder is available. Failures are logged and do not propagate: if the list
    of connected peers cannot be collected or stored, an empty list is stored in place of it.
    """
    recorder = self.stats
    if recorder is None:
        return

    _, own_private_pid = self.get_peer_ids()
    now_ms = self._node_clock.get_time_ms()

    try:
        peers_info = self._node_conn['p2p_world'].get_connected_peers_info()
        connected_ids = [peer_info['id'] for peer_info in peers_info]
        self.stats.store_stat('connected_peers', connected_ids, own_private_pid, now_ms)
    except Exception as e:
        self.stats.store_stat('connected_peers', [], own_private_pid, now_ms)
        self.err(f"[Stats] Error collecting and storing own stats, clearing: {e}")

    try:
        hsm = self.behav

        # (stat name, name of the getter on the behaviour object), stored in this exact order
        for stat_name, getter_name in (('state', 'get_state_name'),
                                       ('action', 'get_action_name'),
                                       ('last_action', 'get_last_completed_action_name')):
            self.stats.store_stat(stat_name, getattr(hsm, getter_name)(), own_private_pid, now_ms)
    except Exception as e:
        self.err(f"[Stats] Error storing HSM stats: {e}")
1602
+
1603
async def send_stats_to_world(self):
    """Sends the agent's currently buffered stats to the world, then asks the world for its own updates (async).

    Nothing is sent when the agent is not in a world, when the peer ID of the world is unknown, when no stats
    recorder is available, or when there is no payload to send. Failures in sending are logged and do not
    propagate.
    """
    if not self.in_world():
        self.deb("[send_stats_to_world] Not in a world, skipping stats send.")
        return

    world_peer_id = self._node_conn.get_world_peer_id()
    if world_peer_id is None:
        self.err("[send_stats_to_world] In world, but world_peer_id is None.")
        return

    # The stats recorder is optional (see collect_and_store_own_stats): without it, there is nothing to send
    if self.stats is None:
        self.deb("[send_stats_to_world] No stats recorder, skipping stats send.")
        return

    self.collect_and_store_own_stats()  # update own stats
    payload = self.stats.get_payload_for_world()
    if not payload:
        self.deb("[send_stats_to_world] No stats to send.")
        return

    # Send all stats
    self.out(f"[AGENT] Sending stats update to world {world_peer_id}...")
    if not (await self._node_conn.send(world_peer_id,
                                       channel_trail=None,
                                       content=payload,
                                       content_type=Msg.STATS_UPDATE)):
        self.err("Failed to send stats update to world.")

    # Ask the updates to the world (no overwrite required)
    self.out(f"[AGENT] Requesting stats update from world {world_peer_id}...")
    if not (await self._node_conn.send(world_peer_id,
                                       channel_trail=None,
                                       content={'time_range': self.stats.max_seen_timestamp},
                                       content_type=Msg.STATS_REQUEST)):
        self.err("Failed to request stats to world.")
1635
+
1636
def update_stats_view(self, received_view, overwrite: bool = False):
    """Forwards a received view to the stats recorder (`Stats.update_view`), which updates its world view.

    Args:
        received_view: The view to hand over to the stats recorder.
        overwrite: Forwarded as-is to the stats recorder.
    """

    # The stats recorder is optional (see collect_and_store_own_stats): without it, the view is dropped
    if self.stats is None:
        self.deb("[update_stats_view] No stats recorder, ignoring the received view.")
        return
    self.stats.update_view(received_view, overwrite)
1641
+
1642
+ async def suggest_role_to_world(self, agent: str | None, role: str):
1643
+ """Suggests a role change for one or more agents to the world master. It iterates through the involved agents,
1644
+ checks if their current role differs from the suggested one, and sends a role suggestion message to the
1645
+ world master (async).
1646
+
1647
+ Args:
1648
+ agent: The ID of the agent or a wildcard to suggest the role for.
1649
+ role: The new role to suggest (as a string).
1650
+
1651
+ Returns:
1652
+ True if the suggestion was sent successfully, False otherwise.
1653
+ """
1654
+ self.out("Suggesting role to world")
1655
+
1656
+ agents = self.__involved_agents(agent)
1657
+ role_bits = (self.ROLE_STR_TO_BITS[role] >> 2) << 2
1658
+
1659
+ content = []
1660
+
1661
+ for _agent in agents:
1662
+ cur_role_bits = self.ROLE_STR_TO_BITS[self.all_agents[_agent].get_dynamic_profile()['connections']['role']]
1663
+ cur_role_bits = (cur_role_bits >> 2) << 2
1664
+ if cur_role_bits == role_bits:
1665
+ self.out(f"Not suggesting to change the role of {_agent} "
1666
+ f"since it has already such a role")
1667
+ else:
1668
+ self.out(f"Suggesting to change the role of {_agent} to {self.ROLE_BITS_TO_STR[role_bits]}")
1669
+ content.append({'peer_id': _agent, 'role': role_bits})
1670
+
1671
+ if len(content) > 0:
1672
+ world_peer_id = self._node_conn.get_world_peer_id()
1673
+ if not (await self._node_conn.send(world_peer_id, channel_trail=None,
1674
+ content=content,
1675
+ content_type=Msg.ROLE_SUGGESTION)):
1676
+ self.err("Failed to send role suggestion to the world")
1677
+ return False
1678
+ return True
1679
+
1680
+ async def suggest_badges_to_world(self, agent: str | None = None,
1681
+ score: float = -1.0, badge_type: str = "completed",
1682
+ badge_description: str | None = None):
1683
+ """Suggests one or more badges to the world master for specific agents. This is typically used to reward agents
1684
+ for completing tasks, such as for a competition. It sends a message with the badge details, including the score
1685
+ and type, to the world master (async).
1686
+
1687
+ Args:
1688
+ agent: The ID of the agent or a wildcard for which to suggest the badge.
1689
+ score: The score associated with the badge.
1690
+ badge_type: The type of badge (e.g., 'completed').
1691
+ badge_description: An optional description for the badge.
1692
+
1693
+ Returns:
1694
+ True if the badge suggestion was sent successfully, False otherwise.
1695
+ """
1696
+ self.out("Suggesting one or more badges to world")
1697
+
1698
+ if score < 0.:
1699
+ self.err("Invalid score (did you specify the 'score' argument? it must be positive)")
1700
+ return False
1701
+
1702
+ agents = self.__involved_agents(agent)
1703
+ world_peer_id = self._node_conn.get_world_peer_id()
1704
+
1705
+ if badge_type not in Agent.BADGE_TYPES:
1706
+ self.err(f"Unknown badge type: {badge_type}")
1707
+ return False
1708
+
1709
+ list_of_badge_dictionaries = []
1710
+ for peer_id in agents:
1711
+ list_of_badge_dictionaries.append({'peer_id': peer_id,
1712
+ 'score': score,
1713
+ 'badge_type': badge_type,
1714
+ 'badge_description': badge_description,
1715
+ 'agent_token': self._node_conn.get_last_token(peer_id)})
1716
+
1717
+ if not (await self._node_conn.send(world_peer_id, channel_trail=None,
1718
+ content=list_of_badge_dictionaries,
1719
+ content_type=Msg.BADGE_SUGGESTIONS)):
1720
+ self.err("Failed to send badge suggestions to the world")
1721
+ return False
1722
+ else:
1723
+ return True
1724
+
1725
async def __ask_gen_or_learn(self, for_what: str, agent: str,
                             u_hashes: list[str] | None,
                             yhat_hashes: list[str] | None,
                             samples: int = 100,
                             from_state: str | None = None,
                             to_state: str | None = None,
                             time: float = -1., timeout: float = -1.,
                             ref_uuid: str | None = None):
    """A private helper method that encapsulates the logic for sending a 'do_gen' or 'do_learn' action request to
    another agent. It sets up recipients for direct (non-pubsub) streams and adds the target agent to the set of
    agents asked (async).

    Args:
        for_what: A string indicating whether to ask for 'gen' or 'learn'.
        agent: The ID of the agent to send the request to.
        u_hashes: A list of input stream hashes (None is handled as an empty list).
        yhat_hashes: A list of target stream hashes, only used for learning (None is handled as an empty list).
        samples: The number of samples.
        from_state: The optional starting state from which the 'do_gen'/'do_learn' should be executed.
        to_state: The optional destination state where the 'do_gen'/'do_learn' should lead if correctly executed.
        time: The time duration.
        timeout: The request timeout.
        ref_uuid: The UUID for the request.

    Returns:
        True if the request was sent successfully, False otherwise.
    """
    if agent not in self.all_agents:
        self.err(f"Unknown agent: {agent}")
        return False

    assert for_what in ["gen", "learn"]
    learning = for_what == "learn"

    # The signature allows None for both the lists of hashes: normalizing them once, so that nothing below
    # has to deal with None (len() and iteration would fail on it)
    u_hashes = [] if u_hashes is None else u_hashes
    yhat_hashes = [] if yhat_hashes is None else yhat_hashes

    if learning:
        for yhat_hash in yhat_hashes:
            yhat_stream_dict = self.known_streams[yhat_hash]
            for yhat_stream in yhat_stream_dict.values():
                if isinstance(yhat_stream, BufferedDataStream):
                    y_text = yhat_stream.to_text_snippet(length=200)
                    if y_text is not None and len(y_text) > 0:
                        self.out("Asking to learn: \"" + y_text + "\"")

    # Building the arguments of the action (lists of hashes are only included when not empty)
    args = {}
    if len(u_hashes) > 0:
        args["u_hashes"] = u_hashes
    if learning and len(yhat_hashes) > 0:
        args["yhat_hashes"] = yhat_hashes
    args |= {"samples": samples, "time": time, "timeout": timeout}

    # Triggering
    if not await self.set_next_action(agent, action="do_learn" if learning else "do_gen", args=args,
                                      from_state=from_state, to_state=to_state,
                                      ref_uuid=ref_uuid):
        if learning:
            self.err(f"Unable to ask {agent} to learn to generate")
        else:
            self.err(f"Unable to ask {agent} to generate")
        return False

    if samples > 1:
        self._agents_who_were_asked.add(agent)

    # Setting recipient in the case of direct messages
    # (differently, in case of pubsub, the agent is already sending messages to all)
    direct_candidates = u_hashes + yhat_hashes if learning else u_hashes
    for net_hash in direct_candidates:
        if not DataProps.is_pubsub_from_net_hash(net_hash):
            self.add_recipient(net_hash, agent, samples)
    return True
1811
+
1812
def __process_streams(self,
                      u_hashes: list[str] | None,
                      yhat_hashes: list[str] | None,
                      learn: bool = False,
                      recipient_info: tuple[str, int] | None = None,
                      ref_uuid: str | None = None):
    """A private helper method that contains the core logic for processing data streams, either for generation or
    learning. It generates the outputs from the input streams, optionally learns from the target streams, and sets
    the outputs on the processor streams. It's designed to be called repeatedly by multistep actions like `do_gen`
    and `do_learn`.

    Args:
        u_hashes: A list of input stream hashes.
        yhat_hashes: A list of target stream hashes (for learning).
        learn: A boolean to indicate if the task is a learning task.
        recipient_info: The tuple (ID, samples), being 'ID' the peer ID of the agent to send data back to and
            being 'samples' the number of total samples in the father request.
        ref_uuid: The UUID for the request.

    Returns:
        True if the stream processing is successful, False otherwise.
    """

    # Getting current step index
    k = self.get_action_step()

    # Checking data (first step only)
    if k == 0:
        self.deb("[__process_streams] First action step")

        for label, hashes in (("u_hash", u_hashes), ("yhat_hash", yhat_hashes)):
            if hashes is None:
                continue
            for net_hash in hashes:
                if net_hash is not None and net_hash not in self.known_streams:
                    self.err(f"Unknown stream ({label}): {net_hash}")
                    return False

    if self.is_last_action_step():
        self.deb("[__process_streams] Last action step detected")

    self.deb(f"[__process_streams] Generating data, step {k}")

    # Generate output
    outputs, data_tag_from_inputs = (
        self.generate(input_net_hashes=u_hashes, first=(k == 0), last=self.is_last_action_step(),
                      ref_uuid=ref_uuid))
    if outputs is None:
        return False
    self.deb(f"[__process_streams] data_tag_from_inputs: {data_tag_from_inputs}")
    if data_tag_from_inputs is None:
        data_tag_from_inputs = -1
        self.deb(f"[__process_streams] data_tag_from_inputs (forced): {data_tag_from_inputs}")

    # Learn
    if learn:
        self.deb(f"[__process_streams] learning, step {k}")
        loss_values, data_tags_from_targets = self.learn_generate(outputs=outputs, targets_net_hashes=yhat_hashes)
        self.deb(f"[__process_streams] data_tags_from_targets: {data_tags_from_targets}")

        # Checking for failure before touching the tags, that might be missing when learning failed
        if loss_values is None:
            return False

        # Fusing data tags
        data_tags = [data_tag_from_inputs if _data_tag == -1 else _data_tag for _data_tag in data_tags_from_targets]
        self.print(f"Losses: {loss_values}, Step: {k}, Tags: {data_tags}")
    else:
        data_tags = [data_tag_from_inputs] * len(outputs)
    self.deb(f"[__process_streams] data_tags (final): {data_tags}")

    # Set each data sample in "outputs" to the right stream
    # NOTE(review): nesting below was reconstructed from a rendering with no indentation - confirm against the file
    i = 0
    for net_hash, stream_dict in self.proc_streams.items():

        # Setting the data sample
        for name, stream_obj in stream_dict.items():

            # Public output streams are only considered if the agent IS NOT acting in a world
            # private output streams are only considered if the agent IS acting in a world
            if self.behaving_in_world() != stream_obj.props.is_public():

                # Guessing recipient of the communication
                # NOTE(review): recipient_info defaults to None, and unpacking None raises - confirm that callers
                # always provide it when non-pubsub output streams are involved
                if i == 0 and not DataProps.is_pubsub_from_net_hash(net_hash):
                    _recipient, _samples = recipient_info
                    self.add_recipient(net_hash, _recipient, _samples)

                self.deb(f"[__process_streams] Setting the {i}-th network output to stream with "
                         f"net_hash: {net_hash}, name: {name}")

                # Here we exploit the fact that streams were inserted in order
                try:
                    stream_obj.set(stream_obj.props.check_and_postprocess(outputs[i]), data_tags[i])
                except Exception as e:
                    self.err(f"Error while post-processing the processor output\nException: {e}")
                    return False

                if k == 0:
                    stream_obj.set_uuid(ref_uuid, expected=False)
                    stream_obj.set_uuid(None, expected=True)
                i += 1

    return True
1918
+
1919
async def __complete_do(self, do_what: str, peer_id_who_asked: str, all_hashes: list[str] | None,
                        send_back_confirmation: bool = True, ref_uuid: str | None = None):
    """A private helper method to be called at the end of a `do_gen` or `do_learn` action. It marks the UUIDs of
    the processor output streams as clearable, removes the requesting agent from the recipients of the involved
    streams, clears the UUIDs of such streams, and optionally confirms the completion to the requesting agent
    (async).

    Args:
        do_what: A string ('gen' or 'learn') indicating which task was completed.
        peer_id_who_asked: The ID of the agent who requested the task.
        all_hashes: A list of all stream hashes involved in the task.
        send_back_confirmation: A boolean to indicate if a confirmation message should be sent.
        ref_uuid: The UUID forwarded with the confirmation message.

    Returns:
        True if the completion process is successful, False otherwise.
    """
    assert do_what in ["gen", "learn"]

    # Logging a text snippet of what was generated (buffered streams only)
    if do_what == "gen":
        for produced_dict in self.proc_streams.values():
            for produced in produced_dict.values():
                if not isinstance(produced, BufferedDataStream):
                    continue
                snippet = produced.to_text_snippet(length=200)
                if snippet is not None:
                    self.out("Generated: \"" + snippet + "\"")

    # NOTE(review): nesting of the recipient-marking call was reconstructed from a rendering with no
    # indentation - confirm against the file
    in_world = None
    for net_hash, stream_dict in self.proc_streams.items():
        for stream_obj in stream_dict.values():
            in_world = self.behaving_in_world()
            if stream_obj.props.is_public() == in_world:
                continue
            stream_obj.mark_uuid_as_clearable()
            self.mark_recipient_as_removable(net_hash, peer_id_who_asked)

    for net_hash in (all_hashes if all_hashes is not None else ()):
        self.remove_recipient(net_hash, peer_id_who_asked)
        for stream_obj in self.known_streams[net_hash].values():
            stream_obj.set_uuid(None, expected=False)
            stream_obj.set_uuid(None, expected=True)

    # Confirming
    if not send_back_confirmation:
        return True

    confirmed = await self.set_next_action(peer_id_who_asked, action="done_" + do_what, args={}, ref_uuid=ref_uuid)
    if confirmed:
        return True

    self.err(f"Unable to confirm '{do_what}' to {peer_id_who_asked}")
    return False
1965
+
1966
+ def __compare_streams(self, net_hash_a: str, net_hash_b: str,
1967
+ how: str = "mse", steps: int = 100, re_offset: bool = False):
1968
+ """A private helper method that compares two buffered data streams based on a specified metric (e.g., MSE,
1969
+ max accuracy). It handles stream compatibility checks, data retrieval, and the actual comparison, returning a
1970
+ dissimilarity score.
1971
+
1972
+ Args:
1973
+ net_hash_a: The network hash of the first stream.
1974
+ net_hash_b: The network hash of the second stream.
1975
+ how: The comparison metric ('mse', 'max', 'geqX').
1976
+ steps: The number of samples to compare.
1977
+ re_offset: A boolean to re-align stream tags before comparison.
1978
+
1979
+ Returns:
1980
+ A tuple containing the dissimilarity score and a success flag (e.g., `(0.5, True)`).
1981
+ """
1982
+ if net_hash_a not in self.known_streams:
1983
+ self.err(f"Unknown stream (net_hash_a): {net_hash_a}")
1984
+ return -1., False
1985
+
1986
+ if net_hash_b not in self.known_streams:
1987
+ self.err(f"Unknown stream (net_hash_b): {net_hash_b}")
1988
+ return -1., False
1989
+
1990
+ if steps <= 0:
1991
+ self.err(f"Invalid number of steps: {steps}")
1992
+ return -1., False
1993
+
1994
+ if how not in ["mse", "max"] and not how.startswith("geq"):
1995
+ self.err(f"Data can be compared by MSE, or by comparing the argmax ('max'), or comparing the number "
1996
+ f"of corresponding bits (obtained by 'geqX', where 'X' is a number). Unknown: {how})")
1997
+ return -1., False
1998
+
1999
+ stream_dict_a = self.known_streams[net_hash_a]
2000
+ stream_dict_b = self.known_streams[net_hash_b]
2001
+
2002
+ if len(stream_dict_a) == 1 and len(stream_dict_b) == 1:
2003
+
2004
+ # If there is only 1 stream is each group, things are easy
2005
+ stream_a = next(iter(stream_dict_a.values()))
2006
+ stream_b = next(iter(stream_dict_b.values()))
2007
+ elif len(stream_dict_a) == 1 and len(stream_dict_b) > 1:
2008
+
2009
+ # If there is only 1 stream is one of the groups, we look for a compatible stream in the other group,
2010
+ # giving priority to streams with labels
2011
+ stream_a = next(iter(stream_dict_a.values()))
2012
+ stream_b = None
2013
+ for stream_obj in stream_dict_b.values():
2014
+ if (stream_a.get_props().has_tensor_labels() and stream_obj.get_props().has_tensor_labels() and
2015
+ stream_obj.get_props().is_compatible(stream_a.get_props())):
2016
+ stream_b = stream_obj
2017
+ break
2018
+ if stream_b is None:
2019
+ for stream_obj in stream_dict_b.values():
2020
+ if stream_obj.get_props().is_compatible(stream_a.get_props()):
2021
+ stream_b = stream_obj
2022
+ break
2023
+ elif len(stream_dict_a) > 1 and len(stream_dict_b) == 1:
2024
+
2025
+ # If there is only 1 stream is one of the groups, we look for a compatible stream in the other group,
2026
+ # giving priority to streams with labels
2027
+ stream_a = None
2028
+ stream_b = next(iter(stream_dict_b.values()))
2029
+ for stream_obj in stream_dict_a.values():
2030
+ if (stream_b.get_props().has_tensor_labels() and stream_obj.get_props().has_tensor_labels() and
2031
+ stream_obj.get_props().is_compatible(stream_b.get_props())):
2032
+ stream_a = stream_obj
2033
+ break
2034
+ if stream_a is None:
2035
+ for stream_obj in stream_dict_a.values():
2036
+ if stream_obj.get_props().is_compatible(stream_b.get_props()):
2037
+ stream_a = stream_obj
2038
+ break
2039
+ else:
2040
+
2041
+ # If both groups have more than a stream, let's give priority to streams with labels to find a match
2042
+ stream_a = None
2043
+ stream_b = None
2044
+ for stream_obj_a in stream_dict_a.values():
2045
+ if not stream_obj_a.get_props().has_tensor_labels():
2046
+ continue
2047
+ if stream_a is not None and stream_b is not None:
2048
+ break
2049
+ for stream_obj_b in stream_dict_b.values():
2050
+ if (stream_obj_b.get_props().has_tensor_labels() and
2051
+ stream_obj_a.get_props().is_compatible(stream_obj_b.get_props())):
2052
+ stream_a = stream_obj_a
2053
+ stream_b = stream_obj_b
2054
+ break
2055
+ if stream_a is None and stream_b is None:
2056
+ for stream_obj_a in stream_dict_a.values():
2057
+ if stream_a is not None and stream_b is not None:
2058
+ break
2059
+ for stream_obj_b in stream_dict_b.values():
2060
+ if stream_obj_a.get_props().is_compatible(stream_obj_b.get_props()):
2061
+ stream_a = stream_obj_a
2062
+ stream_b = stream_obj_b
2063
+ break
2064
+
2065
+ if stream_a is None:
2066
+ self.err(f"Cannot find the data stream to consider in the comparison, {net_hash_a}")
2067
+ return -1., False
2068
+ if stream_b is None:
2069
+ self.err(f"Cannot find the data stream to consider in the comparison, {net_hash_b}")
2070
+ return -1., False
2071
+
2072
+ if not isinstance(stream_a, BufferedDataStream):
2073
+ self.err(f"Can only compare buffered streams and {net_hash_a} is not buffered")
2074
+ return -1., False
2075
+
2076
+ if not isinstance(stream_b, BufferedDataStream):
2077
+ self.err(f"Can only compare buffered streams and {net_hash_b} is not buffered")
2078
+ return -1., False
2079
+
2080
+ if steps > len(stream_a) and steps > len(stream_b):
2081
+ self.err(f"Cannot compare streams for {steps} steps, since both of them are shorter "
2082
+ f"(length of the first stream is {len(stream_a)}, of the second stream is {len(stream_b)})")
2083
+
2084
+ if not stream_a.get_props().is_compatible(stream_b.get_props()):
2085
+ self.err(f"Cannot compare incompatible streams")
2086
+
2087
+ stream_a.restart()
2088
+ stream_b.restart()
2089
+
2090
+ def compare(_a: torch.Tensor | str, _b: torch.Tensor | str, _how: str = "mse") -> float:
2091
+ """Compare two samples of signals or descriptors, returning a dissimilarity score >= 0."""
2092
+
2093
+ assert how in ['mse', 'max', 'same'] or how.startswith("geq"), f"Invalid comparison in terms of {how}"
2094
+
2095
+ if isinstance(_a, torch.Tensor) and isinstance(_b, torch.Tensor):
2096
+ if _a.shape != _b.shape:
2097
+ return 1. # Mismatching
2098
+ if _a.dtype == torch.long and _b.dtype == torch.long: # Token IDS
2099
+ return 1. - float((_a == _b).sum().item()) / _a.numel() # Accuracy
2100
+ elif how == "mse":
2101
+ ret = torch.nn.functional.mse_loss(_a, _b, reduction='mean')
2102
+ elif how == "max":
2103
+ ret = 1. - float((torch.argmax(_a) == torch.argmax(_b)).sum().item())
2104
+ elif how == "same":
2105
+ ret = 1. - float(torch.eq(_a, _b).sum()) / _a.numel()
2106
+ else:
2107
+ thres = float(how[3:])
2108
+ ret = 1. - float(torch.sum((_a > thres) == (_b > thres)).item()) / _a.numel()
2109
+ else:
2110
+ ret = 1. - float(_a == _b) # Strings (always handled as 'same')
2111
+ return ret
2112
+
2113
+ # Comparing data (averaging)
2114
+ o = 0.
2115
+ k_b = 0
2116
+ a_tag_offset = 0
2117
+ b_tag_offset = 0
2118
+ a_tag = None
2119
+ a_tag_prev = None
2120
+ for k_a in range(0, steps):
2121
+
2122
+ restart_detected = False
2123
+ if a_tag is not None:
2124
+ a_tag_prev = a_tag
2125
+
2126
+ # Signals or descriptors
2127
+ a, a_tag = stream_a[k_a]
2128
+ b, b_tag = stream_b[k_b]
2129
+
2130
+ # If the streams do not share the same first tag equal to zero, and we asked to re-offset them,
2131
+ # then we force the initial offsets to be zero on both
2132
+ # if not, then re-offset the tags
2133
+ if k_a == 0 and k_b == 0 and re_offset:
2134
+ a_tag_offset = a_tag
2135
+ b_tag_offset = b_tag
2136
+
2137
+ # Offset-based tags
2138
+ a_tag_w_offset = a_tag - a_tag_offset
2139
+ b_tag_w_offset = b_tag - b_tag_offset
2140
+
2141
+ # Checking
2142
+ if a is None:
2143
+ self.err("Cannot compare stream samples if the reference stream yields None")
2144
+ return -1., False
2145
+
2146
+ # Some streams might have been pre-buffered in advance, and have increasing data tags belonging to finite,
2147
+ # fixed set (such as 0, 1, 2, ..., N). when continuously streaming them, we will go from tag N to tag 0 at
2148
+ # a certain point, which is a "restart".
2149
+ # We have to remember that this happened, and we do it for stream "a", our "reference" stream.
2150
+ # Then, below, we will fix tags on stream "b" if needed, considering that such a restart happened.
2151
+ if a_tag_prev is not None and a_tag < a_tag_prev:
2152
+ restart_detected = True
2153
+
2154
+ # Some streams might have been pre-buffered in advance, and have a fixed data tag (usually -1).
2155
+ # Being it negative, it will happen that the data tag will be replaced by a clock cycle, but this function
2156
+ # does not change clock cycles at all, so all samples will have the exact same data tag.
2157
+ # The following code automatically advances the tag by 1 for stream "a", that is expected to be the
2158
+ # reference stream (i.e., the one for which the agent has all samples, with no missing data in between)
2159
+ if a_tag_prev is not None and a_tag <= a_tag_prev:
2160
+ a_tag = a_tag_prev + 1 # Fixed tag detected (patching)
2161
+ a_tag_w_offset = a_tag - a_tag_offset
2162
+
2163
+ # Fixing
2164
+ if b is None:
2165
+ o = o + (1. if how != "mse" else (o / steps) * 1.1)
2166
+ self.print(f"Comparing: the second stream yields None")
2167
+ else:
2168
+ if b_tag_w_offset == a_tag_w_offset:
2169
+ o += compare(a, b, how)
2170
+ k_b += 1
2171
+ self.print(f"Comparing tags: {a_tag} vs {b_tag}, samples: {a} vs {b}")
2172
+ elif b_tag_w_offset > a_tag_w_offset:
2173
+ if not restart_detected:
2174
+ o = o + (1. if how != "mse" else (o / steps) * 1.1) # Don't change k_b, some samples missing
2175
+ self.print(f"Comparing tags: {a_tag} vs {b_tag} -> "
2176
+ f"expected one was missing, samples: {a} vs {b}")
2177
+ else:
2178
+ o = o + (1. if how != "mse" else (o / steps) * 1.1)
2179
+ self.print(f"Comparing tags: {a_tag} vs {b_tag} -> "
2180
+ f"expected one was missing, samples: {a} vs {b}")
2181
+ k_b += 1 # A restart was detected, it means that "stream_b" is behind, let's move it ahead
2182
+ elif b_tag_w_offset < a_tag_w_offset:
2183
+ self.print(f"Comparing tags: {a_tag} vs {b_tag} -> too early w.r.t. expected, samples: {a} vs {b}")
2184
+ return -1., False
2185
+
2186
+ self.print(f"Comparing error: {o / steps}")
2187
+
2188
+ # Input("*** press enter to continue ***")
2189
+ return o / steps, True
2190
+
2191
+ def __involved_agents(self, agent: str | None | list[str]):
2192
+ """A private helper method that resolves an agent ID or a wildcard into a list of specific peer IDs.
2193
+ It can resolve a single agent, a group of agents that passed a previous comparison (`<valid_cmp>`), or all
2194
+ currently engaged agents.
2195
+
2196
+ Args:
2197
+ agent: The agent ID or wildcard string.
2198
+
2199
+ Returns:
2200
+ A list of peer IDs corresponding to the involved agents.
2201
+ """
2202
+ if isinstance(agent, list):
2203
+ return agent
2204
+ peer_id = agent
2205
+ engaged_or_found = (
2206
+ self._engaged_agents) if len(self._engaged_agents) > 0 else self._found_agents
2207
+ involved_agents = [peer_id] if peer_id is not None and peer_id != "<valid_cmp>" else (
2208
+ self._valid_cmp_agents) if peer_id is not None and peer_id == "<valid_cmp>" else engaged_or_found
2209
+ if len(involved_agents) == 0:
2210
+ self.err("Not engaged to any agents, no previously searched agent, or no agent specified")
2211
+ return involved_agents
2212
+
2213
+ def __normalize_user_hash(self, net_hashes: list[str] | None) -> list[str]:
2214
+ if net_hashes is None:
2215
+ return []
2216
+ net_hashes_copy = []
2217
+ for i in range(len(net_hashes)):
2218
+ if net_hashes[i] == "<playlist>":
2219
+
2220
+ # From <playlist> to the current element of the playlist
2221
+ net_hashes_copy.append(self._preferred_streams[self._cur_preferred_stream])
2222
+ else:
2223
+
2224
+ # From user specified hash to a net hash (e.g., peer_id:name_or_group to peer_id::ps:name_or_group)
2225
+ net_hashes_copy.append(self.user_stream_hash_to_net_hash(net_hashes[i]))
2226
+ return net_hashes_copy