unaiverse 0.1.12__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (47) hide show
  1. unaiverse/__init__.py +19 -0
  2. unaiverse/agent.py +2226 -0
  3. unaiverse/agent_basics.py +2389 -0
  4. unaiverse/clock.py +234 -0
  5. unaiverse/dataprops.py +1282 -0
  6. unaiverse/hsm.py +2471 -0
  7. unaiverse/modules/__init__.py +18 -0
  8. unaiverse/modules/cnu/__init__.py +17 -0
  9. unaiverse/modules/cnu/cnus.py +536 -0
  10. unaiverse/modules/cnu/layers.py +261 -0
  11. unaiverse/modules/cnu/psi.py +60 -0
  12. unaiverse/modules/hl/__init__.py +15 -0
  13. unaiverse/modules/hl/hl_utils.py +411 -0
  14. unaiverse/modules/networks.py +1509 -0
  15. unaiverse/modules/utils.py +748 -0
  16. unaiverse/networking/__init__.py +16 -0
  17. unaiverse/networking/node/__init__.py +18 -0
  18. unaiverse/networking/node/connpool.py +1332 -0
  19. unaiverse/networking/node/node.py +2752 -0
  20. unaiverse/networking/node/profile.py +446 -0
  21. unaiverse/networking/node/tokens.py +79 -0
  22. unaiverse/networking/p2p/__init__.py +188 -0
  23. unaiverse/networking/p2p/go.mod +127 -0
  24. unaiverse/networking/p2p/go.sum +548 -0
  25. unaiverse/networking/p2p/golibp2p.py +18 -0
  26. unaiverse/networking/p2p/golibp2p.pyi +136 -0
  27. unaiverse/networking/p2p/lib.go +2765 -0
  28. unaiverse/networking/p2p/lib_types.py +311 -0
  29. unaiverse/networking/p2p/message_pb2.py +50 -0
  30. unaiverse/networking/p2p/messages.py +360 -0
  31. unaiverse/networking/p2p/mylogger.py +78 -0
  32. unaiverse/networking/p2p/p2p.py +900 -0
  33. unaiverse/networking/p2p/proto-go/message.pb.go +846 -0
  34. unaiverse/stats.py +1506 -0
  35. unaiverse/streamlib/__init__.py +15 -0
  36. unaiverse/streamlib/streamlib.py +210 -0
  37. unaiverse/streams.py +804 -0
  38. unaiverse/utils/__init__.py +16 -0
  39. unaiverse/utils/lone_wolf.json +28 -0
  40. unaiverse/utils/misc.py +441 -0
  41. unaiverse/utils/sandbox.py +292 -0
  42. unaiverse/world.py +384 -0
  43. unaiverse-0.1.12.dist-info/METADATA +366 -0
  44. unaiverse-0.1.12.dist-info/RECORD +47 -0
  45. unaiverse-0.1.12.dist-info/WHEEL +5 -0
  46. unaiverse-0.1.12.dist-info/licenses/LICENSE +177 -0
  47. unaiverse-0.1.12.dist-info/top_level.txt +1 -0
@@ -0,0 +1,2389 @@
1
+ """
2
+ █████ █████ ██████ █████ █████ █████ █████ ██████████ ███████████ █████████ ██████████
3
+ ░░███ ░░███ ░░██████ ░░███ ░░███ ░░███ ░░███ ░░███░░░░░█░░███░░░░░███ ███░░░░░███░░███░░░░░█
4
+ ░███ ░███ ░███░███ ░███ ██████ ░███ ░███ ░███ ░███ █ ░ ░███ ░███ ░███ ░░░ ░███ █ ░
5
+ ░███ ░███ ░███░░███░███ ░░░░░███ ░███ ░███ ░███ ░██████ ░██████████ ░░█████████ ░██████
6
+ ░███ ░███ ░███ ░░██████ ███████ ░███ ░░███ ███ ░███░░█ ░███░░░░░███ ░░░░░░░░███ ░███░░█
7
+ ░███ ░███ ░███ ░░█████ ███░░███ ░███ ░░░█████░ ░███ ░ █ ░███ ░███ ███ ░███ ░███ ░ █
8
+ ░░████████ █████ ░░█████░░████████ █████ ░░███ ██████████ █████ █████░░█████████ ██████████
9
+ ░░░░░░░░ ░░░░░ ░░░░░ ░░░░░░░░ ░░░░░ ░░░ ░░░░░░░░░░ ░░░░░ ░░░░░ ░░░░░░░░░ ░░░░░░░░░░
10
+ A Collectionless AI Project (https://collectionless.ai)
11
+ Registration/Login: https://unaiverse.io
12
+ Code Repositories: https://github.com/collectionlessai/
13
+ Main Developers: Stefano Melacci (Project Leader), Christian Di Maio, Tommaso Guidi
14
+ """
15
+ import os
16
+ import torch
17
+ import types
18
+ import pickle
19
+ import uuid as _uuid
20
+ import importlib.resources
21
+ from PIL.Image import Image
22
+ from unaiverse.stats import Stats
23
+ from unaiverse.clock import Clock
24
+ from collections.abc import Callable
25
+ from unaiverse.networking.p2p.messages import Msg
26
+ from unaiverse.dataprops import DataProps, Data4Proc
27
+ from unaiverse.networking.node.profile import NodeProfile
28
+ from unaiverse.utils.misc import GenException, FileTracker
29
+ from unaiverse.streams import BufferedDataStream, DataStream
30
+ from unaiverse.networking.node.connpool import ConnectionPools
31
+ from unaiverse.hsm import HybridStateMachine, Action, ActionRequest
32
+ from unaiverse.modules.utils import AgentProcessorChecker, ModuleWrapper
33
+
34
+
35
class AgentBasics:
    """This class contains those methods and properties that are about building the agent, known agents,
    known streams, etc., and no actions at all (see the class "Agent" for actions)."""

    DEBUG = True  # Turns on/off extra logging

    # Role bits (a.k.a. role int): default roles, shared by every possible agent.
    # Bit 0 means "about world"; bit 1 means "master"; bits >= 2 are reserved for custom roles.
    ROLE_PUBLIC = 0 << 0  # 00000000 = 0 means "public"
    ROLE_WORLD_MASTER = (1 << 0) | (1 << 1)  # 00000011 = 3 means "world master" (the first bit means "about world")
    ROLE_WORLD_AGENT = (1 << 0) | (0 << 1)  # 00000001 = 1 means "world agent" (the first bit means "about world")
    CUSTOM_ROLES = []  # Filled by worlds when role files are discovered (see augment_roles)

    # From role bits (int) to string
    ROLE_BITS_TO_STR = {
        ROLE_PUBLIC: "public_agent",
        ROLE_WORLD_MASTER: "world_master",
        ROLE_WORLD_AGENT: "world_agent",
    }

    # From role string to bits (int)
    ROLE_STR_TO_BITS = {
        "public_agent": ROLE_PUBLIC,
        "world_master": ROLE_WORLD_MASTER,
        "world_agent": ROLE_WORLD_AGENT,
    }

    # Types of badges
    BADGE_TYPES = {'completed', 'attended', 'intermediate', 'pro'}

    # The type associated to a human: it is not exploited at a node-level, only at agent level
    HUMAN = "human"  # Human agent
    def __init__(self,
                 proc: ModuleWrapper | torch.nn.Module | None,
                 proc_inputs: list[Data4Proc] | None = None,
                 proc_outputs: list[Data4Proc] | None = None,
                 proc_opts: dict | None = None,
                 behav: HybridStateMachine | None = None,
                 behav_lone_wolf: HybridStateMachine | str = "serve",
                 merge_flat_stream_labels: bool = False,
                 buffer_generated: bool = False,
                 buffer_generated_by_others: str = "none",
                 world_folder: str | None = None,
                 policy_filter: callable = None,
                 policy_filter_lone_wolf: callable = None):
        """Create a new agent.

        Args:
            proc: The processing module (e.g., a neural network) for the agent. Can be None or "default".
            proc_inputs: list of DataProps defining the expected inputs for processor (if None it will be guessed).
            proc_outputs: list of DataProps defining the expected outputs from processor (if None it will be guessed).
            proc_opts: A dictionary of options for the processor.
            behav: The HybridStateMachine that describes the agent's behavior when joining a world.
            behav_lone_wolf: The HybridStateMachine that describes the agent's behavior when in the public net
                (it can also be a string "serve" or "ask", that will load pre-designed HSMs).
            merge_flat_stream_labels: If True, merges flat stream labels across all owned streams.
            buffer_generated: If True, generated streams will be buffered.
            buffer_generated_by_others: If set to "one" or "all", streams generated by other agents will be buffered
                ("one" per peer or "all"). If set to "none", no buffering will happen (default).
            world_folder: World only. Folder where the world data is (role files, represented by *.json behavior files).
            policy_filter: The name of a method of the Agent class or a function that implements a
                policy filtering function, overriding what the action-selection-policy decided (when in a world).
            policy_filter_lone_wolf: Same as policy_filter, but for the public network.
        """

        # Agent-related features
        self.behav = behav  # HSM that describes the agent behavior in the private/world net
        self.behav_lone_wolf = behav_lone_wolf  # HSM that describes the agent behavior in the public net
        self.behav_wildcards = {}  # Wildcard name -> value, used to patch HSM arguments at runtime
        self.proc = proc
        self.proc_updated_since_last_save = False  # Dirty flag: processor weights changed after last save
        self.proc_inputs = proc_inputs
        self.proc_outputs = proc_outputs
        self.proc_opts = proc_opts
        self.proc_last_inputs = None  # Cache of the inputs used in the last processor call
        self.proc_last_outputs = None  # Cache of the outputs produced by the last processor call
        self.proc_optional_inputs = None
        self.proc_net_hash = {'public': None, 'private': None}  # Net hash of processor-output streams, per scope
        self.proc_in_net_hash = {'public': None, 'private': None}  # Net hash of processor-input streams, per scope
        self.merge_flat_stream_labels = merge_flat_stream_labels
        self.buffer_generated = buffer_generated
        self.buffer_generated_by_others = buffer_generated_by_others
        self.world_folder = world_folder
        self.policy_filter = policy_filter
        self.policy_filter_opts = {}
        self.policy_filter_lone_wolf = policy_filter_lone_wolf
        self.policy_filter_lone_wolf_opts = {}

        if self.buffer_generated_by_others not in {"one", "all", "none"}:
            raise GenException("Param buffer_generated_by_others can be set to 'one', 'all', or 'none' only.")

        # Streams
        self.known_streams = {}  # All streams that are known to this agent
        self.owned_streams = {}  # The streams that are generated/offered by this agent
        self.env_streams = {}  # The owned streams that come from environmental sources (e.g., a camera)
        self.proc_streams = {}  # The owned streams that are generated by the agent's processor
        self.compat_in_streams = set()  # Streams compatible with the processor input (dynamically set)
        self.compat_out_streams = set()  # Streams compatible with the processor output (dynamically set)

        # Agents, world masters, expected world masters
        self.all_agents = {}  # ID -> profile (all types of agent)
        self.public_agents = {}  # ID -> profile of lone wolves talking to this world in a public manner (profile)
        self.world_agents = {}  # ID -> profile of all agents living in this world (profile)
        self.world_masters = {}  # ID -> profile of all master-agents living in this world (profile)
        self.human_agents = {}  # ID -> profile (human agents)
        self.artificial_agents = {}  # ID -> profile (artificial agent)
        self.world_profile = None
        self.is_world = False  # If this instance is about a world: it will be discovered at creation time

        # World specific attributes (they are only used if this agent is actually a world)
        self.agent_actions = None
        self.role_to_behav = {}
        self.agent_badges: dict[str, list[dict]] = {}  # Peer_id -> collected badges for other agents
        self.role_changed_by_world: bool = False
        self.received_address_update: bool = False

        # Internal properties about the way streams are used
        self.last_buffered_peer_id_to_info = {}  # If buffering was turned on
        self.last_ref_uuid = None
        self.recipients = {}  # The peer IDs of the recipients of the next batch of direct messages
        self.overridden_action_step = None
        self.locked_set_proc_input = False

        # Stats
        self.stats: Stats | None = None
        self.agent_stats_code = None

        # Information inherited from the node that hosts this agent (filled later by set_node_info)
        self._node_name = "unk"
        self._node_clock = None
        self._node_conn = None
        self._node_profile = None
        self._node_out_fcn = print
        self._node_ask_to_get_in_touch_fcn = None
        self._node_purge_fcn = None
        self._node_agents_waiting = None
        self._node_identity_dir = ''
        self._debug_flag = False
        self._basic_print_on = True

        # Checking
        if not (self.proc is None or
                (isinstance(self.proc, torch.nn.Module) or (isinstance(self.proc, str) and self.proc == "default"))):
            raise GenException("Invalid data processor: it must be either the string 'default' or a torch.nn.module")
        if not (self.behav is None or isinstance(self.behav, HybridStateMachine)):
            raise GenException("Invalid behavior: it must be either None or a HybridStateMachine")

        # Filling (guessing) missing processor-related info (proc_inputs and proc_outputs)
        # and allocating a dummy processor if it was not specified (if None)
        AgentProcessorChecker(self)

        # The stream_hash of compatible streams for each data_props are stored in a set
        # (this replaces the placeholder sets created above with one set per input/output slot)
        self.compat_in_streams = [set() for _ in range(len(self.proc_inputs))] \
            if self.proc_inputs is not None else None
        self.compat_out_streams = [set() for _ in range(len(self.proc_outputs))] \
            if self.proc_outputs is not None else None

        # Loading default public HSM
        if hasattr(self, "do_gen"):  # Trick to distinguish if this is an Agent or a World (both sons of this class)
            self.is_world = False

            # Setting an empty HSM as default is not provided (private/world)
            if self.behav is None:
                self.behav = HybridStateMachine(self, policy=self.policy_default)
                self.behav.add_state("empty")

            if self.behav_lone_wolf is not None and isinstance(self.behav_lone_wolf, str):
                template_string = self.behav_lone_wolf
                # NOTE(review): both 'serve' and 'ask' currently load the same lone_wolf.json —
                # confirm whether a dedicated 'ask' template was intended
                if template_string == "serve":
                    json_to_load = "lone_wolf.json"
                elif template_string == "ask":
                    json_to_load = "lone_wolf.json"
                else:
                    raise ValueError("Invalid behav_lone_wolf: it must be an HybridStateMachine or a string "
                                     "in ('serve', 'ask')")

                # Safe way to load a file packed in a pip package
                self.behav_lone_wolf = HybridStateMachine(self, policy=self.policy_default)
                utils_path = importlib.resources.files("unaiverse.utils")
                json_file = utils_path.joinpath(json_to_load)
                file = json_file.open()
                self.behav_lone_wolf.load(file)
                file.close()

            self.set_policy_filter(self.policy_filter_lone_wolf, public=True)
        else:
            self.is_world = True
            if self.world_folder is None:
                raise GenException("No world folder was indicated (world_folder argument)")
