unaiverse-0.1.11-cp311-cp311-macosx_11_0_arm64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.

This version of unaiverse might be problematic.

Files changed (50)
  1. unaiverse/__init__.py +19 -0
  2. unaiverse/agent.py +2090 -0
  3. unaiverse/agent_basics.py +1948 -0
  4. unaiverse/clock.py +221 -0
  5. unaiverse/dataprops.py +1236 -0
  6. unaiverse/hsm.py +1892 -0
  7. unaiverse/modules/__init__.py +18 -0
  8. unaiverse/modules/cnu/__init__.py +17 -0
  9. unaiverse/modules/cnu/cnus.py +536 -0
  10. unaiverse/modules/cnu/layers.py +261 -0
  11. unaiverse/modules/cnu/psi.py +60 -0
  12. unaiverse/modules/hl/__init__.py +15 -0
  13. unaiverse/modules/hl/hl_utils.py +411 -0
  14. unaiverse/modules/networks.py +1509 -0
  15. unaiverse/modules/utils.py +710 -0
  16. unaiverse/networking/__init__.py +16 -0
  17. unaiverse/networking/node/__init__.py +18 -0
  18. unaiverse/networking/node/connpool.py +1308 -0
  19. unaiverse/networking/node/node.py +2499 -0
  20. unaiverse/networking/node/profile.py +446 -0
  21. unaiverse/networking/node/tokens.py +79 -0
  22. unaiverse/networking/p2p/__init__.py +187 -0
  23. unaiverse/networking/p2p/go.mod +127 -0
  24. unaiverse/networking/p2p/go.sum +548 -0
  25. unaiverse/networking/p2p/golibp2p.py +18 -0
  26. unaiverse/networking/p2p/golibp2p.pyi +135 -0
  27. unaiverse/networking/p2p/lib.go +2662 -0
  28. unaiverse/networking/p2p/lib.go.sha256 +1 -0
  29. unaiverse/networking/p2p/lib_types.py +312 -0
  30. unaiverse/networking/p2p/message_pb2.py +50 -0
  31. unaiverse/networking/p2p/messages.py +362 -0
  32. unaiverse/networking/p2p/mylogger.py +77 -0
  33. unaiverse/networking/p2p/p2p.py +871 -0
  34. unaiverse/networking/p2p/proto-go/message.pb.go +846 -0
  35. unaiverse/networking/p2p/unailib.cpython-311-darwin.so +0 -0
  36. unaiverse/stats.py +1481 -0
  37. unaiverse/streamlib/__init__.py +15 -0
  38. unaiverse/streamlib/streamlib.py +210 -0
  39. unaiverse/streams.py +776 -0
  40. unaiverse/utils/__init__.py +16 -0
  41. unaiverse/utils/lone_wolf.json +24 -0
  42. unaiverse/utils/misc.py +310 -0
  43. unaiverse/utils/sandbox.py +293 -0
  44. unaiverse/utils/server.py +435 -0
  45. unaiverse/world.py +335 -0
  46. unaiverse-0.1.11.dist-info/METADATA +367 -0
  47. unaiverse-0.1.11.dist-info/RECORD +50 -0
  48. unaiverse-0.1.11.dist-info/WHEEL +6 -0
  49. unaiverse-0.1.11.dist-info/licenses/LICENSE +43 -0
  50. unaiverse-0.1.11.dist-info/top_level.txt +1 -0
unaiverse/modules/utils.py
@@ -0,0 +1,710 @@
+ """
+ █████ █████ ██████ █████ █████ █████ █████ ██████████ ███████████ █████████ ██████████
+ ░░███ ░░███ ░░██████ ░░███ ░░███ ░░███ ░░███ ░░███░░░░░█░░███░░░░░███ ███░░░░░███░░███░░░░░█
+ ░███ ░███ ░███░███ ░███ ██████ ░███ ░███ ░███ ░███ █ ░ ░███ ░███ ░███ ░░░ ░███ █ ░
+ ░███ ░███ ░███░░███░███ ░░░░░███ ░███ ░███ ░███ ░██████ ░██████████ ░░█████████ ░██████
+ ░███ ░███ ░███ ░░██████ ███████ ░███ ░░███ ███ ░███░░█ ░███░░░░░███ ░░░░░░░░███ ░███░░█
+ ░███ ░███ ░███ ░░█████ ███░░███ ░███ ░░░█████░ ░███ ░ █ ░███ ░███ ███ ░███ ░███ ░ █
+ ░░████████ █████ ░░█████░░████████ █████ ░░███ ██████████ █████ █████░░█████████ ██████████
+ ░░░░░░░░ ░░░░░ ░░░░░ ░░░░░░░░ ░░░░░ ░░░ ░░░░░░░░░░ ░░░░░ ░░░░░ ░░░░░░░░░ ░░░░░░░░░░
+ A Collectionless AI Project (https://collectionless.ai)
+ Registration/Login: https://unaiverse.io
+ Code Repositories: https://github.com/collectionlessai/
+ Main Developers: Stefano Melacci (Project Leader), Christian Di Maio, Tommaso Guidi
+ """
+ import io
+ import os
+ import torch
+ import random
+ import inspect
+ import threading
+ import numpy as np
+ from PIL import Image
+ from torch.utils.data import DataLoader
+ from unaiverse.dataprops import Data4Proc
+ from torchvision import datasets, transforms
+ 
+ 
+ def transforms_factory(trans_type: str, add_batch_dim: bool = True, return_inverse: bool = False):
+     supported_types = {"rgb*", "gray*",
+                        "rgb-no_norm*", "gray-no_norm*",
+                        "rgb", "gray",
+                        "rgb-no_norm", "gray-no_norm",
+                        "gray_mnist"}
+ 
+     # Match trans_type against the supported types; a trailing integer (e.g. "rgb224")
+     # selects the starred recipe and provides the resize/crop size
+     found = False
+     num = -1
+     for _type in supported_types:
+         if _type.endswith("*"):
+             has_star = True
+             __type = _type[0:-1]
+         else:
+             has_star = False
+             __type = _type
+ 
+         if has_star and trans_type.startswith(__type) and len(trans_type) > len(__type):
+             try:
+                 num = int(trans_type[len(__type):])
+                 trans_type = _type
+                 found = True
+                 break
+             except ValueError:
+                 pass
+         elif trans_type == _type:
+             found = True
+             break
+ 
+     if not found:
+         raise ValueError(f"Invalid transformation type '{trans_type}': must be one of {supported_types}, "
+                          f"where * is an integer number")
+ 
+     trans = None
+     inverse_trans = None
+ 
+     if trans_type == "rgb*":
+         trans = transforms.Compose([
+             transforms.Lambda(lambda img: img.convert("RGB") if img.mode != "RGB" else img),  # Ensure 3 channels
+             transforms.Resize(num),
+             transforms.CenterCrop(num),
+             transforms.ToTensor(),  # Convert PIL to tensor (3, H, W), float [0,1]
+             transforms.Normalize(mean=[0.485, 0.456, 0.406],
+                                  std=[0.229, 0.224, 0.225]),
+         ])
+         inverse_trans = transforms.Compose([
+             transforms.Normalize(mean=[0., 0., 0.],
+                                  std=[1. / 0.229, 1. / 0.224, 1. / 0.225]),
+             transforms.Lambda(lambda x: x + torch.tensor([0.485, 0.456, 0.406]).view(3, 1, 1)),
+             transforms.ToPILImage()
+         ])
+     elif trans_type == "gray*":
+         trans = transforms.Compose([
+             transforms.Lambda(lambda img: img.convert("L") if img.mode != "L" else img),  # Ensure 1 channel
+             transforms.Resize(num),
+             transforms.CenterCrop(num),
+             transforms.ToTensor(),  # Convert PIL to tensor (1, H, W), float [0,1]
+             transforms.Normalize(mean=[0.45],
+                                  std=[0.225])
+         ])
+         inverse_trans = transforms.Compose([
+             transforms.Normalize(mean=[0.],
+                                  std=[1. / 0.225]),
+             transforms.Lambda(lambda x: x + 0.45),
+             transforms.ToPILImage()
+         ])
+     elif trans_type == "rgb-no_norm*":
+         trans = transforms.Compose([
+             transforms.Lambda(lambda img: img.convert("RGB") if img.mode != "RGB" else img),  # Ensure 3 channels
+             transforms.Resize(num),
+             transforms.CenterCrop(num),
+             transforms.PILToTensor(),  # Convert PIL to tensor (3, H, W), uint [0,255]
+         ])
+         inverse_trans = transforms.Compose([transforms.ToPILImage()])
+     elif trans_type == "gray-no_norm*":
+         trans = transforms.Compose([
+             transforms.Lambda(lambda img: img.convert("L") if img.mode != "L" else img),  # Ensure 1 channel
+             transforms.Resize(num),
+             transforms.CenterCrop(num),
+             transforms.PILToTensor(),  # Convert PIL to tensor (1, H, W), uint [0,255]
+         ])
+         inverse_trans = transforms.Compose([transforms.ToPILImage()])  # Compose, so add_batch_dim can prepend below
+     elif trans_type == "rgb":
+         trans = transforms.Compose([
+             transforms.Lambda(lambda img: img.convert("RGB") if img.mode != "RGB" else img),  # Ensure 3 channels
+             transforms.ToTensor(),  # Convert PIL to tensor (3, H, W), float [0,1]
+             transforms.Normalize(mean=[0.485, 0.456, 0.406],
+                                  std=[0.229, 0.224, 0.225])
+         ])
+         inverse_trans = transforms.Compose([
+             transforms.Normalize(mean=[0., 0., 0.],
+                                  std=[1. / 0.229, 1. / 0.224, 1. / 0.225]),
+             transforms.Lambda(lambda x: x + torch.tensor([0.485, 0.456, 0.406]).view(3, 1, 1)),
+             transforms.ToPILImage()
+         ])
+     elif trans_type == "gray":
+         trans = transforms.Compose([
+             transforms.Lambda(lambda img: img.convert("L") if img.mode != "L" else img),  # Ensure 1 channel
+             transforms.ToTensor(),  # Convert PIL to tensor (1, H, W), float [0,1]
+             transforms.Normalize(mean=[0.45],
+                                  std=[0.225])
+         ])
+         inverse_trans = transforms.Compose([
+             transforms.Normalize(mean=[0.],
+                                  std=[1. / 0.225]),
+             transforms.Lambda(lambda x: x + 0.45),
+             transforms.ToPILImage()
+         ])
+     elif trans_type == "rgb-no_norm":
+         trans = transforms.Compose([
+             transforms.Lambda(lambda img: img.convert("RGB") if img.mode != "RGB" else img),  # Ensure 3 channels
+             transforms.PILToTensor(),  # Convert PIL to tensor (3, H, W), uint [0,255]
+         ])
+         inverse_trans = transforms.Compose([transforms.ToPILImage()])
+     elif trans_type == "gray-no_norm":
+         trans = transforms.Compose([
+             transforms.Lambda(lambda img: img.convert("L") if img.mode != "L" else img),  # Ensure 1 channel
+             transforms.PILToTensor(),  # Convert PIL to tensor (1, H, W), uint [0,255]
+         ])
+         inverse_trans = transforms.Compose([transforms.ToPILImage()])
+     elif trans_type == "gray_mnist":
+         trans = transforms.Compose([
+             transforms.Lambda(lambda img: img.convert("L") if img.mode != "L" else img),  # Ensure 1 channel
+             transforms.Resize(28),
+             transforms.CenterCrop(28),
+             transforms.ToTensor(),  # Convert PIL to tensor (1, H, W), float [0,1]
+             transforms.Normalize(mean=[0.1307],  # MNIST
+                                  std=[0.3081])  # MNIST
+         ])
+         inverse_trans = transforms.Compose([
+             transforms.Normalize(mean=[0.],
+                                  std=[1. / 0.3081]),
+             transforms.Lambda(lambda x: x + 0.1307),
+             transforms.ToPILImage()
+         ])
+ 
+     if add_batch_dim:
+         trans.transforms.append(transforms.Lambda(lambda x: x.unsqueeze(0)))
+         inverse_trans.transforms.insert(0, transforms.Lambda(lambda x: x.squeeze(0)))
+ 
+     return trans if not return_inverse else inverse_trans
+ 
+ 
+ def hard_tanh(x: torch.Tensor) -> torch.Tensor:
+     return torch.clamp(x, min=-1., max=1.)
+ 
+ 
+ def target_shape_fixed_cross_entropy(output, target, *args, **kwargs):
+     if len(target.shape) > 1:
+         target = target.squeeze(0)
+     return torch.nn.functional.cross_entropy(output, target, *args, **kwargs)
+ 
+ 
+ def set_seed(seed: int) -> None:
+     if seed >= 0:
+         torch.manual_seed(seed)
+         random.seed(seed)
+         np.random.seed(seed)
+ 
+ 
+ def get_proc_inputs_and_proc_outputs_for_rnn(u_shape: torch.Size | tuple, du_dim: int, y_dim: int):
+     if isinstance(u_shape, torch.Size):
+         u_shape = tuple(u_shape)
+     proc_inputs = [
+         Data4Proc(data_type="tensor", tensor_shape=(None,) + u_shape, tensor_dtype=torch.float32,
+                   pubsub=False, private_only=True),
+         Data4Proc(data_type="tensor", tensor_shape=(None, du_dim,), tensor_dtype=torch.float32,
+                   pubsub=False, private_only=True)
+     ]
+     proc_outputs = [
+         Data4Proc(data_type="tensor", tensor_shape=(None, y_dim), tensor_dtype=torch.float32,
+                   pubsub=False, private_only=True)
+     ]
+     return proc_inputs, proc_outputs
+ 
+ 
+ def get_proc_inputs_and_proc_outputs_for_image_classification(y_dim: int):
+     if y_dim == -1:
+         y_dim = 1000  # Assuming ImageNet-trained models
+     proc_inputs = [Data4Proc(data_type="img", pubsub=False, private_only=True)]
+     proc_outputs = [Data4Proc(data_type="tensor", tensor_shape=(None, y_dim), tensor_dtype=torch.float32,
+                               pubsub=False, private_only=True)]
+     return proc_inputs, proc_outputs
+ 
+ 
+ def isinstance_fcn(obj, class_to_check):
+     return isinstance(obj, class_to_check)
+ 
+ 
+ def error_rate_mnist_test_set(network: torch.nn.Module, mnist_data_save_path: str):
+ 
+     # Getting MNIST test set
+     mnist_test = datasets.MNIST(root=mnist_data_save_path,
+                                 train=False, download=True,
+                                 transform=transforms.Compose([
+                                     transforms.ToTensor(),
+                                     transforms.Normalize((0.1307,), (0.3081,))
+                                 ]))
+     mnist_test = DataLoader(mnist_test, batch_size=200, shuffle=False)
+ 
+     # Checking error rate
+     error_rate = 0.
+     n = 0
+     training_flag_backup = network.training
+     network.eval()
+     device = next(network.parameters()).device
+     with torch.no_grad():  # No gradients are needed for evaluation
+         for x, y in mnist_test:
+             x = x.to(device)
+             y = y.to(device)
+             o = network(x)
+             c = torch.argmax(o, dim=1)
+             error_rate += float(torch.sum(c != y).item())
+             n += x.shape[0]
+     error_rate /= n
+     network.train(training_flag_backup)  # Restore the original train/eval mode on all sub-modules
+ 
+     return error_rate
+ 
+ 
+ class MultiIdentity(torch.nn.Module):
+     def __init__(self):
+         super().__init__()
+ 
+     def forward(self, *args):
+         if len(args) == 1:
+             return args[0]
+         return args
+ 
+ 
+ class HumanModule(torch.nn.Module):
+     def __init__(self):
+         super().__init__()
+ 
+     def forward(self, text: str = None, img: Image.Image = None):
+         return text, img
+ 
+ 
+ class ModuleWrapper(torch.nn.Module):
+     def __init__(self,
+                  module: torch.nn.Module | None = None,
+                  proc_inputs: list[Data4Proc] | None = None,
+                  proc_outputs: list[Data4Proc] | None = None,
+                  seed: int = -1):
+         super(ModuleWrapper, self).__init__()
+         self.device = None  # The device which is supposed to host the module
+         self.module = None  # The module itself
+         self.proc_inputs = proc_inputs  # The list of Data4Proc objects describing the input types of the module
+         self.proc_outputs = proc_outputs  # The list of Data4Proc objects describing the output types of the module
+ 
+         # Working
+         set_seed(seed)
+         device_env = os.getenv("PROC_DEVICE", None)
+         self.device = torch.device("cpu") if device_env is None else torch.device(device_env)
+         self.module = module.to(self.device) if module is not None else None
+ 
+     def forward(self, *args, **kwargs):
+ 
+         # The forward signature expected by the callers of this method is:
+         #     forward(self, *args, first: bool, last: bool, **kwargs)
+         # so 'first' and 'last' must be discarded (if present) before reaching an
+         # external module that was not designed for this library
+         kwargs.pop('first', None)
+         kwargs.pop('last', None)
+ 
+         # Calling the module
+         return self.module(*args, **kwargs)
+ 
+ 
+ class AgentProcessorChecker:
+ 
+     def __init__(self, processor_container: object):
+         assert hasattr(processor_container, 'proc'), "Invalid processor container object"
+         assert hasattr(processor_container, 'proc_inputs'), "Invalid processor container object"
+         assert hasattr(processor_container, 'proc_outputs'), "Invalid processor container object"
+         assert hasattr(processor_container, 'proc_opts'), "Invalid processor container object"
+         assert hasattr(processor_container, 'proc_optional_inputs'), "Invalid processor container object"
+ 
+         # Getting processor-related info from the main object which collects the processor and its properties
+         proc: torch.nn.Module | str | None = processor_container.proc
+         proc_inputs: list[Data4Proc] | None = processor_container.proc_inputs
+         proc_outputs: list[Data4Proc] | None = processor_container.proc_outputs
+         proc_opts: dict | None = processor_container.proc_opts
+         proc_optional_inputs: list | None = processor_container.proc_optional_inputs
+ 
+         assert proc is None or isinstance(proc, (torch.nn.Module, str)), \
+             "Processor (proc) must be a torch.nn.Module (or the string 'human')"
+         assert (proc_inputs is None or
+                 (isinstance_fcn(proc_inputs, list) and
+                  (len(proc_inputs) == 0 or isinstance_fcn(proc_inputs[0], Data4Proc)))), \
+             "Invalid proc_inputs: it must be None or a list of Data4Proc"
+         assert (proc_outputs is None or
+                 (isinstance_fcn(proc_outputs, list) and
+                  (len(proc_outputs) == 0 or isinstance_fcn(proc_outputs[0], Data4Proc)))), \
+             "Invalid proc_outputs: it must be None or a list of Data4Proc"
+         assert (proc_opts is None or isinstance_fcn(proc_opts, dict)), \
+             "Invalid proc_opts: it must be None or a dictionary"
+ 
+         # Saving as attributes
+         self.proc = proc
+         self.proc_inputs = proc_inputs
+         self.proc_outputs = proc_outputs
+         self.proc_opts = proc_opts
+         self.proc_optional_inputs = proc_optional_inputs
+ 
+         # Dummy processor (if no processor was provided)
+         if self.proc is None:
+             self.proc = ModuleWrapper(module=MultiIdentity())
+             self.proc.device = torch.device("cpu")
+             if self.proc_inputs is None:
+                 self.proc_inputs = [Data4Proc(data_type="all", pubsub=False, private_only=False)]
+             if self.proc_outputs is None:
+                 self.proc_outputs = [Data4Proc(data_type="all", pubsub=False, private_only=False)]
+             self.proc_opts = {'optimizer': None, 'losses': [None] * len(self.proc_outputs)}
+         else:
+ 
+             # String telling it is a human
+             if isinstance(self.proc, str) and self.proc.lower() == "human":
+                 self.proc = ModuleWrapper(module=HumanModule())
+                 self.proc.device = torch.device("cpu")
+                 self.proc_inputs = [Data4Proc(data_type="text", pubsub=False, private_only=False),
+                                     Data4Proc(data_type="img", pubsub=False, private_only=False)]
+                 self.proc_outputs = [Data4Proc(data_type="text", pubsub=False, private_only=False),
+                                      Data4Proc(data_type="img", pubsub=False, private_only=False)]
+ 
+             # Wrapping to have the basic attributes (device)
+             elif not isinstance(self.proc, ModuleWrapper):
+                 self.proc = ModuleWrapper(module=self.proc)
+                 self.proc.device = torch.device("cpu")
+ 
+             # Guessing inputs, fixing attributes
+             if self.proc_inputs is None:
+                 self.__guess_proc_inputs()
+ 
+             for j in range(len(self.proc_inputs)):
+                 if self.proc_inputs[j].get_name() == "unk":
+                     self.proc_inputs[j].set_name("proc_input_" + str(j))
+ 
+             # Guessing outputs, fixing attributes
+             if self.proc_outputs is None:
+                 self.__guess_proc_outputs()
+ 
+             for j in range(len(self.proc_outputs)):
+                 if self.proc_outputs[j].get_name() == "unk":
+                     self.proc_outputs[j].set_name("proc_output_" + str(j))
+ 
+             # Guessing optimization-related options, fixing attributes
+             if (self.proc_opts is None or len(self.proc_opts) == 0 or
+                     'optimizer' not in self.proc_opts or 'losses' not in self.proc_opts):
+                 self.__guess_proc_opts()
+                 self.__fix_proc_opts()
+ 
+         # Ensuring all is OK
+         if self.proc is not None:
+             assert "optimizer" in self.proc_opts, "Missing 'optimizer' key in proc_opts (required)"
+             assert "losses" in self.proc_opts, "Missing 'losses' key in proc_opts (required)"
+ 
+         # Checking inputs with default values
+         if self.proc_optional_inputs is None:
+             self.__guess_proc_optional_inputs()
+ 
+         # Updating the processor container object
+         processor_container.proc = self.proc
+         processor_container.proc_inputs = self.proc_inputs
+         processor_container.proc_outputs = self.proc_outputs
+         processor_container.proc_opts = self.proc_opts
+         processor_container.proc_optional_inputs = self.proc_optional_inputs
+ 
+     def __guess_proc_inputs(self):
+         if hasattr(self.proc, "proc_inputs"):
+             if self.proc.proc_inputs is not None:
+                 self.proc_inputs = []
+                 for p in self.proc.proc_inputs:
+                     self.proc_inputs.append(p.clone())
+                 return
+ 
+         first_layer = None
+ 
+         # Traverse modules to find the first real layer (skip containers like Sequential)
+         for layer in self.proc.modules():
+             if isinstance(layer, (torch.nn.Sequential,
+                                   torch.nn.ModuleList,
+                                   torch.nn.ModuleDict)):
+                 continue  # Skip containers
+             if isinstance(layer, (torch.nn.Conv2d, torch.nn.Linear, torch.nn.Conv1d, torch.nn.Embedding)):
+                 first_layer = layer
+                 break
+ 
+         if first_layer is None:
+             raise ValueError("Cannot automatically guess the shape of the input data, "
+                              "please explicitly provide it (proc_input)")
+ 
+         # Infer input properties
+         data_desc = "automatically guessed"
+         tensor_shape = None
+         tensor_labels = None
+         tensor_dtype = None
+         stream_to_proc_transforms = None
+         proc_to_stream_transforms = None
+ 
+         if isinstance(first_layer, torch.nn.Conv2d):
+ 
+             if first_layer.in_channels == 3 or first_layer.in_channels == 1:
+                 data_type = "img"
+ 
+                 # Creating dummy PIL images filled with random pixels
+                 rgb_input_img = Image.new('RGB', (224, 224))
+                 pixels = rgb_input_img.load()
+                 for x in range(224):
+                     for y in range(224):
+                         pixels[x, y] = (random.randint(0, 255),
+                                         random.randint(0, 255),
+                                         random.randint(0, 255))
+                 gray_input_img = rgb_input_img.convert('L')
+ 
+                 # Checking if the model supports PIL images as input
+                 # noinspection PyBroadException
+                 try:
+                     _ = self.proc(rgb_input_img)
+                     can_handle_rgb_img = True
+                 except Exception:
+                     can_handle_rgb_img = False
+ 
+                 # noinspection PyBroadException
+                 try:
+                     _ = self.proc(gray_input_img)
+                     can_handle_gray_img = True
+                 except Exception:
+                     can_handle_gray_img = False
+ 
+                 if can_handle_gray_img and can_handle_rgb_img:
+                     stream_to_proc_transforms = None
+                 elif can_handle_rgb_img:
+                     stream_to_proc_transforms = transforms.Grayscale(num_output_channels=3)
+                 elif can_handle_gray_img:
+                     stream_to_proc_transforms = transforms.Grayscale()
+                 else:
+                     if first_layer.in_channels == 1:
+                         stream_to_proc_transforms = transforms_factory("gray-no_norm")
+                     else:
+                         stream_to_proc_transforms = transforms_factory("rgb-no_norm")
+             else:
+ 
+                 # If the number of input channels is not 1 and not 3...
+                 data_type = "tensor"
+                 tensor_shape = (first_layer.in_channels, None, None)
+                 tensor_dtype = torch.float32
+ 
+         elif isinstance(first_layer, torch.nn.Conv1d):
+             data_type = "tensor"
+             tensor_shape = (first_layer.in_channels, None)
+             tensor_dtype = torch.float32
+         elif isinstance(first_layer, torch.nn.Linear):
+             data_type = "tensor"
+             tensor_dtype = torch.float32
+             tensor_shape = (first_layer.in_features,)
+         elif isinstance(first_layer, torch.nn.Embedding):
+ 
+             # noinspection PyBroadException
+             try:
+                 input_text = "testing if tokenizer is present"
+                 _ = self.proc(input_text)
+                 can_handle_text = True
+             except Exception:
+                 can_handle_text = False
+ 
+             # noinspection PyBroadException
+             try:
+                 device = torch.device("cpu")
+                 for param in self.proc.parameters():
+                     device = param.device
+                     break
+                 input_tokens = torch.tensor([[0, 1, 2, 3]], dtype=torch.long, device=device)
+                 _ = self.proc(input_tokens)
+                 can_handle_more_than_one_token = True
+             except Exception:
+                 can_handle_more_than_one_token = False
+ 
+             if can_handle_text:
+                 data_type = "text"
+                 stream_to_proc_transforms = None
+             else:
+                 data_type = "tensor"
+                 if can_handle_more_than_one_token:
+                     tensor_shape = (None,)
+                 else:
+                     tensor_shape = (1,)
+                 tensor_dtype = torch.long
+                 tensor_labels = ["token" + str(i) for i in range(0, first_layer.num_embeddings)]
+         else:
+             raise ValueError("Cannot automatically guess the shape of the input data, "
+                              "please explicitly provide it (proc_input)")
+ 
+         # Setting the input attribute
+         self.proc_inputs = [Data4Proc(name="proc_input_0",
+                                       data_type=data_type,
+                                       data_desc=data_desc,
+                                       tensor_shape=tensor_shape,
+                                       tensor_labels=tensor_labels,
+                                       tensor_dtype=tensor_dtype,
+                                       stream_to_proc_transforms=stream_to_proc_transforms,
+                                       proc_to_stream_transforms=proc_to_stream_transforms,
+                                       pubsub=False,
+                                       private_only=True)]
+ 
+     def __guess_proc_outputs(self):
+         if hasattr(self.proc, "proc_outputs"):
+             if self.proc.proc_outputs is not None:
+                 self.proc_outputs = []
+                 for p in self.proc.proc_outputs:
+                     self.proc_outputs.append(p.clone())
+                 return
+ 
+         proc = self.proc
+         device = self.proc.device
+         inputs = []
+ 
+         for i, proc_input in enumerate(self.proc_inputs):
+             if proc_input.is_tensor():
+                 inputs.append(proc_input.check_and_preprocess(
+                     torch.randn([1] + list(proc_input.tensor_shape),  # Adding batch size here
+                                 dtype=proc_input.tensor_dtype).to(device)))
+             elif proc_input.is_img():
+                 rgb_input_img = Image.new('RGB', (224, 224))
+                 pixels = rgb_input_img.load()
+                 for x in range(224):
+                     for y in range(224):
+                         pixels[x, y] = (random.randint(0, 255),
+                                         random.randint(0, 255),
+                                         random.randint(0, 255))
+                 inputs.append(proc_input.check_and_preprocess(rgb_input_img))
+             elif proc_input.is_text():
+                 inputs.append(proc_input.check_and_preprocess("test text as input"))
+ 
+         # Forward
+         with torch.no_grad():
+             outputs = proc(*inputs)
+             if not isinstance(outputs, tuple | list):
+                 outputs = [outputs]
+             if isinstance(outputs, tuple):
+                 outputs = list(outputs)
+ 
+         # This will be filled below
+         self.proc_outputs = []
+ 
+         for j, output in enumerate(outputs):
+ 
+             # Infer output properties
+             data_desc = "automatically guessed"
+             tensor_shape = None
+             tensor_labels = None
+             tensor_dtype = None
+             stream_to_proc_transforms = None
+             proc_to_stream_transforms = None
+ 
+             if isinstance(output, Image.Image):  # PIL Image
+                 data_type = "img"
+             elif isinstance(output, torch.Tensor):  # Tensor
+                 output_shape = list(output.shape[1:])  # Removing batch size here
+                 if len(output_shape) == 3 and (output_shape[0] == 3 or output_shape[0] == 1):
+                     data_type = "img"
+                     if output_shape[0] == 3:
+                         proc_to_stream_transforms = transforms_factory("rgb", return_inverse=True)
+                     else:
+                         proc_to_stream_transforms = transforms_factory("gray", return_inverse=True)
+                 else:
+                     data_type = "tensor"
+                     tensor_dtype = str(output.dtype)
+                     tensor_shape = output_shape
+                     tensor_labels = None
+             elif isinstance(output, str):
+                 data_type = "text"
+             else:
+                 raise ValueError(f"Unsupported output type {type(output)}")
+ 
+             # Setting the output attribute
+             self.proc_outputs.append(Data4Proc(name="proc_output_" + str(j),
+                                                data_type=data_type,
+                                                data_desc=data_desc,
+                                                tensor_shape=tensor_shape,
+                                                tensor_labels=tensor_labels,
+                                                tensor_dtype=tensor_dtype,
+                                                stream_to_proc_transforms=stream_to_proc_transforms,
+                                                proc_to_stream_transforms=proc_to_stream_transforms,
+                                                pubsub=False,
+                                                private_only=True))
+ 
+     def __guess_proc_opts(self):
+         if self.proc_opts is None:
+             if isinstance(self.proc.module, MultiIdentity) or len(list(self.proc.parameters())) == 0:
+                 self.proc_opts = {"optimizer": None,
+                                   "losses": [None] * len(self.proc_outputs)}
+             else:
+                 self.proc_opts = {"optimizer": torch.optim.SGD(self.proc.parameters(), lr=1e-5),
+                                   "losses": [torch.nn.functional.mse_loss] * len(self.proc_outputs)}
+         else:
+             if "optimizer" not in self.proc_opts:
+                 self.proc_opts["optimizer"] = None
+             if "losses" not in self.proc_opts:
+                 self.proc_opts["losses"] = [None] * len(self.proc_outputs)
+ 
+     def __fix_proc_opts(self):
+         opts = {}
+         found_optimizer = False
+         found_loss = False
+         cannot_fix = False
+ 
+         if "optimizer" in self.proc_opts:
+             found_optimizer = True
+         if "losses" in self.proc_opts:
+             found_loss = True
+ 
+         if not found_loss:
+             opts['losses'] = [torch.nn.functional.mse_loss] * len(self.proc_outputs)
+ 
+         for k, v in self.proc_opts.items():
+             if isinstance(v, torch.optim.Optimizer):
+                 if k == "optimizer":
+                     opts["optimizer"] = v
+                     continue
+                 else:
+                     if not found_optimizer:
+                         opts["optimizer"] = v
+                         found_optimizer = True
+                     else:
+                         cannot_fix = True
+                         break
+             elif k == "losses" and isinstance(v, (list, tuple)):
+                 opts["losses"] = v
+                 continue
+             elif (v == torch.nn.functional.mse_loss or isinstance(v, torch.nn.MSELoss)
+                     or v == torch.nn.functional.binary_cross_entropy or isinstance(v, torch.nn.BCELoss)
+                     or isinstance(v, torch.nn.CrossEntropyLoss) or v == torch.nn.functional.cross_entropy):
+                 if not found_loss:
+                     opts["losses"] = [v]
+                     found_loss = True
+                 else:
+                     cannot_fix = True
+                     break
+             else:
+                 opts[k] = v
+ 
+         if not found_optimizer:
+             if 'lr' in opts:
+                 opts['optimizer'] = torch.optim.SGD(self.proc.parameters(), lr=opts['lr'])
+ 
+         assert not cannot_fix, \
+             "About proc_opts: cannot find required keys ('optimizer', 'losses') and/or cannot automatically guess them"
+ 
+         # Removing the batch dim from targets in case of cross-entropy
+         fixed_list = []
+         for _loss_fcn in opts['losses']:
+             if _loss_fcn == torch.nn.functional.cross_entropy or isinstance(_loss_fcn, torch.nn.CrossEntropyLoss):
+                 fixed_list.append(target_shape_fixed_cross_entropy)
+             else:
+                 fixed_list.append(_loss_fcn)
+         opts['losses'] = fixed_list
+ 
+         # Updating
+         self.proc_opts = opts
+ 
+     def __guess_proc_optional_inputs(self):
+         self.proc_optional_inputs = []
+         if isinstance(self.proc, ModuleWrapper):
+             if hasattr(self.proc.module, "forward"):
+                 sig = inspect.signature(self.proc.module.forward)
+             else:
+                 sig = inspect.signature(self.proc.forward)
+         else:
+             sig = inspect.signature(self.proc.forward)
+ 
+         i = 0
+         for name, param in sig.parameters.items():
+             if i >= len(self.proc_inputs):
+                 break
+             if param.default is not inspect.Parameter.empty:
+                 self.proc_optional_inputs.append({"has_default": True, "default_value": param.default})
+             else:
+                 self.proc_optional_inputs.append({"has_default": False, "default_value": None})
+             i += 1
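
As a quick orientation to the file above: transforms_factory builds paired forward/inverse image pipelines, and the starred types take a trailing integer that sets the resize/crop size. A minimal round-trip sketch, assuming the wheel is installed (the synthetic image below is a stand-in, not part of the package):

    from PIL import Image
    from unaiverse.modules.utils import transforms_factory

    # "rgb224" selects the "rgb*" recipe with num=224: resize + center-crop to 224,
    # then ToTensor and ImageNet-style normalization; add_batch_dim=True (the default)
    # appends an unsqueeze(0), so the result carries a batch dimension
    to_tensor = transforms_factory("rgb224")
    to_image = transforms_factory("rgb224", return_inverse=True)  # Paired inverse

    img = Image.new("RGB", (640, 480), color=(128, 64, 32))  # Synthetic stand-in image
    x = to_tensor(img)      # torch.Size([1, 3, 224, 224])
    restored = to_image(x)  # Squeezes the batch dim, denormalizes, converts back to PIL
    print(x.shape, restored.size)

The inverse pipeline mirrors the forward one in reverse order (squeeze, unnormalize, ToPILImage), which is how __guess_proc_outputs maps tensor outputs shaped like images back to the "img" data type.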