unaiverse 0.1.11__cp311-cp311-macosx_11_0_arm64.whl
This diff shows the content of a publicly released package version as it appears in its public registry. It is provided for informational purposes only and reflects the changes between package versions.
Potentially problematic release.
This version of unaiverse might be problematic.
- unaiverse/__init__.py +19 -0
- unaiverse/agent.py +2090 -0
- unaiverse/agent_basics.py +1948 -0
- unaiverse/clock.py +221 -0
- unaiverse/dataprops.py +1236 -0
- unaiverse/hsm.py +1892 -0
- unaiverse/modules/__init__.py +18 -0
- unaiverse/modules/cnu/__init__.py +17 -0
- unaiverse/modules/cnu/cnus.py +536 -0
- unaiverse/modules/cnu/layers.py +261 -0
- unaiverse/modules/cnu/psi.py +60 -0
- unaiverse/modules/hl/__init__.py +15 -0
- unaiverse/modules/hl/hl_utils.py +411 -0
- unaiverse/modules/networks.py +1509 -0
- unaiverse/modules/utils.py +710 -0
- unaiverse/networking/__init__.py +16 -0
- unaiverse/networking/node/__init__.py +18 -0
- unaiverse/networking/node/connpool.py +1308 -0
- unaiverse/networking/node/node.py +2499 -0
- unaiverse/networking/node/profile.py +446 -0
- unaiverse/networking/node/tokens.py +79 -0
- unaiverse/networking/p2p/__init__.py +187 -0
- unaiverse/networking/p2p/go.mod +127 -0
- unaiverse/networking/p2p/go.sum +548 -0
- unaiverse/networking/p2p/golibp2p.py +18 -0
- unaiverse/networking/p2p/golibp2p.pyi +135 -0
- unaiverse/networking/p2p/lib.go +2662 -0
- unaiverse/networking/p2p/lib.go.sha256 +1 -0
- unaiverse/networking/p2p/lib_types.py +312 -0
- unaiverse/networking/p2p/message_pb2.py +50 -0
- unaiverse/networking/p2p/messages.py +362 -0
- unaiverse/networking/p2p/mylogger.py +77 -0
- unaiverse/networking/p2p/p2p.py +871 -0
- unaiverse/networking/p2p/proto-go/message.pb.go +846 -0
- unaiverse/networking/p2p/unailib.cpython-311-darwin.so +0 -0
- unaiverse/stats.py +1481 -0
- unaiverse/streamlib/__init__.py +15 -0
- unaiverse/streamlib/streamlib.py +210 -0
- unaiverse/streams.py +776 -0
- unaiverse/utils/__init__.py +16 -0
- unaiverse/utils/lone_wolf.json +24 -0
- unaiverse/utils/misc.py +310 -0
- unaiverse/utils/sandbox.py +293 -0
- unaiverse/utils/server.py +435 -0
- unaiverse/world.py +335 -0
- unaiverse-0.1.11.dist-info/METADATA +367 -0
- unaiverse-0.1.11.dist-info/RECORD +50 -0
- unaiverse-0.1.11.dist-info/WHEEL +6 -0
- unaiverse-0.1.11.dist-info/licenses/LICENSE +43 -0
- unaiverse-0.1.11.dist-info/top_level.txt +1 -0
unaiverse/utils/server.py
@@ -0,0 +1,435 @@
"""
 █████ █████ ██████   █████   █████████   █████ █████ █████ ██████████ ███████████   █████████  ██████████
░░███ ░░███ ░░██████ ░░███   ███░░░░░███ ░░███ ░░███ ░░███ ░░███░░░░░█░░███░░░░░███ ███░░░░░███░░███░░░░░█
 ░███  ░███  ░███░███ ░███  ░███    ░███  ░███  ░███  ░███  ░███  █ ░  ░███    ░███░███    ░░░  ░███  █ ░
 ░███  ░███  ░███░░███░███  ░███████████  ░███  ░███  ░███  ░██████    ░██████████ ░░█████████  ░██████
 ░███  ░███  ░███ ░░██████  ░███░░░░░███  ░░███ ███   ░███  ░███░░█    ░███░░░░░███ ░░░░░░░░███ ░███░░█
 ░███  ░███  ░███  ░░█████  ░███    ░███   ░░░█████░  ░███  ░███ ░   █ ░███    ░███ ███    ░███ ░███ ░   █
 ░░████████  █████  ░░█████ █████   █████    ░░███    █████ ██████████ █████   █████░░█████████ ██████████
  ░░░░░░░░  ░░░░░    ░░░░░ ░░░░░   ░░░░░      ░░░    ░░░░░ ░░░░░░░░░░ ░░░░░   ░░░░░  ░░░░░░░░░ ░░░░░░░░░░

A Collectionless AI Project (https://collectionless.ai)
Registration/Login: https://unaiverse.io
Code Repositories: https://github.com/collectionlessai/
Main Developers: Stefano Melacci (Project Leader), Christian Di Maio, Tommaso Guidi
"""
import io
import os
import json
import torch
import base64
from PIL import Image
from flask_cors import CORS
from threading import Thread
from unaiverse.dataprops import DataProps
import torchvision.transforms as transforms
from unaiverse.streams import BufferedDataStream
from unaiverse.networking.node.node import NodeSynchronizer
from flask import Flask, jsonify, request, send_from_directory


class Server:

    def __init__(self, node_synchronizer: NodeSynchronizer,
                 root: str = '../../../../zoo/debug_viewer/www',
                 port: int = 5001,
                 checkpoints: dict[str, list[dict] | int] | str | None = None,
                 y_range: list[float] | None = None):
        self.node_synchronizer = node_synchronizer
        self.node_synchronizer.using_server = True  # Forcing
        self.root = os.path.join(os.path.dirname(os.path.abspath(__file__)), root)
        self.root_css = self.root + "/static/css"
        self.root_js = self.root + "/static/js"
        self.port = port
        self.app = Flask(__name__, template_folder=self.root)
        CORS(self.app)  # To handle cross-origin requests (needed for development)
        self.register_routes()
        self.thumb_transforms = transforms.Compose([transforms.Resize(64), transforms.CenterCrop(64)])
        self.y_range = y_range
        self.visu_name_to_net_hash = {}

        # Loading checkpoints, if needed
        if checkpoints is not None and isinstance(checkpoints, str):  # String: assumed to be a file name
            file_name = checkpoints
            checkpoints = {"checkpoints": None, "matched": -1, "current": 0}
            with open(file_name, 'r') as file:
                checkpoints["checkpoints"] = json.load(file)  # From file name to dictionary
        elif checkpoints is not None:
            checkpoints = {"checkpoints": checkpoints, "matched": -1, "current": 0}
        self.node_synchronizer.server_checkpoints = checkpoints

        # Fixing y_range as needed
        self.y_range = [None, None] if self.y_range is None else self.y_range
        assert len(self.y_range) == 2, "Invalid y_range argument (it must be either None or a list of 2 floats)"

        # Starting a new thread
        thread = Thread(target=self.__run_server)
        thread.start()

    def __run_server(self):
        self.app.run(host='0.0.0.0', port=self.port, threaded=True, debug=False)  # Run Flask with threading enabled
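For orientation, a minimal usage sketch (not part of the released file): it assumes an already-configured NodeSynchronizer named `synch`, and shows the checkpoint forms accepted by the constructor above.

    # Sketch only: `synch` is an assumed, pre-configured NodeSynchronizer.
    from unaiverse.utils.server import Server

    server = Server(synch,
                    port=5001,                 # where the Flask app listens
                    checkpoints="ckpts.json",  # a JSON file name, a pre-loaded structure, or None
                    y_range=[0.0, 1.0])        # fixed y-axis range sent to the viewer
    # Nothing else to call: __init__ starts the Flask server on its own thread.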
    def register_routes(self):
        self.app.add_url_rule('/', view_func=self.serve_index, methods=['GET'])
        self.app.add_url_rule('/<path:filename>', view_func=self.serve_root, methods=['GET'])
        self.app.add_url_rule('/static/css/<path:filename>', view_func=self.serve_static_css, methods=['GET'])
        self.app.add_url_rule('/static/js/<path:filename>', view_func=self.serve_static_js, methods=['GET'])
        self.app.add_url_rule('/get_play_pause_status', view_func=self.get_play_pause_status, methods=['GET'])
        self.app.add_url_rule('/ask_to_pause', view_func=self.ask_to_pause, methods=['GET'])
        self.app.add_url_rule('/ask_to_play', view_func=self.ask_to_play, methods=['GET'])
        self.app.add_url_rule('/get_env_name', view_func=self.get_env_name, methods=['GET'])
        self.app.add_url_rule('/get_summary', view_func=self.get_summary, methods=['GET'])
        self.app.add_url_rule('/get_authority', view_func=self.get_authority, methods=['GET'])
        self.app.add_url_rule('/get_behav', view_func=self.get_behav, methods=['GET'])
        self.app.add_url_rule('/get_behav_status', view_func=self.get_behav_status, methods=['GET'])
        self.app.add_url_rule('/get_list_of_agents', view_func=self.get_list_of_agents, methods=['GET'])
        self.app.add_url_rule('/get_list_of_streams', view_func=self.get_list_of_streams, methods=['GET'])
        self.app.add_url_rule('/get_stream', view_func=self.get_stream, methods=['GET'])
        self.app.add_url_rule('/get_console', view_func=self.get_console, methods=['GET'])
        self.app.add_url_rule('/save', view_func=self.save, methods=['GET'])
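All routes are GET and reply with the {"data": ..., "type": ...} envelope produced by pack_data below, so a client can stay very small. A sketch (host, port and agent names are assumptions):

    import requests

    base = "http://localhost:5001"
    agents = requests.get(f"{base}/get_list_of_agents").json()["data"]["agents"]
    summary = requests.get(f"{base}/get_summary",
                           params={"agent_name": agents[0]}).json()
    print(summary["type"], summary["data"])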
    @staticmethod
    def pack_data(_data):
        _type = type(_data).__name__ if _data is not None else "none"

        def is_tensor_or_list_of_tensors(_d):
            if isinstance(_d, list) and len(_d) > 0 and isinstance(_d[0], torch.Tensor):
                return True
            elif isinstance(_d, torch.Tensor):
                return True
            else:
                return False

        def is_pil_or_list_of_pils(_d):
            if isinstance(_d, list) and len(_d) > 0 and isinstance(_d[0], Image.Image):
                return True
            elif isinstance(_d, Image.Image):
                return True
            else:
                return False

        # List of PyTorch tensors (or Nones)
        def encode_tensor_or_list_of_tensors(__data):
            __type = ""

            if isinstance(__data, list) and len(__data) > 0 and isinstance(__data[0], torch.Tensor):
                found_tensor = False
                __data_b64 = []
                for __tensor in __data:
                    if __tensor is not None:
                        if not found_tensor:
                            found_tensor = True
                            __type = "list_" + type(__data[0]).__name__ + "_" + __data[0].dtype.__str__().split('.')[-1]

                        __data_b64.append(base64.b64encode(__tensor.detach().cpu().numpy().tobytes()).decode('utf-8'))
                    else:
                        __data_b64.append(None)  # There might be some None in some list elements...
                if not found_tensor:
                    __type = "none"
                __data = __data_b64

            # PyTorch tensor
            if isinstance(__data, torch.Tensor):
                __type = __data.dtype.__str__().split('.')[-1]
                __data = base64.b64encode(__data.detach().cpu().numpy().tobytes()).decode('utf-8')

            return __data, __type

        # List of PIL images (or Nones)
        def encode_pil_or_list_of_pils(__data):
            __type = ""

            if isinstance(__data, list) and len(__data) > 0 and isinstance(__data[0], Image.Image):
                found_image = False
                _data_b64 = []
                for __img in __data:
                    if __img is not None:
                        if not found_image:
                            found_image = True
                            __type = "list_png"

                        buffer = io.BytesIO()
                        __img.save(buffer, format="PNG", optimize=True, compress_level=9)
                        buffer.seek(0)
                        _data_b64.append(f"data:image/png;base64,{base64.b64encode(buffer.read()).decode('utf-8')}")
                    else:
                        _data_b64.append(None)  # There might be some None in some list elements...
                if not found_image:
                    __type = "none"
                __data = _data_b64

            # PIL image
            if isinstance(__data, Image.Image):
                __type = "png"
                __buffer = io.BytesIO()
                __data.save(__buffer, format="PNG", optimize=True, compress_level=9)
                __buffer.seek(0)  # Rewind before reading, otherwise the encoded payload would be empty
                __data = f"data:image/png;base64,{base64.b64encode(__buffer.read()).decode('utf-8')}"

            return __data, __type

        # In the case of a dictionary, we look for values that are (lists of) tensors/images and encode them;
        # we augment the key name adding "-type", where "type" is the type of the packed data
        if _type == "dict":
            keys = list(_data.keys())
            for k in keys:
                v = _data[k]
                if is_tensor_or_list_of_tensors(v):
                    v_encoded, v_type = encode_tensor_or_list_of_tensors(v)
                    del _data[k]
                    k = k + "-" + v_type
                    _data[k] = v_encoded
                elif is_pil_or_list_of_pils(v):
                    v_encoded, v_type = encode_pil_or_list_of_pils(v)
                    del _data[k]
                    k = k + "-" + v_type
                    _data[k] = v_encoded
        else:
            if is_tensor_or_list_of_tensors(_data):
                _data, _data_type = encode_tensor_or_list_of_tensors(_data)
                _type += "_" + _data_type
            elif is_pil_or_list_of_pils(_data):
                _data, _data_type = encode_pil_or_list_of_pils(_data)
                _type += "_" + _data_type

        # Generate JSON for the whole data, where some fields might have been base64-encoded (tensors/images)
        return jsonify({"data": _data, "type": _type})
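On the client side, a single packed tensor arrives as base64-encoded raw bytes plus a type string such as "Tensor_float32"; the envelope carries no shape, so the consumer must know it from context (e.g., the stream's declared properties). A decoding sketch under those assumptions:

    import base64
    import numpy as np

    def decode_tensor_payload(payload: dict) -> np.ndarray:
        dtype = payload["type"].split("_")[-1]   # e.g. "Tensor_float32" -> "float32"
        raw = base64.b64decode(payload["data"])  # raw bytes of the (detached, CPU) tensor
        return np.frombuffer(raw, dtype=dtype)   # flat array; reshape with external knowledge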
    def serve_index(self):
        return send_from_directory(self.root, 'index.html')

    def serve_root(self, filename):
        return send_from_directory(self.root, filename)

    def serve_static_js(self, filename):
        return send_from_directory(self.root_js, filename)

    def serve_static_css(self, filename):
        return send_from_directory(self.root_css, filename)
    def get_play_pause_status(self):
        ret = {'status': None,
               'still_to_play': self.node_synchronizer.skip_clear_for,
               'time': self.node_synchronizer.clock.get_time(passed=True),
               'y_range': self.y_range,
               'matched_checkpoint_to_show': None,
               'more_checkpoints_available': False}
        if self.node_synchronizer.synch_cycle == self.node_synchronizer.synch_cycles:
            ret['status'] = 'ended'
        elif self.node_synchronizer.step_event.is_set():
            ret['status'] = 'playing'
        elif self.node_synchronizer.wait_event.is_set():
            ret['status'] = 'paused'
        if self.node_synchronizer.server_checkpoints is not None:
            ret['more_checkpoints_available'] = self.node_synchronizer.server_checkpoints["current"] >= 0
            if self.node_synchronizer.server_checkpoints["matched"] >= 0:
                ret['matched_checkpoint_to_show'] = self.node_synchronizer.server_checkpoints["checkpoints"][
                    self.node_synchronizer.server_checkpoints["matched"]]["show"]
        return Server.pack_data(ret)

    def ask_to_play(self):
        steps = int(request.args.get('steps'))
        if steps >= 0:
            self.node_synchronizer.skip_clear_for = steps - 1
        else:
            self.node_synchronizer.skip_clear_for = steps
        self.node_synchronizer.step_event.set()
        return Server.pack_data(self.node_synchronizer.synch_cycle)

    def ask_to_pause(self):
        self.node_synchronizer.skip_clear_for = 0
        return Server.pack_data(self.node_synchronizer.synch_cycle)
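These two endpoints drive the synchronizer's step event, so stepping the simulation from a script takes a few lines. A sketch against the default port:

    import requests

    base = "http://localhost:5001"
    requests.get(f"{base}/ask_to_play", params={"steps": 10})  # play 10 steps, then pause
    status = requests.get(f"{base}/get_play_pause_status").json()["data"]
    print(status["status"], "- still to play:", status["still_to_play"])
    # requests.get(f"{base}/ask_to_pause")  # ask to stop at the next step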
    def get_env_name(self):
        return Server.pack_data({"name": self.node_synchronizer.world.get_name(),
                                 "title": self.node_synchronizer.world.get_name()})

    def get_summary(self):
        agent_name = request.args.get('agent_name')
        desc = str(self.node_synchronizer.agent_nodes[agent_name].agent) \
            if agent_name != self.node_synchronizer.world.get_name() else str(self.node_synchronizer.world)
        return Server.pack_data(desc)

    def get_authority(self):
        agent_name = request.args.get('agent_name')
        role = self.node_synchronizer.agent_name_to_profile[agent_name].get_dynamic_profile()['connections']['role']
        authority = 1.0 if "high_authority" in role else 0.0
        return Server.pack_data(authority)

    def get_behav(self):
        agent_name = request.args.get('agent_name')
        if agent_name == self.node_synchronizer.world.get_name():
            behav = self.node_synchronizer.world.behav
        else:
            behav = self.node_synchronizer.agent_nodes[agent_name].agent.behav
        return Server.pack_data(str(behav.to_graphviz().source))

    def get_behav_status(self):
        agent_name = request.args.get('agent_name')
        if agent_name == self.node_synchronizer.world.get_name():
            behav = self.node_synchronizer.world.behav
        else:
            behav = self.node_synchronizer.agent_nodes[agent_name].agent.behav
        state = behav.get_state().id if behav.get_state() is not None else None
        action = behav.get_action().id if behav.get_action() is not None else None
        return Server.pack_data({'state': state, 'action': action,
                                 'state_with_action': behav.get_state().has_action()
                                 if state is not None else False})

    def get_list_of_agents(self):
        agents = self.node_synchronizer.agent_nodes
        ret = {"agents": list(agents.keys()),
               "authorities": [1.0 if "teacher" in
                               self.node_synchronizer.agent_name_to_profile[x]
                               .get_dynamic_profile()['connections']['role'] else 0.0
                               for x in agents.keys()]}
        return Server.pack_data(ret)
    def get_list_of_streams(self):
        agent_name = request.args.get('agent_name')
        agent = self.node_synchronizer.agent_nodes[agent_name].agent \
            if agent_name != self.node_synchronizer.world.get_name() \
            else self.node_synchronizer.world
        streams = agent.known_streams
        decoupled_streams = []
        for net_hash, stream_dict in streams.items():
            assert len(stream_dict) <= 2, (f"Agent {agent_name}: "
                                           f"unexpected size of a stream group ({len(stream_dict)}), "
                                           f"expected at most 2. "
                                           f"The net hash is {net_hash} and here is "
                                           f"the corresponding dict: "
                                           f"{str({k: str(v.get_props()) for k, v in stream_dict.items()})}")
            group_name = DataProps.name_or_group_from_net_hash(net_hash)

            found = False
            peer_id = DataProps.peer_id_from_net_hash(net_hash)
            for _agent_name, _agent_node in self.node_synchronizer.agent_nodes.items():
                _agent = _agent_node.agent
                public_peer_id, private_peer_id = _agent.get_peer_ids()
                if peer_id == public_peer_id or peer_id == private_peer_id:
                    group_name = _agent_name.lower() + ":" + group_name
                    self.visu_name_to_net_hash[group_name] = net_hash
                    found = True
                    break
            if not found:
                public_peer_id, private_peer_id = self.node_synchronizer.world.get_peer_ids()
                if peer_id == public_peer_id or peer_id == private_peer_id:
                    group_name = "world" + ":" + group_name
                    self.visu_name_to_net_hash[group_name] = net_hash

            decoupled_streams.append(group_name + " [y]")
            decoupled_streams.append(group_name + " [d]")
        return Server.pack_data(decoupled_streams)
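Each stream group is listed twice: the " [y]" suffix selects the first stream of the group and " [d]" the second, matching the data_id logic of get_stream below. A listing sketch (agent and group names are placeholders):

    import requests

    base = "http://localhost:5001"
    streams = requests.get(f"{base}/get_list_of_streams",
                           params={"agent_name": "agent_0"}).json()["data"]
    print(streams)  # e.g. ["agent_0:vision [y]", "agent_0:vision [d]", ...]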
    def get_stream(self):
        agent_name = request.args.get('agent_name')
        stream_name = request.args.get('stream_name')
        since_step = int(request.args.get('since_step'))
        data_id = 0
        if stream_name.endswith(" [y]"):
            group_name = stream_name[0:stream_name.find(" [y]")]
            data_id = 0
        elif stream_name.endswith(" [d]"):
            group_name = stream_name[0:stream_name.find(" [d]")]
            data_id = 1
        else:
            group_name = stream_name

        if agent_name != self.node_synchronizer.world.get_name():
            agent = self.node_synchronizer.agent_nodes[agent_name].agent
            known_streams = self.node_synchronizer.agent_nodes[agent_name].agent.known_streams
        else:
            agent = self.node_synchronizer.world
            known_streams = self.node_synchronizer.world.known_streams

        net_hash = self.visu_name_to_net_hash[group_name]
        stream_objs = list(known_streams[net_hash].values())
        stream_obj = stream_objs[data_id] if data_id < len(stream_objs) else None

        if stream_obj is None:

            # Missing stream
            ks = [agent._node_clock.get_cycle()]
            data = None
            last_k = agent._node_clock.get_cycle()
            props = None
        elif isinstance(stream_obj, BufferedDataStream):

            # Buffered stream
            ks, data, last_k, props = stream_obj.get_since_cycle(since_step)
        else:

            # Non-buffered stream
            sample = stream_obj.get()
            ks = [agent._node_clock.get_cycle()]
            data = [sample] if sample is not None else None
            last_k = agent._node_clock.get_cycle()
            props = stream_obj.get_props()

        # Data is None if the step index (k) of the stream is -1 (beginning), or if the stream is disabled
        if data is not None:

            # If data has labeled components (and is not "img" and is not "token_ids"),
            # then we take a decision and convert it to a text string
            if props.is_flat_tensor_with_labels():
                for _i, _data in enumerate(data):
                    data[_i] = props.to_text(_data)

            # If data is of type image, we revert the possibly applied transformation and downscale it
            elif props.is_img():
                for _i, _data in enumerate(data):
                    data[_i] = self.thumb_transforms(_data)

        return Server.pack_data({
            "ks": ks,
            "data": data,
            "last_k": last_k
        })
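Since buffered streams answer with everything recorded after since_step together with last_k, a client can poll incrementally. A sketch with placeholder names:

    import requests

    base = "http://localhost:5001"
    last_k = -1
    for _ in range(100):
        ret = requests.get(f"{base}/get_stream",
                           params={"agent_name": "agent_0",
                                   "stream_name": "agent_0:vision [y]",
                                   "since_step": last_k + 1}).json()["data"]
        if ret["data"] is not None:
            print(ret["ks"], ret["data"])
        last_k = ret["last_k"]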
    def get_console(self):
        agent_name = request.args.get('agent_name')
        last_only = request.args.get('last_only')

        is_world = agent_name == self.node_synchronizer.world.get_name()

        if is_world:
            node = self.node_synchronizer.world_node  # The world has its own dedicated node
            agent = node.world
            behav = agent.behav
        else:
            node = self.node_synchronizer.agent_nodes[agent_name]
            agent = node.agent
            behav = agent.behav

        state = behav.get_state().id if behav.get_state() is not None else None
        action = behav.get_action().id if behav.get_action() is not None else None

        output_messages = node._output_messages
        output_ids = node._output_messages_ids
        count = node._output_messages_count
        last_pos = node._output_messages_last_pos

        # Query parameters arrive as strings: treat a missing flag or a "false"-like value as False
        if last_only is None or last_only.lower() in ('', '0', 'false', 'none'):
            return Server.pack_data({
                'output_messages': output_messages,
                'output_messages_count': count,
                'output_messages_last_pos': last_pos,
                'output_messages_ids': output_ids,
                'behav_status': {
                    'state': state,
                    'action': action,
                    'state_with_action': behav.get_state().has_action() if state is not None else False
                }
            })
        else:
            return Server.pack_data({
                'output_messages': [output_messages[last_pos]],
                'output_messages_count': 1,
                'output_messages_last_pos': 0,
                'output_messages_ids': [output_ids[last_pos]],
                'behav_status': {
                    'state': state,
                    'action': action,
                    'state_with_action': behav.get_state().has_action() if state is not None else False
                }
            })

    def save(self):
        return Server.pack_data(self.node_synchronizer.world.env.save())  # TODO
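A final sketch, polling the console of one agent (the last_only flag travels as a string, as handled above; names are placeholders):

    import requests

    base = "http://localhost:5001"
    ret = requests.get(f"{base}/get_console",
                       params={"agent_name": "agent_0", "last_only": "true"}).json()["data"]
    for msg_id, msg in zip(ret["output_messages_ids"], ret["output_messages"]):
        print(msg_id, msg)
    print("behav:", ret["behav_status"])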