224
    def set_node_info(self, clock: Clock, conn: ConnectionPools, profile: NodeProfile,
                      out_fcn, ask_to_get_in_touch_fcn, purge_fcn, node_identity_dir: str,
                      agents_waiting, print_level):
        """Set the required information from the node that hosts this agent.

        Args:
            clock: The global clock instance from the node.
            conn: The connection pool manager from the node.
            profile: The profile of the hosting node.
            out_fcn: The function to use for general output messages.
            ask_to_get_in_touch_fcn: The function to call to request getting in touch with another peer.
            purge_fcn: The function to call to purge (kill/disconnect) a connection.
            node_identity_dir: The folder where the node identity files are stored.
            agents_waiting: Set of agents that connected to this node but have not been evaluated yet to be added.
            print_level: The level of output printing verbosity (0, 1, 2).

        Returns:
            True (unconditionally).
        """

        # Getting basic references (the agent borrows all node-level facilities through these handles)
        self._node_clock = clock
        self._node_conn = conn
        self._node_profile = profile
        self._node_name = profile.get_static_profile()['node_name']
        self._node_out_fcn = out_fcn
        self._node_ask_to_get_in_touch_fcn = ask_to_get_in_touch_fcn
        self._node_purge_fcn = purge_fcn
        self._node_identity_dir = node_identity_dir
        self._node_agents_waiting = agents_waiting
        self._debug_flag = print_level > 1  # Verbosity 2 turns on debug logging

        # Adding peer_id information into the already existing stream data (if any)
        # (initially marked with generic wildcards like <public_peer_id>, ...)
        net_hashes = list(self.known_streams.keys())
        for net_hash in net_hashes:
            if net_hash.startswith("<public_peer_id>") or net_hash.startswith("<private_peer_id>"):
                stream_dict = self.known_streams[net_hash]
                for stream_obj in stream_dict.values():
                    self.add_stream(stream_obj, owned=True)  # This will also re-add streams using the node clock

        # Removing place-holder streams
        for peer_id in ["<public_peer_id>", "<private_peer_id>"]:
            to_remove = []
            for net_hash in self.known_streams.keys():
                if DataProps.peer_id_from_net_hash(net_hash) == peer_id:
                    for _name, _stream in self.known_streams[net_hash].items():
                        to_remove.append((net_hash, _name))

            # Removing (collected first, deleted here, to avoid mutating the dicts while iterating them)
            for (net_hash, name) in to_remove:
                del self.known_streams[net_hash][name]
                if len(self.known_streams[net_hash]) == 0:
                    del self.known_streams[net_hash]

                # Removing all the owned streams (environment and processor streams are of course "owned")
                if net_hash in self.owned_streams:
                    if name in self.owned_streams[net_hash]:
                        del self.owned_streams[net_hash][name]
                    if len(self.owned_streams[net_hash]) == 0:
                        del self.owned_streams[net_hash]
                if net_hash in self.env_streams:
                    if name in self.env_streams[net_hash]:
                        del self.env_streams[net_hash][name]
                    if len(self.env_streams[net_hash]) == 0:
                        del self.env_streams[net_hash]
                if net_hash in self.proc_streams:
                    if name in self.proc_streams[net_hash]:
                        del self.proc_streams[net_hash][name]
                    if len(self.proc_streams[net_hash]) == 0:
                        del self.proc_streams[net_hash]
                self.out(f"Successfully removed known stream with network hash {net_hash}, stream name: {name}")

        # World only: loading action files and refactoring (or building) JSON files of the different roles.
        # This where the world guesses roles.
        if self.is_world:

            # Check role-JSON files in the world folder
            role_json_tracker = FileTracker(self.world_folder, ext=".json")

            # This usually does nothing, but if you like to dynamically create JSON files, overload this method
            self.create_behav_files()

            # Loading and refactoring roles and behaviors
            self.load_and_refactor_action_file_and_behav_files(force_save=role_json_tracker.something_changed())

        # Building combination of default roles (considering public, world_agent, world_master default roles), and
        # agent/world specific roles
        self.augment_roles()

        # Loading the custom Stats code
        if self.world_folder is not None:
            stats_file = os.path.join(self.world_folder, 'stats.py')
            if os.path.exists(stats_file):
                self.out(f"Found custom stats.py at {stats_file}")
                try:
                    with open(stats_file, 'r', encoding='utf-8') as file:
                        self.agent_stats_code = file.read()
                except Exception as e:
                    raise GenException(f'Error while reading/loading the stats.py file: {stats_file} [{e}]')

        # Creating streams associated to the processor input (right now we assume there is no need to buffer them)
        self.create_proc_input_streams(buffered=False)

        # Creating streams associated to the processor output
        self.create_proc_output_streams(buffered=self.buffer_generated)

        # Updating node profile by indicating the processor-related streams
        self.update_streams_in_profile()

        # Print level
        AgentBasics.debug_printing(self._debug_flag)
        return True
334
+
335
+ @staticmethod
336
+ def debug_printing(on: bool = False):
337
+ Stats.DEBUG = on
338
+ AgentBasics.DEBUG = on
339
+ ConnectionPools.DEBUG = on
340
+ HybridStateMachine.DEBUG = on
341
+
342
+ @staticmethod
343
+ def get_hsm_debug_state():
344
+ return HybridStateMachine.DEBUG
345
+
346
+ @staticmethod
347
+ def set_hsm_debug_state(on: bool):
348
+ HybridStateMachine.DEBUG = on
349
+
350
+ def get_proc_output_net_hash(self, public: bool = True):
351
+ return self.proc_net_hash['public'] if public else self.proc_net_hash['private']
352
+
353
+ def get_proc_input_net_hash(self, public: bool = True):
354
+ return self.proc_in_net_hash['public'] if public else self.proc_in_net_hash['private']
355
+
356
+ @staticmethod
357
+ def generate_uuid():
358
+ return _uuid.uuid4().hex[0:8]
359
+
360
+ def augment_roles(self):
361
+ """Augment the custom roles (role1, role2, etc.) with the default ones (public, world_master, etc.), generating
362
+ all the mixed roles (world_master~role1, world_master~role2, ...)"""
363
+
364
+ # Both Agent and World: Fusing basic roles and custom roles
365
+ if len(self.CUSTOM_ROLES) > 0:
366
+ if len(self.CUSTOM_ROLES) > 30: # Safe value, could be increased
367
+ raise GenException("Maximum number of custom role overcame (max is 30)")
368
+ for i, role_str in enumerate(self.CUSTOM_ROLES):
369
+ role_int = 1 << (i + 2) # 000000100, then 00001000, etc. (recall that the first two bits are reserved)
370
+ self.ROLE_BITS_TO_STR[role_int] = role_str
371
+ self.ROLE_STR_TO_BITS[role_str] = role_int
372
+
373
+ # Both Agent and World: Augmenting roles
374
+ roles_not_to_be_augmented = {self.ROLE_PUBLIC, self.ROLE_WORLD_AGENT, self.ROLE_WORLD_MASTER}
375
+ role_bits_to_str_original = {k: v for k, v in self.ROLE_BITS_TO_STR.items()}
376
+ for role_int, role_str in role_bits_to_str_original.items():
377
+ if role_int not in roles_not_to_be_augmented and "~" not in role_str:
378
+ for role_base_int in {self.ROLE_WORLD_AGENT, self.ROLE_WORLD_MASTER}:
379
+ augmented_role_int = role_base_int | role_int
380
+ augmented_role_str = self.ROLE_BITS_TO_STR[role_base_int] + "~" + role_str
381
+ if augmented_role_str not in self.ROLE_STR_TO_BITS:
382
+ self.ROLE_STR_TO_BITS[augmented_role_str] = augmented_role_int
383
+ self.ROLE_BITS_TO_STR[augmented_role_int] = augmented_role_str
384
+
385
    async def clear_world_related_data(self):
        """Destroy all the cached information that is about a world (useful when leaving a world) (async)."""

        # Clearing status variables (agent-status attributes back to their defaults)
        self.reset_agent_status_attrs()

        # Clear/reset: drop world-private streams, drop world agents/masters, then clear the rendezvous tag
        # (the two awaited calls are private helpers defined elsewhere in this class)
        await self.__remove_all_world_private_streams()
        await self.__remove_all_world_related_agents()
        self._node_conn.reset_rendezvous_tag()
395
+
396
    def load_and_refactor_action_file_and_behav_files(self, force_save: bool = False):
        """This method is called when building a world object. It loads the behavior files and refactors them.
        It loads the action file agent.py. It checks consistency between the agent action files agent.py and the roles
        in the behavior files.

        Args:
            force_save: Boolean to force the saving of the JSON and of a "pdf" folder with the PDFs of the state
                machines.

        Raises:
            GenException: If no role files are found, the action file cannot be read/executed, roles mismatch,
                or a behavior file cannot be loaded/saved.
        """

        # World only: the world discovers CUSTOM_ROLES from the JSON files in the world folder
        if self.world_folder is not None and self.is_world:

            # Guessing roles from the list of json files (role name == file name without extension)
            self.CUSTOM_ROLES = [os.path.splitext(f)[0] for f in os.listdir(self.world_folder)
                                 if os.path.isfile(os.path.join(self.world_folder, f))
                                 and f.lower().endswith(".json")]
            if len(self.CUSTOM_ROLES) == 0:
                raise GenException(f"No world-role files (*.json) were found in the world folder {self.world_folder}")

            # Default behaviours (getting roles, that are the names of the files with extension "json")
            # NOTE(review): relies on os.listdir returning files in the same order as the listing above — confirm
            default_behav_files = [os.path.join(self.world_folder, f) for f in os.listdir(self.world_folder)
                                   if os.path.isfile(os.path.join(self.world_folder, f)) and
                                   f.lower().endswith(".json")]

            # Loading action file
            action_file = os.path.join(self.world_folder, 'agent.py')
            try:
                with open(action_file, 'r', encoding='utf-8') as file:
                    self.agent_actions = file.read()
            except Exception as e:
                raise GenException(f'Error while reading the agent.py file: {action_file} [{e}]')

            # Creating a dummy agent which supports the actions of the following state machines
            # SECURITY: exec() runs arbitrary code from agent.py — the world folder must be trusted content
            mod = types.ModuleType("dynamic_module")
            try:
                exec(self.agent_actions, mod.__dict__)
                dummy_agent = mod.WAgent(proc=None)
                dummy_agent.CUSTOM_ROLES = self.CUSTOM_ROLES
            except Exception as e:
                raise GenException(f'Unable to create a valid agent object from the agent action file '
                                   f'{action_file} [{e}]')

            # Checking if the roles you wrote in agent.py are coherent with the JSON files in this folder
            if dummy_agent.CUSTOM_ROLES != self.CUSTOM_ROLES:
                raise GenException(f"Mismatching roles. "
                                   f"Roles in JSON files: {self.CUSTOM_ROLES}. "
                                   f"Roles specified in the agent.py file: {dummy_agent.CUSTOM_ROLES}")

            # Loading and refactoring behaviors
            for role, default_behav_file in zip(self.CUSTOM_ROLES, default_behav_files):
                try:
                    behav = HybridStateMachine(dummy_agent)
                    behav.load(default_behav_file)
                    self.role_to_behav[role] = str(behav)

                    # Adding roles and machines to profile
                    self._node_profile.get_dynamic_profile()['world_roles_fsm'] = self.role_to_behav
                except Exception as e:
                    raise GenException(f'Error while loading or handling '
                                       f'behav file {default_behav_file} for role {role} [{e}]')

                # Refactoring and saving PDF
                # NOTE(review): only_if_changed receives dummy_agent (an agent object, not a bool) — confirm
                # against HybridStateMachine.save's expected signature
                try:
                    if (force_save or
                            behav.save(os.path.join(self.world_folder, f'{role}.json'), only_if_changed=dummy_agent)):
                        os.makedirs(os.path.join(self.world_folder, 'pdf'), exist_ok=True)
                        behav.save_pdf(os.path.join(self.world_folder, 'pdf', f'{role}.pdf'))
                except Exception as e:
                    raise GenException(f'Error while saving the behav file {default_behav_file} for role {role} [{e}]')
466
+
467
+ def create_behav_files(self):
468
+ """This method is called when building a world object. In your custom world-class, you can overload this method
469
+ and create the JSON files with the role-related behaviors, if you like. Recall that acting like this is not
470
+ mandatory at all: you can just manually create the JSON files, and this method will simply do nothing."""
471
+ pass
472
+
473
+ def out(self, msg: str):
474
+ """Print a message to the console, if enabled at node level (it reuses the node-out-function).
475
+
476
+ Args:
477
+ msg: The message string to print.
478
+ """
479
+ self._node_out_fcn(msg)
480
+
481
+ def err(self, msg: str):
482
+ """Print an error message to the console, if enabled at node level (it reuses the node-err-function).
483
+
484
+ Args:
485
+ msg: The error message string to print.
486
+ """
487
+ self.out("<ERROR> " + msg)
488
+
489
+ def print(self, msg: str):
490
+ """Print a message to the console, no matter what.
491
+
492
+ Args:
493
+ msg: The message string to print.
494
+ """
495
+ if self._basic_print_on:
496
+ print(msg)
497
+
498
    def deb(self, msg: str):
        """Print a debug message, only when class-level debugging is enabled (it reuses the agent-out-function).

        The message is tagged with "[DEBUG AGENT]" or "[DEBUG WORLD]" depending on what this instance is.

        Args:
            msg: The debug message string to print.
        """
        if AgentBasics.DEBUG:
            self.out("[DEBUG " + ("AGENT" if not self.is_world else "WORLD") + "] " + msg)
