unaiverse-0.1.12-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (47)
  1. unaiverse/__init__.py +19 -0
  2. unaiverse/agent.py +2226 -0
  3. unaiverse/agent_basics.py +2389 -0
  4. unaiverse/clock.py +234 -0
  5. unaiverse/dataprops.py +1282 -0
  6. unaiverse/hsm.py +2471 -0
  7. unaiverse/modules/__init__.py +18 -0
  8. unaiverse/modules/cnu/__init__.py +17 -0
  9. unaiverse/modules/cnu/cnus.py +536 -0
  10. unaiverse/modules/cnu/layers.py +261 -0
  11. unaiverse/modules/cnu/psi.py +60 -0
  12. unaiverse/modules/hl/__init__.py +15 -0
  13. unaiverse/modules/hl/hl_utils.py +411 -0
  14. unaiverse/modules/networks.py +1509 -0
  15. unaiverse/modules/utils.py +748 -0
  16. unaiverse/networking/__init__.py +16 -0
  17. unaiverse/networking/node/__init__.py +18 -0
  18. unaiverse/networking/node/connpool.py +1332 -0
  19. unaiverse/networking/node/node.py +2752 -0
  20. unaiverse/networking/node/profile.py +446 -0
  21. unaiverse/networking/node/tokens.py +79 -0
  22. unaiverse/networking/p2p/__init__.py +188 -0
  23. unaiverse/networking/p2p/go.mod +127 -0
  24. unaiverse/networking/p2p/go.sum +548 -0
  25. unaiverse/networking/p2p/golibp2p.py +18 -0
  26. unaiverse/networking/p2p/golibp2p.pyi +136 -0
  27. unaiverse/networking/p2p/lib.go +2765 -0
  28. unaiverse/networking/p2p/lib_types.py +311 -0
  29. unaiverse/networking/p2p/message_pb2.py +50 -0
  30. unaiverse/networking/p2p/messages.py +360 -0
  31. unaiverse/networking/p2p/mylogger.py +78 -0
  32. unaiverse/networking/p2p/p2p.py +900 -0
  33. unaiverse/networking/p2p/proto-go/message.pb.go +846 -0
  34. unaiverse/stats.py +1506 -0
  35. unaiverse/streamlib/__init__.py +15 -0
  36. unaiverse/streamlib/streamlib.py +210 -0
  37. unaiverse/streams.py +804 -0
  38. unaiverse/utils/__init__.py +16 -0
  39. unaiverse/utils/lone_wolf.json +28 -0
  40. unaiverse/utils/misc.py +441 -0
  41. unaiverse/utils/sandbox.py +292 -0
  42. unaiverse/world.py +384 -0
  43. unaiverse-0.1.12.dist-info/METADATA +366 -0
  44. unaiverse-0.1.12.dist-info/RECORD +47 -0
  45. unaiverse-0.1.12.dist-info/WHEEL +5 -0
  46. unaiverse-0.1.12.dist-info/licenses/LICENSE +177 -0
  47. unaiverse-0.1.12.dist-info/top_level.txt +1 -0
unaiverse/modules/utils.py
@@ -0,0 +1,748 @@
+ """
+ █████ █████ ██████ █████ █████ █████ █████ ██████████ ███████████ █████████ ██████████
+ ░░███ ░░███ ░░██████ ░░███ ░░███ ░░███ ░░███ ░░███░░░░░█░░███░░░░░███ ███░░░░░███░░███░░░░░█
+ ░███ ░███ ░███░███ ░███ ██████ ░███ ░███ ░███ ░███ █ ░ ░███ ░███ ░███ ░░░ ░███ █ ░
+ ░███ ░███ ░███░░███░███ ░░░░░███ ░███ ░███ ░███ ░██████ ░██████████ ░░█████████ ░██████
+ ░███ ░███ ░███ ░░██████ ███████ ░███ ░░███ ███ ░███░░█ ░███░░░░░███ ░░░░░░░░███ ░███░░█
+ ░███ ░███ ░███ ░░█████ ███░░███ ░███ ░░░█████░ ░███ ░ █ ░███ ░███ ███ ░███ ░███ ░ █
+ ░░████████ █████ ░░█████░░████████ █████ ░░███ ██████████ █████ █████░░█████████ ██████████
+ ░░░░░░░░ ░░░░░ ░░░░░ ░░░░░░░░ ░░░░░ ░░░ ░░░░░░░░░░ ░░░░░ ░░░░░ ░░░░░░░░░ ░░░░░░░░░░
+ A Collectionless AI Project (https://collectionless.ai)
+ Registration/Login: https://unaiverse.io
+ Code Repositories: https://github.com/collectionlessai/
+ Main Developers: Stefano Melacci (Project Leader), Christian Di Maio, Tommaso Guidi
+ """
+ import os
+ import torch
+ import random
+ import inspect
+ import logging
+ import numpy as np
+ from PIL import Image
+ from torch.utils.data import DataLoader
+ from unaiverse.dataprops import Data4Proc
+ from torchvision import datasets, transforms
+
+
+ def transforms_factory(trans_type: str, add_batch_dim: bool = True, return_inverse: bool = False):
+     supported_types = {"rgb*", "gray*",
+                        "rgb-no_norm*", "gray-no_norm*",
+                        "rgb", "gray",
+                        "rgb-no_norm", "gray-no_norm",
+                        "gray_mnist"}
+
+     found = False
+     num = -1
+     for _type in supported_types:
+         if _type.endswith("*"):
+             has_star = True
+             __type = _type[0:-1]
+         else:
+             has_star = False
+             __type = _type
+
+         if has_star and trans_type.startswith(__type) and len(trans_type) > len(__type):
+             try:
+                 num = int(trans_type[len(__type):])
+                 trans_type = _type
+                 found = True
+                 break
+             except ValueError:
+                 pass
+         elif trans_type == _type:
+             found = True
+             break
+
+     if not found:
+         raise ValueError(f"Invalid transformation type '{trans_type}': must be one of {supported_types}, "
+                          f"where * is an integer number")
+
+     trans = None
+     inverse_trans = None
+
+     if trans_type == "rgb*":
+         trans = transforms.Compose([
+             transforms.Lambda(lambda img: img.convert("RGB") if img.mode != "RGB" else img),  # Ensure 3 channels
+             transforms.Resize(num),
+             transforms.CenterCrop(num),
+             transforms.ToTensor(),  # Convert PIL to tensor (3, H, W), float [0,1]
+             transforms.Normalize(mean=[0.485, 0.456, 0.406],
+                                  std=[0.229, 0.224, 0.225]),
+         ])
+         inverse_trans = transforms.Compose([
+             transforms.Normalize(mean=[0., 0., 0.],
+                                  std=[1. / 0.229, 1. / 0.224, 1. / 0.225]),
+             transforms.Lambda(lambda x: x + torch.tensor([0.485, 0.456, 0.406]).view(3, 1, 1)),
+             transforms.ToPILImage()
+         ])
+     elif trans_type == "gray*":
+         trans = transforms.Compose([
+             transforms.Lambda(lambda img: img.convert("L") if img.mode != "L" else img),  # Ensure 1 channel
+             transforms.Resize(num),
+             transforms.CenterCrop(num),
+             transforms.ToTensor(),  # Convert PIL to tensor (1, H, W), float [0,1]
+             transforms.Normalize(mean=[0.45],
+                                  std=[0.225])
+         ])
+         inverse_trans = transforms.Compose([
+             transforms.Normalize(mean=[0.],
+                                  std=[1. / 0.225]),
+             transforms.Lambda(lambda x: x + 0.45),
+             transforms.ToPILImage()
+         ])
+     elif trans_type == "rgb-no_norm*":
+         trans = transforms.Compose([
+             transforms.Lambda(lambda img: img.convert("RGB") if img.mode != "RGB" else img),  # Ensure 3 channels
+             transforms.Resize(num),
+             transforms.CenterCrop(num),
+             transforms.PILToTensor(),  # Convert PIL to tensor (3, H, W), uint [0,255]
+         ])
+         inverse_trans = transforms.Compose([transforms.ToPILImage()])
+     elif trans_type == "gray-no_norm*":
+         trans = transforms.Compose([
+             transforms.Lambda(lambda img: img.convert("L") if img.mode != "L" else img),  # Ensure 1 channel
+             transforms.Resize(num),
+             transforms.CenterCrop(num),
+             transforms.PILToTensor(),  # Convert PIL to tensor (1, H, W), uint [0,255]
+         ])
+         inverse_trans = transforms.Compose([transforms.ToPILImage()])  # Compose, so the batch-dim step can be inserted
+     elif trans_type == "rgb":
+         trans = transforms.Compose([
+             transforms.Lambda(lambda img: img.convert("RGB") if img.mode != "RGB" else img),  # Ensure 3 channels
+             transforms.ToTensor(),  # Convert PIL to tensor (3, H, W), float [0,1]
+             transforms.Normalize(mean=[0.485, 0.456, 0.406],
+                                  std=[0.229, 0.224, 0.225])
+         ])
+         inverse_trans = transforms.Compose([
+             transforms.Normalize(mean=[0., 0., 0.],
+                                  std=[1. / 0.229, 1. / 0.224, 1. / 0.225]),
+             transforms.Lambda(lambda x: x + torch.tensor([0.485, 0.456, 0.406]).view(3, 1, 1)),
+             transforms.ToPILImage()
+         ])
+     elif trans_type == "gray":
+         trans = transforms.Compose([
+             transforms.Lambda(lambda img: img.convert("L") if img.mode != "L" else img),  # Ensure 1 channel
+             transforms.ToTensor(),  # Convert PIL to tensor (1, H, W), float [0,1]
+             transforms.Normalize(mean=[0.45],
+                                  std=[0.225])
+         ])
+         inverse_trans = transforms.Compose([
+             transforms.Normalize(mean=[0.],
+                                  std=[1. / 0.225]),
+             transforms.Lambda(lambda x: x + 0.45),
+             transforms.ToPILImage()
+         ])
+     elif trans_type == "rgb-no_norm":
+         trans = transforms.Compose([
+             transforms.Lambda(lambda img: img.convert("RGB") if img.mode != "RGB" else img),  # Ensure 3 channels
+             transforms.PILToTensor(),  # Convert PIL to tensor (3, H, W), uint [0,255]
+         ])
+         inverse_trans = transforms.Compose([transforms.ToPILImage()])
+     elif trans_type == "gray-no_norm":
+         trans = transforms.Compose([
+             transforms.Lambda(lambda img: img.convert("L") if img.mode != "L" else img),  # Ensure 1 channel
+             transforms.PILToTensor(),  # Convert PIL to tensor (1, H, W), uint [0,255]
+         ])
+         inverse_trans = transforms.Compose([transforms.ToPILImage()])
+     elif trans_type == "gray_mnist":
+         trans = transforms.Compose([
+             transforms.Lambda(lambda img: img.convert("L") if img.mode != "L" else img),  # Ensure 1 channel
+             transforms.Resize(28),
+             transforms.CenterCrop(28),
+             transforms.ToTensor(),  # Convert PIL to tensor (1, H, W), float [0,1]
+             transforms.Normalize(mean=[0.1307],  # MNIST
+                                  std=[0.3081])  # MNIST
+         ])
+         inverse_trans = transforms.Compose([
+             transforms.Normalize(mean=[0.],
+                                  std=[1. / 0.3081]),
+             transforms.Lambda(lambda x: x + 0.1307),
+             transforms.ToPILImage()
+         ])
+
+     if add_batch_dim:
+         trans.transforms.append(transforms.Lambda(lambda x: x.unsqueeze(0)))
+         inverse_trans.transforms.insert(0, transforms.Lambda(lambda x: x.squeeze(0)))
+
+     return trans if not return_inverse else inverse_trans
+
+
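A minimal usage sketch of transforms_factory (not from the package; values are illustrative, assuming this module's imports): the trailing integer in the starred types selects the resize/center-crop size, and return_inverse=True yields the matching inverse pipeline.

    from PIL import Image

    trans = transforms_factory("rgb224")                     # resize/crop to 224, normalize, add batch dim
    inv = transforms_factory("rgb224", return_inverse=True)  # drop batch dim, de-normalize, back to PIL
    x = trans(Image.new("RGB", (640, 480)))                  # tensor of shape (1, 3, 224, 224)
    img = inv(x)                                             # PIL image, 224x224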
+ def hard_tanh(x: torch.Tensor) -> torch.Tensor:
+     return torch.clamp(x, min=-1., max=1.)
+
+
+ def target_shape_fixed_cross_entropy(output, target, *args, **kwargs):
+     if len(target.shape) > 1:
+         target = target.squeeze(0)
+     return torch.nn.functional.cross_entropy(output, target, *args, **kwargs)
+
+
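A small sketch of why the wrapper above exists (illustrative values): a class-index target that arrives with a leading batch dimension would be rejected by plain cross_entropy, while the wrapper squeezes it away.

    import torch

    logits = torch.randn(4, 10)             # 4 samples, 10 classes
    target = torch.tensor([[1, 0, 3, 2]])   # shape (1, 4): leading batch dim from a stream
    loss = target_shape_fixed_cross_entropy(logits, target)  # target squeezed to (4,)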
+ def set_seed(seed: int) -> None:
+     if seed >= 0:
+         torch.manual_seed(seed)
+         random.seed(seed)
+         np.random.seed(seed)
+
+
+ def get_proc_inputs_and_proc_outputs_for_rnn(u_shape: torch.Size | tuple, du_dim: int, y_dim: int):
+     if isinstance(u_shape, torch.Size):
+         u_shape = tuple(u_shape)
+     proc_inputs = [
+         Data4Proc(data_type="tensor", tensor_shape=(None,) + u_shape, tensor_dtype=torch.float32,
+                   pubsub=False, private_only=True),
+         Data4Proc(data_type="tensor", tensor_shape=(None, du_dim,), tensor_dtype=torch.float32,
+                   pubsub=False, private_only=True)
+     ]
+     proc_outputs = [
+         Data4Proc(data_type="tensor", tensor_shape=(None, y_dim), tensor_dtype=torch.float32,
+                   pubsub=False, private_only=True)
+     ]
+     return proc_inputs, proc_outputs
+
+
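A hedged sketch of calling the RNN helper (dimensions hypothetical): the leading None in each tensor_shape stands for the axis that is left unspecified at declaration time.

    import torch

    proc_ins, proc_outs = get_proc_inputs_and_proc_outputs_for_rnn(
        u_shape=torch.Size([32]),  # per-step input shape
        du_dim=8,                  # dimension of the second (du) input
        y_dim=10)                  # output dimension
    # proc_ins[0].tensor_shape == (None, 32); proc_outs[0].tensor_shape == (None, 10)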
+ def get_proc_inputs_and_proc_outputs_for_image_classification(y_dim: int):
+     if y_dim == -1:
+         y_dim = 1000  # Assuming ImageNet-trained models
+     proc_inputs = [Data4Proc(data_type="img", pubsub=False, private_only=True)]
+     proc_outputs = [Data4Proc(data_type="tensor", tensor_shape=(None, y_dim), tensor_dtype=torch.float32,
+                               pubsub=False, private_only=True)]
+     return proc_inputs, proc_outputs
+
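Similarly, a one-line sketch for the image-classification helper: passing y_dim=-1 falls back to the 1000 ImageNet classes.

    proc_ins, proc_outs = get_proc_inputs_and_proc_outputs_for_image_classification(y_dim=-1)
    # proc_ins[0] describes a PIL image input; proc_outs[0].tensor_shape == (None, 1000)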
+
+ def isinstance_fcn(obj, class_to_check):
+     return isinstance(obj, class_to_check)
+
+
+ def error_rate_mnist_test_set(network: torch.nn.Module, mnist_data_save_path: str):
+
+     # Getting MNIST test set
+     mnist_test = datasets.MNIST(root=mnist_data_save_path,
+                                 train=False, download=True,
+                                 transform=transforms.Compose([
+                                     transforms.ToTensor(),
+                                     transforms.Normalize((0.1307,), (0.3081,))
+                                 ]))
+     mnist_test = DataLoader(mnist_test, batch_size=200, shuffle=False)
+
+     # Checking error rate
+     error_rate = 0.
+     n = 0
+     training_flag_backup = network.training
+     network.eval()
+     device = next(network.parameters()).device
+     with torch.no_grad():  # No gradients are needed for evaluation
+         for x, y in mnist_test:
+             x = x.to(device)
+             y = y.to(device)
+             o = network(x)
+             c = torch.argmax(o, dim=1)
+             error_rate += float(torch.sum(c != y).item())
+             n += x.shape[0]
+     error_rate /= n
+     network.train(training_flag_backup)  # Restore the training/eval mode recursively
+
+     return error_rate
+
+
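A usage sketch (assuming torchvision can download MNIST to the given path; the model is illustrative): even an untrained linear model can be scored this way.

    import torch

    net = torch.nn.Sequential(torch.nn.Flatten(), torch.nn.Linear(28 * 28, 10))
    err = error_rate_mnist_test_set(net, mnist_data_save_path="./data")
    print(f"MNIST test error rate: {err:.3f}")  # close to 0.9 for random weights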
+ class MultiIdentity(torch.nn.Module):
+     def __init__(self):
+         super().__init__()
+
+     def forward(self, *args):
+         if len(args) == 1:
+             return args[0]
+         return args
+
+
+ class HumanModule(torch.nn.Module):
+     def __init__(self):
+         super().__init__()
+
+     def forward(self, text: str | None = None, img: Image.Image | None = None, whatever: object | None = None):
+         return text, img
+
+
+ class LoggerModule(torch.nn.Module):
+     def __init__(self, log_file="app_log.txt"):
+         super().__init__()
+         self.log_file = log_file
+         self._initialized = False
+         self._logger = logging.getLogger("CallableLogger")
+         self._logger.setLevel(logging.INFO)
+         self.__handler = None
+         self._idx = 0
+         self._objects = ["telescope", "hammer", "compass", "anchor", "lantern", "keyboard", "cat", "dog", "tiger",
+                          "zebra", "batman", "superman", "candy", "table", "chair", "balloon", "kitchen", "sofa",
+                          "lamp", "arrow", "green", "red", "blue", "yellow", "magenta", "brown", "pink", "orange",
+                          "white", "paris", "rome", "boston", "york", "berlin", "singapore", "taiwan", "japan",
+                          "china", "turkey", "italy", "france", "germany", "spain", "madrid", "barcelona", "portugal",
+                          "norway", "sweden", "belgium", "romania", "sunny", "snowy", "rainy"]
+         random.shuffle(self._objects)
+
+     def __setup_logger(self):
+         self.__handler = logging.FileHandler(self.log_file, mode='w')  # 'w' mode overwrites the file
+         formatter = logging.Formatter('%(message)s')
+         self.__handler.setFormatter(formatter)
+         self._logger.addHandler(self.__handler)
+         self._initialized = True
+
+     def forward(self, text: str, img: Image.Image | None = None):
+         if not self._initialized:
+             self.__setup_logger()
+         self._logger.info("-------------------------------------------------------------------------------")
+         self._logger.info(f"[INPUT] text={text}, "
+                           f"img={img.size if img is not None else None}")
+         text = f"{self._idx}_{self._objects[self._idx]}"
+         self._idx = (self._idx + 1) % len(self._objects)
+         img = None
+         self._logger.info(f"[OUTPUT] text={text}, img={None}")
+         self._logger.info("-------------------------------------------------------------------------------")
+         self.__handler.flush()
+         return text, img
+
+
+ class ModuleWrapper(torch.nn.Module):
+     def __init__(self,
+                  module: torch.nn.Module | None = None,
+                  proc_inputs: list[Data4Proc] | None = None,
+                  proc_outputs: list[Data4Proc] | None = None,
+                  seed: int = -1):
+         super(ModuleWrapper, self).__init__()
+         self.device = None  # The device that is supposed to host the module
+         self.module = None  # The module itself
+         self.proc_inputs = proc_inputs  # The list of Data4Proc objects describing the input types of the module
+         self.proc_outputs = proc_outputs  # The list of Data4Proc objects describing the output types of the module
+
+         # Working
+         set_seed(seed)
+         device_env = os.getenv("PROC_DEVICE", None)
+         self.device = torch.device("cpu") if device_env is None else torch.device(device_env)
+         self.module = module.to(self.device) if module is not None else None
+
+     def forward(self, *args, **kwargs):
+
+         # Callers of this method use the signature:
+         #     forward(self, *args, first: bool, last: bool, **kwargs)
+         # so 'first' and 'last' must be discarded, since an external module not designed
+         # for this library does not expect them ('pop' also tolerates callers that omit them)
+         kwargs.pop('first', None)
+         kwargs.pop('last', None)
+
+         # Calling the module
+         return self.module(*args, **kwargs)
+
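A small sketch of the wrapper contract (module and values illustrative): callers may pass the first/last flags, which are stripped before reaching the wrapped module.

    import torch

    net = torch.nn.Linear(4, 2)
    wrapped = ModuleWrapper(module=net, seed=42)
    x = torch.randn(1, 4, device=wrapped.device)
    y = wrapped(x, first=True, last=True)  # 'first'/'last' are discarded, then net(x) is called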
+
+ class AgentProcessorChecker:
+
+     def __init__(self, processor_container: object):
+         assert hasattr(processor_container, 'proc'), "Invalid processor container object"
+         assert hasattr(processor_container, 'proc_inputs'), "Invalid processor container object"
+         assert hasattr(processor_container, 'proc_outputs'), "Invalid processor container object"
+         assert hasattr(processor_container, 'proc_opts'), "Invalid processor container object"
+         assert hasattr(processor_container, 'proc_optional_inputs'), "Invalid processor container object"
+
+         # Getting processor-related info from the main object that collects the processor and its properties
+         proc: torch.nn.Module = processor_container.proc
+         proc_inputs: list[Data4Proc] | None = processor_container.proc_inputs
+         proc_outputs: list[Data4Proc] | None = processor_container.proc_outputs
+         proc_opts: dict | None = processor_container.proc_opts
+         proc_optional_inputs: list | None = processor_container.proc_optional_inputs
+
+         assert proc is None or isinstance(proc, (torch.nn.Module, str)), \
+             "Processor (proc) must be a torch.nn.Module (or the string 'human')"
+         assert (proc_inputs is None or (
+                 isinstance_fcn(proc_inputs, list) and (len(proc_inputs) == 0 or
+                                                        (len(proc_inputs) > 0 and
+                                                         isinstance_fcn(proc_inputs[0], Data4Proc))))), \
+             "Invalid proc_inputs: it must be None or a list of Data4Proc"
+         assert (proc_outputs is None or (
+                 isinstance_fcn(proc_outputs, list) and (len(proc_outputs) == 0 or
+                                                         (len(proc_outputs) > 0 and
+                                                          isinstance_fcn(proc_outputs[0], Data4Proc))))), \
+             "Invalid proc_outputs: it must be None or a list of Data4Proc"
+         assert (proc_opts is None or isinstance_fcn(proc_opts, dict)), \
+             "Invalid proc_opts: it must be None or a dictionary"
+
+         # Saving as attributes
+         self.proc = proc
+         self.proc_inputs = proc_inputs
+         self.proc_outputs = proc_outputs
+         self.proc_opts = proc_opts
+         self.proc_optional_inputs = proc_optional_inputs
+
+         # Dummy processor (if no processor was provided)
+         if self.proc is None:
+             self.proc = ModuleWrapper(module=MultiIdentity())
+             self.proc.device = torch.device("cpu")
+             if self.proc_inputs is None:
+                 self.proc_inputs = [Data4Proc(data_type="all", pubsub=False, private_only=False)]
+             if self.proc_outputs is None:
+                 self.proc_outputs = [Data4Proc(data_type="all", pubsub=False, private_only=False)]
+             self.proc_opts = {'optimizer': None, 'losses': [None] * len(self.proc_outputs)}
+         else:
+
+             # A string stating that the processor is a human
+             if isinstance(self.proc, str) and self.proc.lower() == "human":
+                 self.proc = ModuleWrapper(module=HumanModule())
+                 self.proc.device = torch.device("cpu")
+                 self.proc_inputs = [Data4Proc(data_type="text", pubsub=False, private_only=False),
+                                     Data4Proc(data_type="img", pubsub=False, private_only=False)]
+                 self.proc_outputs = [Data4Proc(data_type="text", pubsub=False, private_only=False),
+                                      Data4Proc(data_type="img", pubsub=False, private_only=False)]
+
+             # Wrapping to have the basic attributes (device)
+             elif not isinstance(self.proc, ModuleWrapper):
+                 self.proc = ModuleWrapper(module=self.proc)
+                 self.proc.device = torch.device("cpu")
+
+             # Guessing inputs, fixing attributes
+             if self.proc_inputs is None:
+                 self.__guess_proc_inputs()
+
+             for j in range(len(self.proc_inputs)):
+                 if self.proc_inputs[j].get_name() == "unk":
+                     self.proc_inputs[j].set_name("proc_input_" + str(j))
+
+             # Guessing outputs, fixing attributes
+             if self.proc_outputs is None:
+                 self.__guess_proc_outputs()
+
+             for j in range(len(self.proc_outputs)):
+                 if self.proc_outputs[j].get_name() == "unk":
+                     self.proc_outputs[j].set_name("proc_output_" + str(j))
+
+             # Guessing optimization-related options, fixing attributes
+             if (self.proc_opts is None or len(self.proc_opts) == 0 or
+                     'optimizer' not in self.proc_opts or 'losses' not in self.proc_opts):
+                 self.__guess_proc_opts()
+                 self.__fix_proc_opts()
+
+         # Ensuring all is OK
+         if self.proc is not None:
+             assert "optimizer" in self.proc_opts, "Missing 'optimizer' key in proc_opts (required)"
+             assert "losses" in self.proc_opts, "Missing 'losses' key in proc_opts (required)"
+
+         # Checking inputs with default values
+         if self.proc_optional_inputs is None:
+             self.__guess_proc_optional_inputs()
+
+         # Updating processor container object
+         processor_container.proc = self.proc
+         processor_container.proc_inputs = self.proc_inputs
+         processor_container.proc_outputs = self.proc_outputs
+         processor_container.proc_opts = self.proc_opts
+         processor_container.proc_optional_inputs = self.proc_optional_inputs
+
+     def __guess_proc_inputs(self):
+         if hasattr(self.proc, "proc_inputs"):
+             if self.proc.proc_inputs is not None:
+                 self.proc_inputs = []
+                 for p in self.proc.proc_inputs:
+                     self.proc_inputs.append(p.clone())
+                 return
+
+         first_layer = None
+
+         # Traverse modules to find the first real layer (containers such as Sequential,
+         # ModuleList and ModuleDict are not among the searched leaf types, so they are skipped)
+         for layer in self.proc.modules():
+             if isinstance(layer, (torch.nn.Conv2d, torch.nn.Linear, torch.nn.Conv1d, torch.nn.Embedding)):
+                 first_layer = layer
+                 break
+
+         if first_layer is None:
+             raise ValueError("Cannot automatically guess the shape of the input data, "
+                              "please explicitly provide it (proc_input)")
+
+         # Infer input properties
+         data_desc = "automatically guessed"
+         tensor_shape = None
+         tensor_labels = None
+         tensor_dtype = None
+         stream_to_proc_transforms = None
+         proc_to_stream_transforms = None
+
+         if isinstance(first_layer, torch.nn.Conv2d):
+
+             if first_layer.in_channels == 3 or first_layer.in_channels == 1:
+                 data_type = "img"
+
+                 # Creating dummy PIL images
+                 rgb_input_img = Image.new('RGB', (224, 224))
+                 pixels = rgb_input_img.load()
+                 for x in range(224):
+                     for y in range(224):
+                         pixels[x, y] = (random.randint(0, 255),
+                                         random.randint(0, 255),
+                                         random.randint(0, 255))
+                 gray_input_img = rgb_input_img.convert('L')
+
+                 # Checking if the model supports PIL images as input
+                 # noinspection PyBroadException
+                 try:
+                     _ = self.proc(rgb_input_img)
+                     can_handle_rgb_img = True
+                 except Exception:
+                     can_handle_rgb_img = False
+
+                 # noinspection PyBroadException
+                 try:
+                     _ = self.proc(gray_input_img)
+                     can_handle_gray_img = True
+                 except Exception:
+                     can_handle_gray_img = False
+
+                 if can_handle_gray_img and can_handle_rgb_img:
+                     stream_to_proc_transforms = None
+                 elif can_handle_rgb_img:
+                     stream_to_proc_transforms = transforms.Grayscale(num_output_channels=3)
+                 elif can_handle_gray_img:
+                     stream_to_proc_transforms = transforms.Grayscale()
+                 else:
+                     if first_layer.in_channels == 1:
+                         stream_to_proc_transforms = transforms_factory("gray-no_norm")
+                     else:
+                         stream_to_proc_transforms = transforms_factory("rgb-no_norm")
+             else:
+
+                 # If the number of input channels is not 1 and not 3...
+                 data_type = "tensor"
+                 tensor_shape = (first_layer.in_channels, None, None)
+                 tensor_dtype = torch.float32
+
+         elif isinstance(first_layer, torch.nn.Conv1d):
+             data_type = "tensor"
+             tensor_shape = (first_layer.in_channels, None)
+             tensor_dtype = torch.float32
+         elif isinstance(first_layer, torch.nn.Linear):
+             data_type = "tensor"
+             tensor_dtype = torch.float32
+             tensor_shape = (first_layer.in_features,)
+         elif isinstance(first_layer, torch.nn.Embedding):
+
+             # noinspection PyBroadException
+             try:
+                 input_text = "testing if tokenizer is present"
+                 _ = self.proc(input_text)
+                 can_handle_text = True
+                 can_handle_more_than_one_token = True  # Only consulted when text cannot be handled
+             except Exception:
+                 can_handle_text = False
+
+             # noinspection PyBroadException
+             try:
+                 device = torch.device("cpu")
+                 for param in self.proc.parameters():
+                     device = param.device
+                     break
+                 input_tokens = torch.tensor([[0, 1, 2, 3]], dtype=torch.long, device=device)
+                 _ = self.proc(input_tokens)
+                 can_handle_more_than_one_token = True
+             except Exception:
+                 can_handle_more_than_one_token = False
+
+             if can_handle_text:
+                 data_type = "text"
+                 stream_to_proc_transforms = None
+             else:
+                 data_type = "tensor"
+                 if can_handle_more_than_one_token:
+                     tensor_shape = (None,)
+                 else:
+                     tensor_shape = (1,)
+                 tensor_dtype = torch.long
+                 tensor_labels = ["token" + str(i) for i in range(0, first_layer.num_embeddings)]
+         else:
+             raise ValueError("Cannot automatically guess the shape of the input data, "
+                              "please explicitly provide it (proc_input)")
+
+         # Setting the input attribute
+         self.proc_inputs = [Data4Proc(name="proc_input_0",
+                                       data_type=data_type,
+                                       data_desc=data_desc,
+                                       tensor_shape=tensor_shape,
+                                       tensor_labels=tensor_labels,
+                                       tensor_dtype=tensor_dtype,
+                                       stream_to_proc_transforms=stream_to_proc_transforms,
+                                       proc_to_stream_transforms=proc_to_stream_transforms,
+                                       pubsub=False,
+                                       private_only=True)]
+
+     def __guess_proc_outputs(self):
+         if hasattr(self.proc, "proc_outputs"):
+             if self.proc.proc_outputs is not None:
+                 self.proc_outputs = []
+                 for p in self.proc.proc_outputs:
+                     self.proc_outputs.append(p.clone())
+                 return
+
+         proc = self.proc
+         device = self.proc.device
+         inputs = []
+
+         for i, proc_input in enumerate(self.proc_inputs):
+             if proc_input.is_tensor():
+                 inputs.append(proc_input.check_and_preprocess(
+                     torch.randn([1] + list(proc_input.tensor_shape),  # Adding batch size here
+                                 dtype=proc_input.tensor_dtype).to(device)))
+             elif proc_input.is_img():
+                 rgb_input_img = Image.new('RGB', (224, 224))
+                 pixels = rgb_input_img.load()
+                 for x in range(224):
+                     for y in range(224):
+                         pixels[x, y] = (random.randint(0, 255),
+                                         random.randint(0, 255),
+                                         random.randint(0, 255))
+                 inputs.append(proc_input.check_and_preprocess(rgb_input_img))
+             elif proc_input.is_text():
+                 inputs.append(proc_input.check_and_preprocess("test text as input"))
+
+         # Forward
+         with torch.no_grad():
+             outputs = proc(*inputs)
+             if not isinstance(outputs, tuple | list):
+                 outputs = [outputs]
+             if isinstance(outputs, tuple):
+                 outputs = list(outputs)
+
+         # This will be filled below
+         self.proc_outputs = []
+
+         for j, output in enumerate(outputs):
+
+             # Infer output properties
+             data_desc = "automatically guessed"
+             tensor_shape = None
+             tensor_labels = None
+             tensor_dtype = None
+             stream_to_proc_transforms = None
+             proc_to_stream_transforms = None
+
+             if isinstance(output, Image.Image):  # PIL Image
+                 data_type = "img"
+             elif isinstance(output, torch.Tensor):  # Tensor
+                 output_shape = list(output.shape[1:])  # Removing batch size here
+                 if len(output_shape) == 3 and (output_shape[0] == 3 or output_shape[0] == 1):
+                     data_type = "img"
+                     if output_shape[0] == 3:
+                         proc_to_stream_transforms = transforms_factory("rgb", return_inverse=True)
+                     else:
+                         proc_to_stream_transforms = transforms_factory("gray", return_inverse=True)
+                 else:
+                     data_type = "tensor"
+                     tensor_dtype = str(output.dtype)
+                     tensor_shape = output_shape
+                     tensor_labels = None
+             elif isinstance(output, str):
+                 data_type = "text"
+             else:
+                 raise ValueError(f"Unsupported output type {type(output)}")
+
+             # Setting the output attribute
+             self.proc_outputs.append(Data4Proc(name="proc_output_" + str(j),
+                                                data_type=data_type,
+                                                data_desc=data_desc,
+                                                tensor_shape=tensor_shape,
+                                                tensor_labels=tensor_labels,
+                                                tensor_dtype=tensor_dtype,
+                                                stream_to_proc_transforms=stream_to_proc_transforms,
+                                                proc_to_stream_transforms=proc_to_stream_transforms,
+                                                pubsub=False,
+                                                private_only=True))
+
+     def __guess_proc_opts(self):
+         if self.proc_opts is None:
+             if isinstance(self.proc.module, MultiIdentity) or len(list(self.proc.parameters())) == 0:
+                 self.proc_opts = {"optimizer": None,
+                                   "losses": [None] * len(self.proc_outputs)}
+             else:
+                 self.proc_opts = {"optimizer": torch.optim.SGD(self.proc.parameters(), lr=1e-5),
+                                   "losses": [torch.nn.functional.mse_loss] * len(self.proc_outputs)}
+         else:
+             if "optimizer" not in self.proc_opts:
+                 self.proc_opts["optimizer"] = None
+             if "losses" not in self.proc_opts:
+                 self.proc_opts["losses"] = [None] * len(self.proc_outputs)
+
+     def __fix_proc_opts(self):
+         opts = {}
+         found_optimizer = False
+         found_loss = False
+         cannot_fix = False
+
+         if "optimizer" in self.proc_opts:
+             found_optimizer = True
+         if "losses" in self.proc_opts:
+             found_loss = True
+
+         if not found_loss:
+             opts['losses'] = [torch.nn.functional.mse_loss] * len(self.proc_outputs)
+
+         for k, v in self.proc_opts.items():
+             if isinstance(v, torch.optim.Optimizer):
+                 if k == "optimizer":
+                     opts["optimizer"] = v
+                     continue
+                 else:
+                     if not found_optimizer:
+                         opts["optimizer"] = v
+                         found_optimizer = True
+                     else:
+                         cannot_fix = True
+                         break
+             elif k == "losses" and isinstance(v, (list, tuple)):
+                 opts["losses"] = v
+                 continue
+             elif (v == torch.nn.functional.mse_loss or isinstance(v, torch.nn.MSELoss)
+                   or v == torch.nn.functional.binary_cross_entropy or isinstance(v, torch.nn.BCELoss)
+                   or isinstance(v, torch.nn.CrossEntropyLoss) or v == torch.nn.functional.cross_entropy):
+                 if not found_loss:
+                     opts["losses"] = [v]
+                     found_loss = True
+                 else:
+                     cannot_fix = True
+                     break
+             else:
+                 opts[k] = v
+
+         if not found_optimizer:
+             if 'lr' in opts:
+                 opts['optimizer'] = torch.optim.SGD(self.proc.parameters(), lr=opts['lr'])
+
+         assert not cannot_fix, \
+             "About proc_opts: cannot find required keys ('optimizer', 'losses') and/or cannot automatically guess them"
+
+         # Removing batch dim from targets in case of cross-entropy
+         fixed_list = []
+         for _loss_fcn in opts['losses']:
+             if _loss_fcn == torch.nn.functional.cross_entropy or isinstance(_loss_fcn, torch.nn.CrossEntropyLoss):
+                 fixed_list.append(target_shape_fixed_cross_entropy)
+             else:
+                 fixed_list.append(_loss_fcn)
+         opts['losses'] = fixed_list
+
+         # Updating
+         self.proc_opts = opts
+
+     def __guess_proc_optional_inputs(self):
+         self.proc_optional_inputs = []
+         if isinstance(self.proc, ModuleWrapper):
+             if hasattr(self.proc.module, "forward"):
+                 sig = inspect.signature(self.proc.module.forward)
+             else:
+                 sig = inspect.signature(self.proc.forward)
+         else:
+             sig = inspect.signature(self.proc.forward)
+
+         i = 0
+         for name, param in sig.parameters.items():
+             if i >= len(self.proc_inputs):
+                 break
+             if param.default is not inspect.Parameter.empty:
+                 self.proc_optional_inputs.append({"has_default": True, "default_value": param.default})
+             else:
+                 self.proc_optional_inputs.append({"has_default": False, "default_value": None})
+             i += 1
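A hedged end-to-end sketch of the checker (the Container class is hypothetical): any object exposing the five proc* attributes can be passed, and missing pieces are guessed in place.

    import torch

    class Container:
        def __init__(self):
            self.proc = torch.nn.Linear(16, 4)  # bare module: it gets wrapped in a ModuleWrapper
            self.proc_inputs = None             # guessed from the first Linear layer
            self.proc_outputs = None            # guessed via a dry-run forward pass
            self.proc_opts = None               # guessed: SGD optimizer + MSE losses
            self.proc_optional_inputs = None    # guessed from the forward signature

    c = Container()
    AgentProcessorChecker(c)
    print(c.proc_inputs[0].tensor_shape)  # (16,), inferred from in_features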