unaiverse-0.1.6-cp314-cp314t-macosx_10_15_x86_64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of unaiverse might be problematic.
Files changed (50)
  1. unaiverse/__init__.py +19 -0
  2. unaiverse/agent.py +2008 -0
  3. unaiverse/agent_basics.py +1846 -0
  4. unaiverse/clock.py +191 -0
  5. unaiverse/dataprops.py +1209 -0
  6. unaiverse/hsm.py +1880 -0
  7. unaiverse/modules/__init__.py +18 -0
  8. unaiverse/modules/cnu/__init__.py +17 -0
  9. unaiverse/modules/cnu/cnus.py +536 -0
  10. unaiverse/modules/cnu/layers.py +261 -0
  11. unaiverse/modules/cnu/psi.py +60 -0
  12. unaiverse/modules/hl/__init__.py +15 -0
  13. unaiverse/modules/hl/hl_utils.py +411 -0
  14. unaiverse/modules/networks.py +1509 -0
  15. unaiverse/modules/utils.py +680 -0
  16. unaiverse/networking/__init__.py +16 -0
  17. unaiverse/networking/node/__init__.py +18 -0
  18. unaiverse/networking/node/connpool.py +1261 -0
  19. unaiverse/networking/node/node.py +2223 -0
  20. unaiverse/networking/node/profile.py +446 -0
  21. unaiverse/networking/node/tokens.py +79 -0
  22. unaiverse/networking/p2p/__init__.py +198 -0
  23. unaiverse/networking/p2p/go.mod +127 -0
  24. unaiverse/networking/p2p/go.sum +548 -0
  25. unaiverse/networking/p2p/golibp2p.py +18 -0
  26. unaiverse/networking/p2p/golibp2p.pyi +135 -0
  27. unaiverse/networking/p2p/lib.go +2714 -0
  28. unaiverse/networking/p2p/lib.go.sha256 +1 -0
  29. unaiverse/networking/p2p/lib_types.py +312 -0
  30. unaiverse/networking/p2p/message_pb2.py +63 -0
  31. unaiverse/networking/p2p/messages.py +265 -0
  32. unaiverse/networking/p2p/mylogger.py +77 -0
  33. unaiverse/networking/p2p/p2p.py +929 -0
  34. unaiverse/networking/p2p/proto-go/message.pb.go +616 -0
  35. unaiverse/networking/p2p/unailib.cpython-314t-darwin.so +0 -0
  36. unaiverse/streamlib/__init__.py +15 -0
  37. unaiverse/streamlib/streamlib.py +210 -0
  38. unaiverse/streams.py +770 -0
  39. unaiverse/utils/__init__.py +16 -0
  40. unaiverse/utils/ask_lone_wolf.json +27 -0
  41. unaiverse/utils/lone_wolf.json +19 -0
  42. unaiverse/utils/misc.py +305 -0
  43. unaiverse/utils/sandbox.py +293 -0
  44. unaiverse/utils/server.py +435 -0
  45. unaiverse/world.py +175 -0
  46. unaiverse-0.1.6.dist-info/METADATA +365 -0
  47. unaiverse-0.1.6.dist-info/RECORD +50 -0
  48. unaiverse-0.1.6.dist-info/WHEEL +6 -0
  49. unaiverse-0.1.6.dist-info/licenses/LICENSE +43 -0
  50. unaiverse-0.1.6.dist-info/top_level.txt +1 -0
@@ -0,0 +1,680 @@
+ """
+ █████ █████ ██████ █████ █████ █████ █████ ██████████ ███████████ █████████ ██████████
+ ░░███ ░░███ ░░██████ ░░███ ░░███ ░░███ ░░███ ░░███░░░░░█░░███░░░░░███ ███░░░░░███░░███░░░░░█
+ ░███ ░███ ░███░███ ░███ ██████ ░███ ░███ ░███ ░███ █ ░ ░███ ░███ ░███ ░░░ ░███ █ ░
+ ░███ ░███ ░███░░███░███ ░░░░░███ ░███ ░███ ░███ ░██████ ░██████████ ░░█████████ ░██████
+ ░███ ░███ ░███ ░░██████ ███████ ░███ ░░███ ███ ░███░░█ ░███░░░░░███ ░░░░░░░░███ ░███░░█
+ ░███ ░███ ░███ ░░█████ ███░░███ ░███ ░░░█████░ ░███ ░ █ ░███ ░███ ███ ░███ ░███ ░ █
+ ░░████████ █████ ░░█████░░████████ █████ ░░███ ██████████ █████ █████░░█████████ ██████████
+ ░░░░░░░░ ░░░░░ ░░░░░ ░░░░░░░░ ░░░░░ ░░░ ░░░░░░░░░░ ░░░░░ ░░░░░ ░░░░░░░░░ ░░░░░░░░░░
+ A Collectionless AI Project (https://collectionless.ai)
+ Registration/Login: https://unaiverse.io
+ Code Repositories: https://github.com/collectionlessai/
+ Main Developers: Stefano Melacci (Project Leader), Christian Di Maio, Tommaso Guidi
+ """
+ import os
+ import torch
+ import random
+ import inspect
+ import numpy as np
+ from PIL import Image
+ from torch.utils.data import DataLoader
+ from unaiverse.dataprops import Data4Proc
+ from torchvision import datasets, transforms
+
+
+ def transforms_factory(trans_type: str, add_batch_dim: bool = True, return_inverse: bool = False):
+     supported_types = {"rgb*", "gray*",
+                        "rgb-no_norm*", "gray-no_norm*",
+                        "rgb", "gray",
+                        "rgb-no_norm", "gray-no_norm",
+                        "gray_mnist"}
+
+     found = False
+     num = -1
+     for _type in supported_types:
+         if _type.endswith("*"):
+             has_star = True
+             __type = _type[0:-1]
+         else:
+             has_star = False
+             __type = _type
+
+         if has_star and trans_type.startswith(__type) and len(trans_type) > len(__type):
+             try:
+                 num = int(trans_type[len(__type):])
+                 trans_type = _type
+                 found = True
+                 break
+             except ValueError:
+                 pass
+         elif trans_type == _type:
+             found = True
+             break
+
+     if not found:
+         raise ValueError(f"Invalid transformation type '{trans_type}': must be one of {supported_types}, "
+                          f"where * is an integer number")
+
+     trans = None
+     inverse_trans = None
+
+     if trans_type == "rgb*":
+         trans = transforms.Compose([
+             transforms.Lambda(lambda img: img.convert("RGB") if img.mode != "RGB" else img),  # Ensure 3 channels
+             transforms.Resize(num),
+             transforms.CenterCrop(num),
+             transforms.ToTensor(),  # Convert PIL to tensor (3, H, W), float [0,1]
+             transforms.Normalize(mean=[0.485, 0.456, 0.406],
+                                  std=[0.229, 0.224, 0.225]),
+         ])
+         inverse_trans = transforms.Compose([
+             transforms.Normalize(mean=[0., 0., 0.],
+                                  std=[1. / 0.229, 1. / 0.224, 1. / 0.225]),
+             transforms.Lambda(lambda x: x + torch.tensor([0.485, 0.456, 0.406]).view(3, 1, 1)),
+             transforms.ToPILImage()
+         ])
+     elif trans_type == "gray*":
+         trans = transforms.Compose([
+             transforms.Lambda(lambda img: img.convert("L") if img.mode != "L" else img),  # Ensure 1 channel
+             transforms.Resize(num),
+             transforms.CenterCrop(num),
+             transforms.ToTensor(),  # Convert PIL to tensor (1, H, W), float [0,1]
+             transforms.Normalize(mean=[0.45],
+                                  std=[0.225])
+         ])
+         inverse_trans = transforms.Compose([
+             transforms.Normalize(mean=[0.],
+                                  std=[1. / 0.225]),
+             transforms.Lambda(lambda x: x + 0.45),
+             transforms.ToPILImage()
+         ])
+     elif trans_type == "rgb-no_norm*":
+         trans = transforms.Compose([
+             transforms.Lambda(lambda img: img.convert("RGB") if img.mode != "RGB" else img),  # Ensure 3 channels
+             transforms.Resize(num),
+             transforms.CenterCrop(num),
+             transforms.PILToTensor(),  # Convert PIL to tensor (3, H, W), uint [0,255]
+         ])
+         inverse_trans = transforms.Compose([transforms.ToPILImage()])
+     elif trans_type == "gray-no_norm*":
+         trans = transforms.Compose([
+             transforms.Lambda(lambda img: img.convert("L") if img.mode != "L" else img),  # Ensure 1 channel
+             transforms.Resize(num),
+             transforms.CenterCrop(num),
+             transforms.PILToTensor(),  # Convert PIL to tensor (1, H, W), uint [0,255]
+         ])
+         inverse_trans = transforms.Compose([transforms.ToPILImage()])
+     elif trans_type == "rgb":
+         trans = transforms.Compose([
+             transforms.Lambda(lambda img: img.convert("RGB") if img.mode != "RGB" else img),  # Ensure 3 channels
+             transforms.ToTensor(),  # Convert PIL to tensor (3, H, W), float [0,1]
+             transforms.Normalize(mean=[0.485, 0.456, 0.406],
+                                  std=[0.229, 0.224, 0.225])
+         ])
+         inverse_trans = transforms.Compose([
+             transforms.Normalize(mean=[0., 0., 0.],
+                                  std=[1. / 0.229, 1. / 0.224, 1. / 0.225]),
+             transforms.Lambda(lambda x: x + torch.tensor([0.485, 0.456, 0.406]).view(3, 1, 1)),
+             transforms.ToPILImage()
+         ])
+     elif trans_type == "gray":
+         trans = transforms.Compose([
+             transforms.Lambda(lambda img: img.convert("L") if img.mode != "L" else img),  # Ensure 1 channel
+             transforms.ToTensor(),  # Convert PIL to tensor (1, H, W), float [0,1]
+             transforms.Normalize(mean=[0.45],
+                                  std=[0.225])
+         ])
+         inverse_trans = transforms.Compose([
+             transforms.Normalize(mean=[0.],
+                                  std=[1. / 0.225]),
+             transforms.Lambda(lambda x: x + 0.45),
+             transforms.ToPILImage()
+         ])
+     elif trans_type == "rgb-no_norm":
+         trans = transforms.Compose([
+             transforms.Lambda(lambda img: img.convert("RGB") if img.mode != "RGB" else img),  # Ensure 3 channels
+             transforms.PILToTensor(),  # Convert PIL to tensor (3, H, W), uint [0,255]
+         ])
+         inverse_trans = transforms.Compose([transforms.ToPILImage()])
+     elif trans_type == "gray-no_norm":
+         trans = transforms.Compose([
+             transforms.Lambda(lambda img: img.convert("L") if img.mode != "L" else img),  # Ensure 1 channel
+             transforms.PILToTensor(),  # Convert PIL to tensor (1, H, W), uint [0,255]
+         ])
+         inverse_trans = transforms.Compose([transforms.ToPILImage()])
+     elif trans_type == "gray_mnist":
+         trans = transforms.Compose([
+             transforms.Lambda(lambda img: img.convert("L") if img.mode != "L" else img),  # Ensure 1 channel
+             transforms.Resize(28),
+             transforms.CenterCrop(28),
+             transforms.ToTensor(),  # Convert PIL to tensor (1, H, W), float [0,1]
+             transforms.Normalize(mean=[0.1307],  # MNIST
+                                  std=[0.3081])  # MNIST
+         ])
+         inverse_trans = transforms.Compose([
+             transforms.Normalize(mean=[0.],
+                                  std=[1. / 0.3081]),
+             transforms.Lambda(lambda x: x + 0.1307),
+             transforms.ToPILImage()
+         ])
+
+     if add_batch_dim:
+         trans.transforms.append(transforms.Lambda(lambda x: x.unsqueeze(0)))
+         inverse_trans.transforms.insert(0, transforms.Lambda(lambda x: x.squeeze(0)))
+
+     return trans if not return_inverse else inverse_trans
+
+
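# --- Editor's usage sketch (illustrative, not part of the packaged file) ---
# A request like "rgb224" matches the "rgb*" pattern with num=224: resize and
# center-crop to 224, then normalize with ImageNet statistics. Passing
# return_inverse=True returns the paired transform that undoes the pipeline.
to_tensor = transforms_factory("rgb224")                      # PIL -> (1, 3, 224, 224) float tensor
to_image = transforms_factory("rgb224", return_inverse=True)  # (1, 3, 224, 224) tensor -> PIL
restored = to_image(to_tensor(Image.new("RGB", (256, 256))))  # Round-trips to a 224x224 PIL image
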
+ def hard_tanh(x: torch.Tensor) -> torch.Tensor:
+     return torch.clamp(x, min=-1., max=1.)
+
+
+ def target_shape_fixed_cross_entropy(output, target, *args, **kwargs):
+     if len(target.shape) > 1:
+         target = target.squeeze(0)
+     return torch.nn.functional.cross_entropy(output, target, *args, **kwargs)
+
+
+ def set_seed(seed: int) -> None:
+     if seed >= 0:
+         torch.manual_seed(seed)
+         random.seed(seed)
+         np.random.seed(seed)
+
+
+ def get_proc_inputs_and_proc_outputs_for_rnn(u_shape: torch.Size | tuple, du_dim: int, y_dim: int):
+     if isinstance(u_shape, torch.Size):
+         u_shape = tuple(u_shape)
+     proc_inputs = [
+         Data4Proc(data_type="tensor", tensor_shape=(None,) + u_shape, tensor_dtype=torch.float32,
+                   pubsub=False, private_only=True),
+         Data4Proc(data_type="tensor", tensor_shape=(None, du_dim,), tensor_dtype=torch.float32,
+                   pubsub=False, private_only=True)
+     ]
+     proc_outputs = [
+         Data4Proc(data_type="tensor", tensor_shape=(None, y_dim), tensor_dtype=torch.float32,
+                   pubsub=False, private_only=True)
+     ]
+     return proc_inputs, proc_outputs
+
+
+ def get_proc_inputs_and_proc_outputs_for_image_classification(y_dim: int):
+     if y_dim == -1:
+         y_dim = 1000  # Assuming ImageNet-trained models
+     proc_inputs = [Data4Proc(data_type="img", pubsub=False, private_only=True)]
+     proc_outputs = [Data4Proc(data_type="tensor", tensor_shape=(None, y_dim), tensor_dtype=torch.float32,
+                               pubsub=False, private_only=True)]
+     return proc_inputs, proc_outputs
+
+
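# --- Editor's usage sketch (illustrative, not part of the packaged file) ---
# Both helpers above just bundle Data4Proc descriptors; the leading None in
# each tensor_shape stands for the batch dimension, and y_dim == -1 falls back
# to the 1000 ImageNet classes.
rnn_in, rnn_out = get_proc_inputs_and_proc_outputs_for_rnn(u_shape=(8,), du_dim=4, y_dim=2)
img_in, img_out = get_proc_inputs_and_proc_outputs_for_image_classification(y_dim=-1)
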
+ def error_rate_mnist_test_set(network: torch.nn.Module, mnist_data_save_path: str):
+
+     # Getting MNIST test set
+     mnist_test = datasets.MNIST(root=mnist_data_save_path,
+                                 train=False, download=True,
+                                 transform=transforms.Compose([
+                                     transforms.ToTensor(),
+                                     transforms.Normalize((0.1307,), (0.3081,))
+                                 ]))
+     mnist_test = DataLoader(mnist_test, batch_size=200, shuffle=False)
+
+     # Checking error rate
+     error_rate = 0.
+     n = 0
+     training_flag_backup = network.training
+     network.eval()
+     with torch.no_grad():
+         for x, y in mnist_test:
+             o = network(x)
+             c = torch.argmax(o, dim=1)
+             error_rate += float(torch.sum(c != y).item())
+             n += x.shape[0]
+     error_rate /= n
+     network.train(training_flag_backup)  # Restore the previous train/eval mode on all sub-modules
+
+     return error_rate
+
+
+ class MultiIdentity(torch.nn.Module):
+     def __init__(self):
+         super().__init__()
+
+     def forward(self, *args):
+         if len(args) == 1:
+             return args[0]
+         return args
+
+
+ class ModuleWrapper(torch.nn.Module):
+     def __init__(self,
+                  module: torch.nn.Module | None = None,
+                  proc_inputs: list[Data4Proc] | None = None,
+                  proc_outputs: list[Data4Proc] | None = None,
+                  seed: int = -1):
+         super(ModuleWrapper, self).__init__()
+         self.device = None  # The device which is supposed to host the module
+         self.module = None  # The module itself
+         self.proc_inputs = proc_inputs  # The list of Data4Proc objects describing the input types of the module
+         self.proc_outputs = proc_outputs  # The list of Data4Proc objects describing the output types of the module
+
+         # Working
+         set_seed(seed)
+         device_env = os.getenv("PROC_DEVICE", None)
+         self.device = torch.device("cpu") if device_env is None else torch.device(device_env)
+         self.module = module.to(self.device) if module is not None else None
+
+     def forward(self, *args, **kwargs):
+
+         # The callers of this method use the signature:
+         #     forward(self, *args, first: bool, last: bool, **kwargs)
+         # so 'first' and 'last' are discarded here (when present), since an external
+         # module not designed for this library does not use them
+         kwargs.pop('first', None)
+         kwargs.pop('last', None)
+
+         # Calling the module
+         return self.module(*args, **kwargs)
+
+
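# --- Editor's usage sketch (illustrative, not part of the packaged file) ---
# ModuleWrapper lets a plain torch.nn.Module accept the first/last keywords
# that the library's callers pass, silently dropping them before the forward call.
wrapped = ModuleWrapper(module=torch.nn.Linear(4, 2))
out = wrapped(torch.randn(1, 4), first=True, last=False)  # 'first'/'last' are discarded
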
+ class AgentProcessorChecker:
+
+     def __init__(self, processor_container: object):
+         assert hasattr(processor_container, 'proc'), "Invalid processor container object"
+         assert hasattr(processor_container, 'proc_inputs'), "Invalid processor container object"
+         assert hasattr(processor_container, 'proc_outputs'), "Invalid processor container object"
+         assert hasattr(processor_container, 'proc_opts'), "Invalid processor container object"
+         assert hasattr(processor_container, 'proc_optional_inputs'), "Invalid processor container object"
+
+         # Getting processor-related info from the main object which collects the processor and its properties
+         proc: torch.nn.Module = processor_container.proc
+         proc_inputs: list[Data4Proc] | None = processor_container.proc_inputs
+         proc_outputs: list[Data4Proc] | None = processor_container.proc_outputs
+         proc_opts: dict | None = processor_container.proc_opts
+         proc_optional_inputs: list | None = processor_container.proc_optional_inputs
+
+         assert proc is None or isinstance(proc, torch.nn.Module), "Processor (proc) must be a torch.nn.Module"
+         assert (proc_inputs is None or (
+                 isinstance(proc_inputs, list) and (len(proc_inputs) == 0 or
+                                                    isinstance(proc_inputs[0], Data4Proc)))), \
+             "Invalid proc_inputs: it must be None or a list of Data4Proc"
+         assert (proc_outputs is None or (
+                 isinstance(proc_outputs, list) and (len(proc_outputs) == 0 or
+                                                     isinstance(proc_outputs[0], Data4Proc)))), \
+             "Invalid proc_outputs: it must be None or a list of Data4Proc"
+         assert (proc_opts is None or isinstance(proc_opts, dict)), "Invalid proc_opts: it must be None or a dictionary"
+
+         # Saving as attributes
+         self.proc = proc
+         self.proc_inputs = proc_inputs
+         self.proc_outputs = proc_outputs
+         self.proc_opts = proc_opts
+         self.proc_optional_inputs = proc_optional_inputs
+
+         # Dummy processor (if no processor was provided)
+         if self.proc is None:
+             self.proc = ModuleWrapper(module=MultiIdentity())
+             self.proc.device = torch.device("cpu")
+             if self.proc_inputs is None:
+                 self.proc_inputs = [Data4Proc(data_type="all", pubsub=False, private_only=False)]
+             if self.proc_outputs is None:
+                 self.proc_outputs = [Data4Proc(data_type="all", pubsub=False, private_only=False)]
+             self.proc_opts = {'optimizer': None, 'losses': [None] * len(self.proc_outputs)}
+         else:
+
+             # Wrapping to have the basic attributes (device)
+             if not isinstance(self.proc, ModuleWrapper):
+                 self.proc = ModuleWrapper(module=self.proc)
+                 self.proc.device = torch.device("cpu")
+
+             # Guessing inputs, fixing attributes
+             if self.proc_inputs is None:
+                 self.__guess_proc_inputs()
+
+             for j in range(len(self.proc_inputs)):
+                 if self.proc_inputs[j].get_name() == "unk":
+                     self.proc_inputs[j].set_name("proc_input_" + str(j))
+
+             # Guessing outputs, fixing attributes
+             if self.proc_outputs is None:
+                 self.__guess_proc_outputs()
+
+             for j in range(len(self.proc_outputs)):
+                 if self.proc_outputs[j].get_name() == "unk":
+                     self.proc_outputs[j].set_name("proc_output_" + str(j))
+
+             # Guessing optimization-related options, fixing attributes
+             if (self.proc_opts is None or len(self.proc_opts) == 0 or
+                     'optimizer' not in self.proc_opts or 'losses' not in self.proc_opts):
+                 self.__guess_proc_opts()
+                 self.__fix_proc_opts()
+
+         # Ensuring all is OK
+         if self.proc is not None:
+             assert "optimizer" in self.proc_opts, "Missing 'optimizer' key in proc_opts (required)"
+             assert "losses" in self.proc_opts, "Missing 'losses' key in proc_opts (required)"
+
+         # Checking inputs with default values
+         if self.proc_optional_inputs is None:
+             self.__guess_proc_optional_inputs()
+
+         # Updating the processor container object
+         processor_container.proc = self.proc
+         processor_container.proc_inputs = self.proc_inputs
+         processor_container.proc_outputs = self.proc_outputs
+         processor_container.proc_opts = self.proc_opts
+         processor_container.proc_optional_inputs = self.proc_optional_inputs
+
+     def __guess_proc_inputs(self):
+         if hasattr(self.proc, "proc_inputs"):
+             if self.proc.proc_inputs is not None:
+                 self.proc_inputs = []
+                 for p in self.proc.proc_inputs:
+                     self.proc_inputs.append(p.clone())
+                 return
+
+         first_layer = None
+
+         # Traverse modules to find the first real layer (containers like Sequential never match the types below)
+         for layer in self.proc.modules():
+             if isinstance(layer, (torch.nn.Conv2d, torch.nn.Linear, torch.nn.Conv1d, torch.nn.Embedding)):
+                 first_layer = layer
+                 break
+
+         if first_layer is None:
+             raise ValueError("Cannot automatically guess the shape of the input data, "
+                              "please explicitly provide it (proc_input)")
+
+         # Infer input properties
+         data_desc = "automatically guessed"
+         tensor_shape = None
+         tensor_labels = None
+         tensor_dtype = None
+         stream_to_proc_transforms = None
+         proc_to_stream_transforms = None
+
+         if isinstance(first_layer, torch.nn.Conv2d):
+
+             if first_layer.in_channels == 3 or first_layer.in_channels == 1:
+                 data_type = "img"
+
+                 # Creating dummy PIL images
+                 rgb_input_img = Image.new('RGB', (224, 224))
+                 pixels = rgb_input_img.load()
+                 for x in range(224):
+                     for y in range(224):
+                         pixels[x, y] = (random.randint(0, 255),
+                                         random.randint(0, 255),
+                                         random.randint(0, 255))
+                 gray_input_img = rgb_input_img.convert('L')
+
+                 # Checking if the model supports PIL images as input
+                 # noinspection PyBroadException
+                 try:
+                     _ = self.proc(rgb_input_img)
+                     can_handle_rgb_img = True
+                 except Exception:
+                     can_handle_rgb_img = False
+
+                 # noinspection PyBroadException
+                 try:
+                     _ = self.proc(gray_input_img)
+                     can_handle_gray_img = True
+                 except Exception:
+                     can_handle_gray_img = False
+
+                 if can_handle_gray_img and can_handle_rgb_img:
+                     stream_to_proc_transforms = None
+                 elif can_handle_rgb_img:
+                     stream_to_proc_transforms = transforms.Grayscale(num_output_channels=3)
+                 elif can_handle_gray_img:
+                     stream_to_proc_transforms = transforms.Grayscale()
+                 else:
+                     if first_layer.in_channels == 1:
+                         stream_to_proc_transforms = transforms_factory("gray-no_norm")
+                     else:
+                         stream_to_proc_transforms = transforms_factory("rgb-no_norm")
+             else:
+
+                 # If the number of input channels is not 1 and not 3...
+                 data_type = "tensor"
+                 tensor_shape = (first_layer.in_channels, None, None)
+                 tensor_dtype = torch.float32
+
+         elif isinstance(first_layer, torch.nn.Conv1d):
+             data_type = "tensor"
+             tensor_shape = (first_layer.in_channels, None)
+             tensor_dtype = torch.float32
+         elif isinstance(first_layer, torch.nn.Linear):
+             data_type = "tensor"
+             tensor_dtype = torch.float32
+             tensor_shape = (first_layer.in_features,)
+         elif isinstance(first_layer, torch.nn.Embedding):
+
+             # noinspection PyBroadException
+             try:
+                 input_text = "testing if tokenizer is present"
+                 _ = self.proc(input_text)
+                 can_handle_text = True
+                 can_handle_more_than_one_token = True  # Unused
+             except Exception:
+                 can_handle_text = False
+
+             # noinspection PyBroadException
+             try:
+                 device = torch.device("cpu")
+                 for param in self.proc.parameters():
+                     device = param.device
+                     break
+                 input_tokens = torch.tensor([[0, 1, 2, 3]], dtype=torch.long, device=device)
+                 _ = self.proc(input_tokens)
+                 can_handle_more_than_one_token = True
+             except Exception:
+                 can_handle_more_than_one_token = False
+
+             if can_handle_text:
+                 data_type = "text"
+                 stream_to_proc_transforms = None
+             else:
+                 data_type = "tensor"
+                 if can_handle_more_than_one_token:
+                     tensor_shape = (None,)
+                 else:
+                     tensor_shape = (1,)
+                 tensor_dtype = torch.long
+                 tensor_labels = ["token" + str(i) for i in range(0, first_layer.num_embeddings)]
+         else:
+             raise ValueError("Cannot automatically guess the shape of the input data, "
+                              "please explicitly provide it (proc_input)")
+
+         # Setting the input attribute
+         self.proc_inputs = [Data4Proc(name="proc_input_0",
+                                       data_type=data_type,
+                                       data_desc=data_desc,
+                                       tensor_shape=tensor_shape,
+                                       tensor_labels=tensor_labels,
+                                       tensor_dtype=tensor_dtype,
+                                       stream_to_proc_transforms=stream_to_proc_transforms,
+                                       proc_to_stream_transforms=proc_to_stream_transforms,
+                                       pubsub=False,
+                                       private_only=True)]
+
+     def __guess_proc_outputs(self):
+         if hasattr(self.proc, "proc_outputs"):
+             if self.proc.proc_outputs is not None:
+                 self.proc_outputs = []
+                 for p in self.proc.proc_outputs:
+                     self.proc_outputs.append(p.clone())
+                 return
+
+         proc = self.proc
+         device = self.proc.device
+         inputs = []
+
+         for i, proc_input in enumerate(self.proc_inputs):
+             if proc_input.is_tensor():
+                 inputs.append(proc_input.check_and_preprocess(
+                     torch.randn([1] + list(proc_input.tensor_shape),  # Adding batch size here
+                                 dtype=proc_input.tensor_dtype).to(device)))
+             elif proc_input.is_img():
+                 rgb_input_img = Image.new('RGB', (224, 224))
+                 pixels = rgb_input_img.load()
+                 for x in range(224):
+                     for y in range(224):
+                         pixels[x, y] = (random.randint(0, 255),
+                                         random.randint(0, 255),
+                                         random.randint(0, 255))
+                 inputs.append(proc_input.check_and_preprocess(rgb_input_img))
+             elif proc_input.is_text():
+                 inputs.append(proc_input.check_and_preprocess("test text as input"))
+
+         # Forward
+         with torch.no_grad():
+             outputs = proc(*inputs)
+             if not isinstance(outputs, (tuple, list)):
+                 outputs = [outputs]
+             if isinstance(outputs, tuple):
+                 outputs = list(outputs)
+
+         # This will be filled below
+         self.proc_outputs = []
+
+         for j, output in enumerate(outputs):
+
+             # Infer output properties
+             data_desc = "automatically guessed"
+             tensor_shape = None
+             tensor_labels = None
+             tensor_dtype = None
+             stream_to_proc_transforms = None
+             proc_to_stream_transforms = None
+
+             if isinstance(output, Image.Image):  # PIL Image
+                 data_type = "img"
+             elif isinstance(output, torch.Tensor):  # Tensor
+                 output_shape = list(output.shape[1:])  # Removing batch size here
+                 if len(output_shape) == 3 and (output_shape[0] == 3 or output_shape[0] == 1):
+                     data_type = "img"
+                     if output_shape[0] == 3:
+                         proc_to_stream_transforms = transforms_factory("rgb", return_inverse=True)
+                     else:
+                         proc_to_stream_transforms = transforms_factory("gray", return_inverse=True)
+                 else:
+                     data_type = "tensor"
+                     tensor_dtype = str(output.dtype)
+                     tensor_shape = output_shape
+                     tensor_labels = None
+             elif isinstance(output, str):
+                 data_type = "text"
+             else:
+                 raise ValueError(f"Unsupported output type {type(output)}")
+
+             # Setting the output attribute
+             self.proc_outputs.append(Data4Proc(name="proc_output_" + str(j),
+                                                data_type=data_type,
+                                                data_desc=data_desc,
+                                                tensor_shape=tensor_shape,
+                                                tensor_labels=tensor_labels,
+                                                tensor_dtype=tensor_dtype,
+                                                stream_to_proc_transforms=stream_to_proc_transforms,
+                                                proc_to_stream_transforms=proc_to_stream_transforms,
+                                                pubsub=False,
+                                                private_only=True))
+
+     def __guess_proc_opts(self):
+         if self.proc_opts is None:
+             if isinstance(self.proc.module, MultiIdentity) or len(list(self.proc.parameters())) == 0:
+                 self.proc_opts = {"optimizer": None,
+                                   "losses": [None] * len(self.proc_outputs)}
+             else:
+                 self.proc_opts = {"optimizer": torch.optim.SGD(self.proc.parameters(), lr=1e-5),
+                                   "losses": [torch.nn.functional.mse_loss] * len(self.proc_outputs)}
+         else:
+             if "optimizer" not in self.proc_opts:
+                 self.proc_opts["optimizer"] = None
+             if "losses" not in self.proc_opts:
+                 self.proc_opts["losses"] = [None] * len(self.proc_outputs)
+
+     def __fix_proc_opts(self):
+         opts = {}
+         found_optimizer = False
+         found_loss = False
+         cannot_fix = False
+
+         if "optimizer" in self.proc_opts:
+             found_optimizer = True
+         if "losses" in self.proc_opts:
+             found_loss = True
+
+         if not found_loss:
+             opts['losses'] = [torch.nn.functional.mse_loss] * len(self.proc_outputs)
+
+         for k, v in self.proc_opts.items():
+             if isinstance(v, torch.optim.Optimizer):
+                 if k == "optimizer":
+                     opts["optimizer"] = v
+                     continue
+                 else:
+                     if not found_optimizer:
+                         opts["optimizer"] = v
+                         found_optimizer = True
+                     else:
+                         cannot_fix = True
+                         break
+             elif k == "losses" and isinstance(v, (list, tuple)):
+                 opts["losses"] = v
+                 continue
+             elif (v == torch.nn.functional.mse_loss or isinstance(v, torch.nn.MSELoss)
+                   or v == torch.nn.functional.binary_cross_entropy or isinstance(v, torch.nn.BCELoss)
+                   or isinstance(v, torch.nn.CrossEntropyLoss) or v == torch.nn.functional.cross_entropy):
+                 if not found_loss:
+                     opts["losses"] = [v]
+                     found_loss = True
+                 else:
+                     cannot_fix = True
+                     break
+             else:
+                 opts[k] = v
+
+         if not found_optimizer:
+             if 'lr' in opts:
+                 opts['optimizer'] = torch.optim.SGD(self.proc.parameters(), lr=opts['lr'])
+
+         assert not cannot_fix, \
+             "About proc_opts: cannot find required keys ('optimizer', 'losses') and/or cannot automatically guess them"
+
+         # Removing the batch dim from targets in case of cross-entropy
+         fixed_list = []
+         for _loss_fcn in opts['losses']:
+             if _loss_fcn == torch.nn.functional.cross_entropy or isinstance(_loss_fcn, torch.nn.CrossEntropyLoss):
+                 fixed_list.append(target_shape_fixed_cross_entropy)
+             else:
+                 fixed_list.append(_loss_fcn)
+         opts['losses'] = fixed_list
+
+         # Updating
+         self.proc_opts = opts
+
+     def __guess_proc_optional_inputs(self):
+         self.proc_optional_inputs = []
+         if isinstance(self.proc, ModuleWrapper):
+             if hasattr(self.proc.module, "forward"):
+                 sig = inspect.signature(self.proc.module.forward)
+             else:
+                 sig = inspect.signature(self.proc.forward)
+         else:
+             sig = inspect.signature(self.proc.forward)
+
+         i = 0
+         for name, param in sig.parameters.items():
+             if i >= len(self.proc_inputs):
+                 break
+             if param.default is not inspect.Parameter.empty:
+                 self.proc_optional_inputs.append({"has_default": True, "default_value": param.default})
+             else:
+                 self.proc_optional_inputs.append({"has_default": False, "default_value": None})
+             i += 1
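# --- Editor's usage sketch (illustrative, not part of the packaged file) ---
# Any object exposing the five proc* attributes can serve as a processor
# container; AgentProcessorChecker fills in whatever was left as None (here the
# input shape is guessed from the first torch.nn.Linear layer) and writes the
# results back onto the container. The _DemoContainer name is hypothetical.
class _DemoContainer:
    def __init__(self):
        self.proc = torch.nn.Linear(4, 2)
        self.proc_inputs = None
        self.proc_outputs = None
        self.proc_opts = None
        self.proc_optional_inputs = None

container = _DemoContainer()
AgentProcessorChecker(container)
print([p.get_name() for p in container.proc_inputs])  # e.g. ['proc_input_0']
print(sorted(container.proc_opts.keys()))             # Includes 'losses' and 'optimizer'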