506
+
507
+ def get_name(self) -> str:
508
+ """Returns the name of the agent or world from the node's profile.
509
+
510
+ Args:
511
+ None.
512
+
513
+ Returns:
514
+ The name of the agent or world.
515
+ """
516
+ return self._node_name
517
+
518
+ def get_profile(self) -> NodeProfile:
519
+ """Returns the profile of the node hosting this agent/world.
520
+
521
+ Returns:
522
+ The NodeProfile of this node.
523
+ """
524
+ return self._node_profile
525
+
526
+ def get_current_role(self, return_int: bool = False, ignore_base_role: bool = True) -> str | int | None:
527
+ """Returns the current role of the agent.
528
+
529
+ Args:
530
+ return_int: If True, returns the integer representation of the role.
531
+ ignore_base_role: If True, returns only the specific role part, not the base.
532
+
533
+ Returns:
534
+ The role as a string or integer, or None if the agent is not living in any worlds.
535
+ """
536
+ if self.in_world():
537
+ role_str = self._node_profile.get_dynamic_profile()['connections']['role']
538
+ if ignore_base_role:
539
+ role_str = role_str.split("~")[-1]
540
+ if not return_int:
541
+ return role_str
542
+ else:
543
+ return self.ROLE_STR_TO_BITS[role_str]
544
+ else:
545
+ return None
546
+
547
    async def add_agent(self, peer_id: str, profile: NodeProfile) -> bool:
        """Add a new known agent (async).

        The agent is classified by its role bits (public / world agent / world master) and by its
        type (human / artificial), then its advertised streams are checked for compatibility with
        this agent's processor.

        Args:
            peer_id: The unique identifier of the peer.
            profile: The NodeProfile object containing the peer's/agent's information.

        Returns:
            True if the agent was successfully added, False otherwise.
        """

        # If the agent was already there, we remove it and add it again (in case of changes)
        await self.remove_agent(peer_id)  # It has no effects if the agent is not existing

        # Guessing the type of agent to add (accordingly to the default roles shared by every agent)
        # Bit 0 set means "about world"; bits 0+1 set means "world master"
        role = self._node_conn.get_role(peer_id)
        self.all_agents[peer_id] = profile
        if role & 1 == self.ROLE_PUBLIC:
            self.public_agents[peer_id] = profile
            public = True
        elif role & 3 == self.ROLE_WORLD_AGENT:
            self.world_agents[peer_id] = profile
            public = False
        elif role & 3 == self.ROLE_WORLD_MASTER:
            self.world_masters[peer_id] = profile
            public = False
        else:
            self.err(f"Cannot add agent with peer ID {peer_id} - unknown role: {role}")
            return False

        # Human or artificial?
        if profile.get_static_profile()["node_type"] == AgentBasics.HUMAN:
            self.human_agents[peer_id] = profile
        else:
            self.artificial_agents[peer_id] = profile

        # Check compatibility of the streams owned by the agent we are adding with our-agent's processor
        if self.proc_outputs is not None and self.proc_inputs is not None:

            # Check compatibility of the environmental streams of the agent we are adding with our-agent's processor
            environmental_streams = profile.get_dynamic_profile()['streams']
            if (environmental_streams is not None and
                    not (await self.add_compatible_streams(peer_id, environmental_streams,
                                                          buffered=False,
                                                          public=public))):  # This will also "add" the stream
                return False

            # Check compatibility of the generated streams of the agent we are adding with our-agent's processor
            proc_streams = profile.get_dynamic_profile()['proc_outputs']
            if (proc_streams is not None and
                    not (await self.add_compatible_streams(peer_id, profile.get_dynamic_profile()['proc_outputs'],
                                                          buffered=False,
                                                          public=public))):  # This will also "add" the stream
                return False

        self.out(f"Successfully added agent with peer ID {peer_id} (public: {public})")
        return True
604
+
605
    async def remove_agent(self, peer_id: str):
        """Remove an agent (async).

        Drops the peer from every categorization dict, clears its streams from the
        processor-compatibility sets and from the known streams, resets status attributes,
        and cancels any pending HSM action requests it had issued. Does nothing when the
        peer is unknown.

        Args:
            peer_id: The unique identifier of the peer to remove.
        """
        if peer_id in self.all_agents:

            # Removing from agent list (role-specific dicts are mutually exclusive, hence the elif chain)
            del self.all_agents[peer_id]
            if peer_id in self.world_agents:
                del self.world_agents[peer_id]
            elif peer_id in self.world_masters:
                del self.world_masters[peer_id]
            elif peer_id in self.public_agents:
                del self.public_agents[peer_id]

            if peer_id in self.artificial_agents:
                del self.artificial_agents[peer_id]
            elif peer_id in self.human_agents:
                del self.human_agents[peer_id]

            # Clearing from the list of processor-input-compatible-streams
            # (entries are (net_hash, name) tuples; collect first, then remove, to avoid
            # mutating the set while iterating it)
            if self.compat_in_streams is not None:
                for i, _ in enumerate(self.compat_in_streams):
                    to_remove = []
                    for net_hash_name in self.compat_in_streams[i]:
                        if DataProps.peer_id_from_net_hash(net_hash_name[0]) == peer_id:
                            to_remove.append(net_hash_name)
                    for net_hash_name in to_remove:
                        self.compat_in_streams[i].remove(net_hash_name)

            # Clearing from the list of processor-output-compatible-streams
            if self.compat_out_streams is not None:
                for i, _ in enumerate(self.compat_out_streams):
                    to_remove = []
                    for net_hash_name in self.compat_out_streams[i]:
                        if DataProps.peer_id_from_net_hash(net_hash_name[0]) == peer_id:
                            to_remove.append(net_hash_name)
                    for net_hash_name in to_remove:
                        self.compat_out_streams[i].remove(net_hash_name)

            # Clearing streams owned by the removed agent from the list of known streams
            await self.remove_streams(peer_id)

            # Removing from the status variables
            self.remove_peer_from_agent_status_attrs(peer_id)

            # Updating buffered stream index
            if peer_id in self.last_buffered_peer_id_to_info:
                del self.last_buffered_peer_id_to_info[peer_id]  # Only if present

            # Clearing pending requests in the HSMs
            # NOTE(review): get_requests() is assumed to return a snapshot — if it returns a live
            # collection, remove() while iterating would skip entries; confirm against ActionRequest
            behaviors = [self.behav_lone_wolf, self.behav]
            for behav in behaviors:
                if behav is not None and isinstance(behav, HybridStateMachine):
                    actions = behav.get_all_actions()
                    for action in actions:
                        if action.requests.is_requester_known(peer_id):
                            requests = action.requests.get_requests(peer_id)
                            for req in requests:
                                action.requests.remove(req)

            self.out(f"Successfully removed agent with peer ID {peer_id}")
669
+
670
+ def remove_all_agents(self):
671
+ """Remove all known agents."""
672
+
673
+ # Clearing all agents
674
+ self.all_agents = {}
675
+ self.public_agents = {}
676
+ self.world_masters = {}
677
+ self.world_agents = {}
678
+ self.human_agents = {}
679
+ self.artificial_agents = {}
680
+
681
+ # Clearing the list of processor-output-compatible-streams
682
+ if self.compat_in_streams is not None and self.proc_inputs is not None:
683
+ self.compat_in_streams = [set() for _ in range(len(self.proc_inputs))]
684
+ if self.compat_out_streams is not None and self.proc_outputs is not None:
685
+ self.compat_out_streams = [set() for _ in range(len(self.proc_outputs))]
686
+
687
+ # Clearing the list of known streams (not our own streams!)
688
+ self.remove_all_streams(owned_too=False)
689
+ self.out(f"Successfully removed all agents")
690
+
691
+ def add_behav_wildcard(self, wildcard_from: str, wildcard_to: object):
692
+ """Adds a wildcard mapping for the agent's behavior state machine.
693
+
694
+ Args:
695
+ wildcard_from: The string to be used as a wildcard.
696
+ wildcard_to: The object to replace the wildcard.
697
+ """
698
+ self.behav_wildcards[wildcard_from] = wildcard_to
699
+
700
    def add_stream(self, stream: DataStream, owned: bool = True, net_hash: str | None = None) -> dict[str, DataStream]:
        """Add a new stream to the set of known streams.

        The stream joins the group identified by its network hash; a group must be
        homogeneous in its public/pubsub properties, otherwise nothing is added.
        Owned streams are additionally indexed as processor-output streams or
        environmental streams, and get an (empty) recipients slot.

        Args:
            stream: The DataStream object to add.
            owned: If True, the streams are considered owned by this agent.
            net_hash: Optional network hash for the streams. If None, it will be generated.

        Returns:
            A dictionary containing the added stream and the possibly already present streams belonging to the same
            group (stream name -> stream object). Empty dict on property mismatch.
        """

        # Forcing clock: every stream ticks with the node clock
        stream.clock = self._node_clock

        # Stream net hash: derived from the public or private peer ID, depending
        # on the stream's own visibility
        if net_hash is None:
            public_peer_id, private_peer_id = self.get_peer_ids()
            peer_id = public_peer_id if stream.is_public() else private_peer_id
            net_hash = stream.net_hash(peer_id)

        # Adding the new stream; if the group already exists, check the first member
        # only (the group is assumed homogeneous) for public/pubsub consistency
        if net_hash not in self.known_streams:
            self.known_streams[net_hash] = {}
        else:
            for _stream in self.known_streams[net_hash].values():
                public = _stream.get_props().is_public()
                pubsub = _stream.get_props().is_pubsub()
                if public and not stream.get_props().is_public():
                    self.err(f"Cannot add a stream to a group with different properties (public): "
                             f"hash: {net_hash}, name: {stream.get_props().get_name()}, "
                             f"public: {stream.get_props().is_public()}")
                    return {}
                if pubsub and not stream.get_props().is_pubsub():
                    self.err(f"Cannot add a stream to a group with different properties (pubsub): "
                             f"hash: {net_hash}, name: {stream.get_props().get_name()}, "
                             f"public: {stream.get_props().is_public()}")
                    return {}
                break
        self.known_streams[net_hash][stream.get_props().get_name()] = stream

        if owned:

            # Adding an 'owned' processor output stream (i.e., the stream coming from OUR OWN processor);
            # membership is decided by matching (name, group) against proc_outputs
            is_proc_outputs_stream = False
            if self.proc_outputs is not None:
                proc_outputs_name_and_group = set()
                for props in self.proc_outputs:
                    proc_outputs_name_and_group.add((props.get_name(), props.get_group()))
                if (stream.get_props().get_name(), stream.get_props().get_group()) in proc_outputs_name_and_group:
                    if net_hash not in self.proc_streams:
                        self.proc_streams[net_hash] = {}
                    self.proc_streams[net_hash][stream.get_props().get_name()] = stream
                    is_proc_outputs_stream = True

            if net_hash not in self.owned_streams:
                self.owned_streams[net_hash] = {}
            self.owned_streams[net_hash][stream.get_props().get_name()] = stream

            # Owned streams that are not processor outputs are "environmental"
            if not is_proc_outputs_stream:
                if net_hash not in self.env_streams:
                    self.env_streams[net_hash] = {}
                self.env_streams[net_hash][stream.get_props().get_name()] = stream

            # Adding empty recipients slot
            if net_hash not in self.recipients:
                self.recipients[net_hash] = None

        # If needed, merging descriptor labels (attribute labels) and sharing them with all streams
        if self.merge_flat_stream_labels:
            self.merge_flat_data_stream_props()

        return self.known_streams[net_hash]
774
+
775
+ def add_streams(self, streams: list[DataStream], owned: bool = True, net_hash: str | None = None) \
776
+ -> list[dict[str, DataStream]]:
777
+ """Add a list of new streams to this environment.
778
+
779
+ Args:
780
+ streams: A list of DataStream objects to add.
781
+ owned: If True, the streams are considered owned by this agent.
782
+ net_hash: Optional network hash for the streams. If None, it will be generated for each.
783
+
784
+ Returns:
785
+ A list of dictionaries (it could be empty in case of issues), where each dictionary is what
786
+ is returned by add_stream().
787
+ """
788
+
789
+ # Adding the new stream
790
+ ret = []
791
+ for stream in streams:
792
+ stream_dict = self.add_stream(stream, owned, net_hash)
793
+ if len(stream_dict) == 0:
794
+ return []
795
+ ret.append(stream_dict)
796
+ return ret
797
+
798
    async def remove_streams(self, peer_id: str, name: str | None = None, owned_too: bool = False):
        """Remove a known stream (async).

        Args:
            peer_id: The hash of each stream included the peer ID of the owner, so this is the peer ID associated with
                the stream(s) to remove.
            name: The optional name of the stream to remove. If None, all streams with this peer_id are removed.
            owned_too: If True, also removes streams from the owned stream dict (so also environmental and processor).
        """

        # Identifying what to remove: (net_hash, stream_name) pairs whose net hash
        # embeds the given peer ID, optionally filtered by stream name
        to_remove = []
        for net_hash in self.known_streams.keys():
            if DataProps.peer_id_from_net_hash(net_hash) == peer_id:
                for _name, _stream in self.known_streams[net_hash].items():
                    if name is None or name == _name:
                        to_remove.append((net_hash, _name))

        # Removing (note: the `name` parameter is deliberately rebound by the loop variable)
        for (net_hash, name) in to_remove:

            # Owned streams are preserved unless owned_too was requested
            if not owned_too and net_hash in self.owned_streams:
                continue

            del self.known_streams[net_hash][name]
            if len(self.known_streams[net_hash]) == 0:
                del self.known_streams[net_hash]

                # Unsubscribing from pubsub once the whole group is gone
                # (placeholder peer IDs are skipped)
                if DataProps.is_pubsub_from_net_hash(net_hash):
                    if peer_id != "<private_peer_id>" and peer_id != "<public_peer_id>":
                        if not (await self._node_conn.unsubscribe(peer_id, channel=net_hash)):
                            self.err(f"Failed in unsubscribing from pubsub, peer_id: {peer_id}, channel: {net_hash}")
                        else:
                            self.out(f"Successfully unsubscribed from pubsub, peer_id: {peer_id}, channel: {net_hash}")

            # Removing all the owned streams (environment and processor streams are of course "owned");
            # this is only reached for owned streams when owned_too is True
            if net_hash in self.owned_streams:
                if name in self.owned_streams[net_hash]:
                    del self.owned_streams[net_hash][name]
                    if len(self.owned_streams[net_hash]) == 0:
                        del self.owned_streams[net_hash]
            if net_hash in self.env_streams:
                if name in self.env_streams[net_hash]:
                    del self.env_streams[net_hash][name]
                    if len(self.env_streams[net_hash]) == 0:
                        del self.env_streams[net_hash]
            if net_hash in self.proc_streams:
                if name in self.proc_streams[net_hash]:
                    del self.proc_streams[net_hash][name]
                    if len(self.proc_streams[net_hash]) == 0:
                        del self.proc_streams[net_hash]
            self.out(f"Successfully removed known stream with network hash {net_hash}, stream name: {name}")
