unaiverse 0.1.8__cp312-cp312-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of unaiverse might be problematic.



Files changed (50)
  1. unaiverse/__init__.py +19 -0
  2. unaiverse/agent.py +2008 -0
  3. unaiverse/agent_basics.py +2041 -0
  4. unaiverse/clock.py +191 -0
  5. unaiverse/dataprops.py +1209 -0
  6. unaiverse/hsm.py +1889 -0
  7. unaiverse/modules/__init__.py +18 -0
  8. unaiverse/modules/cnu/__init__.py +17 -0
  9. unaiverse/modules/cnu/cnus.py +536 -0
  10. unaiverse/modules/cnu/layers.py +261 -0
  11. unaiverse/modules/cnu/psi.py +60 -0
  12. unaiverse/modules/hl/__init__.py +15 -0
  13. unaiverse/modules/hl/hl_utils.py +411 -0
  14. unaiverse/modules/networks.py +1509 -0
  15. unaiverse/modules/utils.py +710 -0
  16. unaiverse/networking/__init__.py +16 -0
  17. unaiverse/networking/node/__init__.py +18 -0
  18. unaiverse/networking/node/connpool.py +1261 -0
  19. unaiverse/networking/node/node.py +2299 -0
  20. unaiverse/networking/node/profile.py +447 -0
  21. unaiverse/networking/node/tokens.py +79 -0
  22. unaiverse/networking/p2p/__init__.py +188 -0
  23. unaiverse/networking/p2p/go.mod +127 -0
  24. unaiverse/networking/p2p/go.sum +548 -0
  25. unaiverse/networking/p2p/golibp2p.py +18 -0
  26. unaiverse/networking/p2p/golibp2p.pyi +135 -0
  27. unaiverse/networking/p2p/lib.go +2527 -0
  28. unaiverse/networking/p2p/lib.go.sha256 +1 -0
  29. unaiverse/networking/p2p/lib_types.py +312 -0
  30. unaiverse/networking/p2p/message_pb2.py +63 -0
  31. unaiverse/networking/p2p/messages.py +268 -0
  32. unaiverse/networking/p2p/mylogger.py +77 -0
  33. unaiverse/networking/p2p/p2p.py +929 -0
  34. unaiverse/networking/p2p/proto-go/message.pb.go +616 -0
  35. unaiverse/networking/p2p/unailib.cpython-312-aarch64-linux-gnu.so +0 -0
  36. unaiverse/streamlib/__init__.py +15 -0
  37. unaiverse/streamlib/streamlib.py +210 -0
  38. unaiverse/streams.py +770 -0
  39. unaiverse/utils/__init__.py +16 -0
  40. unaiverse/utils/ask_lone_wolf.json +27 -0
  41. unaiverse/utils/lone_wolf.json +19 -0
  42. unaiverse/utils/misc.py +492 -0
  43. unaiverse/utils/sandbox.py +293 -0
  44. unaiverse/utils/server.py +435 -0
  45. unaiverse/world.py +353 -0
  46. unaiverse-0.1.8.dist-info/METADATA +365 -0
  47. unaiverse-0.1.8.dist-info/RECORD +50 -0
  48. unaiverse-0.1.8.dist-info/WHEEL +7 -0
  49. unaiverse-0.1.8.dist-info/licenses/LICENSE +43 -0
  50. unaiverse-0.1.8.dist-info/top_level.txt +1 -0
@@ -0,0 +1,710 @@
+ """
+ █████ █████ ██████ █████ █████ █████ █████ ██████████ ███████████ █████████ ██████████
+ ░░███ ░░███ ░░██████ ░░███ ░░███ ░░███ ░░███ ░░███░░░░░█░░███░░░░░███ ███░░░░░███░░███░░░░░█
+ ░███ ░███ ░███░███ ░███ ██████ ░███ ░███ ░███ ░███ █ ░ ░███ ░███ ░███ ░░░ ░███ █ ░
+ ░███ ░███ ░███░░███░███ ░░░░░███ ░███ ░███ ░███ ░██████ ░██████████ ░░█████████ ░██████
+ ░███ ░███ ░███ ░░██████ ███████ ░███ ░░███ ███ ░███░░█ ░███░░░░░███ ░░░░░░░░███ ░███░░█
+ ░███ ░███ ░███ ░░█████ ███░░███ ░███ ░░░█████░ ░███ ░ █ ░███ ░███ ███ ░███ ░███ ░ █
+ ░░████████ █████ ░░█████░░████████ █████ ░░███ ██████████ █████ █████░░█████████ ██████████
+ ░░░░░░░░ ░░░░░ ░░░░░ ░░░░░░░░ ░░░░░ ░░░ ░░░░░░░░░░ ░░░░░ ░░░░░ ░░░░░░░░░ ░░░░░░░░░░
+ A Collectionless AI Project (https://collectionless.ai)
+ Registration/Login: https://unaiverse.io
+ Code Repositories: https://github.com/collectionlessai/
+ Main Developers: Stefano Melacci (Project Leader), Christian Di Maio, Tommaso Guidi
+ """
+ import os
+ import torch
+ import random
+ import inspect
+ import threading
+ import numpy as np
+ from PIL import Image
+ from torch.utils.data import DataLoader
+ from unaiverse.dataprops import Data4Proc
+ from torchvision import datasets, transforms
+
+
+ def transforms_factory(trans_type: str, add_batch_dim: bool = True, return_inverse: bool = False):
+     supported_types = {"rgb*", "gray*",
+                        "rgb-no_norm*", "gray-no_norm*",
+                        "rgb", "gray",
+                        "rgb-no_norm", "gray-no_norm",
+                        "gray_mnist"}
+
+     found = False
+     num = -1
+     for _type in supported_types:
+         if _type.endswith("*"):
+             has_star = True
+             __type = _type[0:-1]
+         else:
+             has_star = False
+             __type = _type
+
+         if has_star and trans_type.startswith(__type) and len(trans_type) > len(__type):
+             try:
+                 num = int(trans_type[len(__type):])
+                 trans_type = _type
+                 found = True
+                 break
+             except ValueError:
+                 pass
+         elif trans_type == _type:
+             found = True
+             break
+
+     if not found:
+         raise ValueError(f"Invalid transformation type '{trans_type}': must be one of {supported_types}, "
+                          f"where * is an integer number")
+
+     trans = None
+     inverse_trans = None
+
+     if trans_type == "rgb*":
+         trans = transforms.Compose([
+             transforms.Lambda(lambda img: img.convert("RGB") if img.mode != "RGB" else img),  # Ensure 3 channels
+             transforms.Resize(num),
+             transforms.CenterCrop(num),
+             transforms.ToTensor(),  # Convert PIL to tensor (3, H, W), float [0,1]
+             transforms.Normalize(mean=[0.485, 0.456, 0.406],
+                                  std=[0.229, 0.224, 0.225]),
+         ])
+         inverse_trans = transforms.Compose([
+             transforms.Normalize(mean=[0., 0., 0.],
+                                  std=[1. / 0.229, 1. / 0.224, 1. / 0.225]),
+             transforms.Lambda(lambda x: x + torch.tensor([0.485, 0.456, 0.406]).view(3, 1, 1)),
+             transforms.ToPILImage()
+         ])
+     elif trans_type == "gray*":
+         trans = transforms.Compose([
+             transforms.Lambda(lambda img: img.convert("L") if img.mode != "L" else img),  # Ensure 1 channel
+             transforms.Resize(num),
+             transforms.CenterCrop(num),
+             transforms.ToTensor(),  # Convert PIL to tensor (1, H, W), float [0,1]
+             transforms.Normalize(mean=[0.45],
+                                  std=[0.225])
+         ])
+         inverse_trans = transforms.Compose([
+             transforms.Normalize(mean=[0.],
+                                  std=[1. / 0.225]),
+             transforms.Lambda(lambda x: x + 0.45),
+             transforms.ToPILImage()
+         ])
+     elif trans_type == "rgb-no_norm*":
+         trans = transforms.Compose([
+             transforms.Lambda(lambda img: img.convert("RGB") if img.mode != "RGB" else img),  # Ensure 3 channels
+             transforms.Resize(num),
+             transforms.CenterCrop(num),
+             transforms.PILToTensor(),  # Convert PIL to tensor (3, H, W), uint [0,255]
+         ])
+         inverse_trans = transforms.Compose([transforms.ToPILImage()])
+     elif trans_type == "gray-no_norm*":
+         trans = transforms.Compose([
+             transforms.Lambda(lambda img: img.convert("L") if img.mode != "L" else img),  # Ensure 1 channel
+             transforms.Resize(num),
+             transforms.CenterCrop(num),
+             transforms.PILToTensor(),  # Convert PIL to tensor (1, H, W), uint [0,255]
+         ])
+         inverse_trans = transforms.Compose([transforms.ToPILImage()])  # Compose, so that .transforms exists below
+     elif trans_type == "rgb":
+         trans = transforms.Compose([
+             transforms.Lambda(lambda img: img.convert("RGB") if img.mode != "RGB" else img),  # Ensure 3 channels
+             transforms.ToTensor(),  # Convert PIL to tensor (3, H, W), float [0,1]
+             transforms.Normalize(mean=[0.485, 0.456, 0.406],
+                                  std=[0.229, 0.224, 0.225])
+         ])
+         inverse_trans = transforms.Compose([
+             transforms.Normalize(mean=[0., 0., 0.],
+                                  std=[1. / 0.229, 1. / 0.224, 1. / 0.225]),
+             transforms.Lambda(lambda x: x + torch.tensor([0.485, 0.456, 0.406]).view(3, 1, 1)),
+             transforms.ToPILImage()
+         ])
+     elif trans_type == "gray":
+         trans = transforms.Compose([
+             transforms.Lambda(lambda img: img.convert("L") if img.mode != "L" else img),  # Ensure 1 channel
+             transforms.ToTensor(),  # Convert PIL to tensor (1, H, W), float [0,1]
+             transforms.Normalize(mean=[0.45],
+                                  std=[0.225])
+         ])
+         inverse_trans = transforms.Compose([
+             transforms.Normalize(mean=[0.],
+                                  std=[1. / 0.225]),
+             transforms.Lambda(lambda x: x + 0.45),
+             transforms.ToPILImage()
+         ])
+     elif trans_type == "rgb-no_norm":
+         trans = transforms.Compose([
+             transforms.Lambda(lambda img: img.convert("RGB") if img.mode != "RGB" else img),  # Ensure 3 channels
+             transforms.PILToTensor(),  # Convert PIL to tensor (3, H, W), uint [0,255]
+         ])
+         inverse_trans = transforms.Compose([transforms.ToPILImage()])
+     elif trans_type == "gray-no_norm":
+         trans = transforms.Compose([
+             transforms.Lambda(lambda img: img.convert("L") if img.mode != "L" else img),  # Ensure 1 channel
+             transforms.PILToTensor(),  # Convert PIL to tensor (1, H, W), uint [0,255]
+         ])
+         inverse_trans = transforms.Compose([transforms.ToPILImage()])
+     elif trans_type == "gray_mnist":
+         trans = transforms.Compose([
+             transforms.Lambda(lambda img: img.convert("L") if img.mode != "L" else img),  # Ensure 1 channel
+             transforms.Resize(28),
+             transforms.CenterCrop(28),
+             transforms.ToTensor(),  # Convert PIL to tensor (1, H, W), float [0,1]
+             transforms.Normalize(mean=[0.1307],  # MNIST
+                                  std=[0.3081])  # MNIST
+         ])
+         inverse_trans = transforms.Compose([
+             transforms.Normalize(mean=[0.],
+                                  std=[1. / 0.3081]),
+             transforms.Lambda(lambda x: x + 0.1307),
+             transforms.ToPILImage()
+         ])
+
+     if add_batch_dim:
+         trans.transforms.append(transforms.Lambda(lambda x: x.unsqueeze(0)))
+         inverse_trans.transforms.insert(0, transforms.Lambda(lambda x: x.squeeze(0)))
+
+     return trans if not return_inverse else inverse_trans
+
+
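+ # A minimal usage sketch of transforms_factory (illustrative only, kept as a comment so that importing
+ # this module has no side effects; "photo.jpg" is an assumed local file):
+ #
+ #     trans = transforms_factory("rgb224")                     # PIL image -> normalized (1, 3, 224, 224) tensor
+ #     inv = transforms_factory("rgb224", return_inverse=True)  # (1, 3, 224, 224) tensor -> PIL image
+ #     img = Image.open("photo.jpg")
+ #     x = trans(img)
+ #     img_back = inv(x)  # Approximately recovers the resized/cropped image (undoing the normalization)
+
+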
+ def hard_tanh(x: torch.Tensor) -> torch.Tensor:
+     return torch.clamp(x, min=-1., max=1.)
+
+
+ def target_shape_fixed_cross_entropy(output, target, *args, **kwargs):
+     if len(target.shape) > 1:
+         target = target.squeeze(0)
+     return torch.nn.functional.cross_entropy(output, target, *args, **kwargs)
+
+
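+ # Why the squeeze above is needed (illustrative): targets coming from streams carry a leading batch
+ # axis, e.g. shape (1, 1), while torch.nn.functional.cross_entropy expects class indices of shape (batch,):
+ #
+ #     output = torch.randn(1, 10)   # (batch, num_classes)
+ #     target = torch.tensor([[3]])  # (1, 1), as delivered by a stream
+ #     loss = target_shape_fixed_cross_entropy(output, target)  # Internally squeezed to shape (1,)
+
+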
+ def set_seed(seed: int) -> None:
+     if seed >= 0:
+         torch.manual_seed(seed)
+         random.seed(seed)
+         np.random.seed(seed)
+
+
+ def get_proc_inputs_and_proc_outputs_for_rnn(u_shape: torch.Size | tuple, du_dim: int, y_dim: int):
+     if isinstance(u_shape, torch.Size):
+         u_shape = tuple(u_shape)
+     proc_inputs = [
+         Data4Proc(data_type="tensor", tensor_shape=(None,) + u_shape, tensor_dtype=torch.float32,
+                   pubsub=False, private_only=True),
+         Data4Proc(data_type="tensor", tensor_shape=(None, du_dim,), tensor_dtype=torch.float32,
+                   pubsub=False, private_only=True)
+     ]
+     proc_outputs = [
+         Data4Proc(data_type="tensor", tensor_shape=(None, y_dim), tensor_dtype=torch.float32,
+                   pubsub=False, private_only=True)
+     ]
+     return proc_inputs, proc_outputs
+
+
+ def get_proc_inputs_and_proc_outputs_for_image_classification(y_dim: int):
+     if y_dim == -1:
+         y_dim = 1000  # Assuming ImageNet-trained models
+     proc_inputs = [Data4Proc(data_type="img", pubsub=False, private_only=True)]
+     proc_outputs = [Data4Proc(data_type="tensor", tensor_shape=(None, y_dim), tensor_dtype=torch.float32,
+                               pubsub=False, private_only=True)]
+     return proc_inputs, proc_outputs
+
+
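+ # A small sketch of how these helpers can be called (hypothetical dimensions):
+ #
+ #     rnn_inputs, rnn_outputs = get_proc_inputs_and_proc_outputs_for_rnn(u_shape=(32,), du_dim=4, y_dim=10)
+ #     img_inputs, img_outputs = get_proc_inputs_and_proc_outputs_for_image_classification(y_dim=-1)
+ #     # y_dim=-1 falls back to 1000 output classes (ImageNet-trained models)
+
+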
+ def error_rate_mnist_test_set(network: torch.nn.Module, mnist_data_save_path: str):
+
+     # Getting the MNIST test set
+     mnist_test = datasets.MNIST(root=mnist_data_save_path,
+                                 train=False, download=True,
+                                 transform=transforms.Compose([
+                                     transforms.ToTensor(),
+                                     transforms.Normalize((0.1307,), (0.3081,))
+                                 ]))
+     mnist_test = DataLoader(mnist_test, batch_size=200, shuffle=False)
+
+     # Checking the error rate
+     error_rate = 0.
+     n = 0
+     training_flag_backup = network.training
+     network.eval()
+     with torch.no_grad():  # No gradients are needed during evaluation
+         for x, y in mnist_test:
+             o = network(x)
+             c = torch.argmax(o, dim=1)
+             error_rate += float(torch.sum(c != y).item())
+             n += x.shape[0]
+     error_rate /= n
+     network.train(training_flag_backup)  # Restoring the training/eval mode on all submodules
+
+     return error_rate
+
+
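+ # Usage sketch (hypothetical toy network; MNIST is downloaded to the given folder on first use):
+ #
+ #     net = torch.nn.Sequential(torch.nn.Flatten(), torch.nn.Linear(28 * 28, 10))
+ #     err = error_rate_mnist_test_set(net, mnist_data_save_path="./data")
+ #     print(f"MNIST test error rate: {err:.3f}")  # Fraction of misclassified test samples, in [0, 1]
+
+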
+ class MultiIdentity(torch.nn.Module):
+     def __init__(self):
+         super().__init__()
+
+     def forward(self, *args):
+         if len(args) == 1:
+             return args[0]
+         return args
+
+
+ class HumanModule(torch.nn.Module):
+     def __init__(self):
+         super().__init__()
+         self.__out_text = None
+         self.__out_img = None
+         self.__event = threading.Event()
+
+     def forward(self, text: str, img: Image):
+         self.__out_text = None
+         self.__out_img = None
+         self.__event.clear()
+         self.__event.wait()  # Blocks until set (waiting for human inputs)
+         return self.__out_text, self.__out_img
+
+     def go_ahead(self, text: str, img: Image):
+         self.__out_text = text
+         self.__out_img = img
+         self.__event.set()  # Unblocks forward(...)
+
+
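+ # A two-thread handshake sketch for HumanModule (illustrative; it assumes forward(...) is already
+ # waiting when go_ahead(...) is called, which a real UI would have to guarantee):
+ #
+ #     human = HumanModule()
+ #     worker = threading.Thread(target=lambda: print(human("prompt shown to the human", None)))
+ #     worker.start()
+ #     human.go_ahead("typed answer", None)  # Unblocks forward(...), which returns ("typed answer", None)
+ #     worker.join()
+
+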
+ class ModuleWrapper(torch.nn.Module):
+     def __init__(self,
+                  module: torch.nn.Module | None = None,
+                  proc_inputs: list[Data4Proc] | None = None,
+                  proc_outputs: list[Data4Proc] | None = None,
+                  seed: int = -1):
+         super(ModuleWrapper, self).__init__()
+         self.device = None  # The device which is supposed to host the module
+         self.module = None  # The module itself
+         self.proc_inputs = proc_inputs  # The list of Data4Proc objects describing the input types of the module
+         self.proc_outputs = proc_outputs  # The list of Data4Proc objects describing the output types of the module
+
+         # Working
+         set_seed(seed)
+         device_env = os.getenv("PROC_DEVICE", None)
+         self.device = torch.device("cpu") if device_env is None else torch.device(device_env)
+         self.module = module.to(self.device) if module is not None else None
+
+     def forward(self, *args, **kwargs):
+
+         # The forward signature expected by the callers of this method is:
+         #     forward(self, *args, first: bool, last: bool, **kwargs)
+         # so 'first' and 'last' must be discarded, since an external module that was not designed for
+         # this library does not use them (pop-with-default also tolerates calls where they are absent)
+         kwargs.pop('first', None)
+         kwargs.pop('last', None)
+
+         # Calling the module
+         return self.module(*args, **kwargs)
+
+
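+ # Wrapping sketch (hypothetical module): ModuleWrapper lets a plain torch.nn.Module tolerate the
+ # forward(*args, first=..., last=...) calling convention used by this library:
+ #
+ #     wrapped = ModuleWrapper(module=torch.nn.Linear(4, 2))
+ #     y = wrapped(torch.randn(1, 4), first=True, last=False)  # 'first'/'last' are silently discarded
+
+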
+ class AgentProcessorChecker:
+
+     def __init__(self, processor_container: object):
+         assert hasattr(processor_container, 'proc'), "Invalid processor container object"
+         assert hasattr(processor_container, 'proc_inputs'), "Invalid processor container object"
+         assert hasattr(processor_container, 'proc_outputs'), "Invalid processor container object"
+         assert hasattr(processor_container, 'proc_opts'), "Invalid processor container object"
+         assert hasattr(processor_container, 'proc_optional_inputs'), "Invalid processor container object"
+
+         # Getting processor-related info from the main object which collects the processor and its properties
+         proc: torch.nn.Module = processor_container.proc
+         proc_inputs: list[Data4Proc] | None = processor_container.proc_inputs
+         proc_outputs: list[Data4Proc] | None = processor_container.proc_outputs
+         proc_opts: dict | None = processor_container.proc_opts
+         proc_optional_inputs: list | None = processor_container.proc_optional_inputs
+
+         assert proc is None or isinstance(proc, torch.nn.Module) or \
+             (isinstance(proc, str) and proc.lower() == "human"), \
+             "Processor (proc) must be a torch.nn.Module or the string 'human'"
+         assert (proc_inputs is None or (
+                 isinstance(proc_inputs, list) and (len(proc_inputs) == 0 or
+                                                    isinstance(proc_inputs[0], Data4Proc)))), \
+             "Invalid proc_inputs: it must be None or a list of Data4Proc"
+         assert (proc_outputs is None or (
+                 isinstance(proc_outputs, list) and (len(proc_outputs) == 0 or
+                                                     isinstance(proc_outputs[0], Data4Proc)))), \
+             "Invalid proc_outputs: it must be None or a list of Data4Proc"
+         assert (proc_opts is None or isinstance(proc_opts, dict)), "Invalid proc_opts: it must be None or a dictionary"
+
+         # Saving as attributes
+         self.proc = proc
+         self.proc_inputs = proc_inputs
+         self.proc_outputs = proc_outputs
+         self.proc_opts = proc_opts
+         self.proc_optional_inputs = proc_optional_inputs
+
+         # Dummy processor (if no processor was provided)
+         if self.proc is None:
+             self.proc = ModuleWrapper(module=MultiIdentity())
+             self.proc.device = torch.device("cpu")
+             if self.proc_inputs is None:
+                 self.proc_inputs = [Data4Proc(data_type="all", pubsub=False, private_only=False)]
+             if self.proc_outputs is None:
+                 self.proc_outputs = [Data4Proc(data_type="all", pubsub=False, private_only=False)]
+             self.proc_opts = {'optimizer': None, 'losses': [None] * len(self.proc_outputs)}
+         else:
+
+             # A string stating that the processor is a human
+             if isinstance(self.proc, str) and self.proc.lower() == "human":
+                 self.proc = ModuleWrapper(module=HumanModule())
+                 self.proc.device = torch.device("cpu")
+                 self.proc_inputs = [Data4Proc(data_type="text", pubsub=False, private_only=False),
+                                     Data4Proc(data_type="img", pubsub=False, private_only=False)]
+                 self.proc_outputs = [Data4Proc(data_type="text", pubsub=False, private_only=False),
+                                      Data4Proc(data_type="img", pubsub=False, private_only=False)]
+
+             # Wrapping to have the basic attributes (device)
+             elif not isinstance(self.proc, ModuleWrapper):
+                 self.proc = ModuleWrapper(module=self.proc)
+                 self.proc.device = torch.device("cpu")
+
+             # Guessing inputs, fixing attributes
+             if self.proc_inputs is None:
+                 self.__guess_proc_inputs()
+
+             for j in range(len(self.proc_inputs)):
+                 if self.proc_inputs[j].get_name() == "unk":
+                     self.proc_inputs[j].set_name("proc_input_" + str(j))
+
+             # Guessing outputs, fixing attributes
+             if self.proc_outputs is None:
+                 self.__guess_proc_outputs()
+
+             for j in range(len(self.proc_outputs)):
+                 if self.proc_outputs[j].get_name() == "unk":
+                     self.proc_outputs[j].set_name("proc_output_" + str(j))
+
+             # Guessing optimization-related options, fixing attributes
+             if (self.proc_opts is None or len(self.proc_opts) == 0 or
+                     'optimizer' not in self.proc_opts or 'losses' not in self.proc_opts):
+                 self.__guess_proc_opts()
+                 self.__fix_proc_opts()
+
+         # Ensuring all is OK
+         if self.proc is not None:
+             assert "optimizer" in self.proc_opts, "Missing 'optimizer' key in proc_opts (required)"
+             assert "losses" in self.proc_opts, "Missing 'losses' key in proc_opts (required)"
+
+         # Checking inputs with default values
+         if self.proc_optional_inputs is None:
+             self.__guess_proc_optional_inputs()
+
+         # Updating the processor container object
+         processor_container.proc = self.proc
+         processor_container.proc_inputs = self.proc_inputs
+         processor_container.proc_outputs = self.proc_outputs
+         processor_container.proc_opts = self.proc_opts
+         processor_container.proc_optional_inputs = self.proc_optional_inputs
+
+     def __guess_proc_inputs(self):
+         if hasattr(self.proc, "proc_inputs"):
+             if self.proc.proc_inputs is not None:
+                 self.proc_inputs = []
+                 for p in self.proc.proc_inputs:
+                     self.proc_inputs.append(p.clone())
+                 return
+
+         first_layer = None
+
+         # Traverse modules to find the first real layer: containers such as Sequential, ModuleList and
+         # ModuleDict are skipped implicitly, since only leaf layer types are matched here
+         for layer in self.proc.modules():
+             if isinstance(layer, (torch.nn.Conv2d, torch.nn.Linear, torch.nn.Conv1d, torch.nn.Embedding)):
+                 first_layer = layer
+                 break
+
+         if first_layer is None:
+             raise ValueError("Cannot automatically guess the shape of the input data, "
+                              "please explicitly provide it (proc_input)")
+
+         # Infer input properties
+         data_desc = "automatically guessed"
+         tensor_shape = None
+         tensor_labels = None
+         tensor_dtype = None
+         stream_to_proc_transforms = None
+         proc_to_stream_transforms = None
+
+         if isinstance(first_layer, torch.nn.Conv2d):
+
+             if first_layer.in_channels == 3 or first_layer.in_channels == 1:
+                 data_type = "img"
+
+                 # Creating dummy PIL images (random noise over the full 224x224 canvas)
+                 rgb_input_img = Image.new('RGB', (224, 224))
+                 pixels = rgb_input_img.load()
+                 for x in range(224):
+                     for y in range(224):
+                         pixels[x, y] = (random.randint(0, 255),
+                                         random.randint(0, 255),
+                                         random.randint(0, 255))
+                 gray_input_img = rgb_input_img.convert('L')
+
+                 # Checking if the model supports PIL images as input
+                 # noinspection PyBroadException
+                 try:
+                     _ = self.proc(rgb_input_img)
+                     can_handle_rgb_img = True
+                 except Exception:
+                     can_handle_rgb_img = False
+
+                 # noinspection PyBroadException
+                 try:
+                     _ = self.proc(gray_input_img)
+                     can_handle_gray_img = True
+                 except Exception:
+                     can_handle_gray_img = False
+
+                 if can_handle_gray_img and can_handle_rgb_img:
+                     stream_to_proc_transforms = None
+                 elif can_handle_rgb_img:
+                     stream_to_proc_transforms = transforms.Grayscale(num_output_channels=3)
+                 elif can_handle_gray_img:
+                     stream_to_proc_transforms = transforms.Grayscale()
+                 else:
+                     if first_layer.in_channels == 1:
+                         stream_to_proc_transforms = transforms_factory("gray-no_norm")
+                     else:
+                         stream_to_proc_transforms = transforms_factory("rgb-no_norm")
+             else:
+
+                 # If the number of input channels is neither 1 nor 3...
+                 data_type = "tensor"
+                 tensor_shape = (first_layer.in_channels, None, None)
+                 tensor_dtype = torch.float32
+
+         elif isinstance(first_layer, torch.nn.Conv1d):
+             data_type = "tensor"
+             tensor_shape = (first_layer.in_channels, None)
+             tensor_dtype = torch.float32
+         elif isinstance(first_layer, torch.nn.Linear):
+             data_type = "tensor"
+             tensor_dtype = torch.float32
+             tensor_shape = (first_layer.in_features,)
+         elif isinstance(first_layer, torch.nn.Embedding):
+
+             # noinspection PyBroadException
+             try:
+                 input_text = "testing if tokenizer is present"
+                 _ = self.proc(input_text)
+                 can_handle_text = True
+                 can_handle_more_than_one_token = True  # Overwritten by the next try/except (kept for clarity)
+             except Exception:
+                 can_handle_text = False
+
+             # noinspection PyBroadException
+             try:
+                 device = torch.device("cpu")
+                 for param in self.proc.parameters():
+                     device = param.device
+                     break
+                 input_tokens = torch.tensor([[0, 1, 2, 3]], dtype=torch.long, device=device)
+                 _ = self.proc(input_tokens)
+                 can_handle_more_than_one_token = True
+             except Exception:
+                 can_handle_more_than_one_token = False
+
+             if can_handle_text:
+                 data_type = "text"
+                 stream_to_proc_transforms = None
+             else:
+                 data_type = "tensor"
+                 if can_handle_more_than_one_token:
+                     tensor_shape = (None,)
+                 else:
+                     tensor_shape = (1,)
+                 tensor_dtype = torch.long
+                 tensor_labels = ["token" + str(i) for i in range(0, first_layer.num_embeddings)]
+         else:
+             raise ValueError("Cannot automatically guess the shape of the input data, "
+                              "please explicitly provide it (proc_input)")
+
+         # Setting the input attribute
+         self.proc_inputs = [Data4Proc(name="proc_input_0",
+                                       data_type=data_type,
+                                       data_desc=data_desc,
+                                       tensor_shape=tensor_shape,
+                                       tensor_labels=tensor_labels,
+                                       tensor_dtype=tensor_dtype,
+                                       stream_to_proc_transforms=stream_to_proc_transforms,
+                                       proc_to_stream_transforms=proc_to_stream_transforms,
+                                       pubsub=False,
+                                       private_only=True)]
+
+     def __guess_proc_outputs(self):
+         if hasattr(self.proc, "proc_outputs"):
+             if self.proc.proc_outputs is not None:
+                 self.proc_outputs = []
+                 for p in self.proc.proc_outputs:
+                     self.proc_outputs.append(p.clone())
+                 return
+
+         proc = self.proc
+         device = self.proc.device
+         inputs = []
+
+         for i, proc_input in enumerate(self.proc_inputs):
+             if proc_input.is_tensor():
+                 inputs.append(proc_input.check_and_preprocess(
+                     torch.randn([1] + list(proc_input.tensor_shape),  # Adding batch size here
+                                 dtype=proc_input.tensor_dtype).to(device)))
+             elif proc_input.is_img():
+                 rgb_input_img = Image.new('RGB', (224, 224))
+                 pixels = rgb_input_img.load()
+                 for x in range(224):
+                     for y in range(224):
+                         pixels[x, y] = (random.randint(0, 255),
+                                         random.randint(0, 255),
+                                         random.randint(0, 255))
+                 inputs.append(proc_input.check_and_preprocess(rgb_input_img))
+             elif proc_input.is_text():
+                 inputs.append(proc_input.check_and_preprocess("test text as input"))
+
+         # Forward
+         with torch.no_grad():
+             outputs = proc(*inputs)
+         if not isinstance(outputs, tuple | list):
+             outputs = [outputs]
+         if isinstance(outputs, tuple):
+             outputs = list(outputs)
+
+         # This will be filled below
+         self.proc_outputs = []
+
+         for j, output in enumerate(outputs):
+
+             # Infer output properties
+             data_desc = "automatically guessed"
+             tensor_shape = None
+             tensor_labels = None
+             tensor_dtype = None
+             stream_to_proc_transforms = None
+             proc_to_stream_transforms = None
+
+             if isinstance(output, Image.Image):  # PIL Image
+                 data_type = "img"
+             elif isinstance(output, torch.Tensor):  # Tensor
+                 output_shape = list(output.shape[1:])  # Removing batch size here
+                 if len(output_shape) == 3 and (output_shape[0] == 3 or output_shape[0] == 1):
+                     data_type = "img"
+                     if output_shape[0] == 3:
+                         proc_to_stream_transforms = transforms_factory("rgb", return_inverse=True)
+                     else:
+                         proc_to_stream_transforms = transforms_factory("gray", return_inverse=True)
+                 else:
+                     data_type = "tensor"
+                     tensor_dtype = str(output.dtype)
+                     tensor_shape = output_shape
+                     tensor_labels = None
+             elif isinstance(output, str):
+                 data_type = "text"
+             else:
+                 raise ValueError(f"Unsupported output type {type(output)}")
+
+             # Setting the output attribute
+             self.proc_outputs.append(Data4Proc(name="proc_output_" + str(j),
+                                                data_type=data_type,
+                                                data_desc=data_desc,
+                                                tensor_shape=tensor_shape,
+                                                tensor_labels=tensor_labels,
+                                                tensor_dtype=tensor_dtype,
+                                                stream_to_proc_transforms=stream_to_proc_transforms,
+                                                proc_to_stream_transforms=proc_to_stream_transforms,
+                                                pubsub=False,
+                                                private_only=True))
+
+     def __guess_proc_opts(self):
+         if self.proc_opts is None:
+             if isinstance(self.proc.module, MultiIdentity) or len(list(self.proc.parameters())) == 0:
+                 self.proc_opts = {"optimizer": None,
+                                   "losses": [None] * len(self.proc_outputs)}
+             else:
+                 self.proc_opts = {"optimizer": torch.optim.SGD(self.proc.parameters(), lr=1e-5),
+                                   "losses": [torch.nn.functional.mse_loss] * len(self.proc_outputs)}
+         else:
+             if "optimizer" not in self.proc_opts:
+                 self.proc_opts["optimizer"] = None
+             if "losses" not in self.proc_opts:
+                 self.proc_opts["losses"] = [None] * len(self.proc_outputs)
+
+     def __fix_proc_opts(self):
+         opts = {}
+         found_optimizer = False
+         found_loss = False
+         cannot_fix = False
+
+         if "optimizer" in self.proc_opts:
+             found_optimizer = True
+         if "losses" in self.proc_opts:
+             found_loss = True
+
+         if not found_loss:
+             opts['losses'] = [torch.nn.functional.mse_loss] * len(self.proc_outputs)
+
+         for k, v in self.proc_opts.items():
+             if isinstance(v, torch.optim.Optimizer):
+                 if k == "optimizer":
+                     opts["optimizer"] = v
+                     continue
+                 else:
+                     if not found_optimizer:
+                         opts["optimizer"] = v
+                         found_optimizer = True
+                     else:
+                         cannot_fix = True
+                         break
+             elif k == "losses" and (isinstance(v, list) or isinstance(v, tuple)):
+                 opts["losses"] = v
+                 continue
+             elif (v == torch.nn.functional.mse_loss or isinstance(v, torch.nn.MSELoss)
+                   or v == torch.nn.functional.binary_cross_entropy or isinstance(v, torch.nn.BCELoss)
+                   or isinstance(v, torch.nn.CrossEntropyLoss) or v == torch.nn.functional.cross_entropy):
+                 if not found_loss:
+                     opts["losses"] = [v]
+                     found_loss = True
+                 else:
+                     cannot_fix = True
+                     break
+             else:
+                 opts[k] = v
+
+         if not found_optimizer:
+             if 'lr' in opts:
+                 opts['optimizer'] = torch.optim.SGD(self.proc.parameters(), lr=opts['lr'])
+
+         assert not cannot_fix, \
+             "About proc_opts: cannot find required keys ('optimizer', 'losses') and/or cannot automatically guess them"
+
+         # Removing the batch dim from targets in case of cross-entropy
+         fixed_list = []
+         for _loss_fcn in opts['losses']:
+             if _loss_fcn == torch.nn.functional.cross_entropy or isinstance(_loss_fcn, torch.nn.CrossEntropyLoss):
+                 fixed_list.append(target_shape_fixed_cross_entropy)
+             else:
+                 fixed_list.append(_loss_fcn)
+         opts['losses'] = fixed_list
+
+         # Updating
+         self.proc_opts = opts
+
+     def __guess_proc_optional_inputs(self):
+         self.proc_optional_inputs = []
+         if isinstance(self.proc, ModuleWrapper):
+             if hasattr(self.proc.module, "forward"):
+                 sig = inspect.signature(self.proc.module.forward)
+             else:
+                 sig = inspect.signature(self.proc.forward)
+         else:
+             sig = inspect.signature(self.proc.forward)
+
+         i = 0
+         for name, param in sig.parameters.items():
+             if i >= len(self.proc_inputs):
+                 break
+             if param.default is not inspect.Parameter.empty:
+                 self.proc_optional_inputs.append({"has_default": True, "default_value": param.default})
+             else:
+                 self.proc_optional_inputs.append({"has_default": False, "default_value": None})
+             i += 1
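+
+
+ # End-to-end sketch (hypothetical container class): AgentProcessorChecker validates and completes the
+ # processor-related attributes of any object exposing them, guessing whatever was left unspecified:
+ #
+ #     class ProcContainer:
+ #         proc = torch.nn.Sequential(torch.nn.Linear(4, 2))
+ #         proc_inputs = None
+ #         proc_outputs = None
+ #         proc_opts = None
+ #         proc_optional_inputs = None
+ #
+ #     container = ProcContainer()
+ #     AgentProcessorChecker(container)  # container.proc is now a ModuleWrapper; the rest is filled in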