850
+
851
+ def remove_all_streams(self, owned_too: bool = False):
852
+ """Remove all not-owned streams.
853
+
854
+ Args:
855
+ owned_too: If True, also removes the owned streams of this agent (so also environmental and processor ones).
856
+ """
857
+ if not owned_too:
858
+ self.known_streams = {k: v for k, v in self.owned_streams}
859
+ else:
860
+ self.known_streams = {}
861
+ self.owned_streams = {}
862
+ self.env_streams = {}
863
+ self.proc_streams = {}
864
+ self.out(f"Successfully removed all streams!")
865
+
866
+ def find_streams(self, peer_id: str, name_or_group: str | None = None) -> dict[str, dict[str, DataStream]]:
867
+ """Find streams associated with a given peer ID and optionally by name or group.
868
+
869
+ Args:
870
+ peer_id: The peer ID of the (owner of the) streams to find.
871
+ name_or_group: Optional name or group of the streams to find.
872
+
873
+ Returns:
874
+ A dictionary where keys are network hashes and values are dictionaries of streams
875
+ (stream name to DataStream object) matching the criteria.
876
+ """
877
+ ret = {}
878
+ for net_hash, streams_dict in self.known_streams.items():
879
+ _peer_id = DataStream.peer_id_from_net_hash(net_hash)
880
+ _name_or_group = DataStream.name_or_group_from_net_hash(net_hash)
881
+ if peer_id == _peer_id:
882
+ if name_or_group is None or name_or_group == _name_or_group:
883
+ ret[net_hash] = streams_dict
884
+ else:
885
+ for _name, _stream in streams_dict.items():
886
+ if name_or_group == _name:
887
+ if net_hash not in ret:
888
+ ret[net_hash] = {}
889
+ ret[net_hash][name_or_group] = _stream
890
+ return ret
891
+
892
+ def get_last_streamed_data(self, agent_name: str):
893
+ """Find streams associated with a given peer ID and optionally by name or group.
894
+
895
+ Args:
896
+ agent_name: The name of the agent.
897
+
898
+ Returns:
899
+ A list of data samples taken from all the known streams associated to the provided agent.
900
+ """
901
+ data_list = []
902
+ for peer_id, profile in self.all_agents.items():
903
+ if profile.get_static_profile()['node_name'] == agent_name:
904
+ net_hash_to_stream_dict = self.find_streams(peer_id, name_or_group="processor")
905
+ for net_hash, streams_dict in net_hash_to_stream_dict.items():
906
+ for stream_name, stream_obj in streams_dict.items():
907
+ data_list.append(stream_obj.get())
908
+ return data_list
909
+
910
    def merge_flat_data_stream_props(self):
        """Merge the labels of the descriptor components, across all streams, sharing them.

        Only owned streams that are flat tensors with labels participate. A
        superset of all labels (in first-seen order) is built, then every
        participating stream aligns its own labels to that superset; buffered
        streams also re-map their buffered data accordingly.
        """

        # Set of pivot labels (kept as a list to preserve insertion order)
        superset_labels = []

        # Checking the whole list of streams, but considering only the ones with generic data, flat, and labels
        considered_streams = []

        for stream_dict in self.owned_streams.values():
            for stream in stream_dict.values():

                # Skipping not flat, or not generic, or unlabeled streams
                if not stream.props.is_flat_tensor_with_labels():
                    continue

                # Saving list of considered streams
                considered_streams.append(stream)

                # Adding the current stream-labels to the pivot labels
                for label in stream.props.tensor_labels:
                    if label not in superset_labels:
                        superset_labels.append(label)

        # Telling each stream in which positions their labels fall, given the pivot labels
        for stream in considered_streams:

            # In the case of BufferedDataStream, we have to update the data buffer by clearing previously applied
            # adaptation first (I know it looks similar to what is done below, but we must clear first!)
            if isinstance(stream, BufferedDataStream):
                for i, (data, data_tag) in enumerate(stream.data_buffer):
                    stream.data_buffer[i] = (stream.props.clear_label_adaptation(data), data_tag)

            # Updating labels (interleave_with aligns this stream's labels to the superset;
            # exact semantics are defined by the tensor_labels implementation)
            stream.props.tensor_labels.interleave_with(superset_labels)

            # In the case of BufferedDataStream, we have to update the data buffer with the new labels
            if isinstance(stream, BufferedDataStream):
                for i, (data, data_tag) in enumerate(stream.data_buffer):
                    stream.data_buffer[i] = (stream.props.adapt_tensor_to_tensor_labels(data), data_tag)
950
+
951
+ def user_stream_hash_to_net_hash(self, user_stream_hash: str) -> str | None:
952
+ """Converts a user-defined stream hash (peer_id:name_or_group) to a network hash
953
+ (peer_id::dm:... or peer_id::ps:name_or_group) by searching the known hashes in the known streams.
954
+
955
+ Args:
956
+ user_stream_hash: The user-defined stream hash string (peer_id:name_or_group).
957
+
958
+ Returns:
959
+ The corresponding network hash string (peer_id::dm:... or peer_id::ps:name_or_group), or None if not found.
960
+ """
961
+ if user_stream_hash is None:
962
+ return None
963
+ if "::" in user_stream_hash:
964
+ return user_stream_hash # It was already fine
965
+ components = user_stream_hash.split(":")
966
+ peer_id = components[0]
967
+ name_or_group = components[-1]
968
+ for net_hash in self.known_streams.keys():
969
+ _peer_id = DataStream.peer_id_from_net_hash(net_hash)
970
+ _name_or_group = DataStream.name_or_group_from_net_hash(net_hash)
971
+ if _peer_id == peer_id and _name_or_group == name_or_group:
972
+ return net_hash
973
+ return None
974
+
975
    def create_proc_input_streams(self, buffered: bool = False):
        """Creates the processor input streams based on the `proc_inputs` defined for the agent.

        Each created stream is registered as owned, recorded in the proc-input net
        hash map (public/private), and forced into the corresponding
        processor-input compatibility set.

        Args:
            buffered: If True, the created streams will be of type BufferedDataStream.
        """

        # Adding input streams (grouped together), passing the node clock
        if self.proc_inputs is not None:
            for i, procs in enumerate(self.proc_inputs):
                procs.set_group("processor_in")  # Adding default group info, forced, do not change this!

                # Creating the streams (one per DataProps in the group)
                for props in procs.props:
                    if not buffered:
                        stream = DataStream(props=props.clone(), clock=self._node_clock)
                    else:
                        stream = BufferedDataStream(props=props.clone(), clock=self._node_clock)

                    self.add_stream(stream, owned=True)

                    # Recomputing the net hash the same way add_stream() does
                    public_peer_id, private_peer_id = self.get_peer_ids()
                    peer_id = public_peer_id if stream.is_public() else private_peer_id
                    net_hash = stream.net_hash(peer_id)
                    if stream.is_public():
                        self.proc_in_net_hash['public'] = net_hash
                    else:
                        self.proc_in_net_hash['private'] = net_hash

                    # Forcing the input stream to be compatible with proc inputs
                    self.compat_in_streams[i].add((net_hash, props.get_name()))
1007
    def create_proc_output_streams(self, buffered: bool = False):
        """Creates the processor output streams based on the `proc_outputs` defined for the agent.

        Each created stream is registered as owned (add_stream() will classify it as a
        processor stream via its name/group) and its net hash is recorded in the
        proc net hash map (public/private). Unlike the input counterpart, no
        compatibility set is touched here.

        Args:
            buffered: If True, the created streams will be of type BufferedDataStream.
        """

        # Adding generated streams (grouped together), passing the node clock
        if self.proc_outputs is not None:
            for i, procs in enumerate(self.proc_outputs):
                procs.set_group("processor")  # Adding default group info, forced, do not change this!

                # Creating the streams (one per DataProps in the group)
                for props in procs.props:
                    if not buffered:
                        stream = DataStream(props=props.clone(), clock=self._node_clock)
                    else:
                        stream = BufferedDataStream(props=props.clone(), clock=self._node_clock)

                    self.add_stream(stream, owned=True)

                    # Recomputing the net hash the same way add_stream() does
                    public_peer_id, private_peer_id = self.get_peer_ids()
                    peer_id = public_peer_id if stream.is_public() else private_peer_id
                    net_hash = stream.net_hash(peer_id)
                    if stream.is_public():
                        self.proc_net_hash['public'] = net_hash
                    else:
                        self.proc_net_hash['private'] = net_hash
1035
+
1036
    async def add_compatible_streams(self, peer_id: str,
                                     streams_in_profile: list[DataProps], buffered: bool = False,
                                     add_all: bool = False, public: bool = True) -> bool:
        """Add to the list of processor-compatible-streams those streams provided as arguments that are actually
        found to be compatible with the processor (if they are pubsub, it also subscribes to them) (async).

        Args:
            peer_id: The peer ID of the agent providing the streams.
            streams_in_profile: A list of DataProps objects representing the streams from the peer's profile.
                NOTE(review): each element is actually handled as a dict and converted via
                DataProps.from_dict() — the annotation seems aspirational; confirm against callers.
            buffered: If True, the added streams will be of type BufferedDataStream.
            add_all: If True, all streams from the profile are added, regardless of processor compatibility.
            public: Consider public streams only (or private streams only).

        Returns:
            True if compatible streams were successfully added and subscribed to, False otherwise.
        """
        added_streams = []

        if add_all:

            # This is the case in which we add all streams, storing all pairs (DataProps, net_hash)
            for j in streams_in_profile:
                jj = DataProps.from_dict(j)
                if public == jj.is_public():
                    net_hash = jj.net_hash(peer_id)
                    added_streams.append((jj, net_hash))
        else:

            # This is the case in which a processor is present, hence storing pairs (DataProps, net_hash)
            # of the found compatible streams
            added_net_hash_to_prop_name = {}

            # Find streams that are compatible with our 'proc_inputs'
            # (assumes proc_inputs/proc_outputs are not None on this path — TODO confirm)
            for i, in_proc in enumerate(self.proc_inputs):
                for j in streams_in_profile:
                    jj = DataProps.from_dict(j)
                    if public == jj.is_public() and in_proc.is_compatible(jj):
                        net_hash = jj.net_hash(peer_id)

                        # De-duplicating by (net_hash, prop name)
                        if net_hash not in added_net_hash_to_prop_name:
                            added_net_hash_to_prop_name[net_hash] = set()
                        if jj.name not in added_net_hash_to_prop_name[net_hash]:
                            added_net_hash_to_prop_name[net_hash].add(jj.name)
                            added_streams.append((jj, net_hash))

                        # Saving the position in the proc_input list
                        self.compat_in_streams[i].add((net_hash, jj.get_name()))

            # Find streams that are compatible with our 'proc_outputs'. A stream of class IDs
            # is also accepted as target when the matching loss is a cross-entropy variant.
            has_cross_entropy = []
            if 'losses' in self.proc_opts:
                for i in range(0, len(self.proc_outputs)):
                    if self.proc_opts['losses'][i] is not None and \
                            (self.proc_opts['losses'][i] == torch.nn.functional.cross_entropy or
                             isinstance(self.proc_opts['losses'][i], torch.nn.CrossEntropyLoss) or
                             "cross_entropy" in self.proc_opts['losses'][i].__name__):
                        has_cross_entropy.append(True)
                    else:
                        has_cross_entropy.append(False)

            for i, out_proc in enumerate(self.proc_outputs):
                for j in streams_in_profile:
                    jj = DataProps.from_dict(j)
                    if (public == jj.is_public() and
                            (out_proc.is_compatible(jj) or (jj.is_tensor_target_id() and has_cross_entropy[i]))):
                        net_hash = jj.net_hash(peer_id)

                        if net_hash not in added_net_hash_to_prop_name:
                            added_net_hash_to_prop_name[net_hash] = set()
                        if jj.name not in added_net_hash_to_prop_name[net_hash]:
                            added_net_hash_to_prop_name[net_hash].add(jj.name)
                            added_streams.append((jj, net_hash))

                        # Saving the position in the proc_output list
                        self.compat_out_streams[i].add((net_hash, jj.get_name()))

        net_hashes_to_subscribe = set()

        # For each compatible stream found...
        for (props, net_hash) in added_streams:

            # Check if it is a new stream or a data stream to add to an already known stream
            already_known_stream = net_hash in self.known_streams

            # Creating the stream object
            if not buffered:
                stream = DataStream(props=props.clone(), clock=self._node_clock)
            else:
                stream = BufferedDataStream(props=props.clone(), clock=self._node_clock)

            # Add the data stream to the list of known streams
            # if the stream already exists it will be overwritten (which is fine in case of changes)
            self.add_stream(stream, owned=False, net_hash=net_hash)

            # If the stream is over PubSub, and we are not already subscribed, we will subscribe
            if props.is_pubsub() and not already_known_stream:
                net_hashes_to_subscribe.add(net_hash)

        # Opening PubSubs; a single failure aborts with False
        for net_hash in net_hashes_to_subscribe:
            self.out(f"Opening channel for the not-owned but processor-compatible stream {net_hash}")
            if not (await self._node_conn.subscribe(peer_id, channel=net_hash)):
                self.err(f"Error subscribing to {net_hash}")
                return False

        return True
1142
+
1143
+ async def subscribe_to_pubsub_owned_streams(self) -> bool:
1144
+ """Subscribes to all owned streams that are marked as PubSub (async).
1145
+
1146
+ Returns:
1147
+ True if all subscriptions were successful, False otherwise.
1148
+ """
1149
+
1150
+ # Opening channels for all the (groups of) owned streams (generated and not)
1151
+ for net_hash in self.owned_streams.keys():
1152
+ is_pubsub = DataStream.is_pubsub_from_net_hash(net_hash)
1153
+
1154
+ if is_pubsub:
1155
+ self.out(f"Opening channel for the owned stream {net_hash}")
1156
+ peer_id = DataStream.peer_id_from_net_hash(net_hash) # Guessing peer ID from the net hash
1157
+
1158
+ if not (await self._node_conn.subscribe(peer_id, channel=net_hash)):
1159
+ self.err(f"Cannot open a channel for owned stream hash {net_hash}")
1160
+ return False
1161
+ return True
1162
+
1163
    def update_streams_in_profile(self):
        """Updates the agent's profile with information about its owned (environmental and processor) streams.

        Fills 'proc_outputs'/'proc_inputs' in the dynamic profile (flattened lists of
        DataProps dicts) and, separately, the 'streams' entry with the
        locally-created environmental streams (owned but not processor ones).
        """

        # Filling the information about the streams that can be generated and handled
        dynamic_profile = self._node_profile.get_dynamic_profile()
        if hasattr(self, 'proc_outputs') and hasattr(self, 'proc_inputs'):
            dynamic_profile['proc_outputs'] = \
                [dct for d in self.proc_outputs for dct in d.to_list_of_dicts()]  # List of dict of DataProp
            dynamic_profile['proc_inputs'] = \
                [dct for d in self.proc_inputs for dct in d.to_list_of_dicts()]  # List of dict of DataProp

        # Adding the list of locally-created ("environmental") streams to the profile:
        # owned streams that are not processor streams and whose net hash belongs to
        # one of our own peer IDs
        list_of_props = []
        public_peer_id, private_peer_id = self.get_peer_ids()
        for net_hash, streams_dict in self.owned_streams.items():
            if net_hash not in self.proc_streams.keys():
                if (DataProps.peer_id_from_net_hash(net_hash) == public_peer_id or
                        DataProps.peer_id_from_net_hash(net_hash) == private_peer_id):
                    for stream in streams_dict.values():
                        list_of_props.append(stream.get_props().to_dict())  # DataProp
        # NOTE(review): when list_of_props is empty, any stale 'streams' entry is left untouched
        if len(list_of_props) > 0:
            dynamic_profile['streams'] = list_of_props
1185
+
1186
+ async def send_profile_to_all(self):
1187
+ """Sends the agent's profile to all known agents (async)."""
1188
+
1189
+ for peer_id in self.all_agents.keys():
1190
+ self.out(f"Sending profile to {peer_id}")
1191
+ if not (await self._node_conn.send(peer_id, channel_trail=None,
1192
+ content=self._node_profile.get_all_profile(),
1193
+ content_type=Msg.PROFILE)):
1194
+ self.err("Failed to send profile, removing (disconnecting) " + peer_id)
1195
+ await self.remove_agent(peer_id)
1196
+
1197
    def generate(self, input_net_hashes: list[str] | None = None,
                 inputs: list[str | torch.Tensor | Image] | None = None,
                 first: bool = False, last: bool = False, ref_uuid: str | None = None) -> (
            tuple[tuple[torch.Tensor] | None, int]):
        """Generate new signals.

        Matches the given input streams against the processor input slots (via the
        compatibility sets), preprocesses the collected samples, and runs the
        processor (inference) on them.

        Args:
            input_net_hashes: A list of network hashes to be considered as input streams (they will be sub-selected).
            inputs: A list of data to be directly provided as input to the processor (if not None, input_net_hashes is
                ignored).
            first: If True, indicates this is the first generation call in a sequence.
            last: If True, indicates this is the last generation call in a sequence.
            ref_uuid: An optional UUID to match against input stream UUIDs (it can be None).

        Returns:
            A tuple containing:
            - A tuple of torch.Tensor objects representing the generated output, or None if generation failed.
            - An integer representing a data tag or status (-1 means no common tag / failure).
        """

        # Preparing processor input: one slot per processor input
        if inputs is None:
            inputs = [None] * len(self.proc_inputs)
        matched = set()
        data_tag = None

        if input_net_hashes is None:
            input_net_hashes = []

        # Checking UUIDs and searching the provided input streams: we look to match them with the processor input
        for net_hash in input_net_hashes:
            stream_dict = self.known_streams[net_hash]
            for stream_name, stream in stream_dict.items():

                # Checking the UUID in our known streams, comparing it with the UUID provided as input:
                # if they are not compatible, we don't generate at all
                if ref_uuid is not None and stream.get_uuid(expected=False) != ref_uuid:
                    self.deb(f"[generate] The UUID ({stream.get_uuid(expected=False)}, expected: "
                             f"{stream.get_uuid(expected=True)}) of stream {net_hash} is not the one we were "
                             f"looking for ({ref_uuid}), skipping this data stream")
                    continue

                # Matching the currently checked input stream with one of the processor inputs
                stream_sample = stream.get(requested_by="generate")
                for i in range(len(self.proc_inputs)):

                    # If the current input stream is compatible with the i-th input slot...
                    if (net_hash, stream_name) in self.compat_in_streams[i]:

                        # If the current input stream was already assigned to another input slot
                        # (different from "i") we skip the generation
                        if (net_hash, stream_name) in matched:
                            self.err("Cannot generate: ambiguous input streams provided "
                                     "(they can match multiple processor inputs)")
                            return None, -1

                        # Found a valid assignment: getting stream sample
                        self.deb(f"[generate] Setting the {i}-th network input to stream with "
                                 f"net_hash: {net_hash}, name: {stream_name}")
                        if stream_sample is None:
                            self.deb(f"[generate] Failed setting the {i}-th input, got a None sample")
                        else:
                            self.deb(f"[generate] Going ahead setting the {i}-th input, got a valid sample")

                        # Found a valid assignment: associating it to the i-th input slot
                        try:
                            inputs[i] = self.proc_inputs[i].check_and_preprocess(stream_sample,
                                                                                 device=self.proc.device)
                        except Exception as e:
                            self.err(f"Error while checking and preprocessing the {i}-th input [{e}]")
                            continue

                        self.deb(f"[generate] Finished setting the {i}-th input, preprocessing complete")

                        # Found a valid assignment: saving match
                        matched.add((net_hash, stream_name))

                        # If all the inputs share the same data tag, we will return it,
                        # otherwise we set it at -1 (meaning no tag)
                        if data_tag is None:
                            data_tag = stream.get_tag()
                        elif data_tag != stream.get_tag():
                            data_tag = -1

                        if AgentBasics.DEBUG:
                            if stream.props.is_text():
                                self.deb(f"[generate] Input {i} of the network: {stream_sample}")
                        break

        # Checking if we were able to match some data for each input slot of the network (processor)
        for i in range(len(self.proc_inputs)):
            if inputs[i] is None:
                if self.proc_optional_inputs[i]["has_default"]:
                    inputs[i] = self.proc_optional_inputs[i]["default_value"]
                else:
                    self.err(
                        f"Cannot generate: couldn't find a valid input for the "
                        f"{i}-th input position of the processor (and no default values are present)")
                    return None, -1
            else:
                # NOTE(review): this clobbers the stream-derived tag whenever a slot is
                # already filled (e.g. inputs passed directly) — confirm this is intended
                data_tag = -1

        if AgentBasics.DEBUG:
            if inputs is not None:
                input_shapes = []
                for x in inputs:
                    if isinstance(x, torch.Tensor):
                        input_shapes.append(x.shape)
                    else:
                        input_shapes.append("<non-tensor>")
                self.deb(f"[generate] Input shapes: {input_shapes}")
                self.deb(f"[generate] Input data tag: {data_tag}")

        # Calling processor (inference) passing the collected inputs
        inputs = self.proc_callback_inputs(inputs)
        try:
            outputs = self.proc(*inputs, first=first, last=last)

            # Ensuring the output is a tuple, even if composed by a single tensor
            if not isinstance(outputs, tuple):
                outputs = (outputs, )
        except Exception as e:
            self.err(f"Error while calling the processor [{e}]")
            outputs = (None, ) * len(self.proc_outputs)
        outputs = self.proc_callback_outputs(outputs)

        # Saving the reference UUID of this generation
        self.last_ref_uuid = ref_uuid

        if AgentBasics.DEBUG:
            if outputs is not None:
                i = 0
                for net_hash, stream_dict in self.proc_streams.items():
                    for stream in stream_dict.values():
                        if self.behaving_in_world() != stream.props.is_public():
                            if outputs[i] is not None:
                                if stream.props.is_tensor() or stream.props.is_text():
                                    self.deb(f"[generate] outputs[{i}]: {str(stream.props.to_text(outputs[i]))}")
                                else:
                                    self.deb(f"[generate] outputs[{i}]: not None, but it cannot be converted to text")
                            else:
                                self.deb(f"[generate] outputs[{i}]: None")
                            i += 1
                self.deb(f"[generate] Output shapes: {[x.shape for x in outputs if isinstance(x, torch.Tensor)]}")

        return outputs, data_tag
1343
+
1344
+ def learn_generate(self,
1345
+ outputs: tuple[torch.Tensor],
1346
+ targets_net_hashes: list[str] | None) -> tuple[list[float] | None, list[float] | None]:
1347
+ """Learn (i.e., update model params) by matching the given processor outputs with a set of targets (if any).
1348
+
1349
+ Args:
1350
+ outputs: A tuple of torch.Tensor representing the outputs generated by the agent's processor.
1351
+ targets_net_hashes: An optional list of network hashes identifying the streams
1352
+ from which target data should be retrieved for learning.
1353
+ If None, losses are evaluated without explicit targets.
1354
+
1355
+ Returns:
1356
+ A tuple containing:
1357
+ - A list of float values representing the individual loss values for each output.
1358
+ Returns None if targets are specified but cannot be found.
1359
+ - A list of integers representing the data tags of the given target streams (None if no targets were given).
1360
+ """
1361
+
1362
+ # Cannot learn without optimizer and losses
1363
+ if (self.proc_opts['optimizer'] is None or self.proc_opts['losses'] is None or
1364
+ len(self.proc_opts['losses']) == 0):
1365
+ return None, None
1366
+
1367
+ # Matching targets with the output slots of the processor
1368
+ at_least_one_target_found = False
1369
+ if targets_net_hashes is not None:
1370
+ targets = [None] * len(self.proc_outputs)
1371
+ matched = set()
1372
+ data_tags = [-1] * len(self.proc_outputs)
1373
+
1374
+ # For each target stream group...
1375
+ for net_hash in targets_net_hashes:
1376
+ stream_dict = self.known_streams[net_hash]
1377
+
1378
+ # For each stream of the current target group....
1379
+ for stream_name, stream in stream_dict.items():
1380
+ stream_sample = None
1381
+
1382
+ # For each output slot of our processor... (index "i")
1383
+ for i in range(len(self.proc_outputs)):
1384
+
1385
+ # Check if the i-th target was already assigned or if the i-th output is not a tensor
1386
+ if targets[i] is not None or not isinstance(outputs[i], torch.Tensor):
1387
+ continue
1388
+
1389
+ # If the target stream is compatible with the i-th output of the processor...
1390
+ if (net_hash, stream_name) in self.compat_out_streams[i]:
1391
+
1392
+ # If the current target was already assigned to another output slot (different from "i)"
1393
+ # we skip learning
1394
+ if (net_hash, stream_name) in matched:
1395
+ self.err("Cannot generate: ambiguous target streams provided "
1396
+ "(they can match multiple processor outputs)")
1397
+ return None, None
1398
+
1399
+ # Found a valid assignment: getting stream sample
1400
+ if stream_sample is None:
1401
+ stream_sample = stream.get(requested_by="learn_generate")
1402
+ if stream_sample is None:
1403
+ return None, None
1404
+
1405
+ # Found a valid assignment: associating target to the i-th output slot
1406
+ try:
1407
+ targets[i] = self.proc_outputs[i].check_and_preprocess(stream_sample,
1408
+ allow_class_ids=True,
1409
+ targets=True,
1410
+ device=self.proc.device)
1411
+ except Exception as e:
1412
+ self.err(f"Error while checking and preprocessing the {i}-th targets [{e}]")
1413
+
1414
+ # Found a valid assignment: saving match
1415
+ matched.add((net_hash, stream_name))
1416
+
1417
+ # Saving tag
1418
+ data_tags[i] = stream.get_tag()
1419
+
1420
+ # Confirming
1421
+ at_least_one_target_found = True
1422
+
1423
+ if AgentBasics.DEBUG:
1424
+ if stream.props.is_tensor():
1425
+ self.deb("[generate] Target of the network: " +
1426
+ str(stream.props.to_text(targets[i])))
1427
+ elif stream.props.is_text():
1428
+ self.deb("[generate] Target of the network: " + stream_sample)
1429
+ break
1430
+
1431
+ # If no targets were matched, we skip learning
1432
+ if not at_least_one_target_found:
1433
+ self.err(f"Cannot learn: cannot find a valid target for any output positions of the processor")
1434
+ return None, None
1435
+ else:
1436
+
1437
+ # If no targets were provided, it is expected to be the case of fully unsupervised learning
1438
+ data_tags = None
1439
+ targets = None
1440
+
1441
+ # Retrieving custom elements from the option dictionary
1442
+ loss_functions: list = self.proc_opts['losses']
1443
+ optimizer: torch.optim.optimizer.Optimizer | None = self.proc_opts['optimizer']
1444
+
1445
+ # Evaluating loss function(s), one for each processor output slot (they are set to 0. if no targets are there)
1446
+ if targets_net_hashes is not None:
1447
+
1448
+ # Supervised or partly supervised learning
1449
+ loss_values = [loss_fcn(outputs[i], targets[i]) if targets[i] is not None else
1450
+ torch.tensor(0., device=self.proc.device)
1451
+ for i, loss_fcn in enumerate(loss_functions)]
1452
+ loss = torch.stack(loss_values).sum() # Sum of losses
1453
+ else:
1454
+
1455
+ # Unsupervised learning
1456
+ loss_values = [loss_fcn(outputs[i]) for i, loss_fcn in enumerate(loss_functions)]
1457
+ loss = torch.stack(loss_values).sum() # Sum of losses
1458
+
1459
+ # Learning step
1460
+ optimizer.zero_grad()
1461
+ loss.backward()
1462
+ optimizer.step()
1463
+ # This is where parameters are actually updated, but the flag is set
1464
+ # upon success in do_learn which is the outer method calling this one
1465
+ # self.proc_updated_since_last_save = True
1466
+
1467
+ # Teaching (for autoregressive models, expected to have attribute "y")
1468
+ if hasattr(self.proc, 'y'):
1469
+ self.proc.y = targets[0]
1470
+
1471
+ # Returning a list of float values and the data tags of the targets
1472
+ return [loss_value.item() for loss_value in loss_values], data_tags
1473
+
1474
+ async def behave(self):
1475
+ """Behave in the current environment, calling the state-machines of the public and private networks (async)."""
1476
+
1477
+ if self.in_world():
1478
+ self.out("Behaving (world)...")
1479
+ if self.behav is None:
1480
+ self.err("No behaviour specified")
1481
+ else:
1482
+ self.behav_lone_wolf.enable(False)
1483
+ self.behav.enable(True)
1484
+ await self.behav.act()
1485
+ self.behav.enable(False)
1486
+
1487
+ self.out("Behaving (public)...")
1488
+ if self.behav_lone_wolf is None:
1489
+ self.err("No behaviour specified")
1490
+ else:
1491
+ self.behav.enable(False)
1492
+ self.behav_lone_wolf.enable(True)
1493
+ await self.behav_lone_wolf.act()
1494
+ self.behav_lone_wolf.enable(False)
1495
+
1496
+ def learn_behave(self, state: int, last_action: int, prev_state: int):
1497
+ """A placeholder method for behavioral learning, intended to be implemented by child classes.
1498
+ It receives state and action information to update a behavioral model.
1499
+
1500
+ Args:
1501
+ state: The current state of the agent.
1502
+ last_action: The last action taken.
1503
+ prev_state: The previous state of the agent.
1504
+
1505
+ Returns:
1506
+ An integer representing a new state, or similar feedback.
1507
+ """
1508
+ pass
1509
+
1510
+ def get_peer_ids(self):
1511
+ """Retrieve the public and private peer IDs of the agent, from the underlying node's dynamic profile.
1512
+
1513
+ Returns:
1514
+ A tuple containing the public peer ID and the private peer ID.
1515
+ If either ID is not available, a placeholder string is returned <public_peer_id>, <private_peer_id>.
1516
+ """
1517
+ public_peer_id = None
1518
+ private_peer_id = None
1519
+ if self._node_profile is not None:
1520
+ dynamic_profile = self._node_profile.get_dynamic_profile()
1521
+ public_peer_id = dynamic_profile['peer_id'] # Public
1522
+ private_peer_id = dynamic_profile['private_peer_id'] # Private
1523
+ public_peer_id = '<public_peer_id>' if public_peer_id is None else public_peer_id
1524
+ private_peer_id = '<private_peer_id>' if private_peer_id is None else private_peer_id
1525
+ return public_peer_id, private_peer_id
1526
+
1527
+ def evaluate_profile(self, role: int, profile: NodeProfile) -> bool:
1528
+ """Evaluate if a given profile is valid for this agent based on its role. It helps in identifying and filtering
1529
+ out invalid or 'cheating' profiles.
1530
+
1531
+ Args:
1532
+ role: The expected integer role (e.g., ROLE_PUBLIC, ROLE_WORLD_MASTER) for the profile.
1533
+ profile: The NodeProfile object to be evaluated.
1534
+
1535
+ Returns:
1536
+ True if the profile is considered valid for the specified role, False otherwise.
1537
+ """
1538
+
1539
+ # If the role in the profile is not the provided role, a profile-cheater was found
1540
+ if (profile.get_dynamic_profile()['connections']['role'] in self.ROLE_STR_TO_BITS and
1541
+ self.ROLE_STR_TO_BITS[profile.get_dynamic_profile()['connections']['role']] != role):
1542
+ self.out(f"Cheater found: "
1543
+ f"{profile.get_dynamic_profile()['connections']['role']} != {self.ROLE_BITS_TO_STR[role]}")
1544
+ return False # Cheater found
1545
+
1546
+ # These are just examples: you are expected to reimplement this method in your custom agent file
1547
+ if (role & 1 == self.ROLE_PUBLIC and
1548
+ profile.get_dynamic_profile()['guessed_location'] == 'Some Dummy Location, Just An Example Here'):
1549
+ return False
1550
+ elif (role & 3 == self.ROLE_WORLD_MASTER and
1551
+ profile.get_dynamic_profile()['guessed_location'] == 'Some Other Location, Just Another Example Here'):
1552
+ return False
1553
+ else:
1554
+ return True
1555
+
1556
+ def accept_new_role(self, role: int):
1557
+ """Set the agent's role and optionally load a default behavior (private/world behaviour).
1558
+
1559
+ Args:
1560
+ role: The integer role to assign to the agent (e.g., ROLE_PUBLIC, ROLE_WORLD_MASTER).
1561
+ """
1562
+ base_role_str = self.ROLE_BITS_TO_STR[(role >> 2) << 2]
1563
+ full_role_str = self.ROLE_BITS_TO_STR[role]
1564
+
1565
+ self._node_profile.get_dynamic_profile()['connections']['role'] = full_role_str
1566
+
1567
+ base_role_to_behav = self.world_profile.get_dynamic_profile()['world_roles_fsm']
1568
+ if base_role_str in base_role_to_behav:
1569
+ default_behav = self.world_profile.get_dynamic_profile()['world_roles_fsm'][base_role_str]
1570
+ else:
1571
+ default_behav = None # A public role will not be found in the map
1572
+
1573
+ if default_behav is not None and len(default_behav) > 0:
1574
+ default_behav_hsm = HybridStateMachine(self)
1575
+ default_behav_hsm.load(default_behav)
1576
+ self.behav = HybridStateMachine(self, policy=self.policy_default)
1577
+ self.behav.include(default_behav_hsm, make_a_copy=True)
1578
+ self.behav.set_role(base_role_str)
1579
+ self.set_policy_filter(self.policy_filter, public=False)
1580
+
1581
+ def in_world(self):
1582
+ """Check if the agent is currently operating within a 'world'.
1583
+
1584
+ Returns:
1585
+ True if the agent is in a world, False otherwise.
1586
+ """
1587
+ if self._node_profile is not None:
1588
+ return self.ROLE_STR_TO_BITS[self._node_profile.get_dynamic_profile()['connections']['role']] & 1 == 1
1589
+ else:
1590
+ return False
1591
+
1592
+ def behaving_in_world(self):
1593
+ """Checks if the agent's world-specific behavior state machine is currently active.
1594
+
1595
+ Returns:
1596
+ True if the world behavior is active, False otherwise.
1597
+ """
1598
+ return self.behav.is_enabled()
1599
+
1600
+ def get_stream_sample(self, net_hash: str, sample_dict: dict[str, dict[str, torch.Tensor | None | int | str]]):
1601
+ """Receive and process stream samples that were provided by another agent.
1602
+
1603
+ Args:
1604
+ net_hash: The network hash identifying the source of the stream samples.
1605
+ sample_dict: A dictionary where keys are stream names and values are dictionaries
1606
+ containing 'data', 'data_tag', and 'data_uuid' for each sample.
1607
+
1608
+ Returns:
1609
+ True if the stream samples were successfully processed and stored, False otherwise
1610
+ (e.g., if the stream is unknown, not compatible, or data is None/stale).
1611
+ """
1612
+
1613
+ # Let's be sure that the net hash is converted from the user's perspective to the one of the code here
1614
+ net_hash = DataProps.normalize_net_hash(net_hash)
1615
+
1616
+ self.out(f"Got a stream sample from {net_hash}...")
1617
+ if sample_dict is None: # hasattr(sample_dict, "keys") //// or not isinstance(sample_dict, dict):
1618
+ self.err(f"Invalid sample (expected a dictionary, got {type(sample_dict)})")
1619
+ return False
1620
+
1621
+ if net_hash in self.known_streams:
1622
+ for name, data_and_tag_and_uuid in sample_dict.items():
1623
+ if ('data' not in data_and_tag_and_uuid or
1624
+ 'data_tag' not in data_and_tag_and_uuid or
1625
+ 'data_uuid' not in data_and_tag_and_uuid):
1626
+ self.err(f"Invalid sample in data stream named {name} (missing one or more keys)")
1627
+ return False
1628
+
1629
+ if AgentBasics.DEBUG:
1630
+ if net_hash in self.known_streams and name in self.known_streams[net_hash]:
1631
+ self.deb(f"[get_stream_sample] Local data stream {name} status: tag="
1632
+ f"{self.known_streams[net_hash][name].get_tag()}, uuid="
1633
+ f"{self.known_streams[net_hash][name].get_uuid(expected=False)}, uuid-expected="
1634
+ f"{self.known_streams[net_hash][name].get_uuid(expected=True)}")
1635
+
1636
+ data, data_tag, data_uuid = (data_and_tag_and_uuid['data'],
1637
+ data_and_tag_and_uuid['data_tag'],
1638
+ data_and_tag_and_uuid['data_uuid'])
1639
+
1640
+ # - data must be not None
1641
+ # - the stream name must be known
1642
+ # - if the UUID associated to our local stream is the same of the data, then we check tag order
1643
+ # - if the UUID associated to our local stream is the expected one, we don't check tag order
1644
+ skip = False
1645
+ reason = None
1646
+ if not skip:
1647
+ if data is None:
1648
+ skip = True
1649
+ reason = "Data is None"
1650
+ if not skip:
1651
+ if net_hash not in self.known_streams:
1652
+ skip = True
1653
+ reason = f"The net hash {net_hash} is not a known stream hash"
1654
+ if not skip:
1655
+ if name not in self.known_streams[net_hash]:
1656
+ skip = True
1657
+ reason = f"The data stream named {name} is present for net hash {net_hash}"
1658
+ if not skip:
1659
+ if (self.known_streams[net_hash][name].get_uuid(expected=True) is not None and
1660
+ data_uuid != self.known_streams[net_hash][name].get_uuid(expected=True)):
1661
+ skip = True
1662
+ reason = (f"The data UUID {data_uuid} is not the expected one "
1663
+ f"{self.known_streams[net_hash][name].get_uuid(expected=True)}")
1664
+ if not skip:
1665
+ if (self.known_streams[net_hash][name].get_uuid(expected=True) is None and
1666
+ self.known_streams[net_hash][name].get_uuid(expected=False) is not None and
1667
+ data_uuid != self.known_streams[net_hash][name].get_uuid(expected=False)):
1668
+ skip = True
1669
+ reason = (f"The data UUID {data_uuid} is not the one of the stream, which is "
1670
+ f"{self.known_streams[net_hash][name].get_uuid(expected=False)}")
1671
+ if not skip:
1672
+ if (self.known_streams[net_hash][name].get_uuid(expected=True) is None and
1673
+ self.known_streams[net_hash][name].get_uuid(expected=False) is not None and
1674
+ data_uuid == self.known_streams[net_hash][name].get_uuid(expected=False) and
1675
+ data_tag <= self.known_streams[net_hash][name].get_tag()):
1676
+ skip = True
1677
+ reason = (f"The data tag {data_tag} is less or equal to the already present one "
1678
+ f"({self.known_streams[net_hash][name].get_tag()})")
1679
+
1680
+ # If we sample can be accepted...
1681
+ if not skip:
1682
+ self.out(f"Accepted sample named {name}: tag={data_tag}, uuid={data_uuid}")
1683
+
1684
+ # Saving the data sample on the known stream objects
1685
+ if AgentBasics.DEBUG:
1686
+ self.deb(f"data={self.known_streams[net_hash][name].props.to_text(data)}")
1687
+ self.known_streams[net_hash][name].set(data, data_tag)
1688
+
1689
+ # If the local stream was expecting data with a certain UUID, and we got it ...
1690
+ # OR
1691
+ # if the local stream was not expecting anything and was also not set to any UUID, and we got data
1692
+ # with some UUID ...
1693
+ # THEN
1694
+ # we clear expectations and set the current UUID to the one of the data.
1695
+ # (the second part of the OR above is the case of data that arrives before an action request,
1696
+ # since action requests set expectations only)
1697
+ if ((self.known_streams[net_hash][name].get_uuid(expected=True) is not None and
1698
+ data_uuid == self.known_streams[net_hash][name].get_uuid(expected=True)) or
1699
+ (self.known_streams[net_hash][name].get_uuid(expected=True) is None and
1700
+ self.known_streams[net_hash][name].get_uuid(expected=False) is None and
1701
+ data_uuid is not None)):
1702
+
1703
+ # Setting what was the expected UUID as the local UUID from now on
1704
+ self.known_streams[net_hash][name].set_uuid(data_uuid, expected=False) # Setting current
1705
+ self.known_streams[net_hash][name].set_uuid(None, expected=True) # Clearing expected
1706
+
1707
+ if AgentBasics.DEBUG:
1708
+ self.deb(f"[get_stream_sample] Switched uuid of the local data stream!")
1709
+ self.deb(f"[get_stream_sample] New local data stream status: tag="
1710
+ f"{self.known_streams[net_hash][name].get_tag()}, uuid="
1711
+ f"{self.known_streams[net_hash][name].get_uuid(expected=False)}, uuid-expected="
1712
+ f"{self.known_streams[net_hash][name].get_uuid(expected=True)}")
1713
+
1714
+ # Clearing UUID if marked as such (useful for single shot actions with no "done"-like feedback)
1715
+ stream_obj = self.known_streams[net_hash][name]
1716
+ if stream_obj.data_uuid_clearable:
1717
+ self.deb(f"[get_stream_sample] Clearing marked data stream {net_hash}.{name}")
1718
+ stream_obj.clear_uuid_if_marked_as_clearable()
1719
+
1720
+ # Buffering data, if it was requested and if this sample comes from somebody's processor
1721
+ if (self.buffer_generated_by_others != "none" and
1722
+ DataProps.name_or_group_from_net_hash(net_hash) == "processor"):
1723
+ self.deb(f"[get_stream_sample] Buffering others' processor generated data...")
1724
+
1725
+ # Getting the streams of the processor of the source agent
1726
+ _processor_stream_dict = self.known_streams[net_hash]
1727
+ _peer_id = DataProps.peer_id_from_net_hash(net_hash)
1728
+
1729
+ # Setting buffered stream counter
1730
+ clear = False
1731
+ if _peer_id in self.last_buffered_peer_id_to_info:
1732
+ if self.buffer_generated_by_others == "one":
1733
+ _buffered_uuid_to_id = self.last_buffered_peer_id_to_info[_peer_id]["uuid_to_id"]
1734
+ if data_uuid not in _buffered_uuid_to_id:
1735
+ _id = next(iter(_buffered_uuid_to_id.values()))
1736
+ _buffered_uuid_to_id.clear()
1737
+ _buffered_uuid_to_id[data_uuid] = _id
1738
+ clear = True
1739
+ else:
1740
+ self.last_buffered_peer_id_to_info[_peer_id] = {"uuid_to_id": {}, "net_hash": None}
1741
+ _buffered_uuid_to_id = self.last_buffered_peer_id_to_info[_peer_id]["uuid_to_id"]
1742
+ if data_uuid not in _buffered_uuid_to_id:
1743
+ _buffered_uuid_to_id[data_uuid] = sum(
1744
+ len(v["uuid_to_id"]) for v in self.last_buffered_peer_id_to_info.values()) + 1
1745
+ _buffered_id = _buffered_uuid_to_id[data_uuid]
1746
+
1747
+ # Building net hash to retrieve the buffered stream
1748
+ _net_hash = DataProps.build_net_hash(
1749
+ _peer_id,
1750
+ pubsub=False,
1751
+ name_or_group=("buffered" + str(_buffered_id)))
1752
+
1753
+ # If the buffered stream was not created before
1754
+ if _net_hash not in self.known_streams:
1755
+ self.deb(f"[get_stream_sample] Adding a new buffered stream to the list of known "
1756
+ f"streams, hash: {_net_hash}")
1757
+ for stream_obj in _processor_stream_dict.values():
1758
+
1759
+ # Same properties of the stream of the processor of the source agent
1760
+ props = stream_obj.get_props().clone()
1761
+ props.set_group("buffered" + str(_buffered_id))
1762
+
1763
+ # Adding the newly created stream
1764
+ self.add_stream(BufferedDataStream(props=props, clock=self._node_clock),
1765
+ owned=False,
1766
+ net_hash=_net_hash)
1767
+
1768
+ # Saving hash of the new buffered stream
1769
+ self.last_buffered_peer_id_to_info[_peer_id]["net_hash"] = _net_hash
1770
+ else:
1771
+ if clear:
1772
+ for stream_obj in self.known_streams[_net_hash].values():
1773
+ stream_obj.clear_buffer()
1774
+
1775
+ # Saving sample
1776
+ self.known_streams[_net_hash][name].set(data, data_tag)
1777
+
1778
+ # Clearing all UUID of the locally buffered stream
1779
+ self.known_streams[_net_hash][name].set_uuid(None, expected=False)
1780
+ self.known_streams[_net_hash][name].set_uuid(None, expected=True)
1781
+
1782
+ # If we decided to skip this sample...
1783
+ else:
1784
+ self.out(f"Skipping sample named {name} in net hash {net_hash}: tag={data_tag}, uuid={data_uuid}" +
1785
+ (", data is None!" if data is None else ""))
1786
+
1787
+ behav = self.behav if self.behav.is_enabled() else self.behav_lone_wolf
1788
+ if behav.are_debug_messages_active():
1789
+ behav.action_out_fcn(behav.print_start +
1790
+ f"Skipping sample named {name} received in net hash {net_hash}, "
1791
+ f"tag={data_tag}, uuid={data_uuid}: {reason}")
1792
+
1793
+ if AgentBasics.DEBUG:
1794
+ if net_hash not in self.known_streams:
1795
+ self.deb(f"[get_stream_sample] "
1796
+ f"The net hash {net_hash} was not found in the set of known streams")
1797
+ else:
1798
+ if name not in self.known_streams[net_hash]:
1799
+ self.deb(f"[get_stream_sample] The net hash was known, but the data stream "
1800
+ f"named {name} is not known")
1801
+ else:
1802
+ self.deb(f"[get_stream_sample] "
1803
+ f"data={self.known_streams[net_hash][name].props.to_text(data)}")
1804
+ return True
1805
+
1806
+ # If this stream is not known at all...
1807
+ else:
1808
+ self.out(f"Skipping sample from {net_hash} (data stream is unknown)")
1809
+ return False
1810
+
1811
    async def send_stream_samples(self):
        """Collect and send stream samples from all owned streams to appropriate recipients (async).

        For each owned stream group (except the agent's own processor input): fresh samples are
        gathered into a per-stream payload and sent either by pubsub (when no explicit recipients
        are registered for the group) or by direct message to each registered recipient.
        """

        # Get samples from all the owned streams
        for net_hash, streams_dict in self.owned_streams.items():

            # Skipping our processor input
            if DataProps.name_or_group_from_net_hash(net_hash) == "processor_in":
                continue

            # Preparing content to send (one entry per stream name in this group)
            something_to_send = False
            content = {name: {} for name in streams_dict.keys()}
            content_data = {name: None for name in streams_dict.keys()}
            for name, stream in streams_dict.items():
                data = stream.get(requested_by="send_stream_samples")

                if data is not None:
                    something_to_send = True
                    self.deb(f"[send_stream_samples] Preparing to send stream samples from {net_hash}, named {name} "
                             f"(tag={stream.get_tag()}, uuid={stream.get_uuid()})")

                    content[name] = {'data': data, 'data_tag': stream.get_tag(), 'data_uuid': stream.get_uuid()}
                    content_data[name] = data

                    # Single-shot streams drop their UUID once the sample has been collected
                    stream.clear_uuid_if_marked_as_clearable()

            # Checking if there is something valid in this group of streams
            if not something_to_send:
                continue

            # Guessing recipients of direct message (if None, then PubSub)
            recipients = self.recipients[net_hash]

            # Debug: force pubsub to be sent as direct message to the first agent
            # if self._recipients[net_hash] is None:
            #     for peer_id in self.all_agents.keys():
            #         recipient = peer_id
            #         break

            # If pubsub...
            if recipients is None:
                if DataStream.is_pubsub_from_net_hash(net_hash):
                    self.deb(f"[send_stream_samples] Sending stream samples of the whole {net_hash} by pubsub")

                    # The callback may transform each sample right before it leaves (recipient is None for pubsub)
                    for name in content.keys():
                        content[name]['data'] = self.callback_before_sending_sample(content_data[name],
                                                                                    content[name]['data_tag'],
                                                                                    net_hash, name, None)
                        self.deb(f"[send_stream_samples] - Sending {content[name]['data']}")

                    peer_id = DataStream.peer_id_from_net_hash(net_hash)  # Guessing agent peer ID from the net hash
                    ret = await self._node_conn.publish(peer_id, channel=net_hash,
                                                        content_type=Msg.STREAM_SAMPLE,
                                                        content=content)

                    self.deb(f"[send_stream_samples] Sending returned: " + str(ret))

            # If direct message...
            else:
                if not DataStream.is_pubsub_from_net_hash(net_hash):
                    _recipients = list(recipients.keys())
                    for i, _recipient in enumerate(_recipients):
                        self.deb(f"[send_stream_samples] Sending samples by direct message, to {_recipient}")

                        peer_id = _recipient  # Peer ID from the recipient information
                        name_or_group = DataProps.name_or_group_from_net_hash(net_hash)

                        # The callback is invoked per recipient: payloads may differ per destination
                        for name in content.keys():
                            content[name]['data'] = self.callback_before_sending_sample(content_data[name],
                                                                                        content[name]['data_tag'],
                                                                                        net_hash, name, _recipient)
                            self.deb(f"[send_stream_samples] - Sending {content[name]['data']}")

                        ret = await self._node_conn.send(peer_id, channel_trail=name_or_group,
                                                         content_type=Msg.STREAM_SAMPLE,
                                                         content=content)

                        # Book-keeping: one sample delivered to this recipient for this group
                        self.recipient_got_one(net_hash, _recipient)
                        self.deb(f"[send_stream_samples] Sending returned: " + str(ret))
                else:
                    # Explicit recipients on a pubsub stream violate the pubsub/direct invariant
                    raise ValueError(f"Unexpected scenario: recipients set ({list(recipients.keys())}) "
                                     f"and sending on a pubsub stream")
1893
+
1894
+ def disable_proc_input(self, public: bool):
1895
+ stream_dict = self.owned_streams[self.get_proc_input_net_hash(public=public)]
1896
+ for stream_obj in stream_dict.values():
1897
+ if stream_obj.is_public() == public:
1898
+ stream_obj.disable()
1899
+
1900
+ def enable_proc_input(self, public: bool):
1901
+ stream_dict = self.owned_streams[self.get_proc_input_net_hash(public=public)]
1902
+ for stream_obj in stream_dict.values():
1903
+ if stream_obj.is_public() == public:
1904
+ stream_obj.enable()
1905
+
1906
    def set_proc_input(self, data: str | Image | torch.Tensor | None, public: bool = False,
                       uuid: str | None = None, data_type: str = "auto", data_tag: int = -1):
        """Inject a sample into the agent's own (non-pubsub) processor-input stream.

        The first owned "processor_in" stream on the requested side whose properties match the
        data type (text/image/tensor, either inferred from `data` or forced via `data_type`)
        receives the sample; its expected UUID is cleared, its current UUID and tag are set.

        Args:
            data: The sample to store (str, Image, torch.Tensor), or None to clear the slot.
            public: If True target the public-side stream, otherwise the private/world one.
            uuid: UUID to set as the stream's current UUID (expected UUID is always cleared).
            data_type: "auto" to infer from `data`'s Python type, or "text"/"img"/"tensor" to force.
            data_tag: Tag to attach to the sample (default -1).

        Returns:
            True if a compatible stream was found and updated, False otherwise.
        """
        # Public and private sides are addressed through different peer IDs
        peer_id = self.get_peer_ids()[0] if public else self.get_peer_ids()[1]
        proc_in = self.find_streams(peer_id, "processor_in")
        if proc_in is None or len(proc_in) == 0:
            return False
        for net_hash, stream_dict in proc_in.items():
            # Only direct (non-pubsub) input streams are eligible
            if not DataProps.is_pubsub_from_net_hash(net_hash):
                for stream_name, stream_obj in stream_dict.items():
                    if stream_obj.props.is_public() == public:
                        # Accept when data is None (clearing) or when the declared/inferred type
                        # agrees with the stream's properties
                        if (data is None or
                                ((data_type == "text" or isinstance(data, str)) and stream_obj.props.is_text()) or
                                ((data_type == "img" or isinstance(data, Image)) and stream_obj.props.is_img()) or
                                ((data_type == "tensor" or isinstance(data, torch.Tensor))
                                 and stream_obj.props.is_tensor())):
                            stream_obj.set(data)  # This might fail if the stream is disabled
                            stream_obj.set_uuid(None, expected=True)
                            stream_obj.set_uuid(uuid, expected=False)
                            stream_obj.set_tag(data_tag)
                            # First compatible stream wins
                            return True
        return False
1927
+
1928
+ def get_tag(self, net_hash: str):
1929
+ if net_hash in self.known_streams:
1930
+ data_tag = -1
1931
+ stream_dict = self.known_streams[net_hash]
1932
+ for stream_obj in stream_dict.values():
1933
+ data_tag = max(data_tag, stream_obj.get_tag())
1934
+ return data_tag
1935
+ return -1
1936
+
1937
+ def set_tag(self, net_hash: str, data_tag: int):
1938
+ if net_hash in self.known_streams:
1939
+ stream_dict = self.known_streams[net_hash]
1940
+ for stream_obj in stream_dict.values():
1941
+ stream_obj.set_tag(data_tag)
1942
+
1943
+ def set_uuid(self, net_hash: str, uuid: int | None, expected: bool = False):
1944
+ if net_hash in self.known_streams:
1945
+ stream_dict = self.known_streams[net_hash]
1946
+ for stream_obj in stream_dict.values():
1947
+ stream_obj.set_uuid(uuid, expected=expected)
1948
+
1949
+ def force_action_step(self, step: int):
1950
+ self.overridden_action_step = step if step >= 0 else None
1951
+
1952
+ def get_action_step(self):
1953
+ """Retrieve the current action step from the agent's private/world behavior.
1954
+
1955
+ Returns:
1956
+ The current action step object from the HybridStateMachine's active action, or None if no action.
1957
+ """
1958
+ behav = self.behav if self.behav.is_enabled() else self.behav_lone_wolf
1959
+ return behav.get_action_step() if self.overridden_action_step is None else self.overridden_action_step
1960
+
1961
+ def is_last_action_step(self):
1962
+ """Check if the agent's current action (private/world behaviour) is on its last step.
1963
+
1964
+ Returns:
1965
+ True if the current action was its last step, False otherwise. Returns None if there is no active action.
1966
+ """
1967
+ behav = self.behav if self.behav.is_enabled() else self.behav_lone_wolf
1968
+ action = behav.get_action()
1969
+ if action is not None:
1970
+ return action.was_last_step_done()
1971
+ else:
1972
+ return None
1973
+
1974
+ def is_multi_steps_action(self):
1975
+ """Determines if the current action is a multistep action.
1976
+
1977
+ Returns:
1978
+ True if the action is multistep, False otherwise.
1979
+ """
1980
+ behav = self.behav if self.behav.is_enabled() else self.behav_lone_wolf
1981
+ action = behav.get_action()
1982
+ return action.is_multi_steps() if action is not None else False
1983
+
1984
+ async def set_policy(self,
1985
+ policy_method_name_or_policy_fcn: str | Callable[[list[Action]], [int, ActionRequest | None]],
1986
+ public: bool = False) -> bool:
1987
+ """Sets the policy to be used in selecting what action to perform in the current state (async).
1988
+
1989
+ Args:
1990
+ policy_method_name_or_policy_fcn: The name of a method of the Agent class that implements a policy function.
1991
+ It is a function that takes a list of `Action` objects that are candidates for execution, and returns
1992
+ the index of the selected action and an ActionRequest object with the action-requester details
1993
+ (requester, arguments, time, and UUID), or -1 and None if no action is selected.
1994
+ By design, every agent implements a basic policy function named "policy_default".
1995
+ public: If True, the policy will be applied to the public HSM, otherwise to the private/world one.
1996
+ """
1997
+ if isinstance(policy_method_name_or_policy_fcn, str):
1998
+ policy_fcn = getattr(self, policy_method_name_or_policy_fcn, None)
1999
+ if not callable(policy_fcn):
2000
+ return False
2001
+ behav = self.behav if not public else self.behav_lone_wolf
2002
+ behav.set_policy(policy_fcn)
2003
+ return True
2004
+ elif callable(policy_method_name_or_policy_fcn):
2005
+ policy_fcn = policy_method_name_or_policy_fcn
2006
+ behav = self.behav if not public else self.behav_lone_wolf
2007
+ behav.set_policy(policy_fcn)
2008
+ return True
2009
+ return False
2010
+
2011
+ def set_policy_filter(self,
2012
+ filter_method_name_or_policy_fcn: str | Callable[
2013
+ [int, ActionRequest | None, list[Action], dict], [int, ActionRequest | None]],
2014
+ public: bool = False) -> bool:
2015
+ """Sets the policy to be used in selecting what action to perform in the current state (async).
2016
+
2017
+ Args:
2018
+ filter_method_name_or_policy_fcn: The name of a method of the Agent class or a function that implements a
2019
+ policy filtering function, overriding what the policy decided.
2020
+ It is a function that takes what the policy decided, a list of `Action` objects that are candidates
2021
+ for execution, and a dictionary with customizable field (always including the "agent" key, with a ref
2022
+ to the current agent) and returns the index of the selected action and an ActionRequest object with the
2023
+ action-requester details (requester, arguments, time, and UUID), or -1 and None
2024
+ if no action is selected.
2025
+ By design, every agent comes with no filtering active.
2026
+ public: If True, the filter will be applied to the public HSM, otherwise to the private/world one.
2027
+ """
2028
+ if isinstance(filter_method_name_or_policy_fcn, str):
2029
+ filter_fcn = getattr(self, filter_method_name_or_policy_fcn, None)
2030
+ if not callable(filter_fcn):
2031
+ return False
2032
+ if public:
2033
+ self.policy_filter_lone_wolf = filter_fcn
2034
+ self.behav_lone_wolf.set_policy_filter(self.policy_filter_lone_wolf, self.policy_filter_lone_wolf_opts)
2035
+ self.policy_filter_lone_wolf_opts['agent'] = self # Forced (do it *after* set_policy_filter)
2036
+ self.policy_filter_lone_wolf_opts['public'] = True
2037
+ else:
2038
+ self.policy_filter = filter_fcn
2039
+ self.behav.set_policy_filter(self.policy_filter, self.policy_filter_opts)
2040
+ self.policy_filter_opts['agent'] = self # Forced (do it *after* set_policy_filter)
2041
+ self.policy_filter_opts['public'] = False
2042
+ return True
2043
+ elif callable(filter_method_name_or_policy_fcn):
2044
+ if public:
2045
+ self.policy_filter_lone_wolf = filter_method_name_or_policy_fcn
2046
+ self.behav_lone_wolf.set_policy_filter(self.policy_filter_lone_wolf, self.policy_filter_lone_wolf_opts)
2047
+ self.policy_filter_lone_wolf_opts['agent'] = self # Forced (do it *after* set_policy_filter)
2048
+ self.policy_filter_lone_wolf_opts['public'] = True
2049
+ else:
2050
+ self.policy_filter = filter_method_name_or_policy_fcn
2051
+ self.behav.set_policy_filter(self.policy_filter, self.policy_filter_opts)
2052
+ self.policy_filter_opts['agent'] = self # Forced (do it *after* set_policy_filter)
2053
+ self.policy_filter_opts['public'] = False
2054
+ return True
2055
+ return False
2056
+
2057
+ def policy_default(self, actions_list: list[Action]) -> tuple[int, ActionRequest | None]:
2058
+ """This is the default policy for selecting which action to execute from a list of feasible actions.
2059
+ It prioritizes actions that have been explicitly requested (i.e., have pending requests) on a first-come,
2060
+ first-served basis. If no requested actions are found, it then selects the first action in the list that is
2061
+ marked as `ready`.
2062
+
2063
+ Args:
2064
+ actions_list: A list of `Action` objects that are candidates for execution.
2065
+
2066
+ Returns:
2067
+ The index of the selected action and an ActionRequest object with the requester details (requester,
2068
+ arguments, time, and UUID), or -1 and None if no action is selected.
2069
+ """
2070
+ for i, action in enumerate(actions_list):
2071
+ _list_of_requests = action.get_list_of_requests()
2072
+ if len(_list_of_requests) > 0:
2073
+ _selected_action_idx = i
2074
+ _selected_request = _list_of_requests.get_oldest_request()
2075
+ return _selected_action_idx, _selected_request
2076
+ for i, action in enumerate(actions_list):
2077
+ if action.is_ready(consider_requests=False):
2078
+ _selected_action_idx = i
2079
+ _selected_request = None
2080
+ return _selected_action_idx, _selected_request
2081
+ _selected_action_idx = -1
2082
+ _selected_request = None
2083
+ return _selected_action_idx, _selected_request
2084
+
2085
+ def add_recipient(self, net_hash: str, peer_id: str | list | tuple | set, samples: int = 1):
2086
+ if net_hash in self.recipients:
2087
+ if self.recipients[net_hash] is None:
2088
+ self.recipients[net_hash] = {}
2089
+ if not isinstance(peer_id, (list, tuple, set)):
2090
+ if peer_id is not None:
2091
+ self.recipients[net_hash][peer_id] = samples
2092
+ else:
2093
+ for _peer_id in peer_id:
2094
+ if _peer_id is not None:
2095
+ self.recipients[net_hash][_peer_id] = samples
2096
+
2097
+ def remove_recipient(self, net_hash: str, peer_id: str):
2098
+ if net_hash in self.recipients and self.recipients[net_hash] is not None:
2099
+ if peer_id in self.recipients[net_hash]:
2100
+ del self.recipients[net_hash][peer_id]
2101
+ if len(self.recipients[net_hash]) == 0:
2102
+ self.recipients[net_hash] = None
2103
+
2104
+ def clear_recipients(self, net_hash: str):
2105
+ if net_hash in self.recipients:
2106
+ self.recipients[net_hash] = None
2107
+
2108
+ def recipient_got_one(self, net_hash: str, peer_id: str):
2109
+ if net_hash in self.recipients and self.recipients[net_hash] is not None:
2110
+ if peer_id in self.recipients[net_hash]:
2111
+ if self.recipients[net_hash][peer_id] == 1 or self.recipients[net_hash][peer_id] < 0:
2112
+ self.remove_recipient(net_hash, peer_id)
2113
+
2114
+ def mark_recipient_as_removable(self, net_hash: str, peer_id: str | list | tuple | set):
2115
+ if net_hash in self.recipients:
2116
+ if isinstance(peer_id, (list, tuple, set)):
2117
+ for _peer_id in peer_id:
2118
+ if _peer_id in self.recipients[net_hash]:
2119
+ if self.recipients[net_hash][_peer_id] > 0:
2120
+ self.recipients[net_hash][_peer_id] = -self.recipients[net_hash][_peer_id]
2121
+ else:
2122
+ if peer_id in self.recipients[net_hash]:
2123
+ if self.recipients[net_hash][peer_id] > 0:
2124
+ self.recipients[net_hash][peer_id] = -self.recipients[net_hash][peer_id]
2125
+
2126
+ def proc_callback_inputs(self, inputs):
2127
+ """A callback method that saves the inputs to the processor right before execution.
2128
+
2129
+ Args:
2130
+ inputs: The data inputs for the processor.
2131
+
2132
+ Returns:
2133
+ The same inputs passed to the function.
2134
+ """
2135
+ self.proc_last_inputs = inputs
2136
+ return inputs
2137
+
2138
+ def proc_callback_outputs(self, outputs):
2139
+ """A callback method that saves the outputs from the processor right after execution.
2140
+
2141
+ Args:
2142
+ outputs: The data outputs from the processor.
2143
+
2144
+ Returns:
2145
+ The same outputs passed to the function.
2146
+ """
2147
+ self.proc_last_outputs = outputs
2148
+ return outputs
2149
+
2150
+ def callback_before_sending_sample(self, data, data_tag: int,
2151
+ net_hash: str, stream_name: str, recipient: str | None):
2152
+ """A callback method that handles the steam data right before sending it through the network.
2153
+
2154
+ Args:
2155
+ data: The stream data sample.
2156
+ data_tag: The tag of the sample.
2157
+ stream_name: The name of the data stream.
2158
+ net_hash: The net hash of the whole stream.
2159
+ recipient: The (planned) recipient of this sample (or None in case of pubsub).
2160
+
2161
+ Returns:
2162
+ The same data passed to the function.
2163
+ """
2164
+ return data
2165
+
2166
+ def agent_state_dict(self):
2167
+ """Returns a dictionary containing an instance of the agent's state that can be saved."""
2168
+ save_in_state = ['world_profile',]
2169
+ return {k: getattr(self, k) for k in save_in_state}
2170
+
2171
    def save(self, where: str = "") -> bool:
        """Save the agent's state, including its processor and other attributes, to a specified location.

        The processor weights (and optimizer state, if any) are written to a ``.pt`` file via
        ``torch.save``; the remaining agent attributes (selected by ``agent_state_dict``) are pickled
        to a ``.pkl`` file. Each file is first written to a ``.tmp`` path and then atomically moved
        into place with ``os.replace``, so a crash mid-save never leaves a corrupted file behind.

        Args:
            where: The directory path where the agent's state should be saved. Defaults to "", which
                resolves to "<node_identity_dir>/agent_state".

        Returns:
            True upon successful saving; False if no destination directory could be determined.

        Raises:
            IOError: If there is an issue with file operations (e.g., directory creation, writing files).
            TypeError, ValueError, RuntimeError: For other potential issues during serialization or saving.
        """

        if where == '':
            if self._node_identity_dir is None or len(self._node_identity_dir) == 0:
                return False  # No explicit path and no identity dir: nowhere to save
            where = os.path.join(self._node_identity_dir, "agent_state")  # Default save path

        os.makedirs(where, exist_ok=True)

        # Saving the processor — only if one exists and it changed since the last successful save
        if self.proc is not None and self.proc_updated_since_last_save:
            pt_final = os.path.join(where, f"{self._node_name}.pt")
            pt_tmp = pt_final + ".tmp"
            try:
                checkpoint = {
                    'model_state_dict': self.proc.state_dict(),
                }

                # If your agent has an optimizer, save its state too
                if self.proc_opts.get('optimizer') is not None:
                    checkpoint['optimizer_state_dict'] = self.proc_opts['optimizer'].state_dict()

                torch.save(checkpoint, pt_tmp)
                os.replace(pt_tmp, pt_final)  # Atomic move
                self.proc_updated_since_last_save = False  # Only cleared after the file is safely in place
            except Exception as e:
                # Remove the partial temp file before propagating the error
                if os.path.exists(pt_tmp):
                    os.remove(pt_tmp)
                self.out(f"Error saving processor: {e}")
                raise e

        # Save Agent State (pickle of the attributes returned by agent_state_dict), same
        # temp-file-then-atomic-replace pattern as above
        pkl_final = os.path.join(where, f"{self._node_name}.pkl")
        pkl_tmp = pkl_final + ".tmp"
        try:
            state = self.agent_state_dict()
            with open(pkl_tmp, "wb") as f:
                pickle.dump(state, f)
            os.replace(pkl_tmp, pkl_final)
        except Exception as e:
            self.out(f"Could not save " + ("agent" if not self.is_world else "world") + f": {e}")
            if os.path.exists(pkl_tmp):
                os.remove(pkl_tmp)
            raise e

        return True
2229
+
2230
    def load(self, where: str = "") -> bool:
        """Load the agent's state from a specified location.

        Restores the pickled agent attributes (written by ``save``) into ``self.__dict__`` and,
        when a matching ``.pt`` checkpoint exists and a processor is present, the processor
        weights (and optimizer state, if present).

        Args:
            where: The directory path from which the agent's state should be loaded. Defaults to "",
                which resolves to "<node_identity_dir>/agent_state".

        Returns:
            True if loading succeeded; False if no saved state could be located.

        Raises:
            Exception: If reading the pickle file or restoring the processor state fails.
        """

        if where == '':
            if self._node_identity_dir is None or len(self._node_identity_dir) == 0:
                return False  # No explicit path and no identity dir: nowhere to load from
            where = os.path.join(self._node_identity_dir, "agent_state")  # Default save path

        # Check if directory exists
        if not os.path.exists(where):
            self.out("No state folder found for " + ("agent" if not self.is_world else "world") +
                     f" {self._node_name}.")
            return False

        # Check if the specific pickle file exists
        pkl_path = os.path.join(where, f"{self._node_name}.pkl")
        if not os.path.exists(pkl_path):
            self.out("No saved state found for " + ("agent" if not self.is_world else "world") +
                     f" {self._node_name}.")
            return False

        # Loading the agent state dictionary
        # NOTE(review): pickle.load on a tampered file can execute arbitrary code — the path lives
        # under the node identity dir, so it is presumably trusted; confirm this holds in deployment.
        try:
            with open(pkl_path, "rb") as f:
                agent_state_dict = pickle.load(f)
        except Exception as e:
            raise Exception(f"Error loading pickle file at {pkl_path}: {e}")

        # Update self's attributes with the loaded object's attributes
        self.__dict__.update(agent_state_dict)

        # Check if we also need to load the processor state
        pt_path = os.path.join(where, f"{self._node_name}.pt")
        load_proc = self.proc is not None and os.path.exists(pt_path)
        if load_proc:
            try:
                # NOTE(review): no map_location is given, so tensors are restored to the device they
                # were saved on — assumes that device is available here; TODO confirm.
                checkpoint = torch.load(pt_path)
                if isinstance(checkpoint, dict) and 'model_state_dict' in checkpoint:
                    self.proc.load_state_dict(checkpoint['model_state_dict'])

                    # Restore Optimizer to proc_opts
                    if 'optimizer_state_dict' in checkpoint and self.proc_opts.get('optimizer') is not None:
                        self.proc_opts['optimizer'].load_state_dict(checkpoint['optimizer_state_dict'])
                else:
                    # Legacy format: the file holds a bare state_dict instead of a checkpoint dict
                    self.proc.load_state_dict(checkpoint)
            except Exception as e:
                raise Exception(f"Error loading processor state: {e}")

        return True
2286
+
2287
+ def __str__(self):
2288
+ """String representation of an agent.
2289
+
2290
+ Returns:
2291
+ A formatted string describing the agent's current state and relationships.
2292
+ """
2293
+ s = ("[" + ("Agent" if not self.is_world else "World") + "]"
2294
+ + f" {self._node_name} (role: {self._node_profile.get_dynamic_profile()['connections']['role']})")
2295
+ if len(self.world_masters) > 0:
2296
+ s += "\n\t- known world masters:"
2297
+ for _s in self.world_masters.keys():
2298
+ s += "\n\t\t" + str(_s)
2299
+ if len(self.world_agents) > 0:
2300
+ s += "\n\t- known agents living in the same world (non-world-masters):"
2301
+ for _s in self.world_agents.keys():
2302
+ s += "\n\t\t" + str(_s)
2303
+ if len(self.public_agents) > 0:
2304
+ s += "\n\t- known lone wolves:"
2305
+ for _s in self.public_agents.keys():
2306
+ s += "\n\t\t" + str(_s)
2307
+ if len(self.known_streams) > 0:
2308
+ s += "\n\t- known_streams:"
2309
+ for _s in self.known_streams:
2310
+ s += "\n\t\t" + str(_s)
2311
+ s += "\n\t- behaviour (public):"
2312
+ s += "\n\t\t" + (str(self.behav_lone_wolf).replace("\n", "\n\t\t")
2313
+ if self.behav_lone_wolf is not None else "none")
2314
+ s += "\n\t- behaviour (private):"
2315
+ s += "\n\t\t" + (str(self.behav).replace("\n", "\n\t\t") if self.behav is not None else "none")
2316
+ s += "\n\t- processor:"
2317
+ s += "\n\t\t" + (str(self.proc).replace("\n", "\n\t\t") if self.proc is not None else "none")
2318
+ return s
2319
+
2320
+ async def __remove_all_world_related_agents(self):
2321
+ """Remove all world-related agents (masters and regular agents) from the agent's known lists (async)."""
2322
+
2323
+ to_remove = list(self.world_masters.keys())
2324
+ for peer_id in to_remove:
2325
+ await self.remove_agent(peer_id)
2326
+
2327
+ to_remove = list(self.world_agents.keys())
2328
+ for peer_id in to_remove:
2329
+ await self.remove_agent(peer_id)
2330
+
2331
+ async def __remove_all_world_private_streams(self):
2332
+ """Remove all known streams that are flagged as not-public and are not owned by this agent (async)."""
2333
+
2334
+ # Find what to remove
2335
+ to_remove = []
2336
+ for net_hash, stream_dict in self.known_streams.items():
2337
+ for name, stream_obj in stream_dict.items():
2338
+ if not stream_obj.get_props().is_public() and net_hash not in self.owned_streams:
2339
+ to_remove.append((DataProps.peer_id_from_net_hash(net_hash), name))
2340
+
2341
+ # Remove it
2342
+ for (peer_id, name) in to_remove:
2343
+ await self.remove_streams(peer_id, name)
2344
+
2345
+ # Nuke recipients and recipient slots associated to these streams
2346
+ recipient_net_hashes = list(self.recipients.keys())
2347
+ for net_hash in recipient_net_hashes:
2348
+ if net_hash not in self.known_streams:
2349
+ del self.recipients[net_hash]
2350
+
2351
+ def remove_peer_from_agent_status_attrs(self, peer_id):
2352
+ """Remove a peer ID from the status of the agent, assuming it to be the represented by attributes that start
2353
+ with '_'."""
2354
+ for attr_name in dir(self):
2355
+ if attr_name.startswith("_") and (not attr_name.startswith("__") and not attr_name.startswith("_Agent")
2356
+ and not attr_name.startswith("_WAgent")):
2357
+ try:
2358
+ value = getattr(self, attr_name)
2359
+ if isinstance(value, list):
2360
+ setattr(self, attr_name, [v for v in value if v != peer_id])
2361
+ elif isinstance(value, set):
2362
+ value.discard(peer_id)
2363
+ elif isinstance(value, dict):
2364
+ if peer_id in value:
2365
+ del value[peer_id]
2366
+ except AttributeError:
2367
+ continue # Skip read-only attributes
2368
+
2369
+ def reset_agent_status_attrs(self):
2370
+ """Resets attributes that represent the status of the agent, assuming to be the ones that start with '_'."""
2371
+ for attr_name in dir(self):
2372
+ if attr_name.startswith("_") and (not attr_name.startswith("__") and not attr_name.startswith("_Agent")
2373
+ and not attr_name.startswith("_WAgent")):
2374
+ try:
2375
+ value = getattr(self, attr_name)
2376
+ if isinstance(value, list):
2377
+ setattr(self, attr_name, [])
2378
+ elif isinstance(value, set):
2379
+ setattr(self, attr_name, set())
2380
+ elif isinstance(value, dict):
2381
+ setattr(self, attr_name, {})
2382
+ elif isinstance(value, int):
2383
+ setattr(self, attr_name, 0)
2384
+ elif isinstance(value, float):
2385
+ setattr(self, attr_name, 0.)
2386
+ elif isinstance(value, bool):
2387
+ setattr(self, attr_name, False)
2388
+ except AttributeError:
2389
+ continue # Skip read-only attributes