unaiverse-0.1.6-cp310-cp310-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of unaiverse might be problematic.

Files changed (50)
  1. unaiverse/__init__.py +19 -0
  2. unaiverse/agent.py +2008 -0
  3. unaiverse/agent_basics.py +1846 -0
  4. unaiverse/clock.py +191 -0
  5. unaiverse/dataprops.py +1209 -0
  6. unaiverse/hsm.py +1880 -0
  7. unaiverse/modules/__init__.py +18 -0
  8. unaiverse/modules/cnu/__init__.py +17 -0
  9. unaiverse/modules/cnu/cnus.py +536 -0
  10. unaiverse/modules/cnu/layers.py +261 -0
  11. unaiverse/modules/cnu/psi.py +60 -0
  12. unaiverse/modules/hl/__init__.py +15 -0
  13. unaiverse/modules/hl/hl_utils.py +411 -0
  14. unaiverse/modules/networks.py +1509 -0
  15. unaiverse/modules/utils.py +680 -0
  16. unaiverse/networking/__init__.py +16 -0
  17. unaiverse/networking/node/__init__.py +18 -0
  18. unaiverse/networking/node/connpool.py +1261 -0
  19. unaiverse/networking/node/node.py +2223 -0
  20. unaiverse/networking/node/profile.py +446 -0
  21. unaiverse/networking/node/tokens.py +79 -0
  22. unaiverse/networking/p2p/__init__.py +198 -0
  23. unaiverse/networking/p2p/go.mod +127 -0
  24. unaiverse/networking/p2p/go.sum +548 -0
  25. unaiverse/networking/p2p/golibp2p.py +18 -0
  26. unaiverse/networking/p2p/golibp2p.pyi +135 -0
  27. unaiverse/networking/p2p/lib.go +2714 -0
  28. unaiverse/networking/p2p/lib.go.sha256 +1 -0
  29. unaiverse/networking/p2p/lib_types.py +312 -0
  30. unaiverse/networking/p2p/message_pb2.py +63 -0
  31. unaiverse/networking/p2p/messages.py +265 -0
  32. unaiverse/networking/p2p/mylogger.py +77 -0
  33. unaiverse/networking/p2p/p2p.py +929 -0
  34. unaiverse/networking/p2p/proto-go/message.pb.go +616 -0
  35. unaiverse/networking/p2p/unailib.cp310-win_amd64.pyd +0 -0
  36. unaiverse/streamlib/__init__.py +15 -0
  37. unaiverse/streamlib/streamlib.py +210 -0
  38. unaiverse/streams.py +770 -0
  39. unaiverse/utils/__init__.py +16 -0
  40. unaiverse/utils/ask_lone_wolf.json +27 -0
  41. unaiverse/utils/lone_wolf.json +19 -0
  42. unaiverse/utils/misc.py +305 -0
  43. unaiverse/utils/sandbox.py +293 -0
  44. unaiverse/utils/server.py +435 -0
  45. unaiverse/world.py +175 -0
  46. unaiverse-0.1.6.dist-info/METADATA +365 -0
  47. unaiverse-0.1.6.dist-info/RECORD +50 -0
  48. unaiverse-0.1.6.dist-info/WHEEL +5 -0
  49. unaiverse-0.1.6.dist-info/licenses/LICENSE +43 -0
  50. unaiverse-0.1.6.dist-info/top_level.txt +1 -0
unaiverse/modules/networks.py
@@ -0,0 +1,1509 @@
1
+ """
2
+ █████ █████ ██████ █████ █████ █████ █████ ██████████ ███████████ █████████ ██████████
3
+ ░░███ ░░███ ░░██████ ░░███ ░░███ ░░███ ░░███ ░░███░░░░░█░░███░░░░░███ ███░░░░░███░░███░░░░░█
4
+ ░███ ░███ ░███░███ ░███ ██████ ░███ ░███ ░███ ░███ █ ░ ░███ ░███ ░███ ░░░ ░███ █ ░
5
+ ░███ ░███ ░███░░███░███ ░░░░░███ ░███ ░███ ░███ ░██████ ░██████████ ░░█████████ ░██████
6
+ ░███ ░███ ░███ ░░██████ ███████ ░███ ░░███ ███ ░███░░█ ░███░░░░░███ ░░░░░░░░███ ░███░░█
7
+ ░███ ░███ ░███ ░░█████ ███░░███ ░███ ░░░█████░ ░███ ░ █ ░███ ░███ ███ ░███ ░███ ░ █
8
+ ░░████████ █████ ░░█████░░████████ █████ ░░███ ██████████ █████ █████░░█████████ ██████████
9
+ ░░░░░░░░ ░░░░░ ░░░░░ ░░░░░░░░ ░░░░░ ░░░ ░░░░░░░░░░ ░░░░░ ░░░░░ ░░░░░░░░░ ░░░░░░░░░░
10
+ A Collectionless AI Project (https://collectionless.ai)
11
+ Registration/Login: https://unaiverse.io
12
+ Code Repositories: https://github.com/collectionlessai/
13
+ Main Developers: Stefano Melacci (Project Leader), Christian Di Maio, Tommaso Guidi
14
+ """
15
+ import os
16
+ import torch
17
+ import shutil
18
+ import numpy as np
19
+ import torchvision
20
+ from PIL import Image
21
+ import urllib.request
22
+ from typing import Callable
23
+ import torch.nn.functional as F
24
+ from unaiverse.dataprops import Data4Proc
25
+ from unaiverse.modules.cnu.cnus import CNUs
26
+ from unaiverse.modules.cnu.layers import LinearCNU
27
+ from transformers import pipeline, AutoProcessor, AutoModelForCausalLM, AutoTokenizer
28
+ from unaiverse.modules.utils import get_proc_inputs_and_proc_outputs_for_image_classification
29
+ from unaiverse.modules.utils import ModuleWrapper, transforms_factory, get_proc_inputs_and_proc_outputs_for_rnn
30
+
31
+
32
+ class RNNTokenLM(ModuleWrapper):
33
+
34
+ def __init__(self, num_emb: int, emb_dim: int, y_dim: int, h_dim: int, batch_size: int = 1, seed: int = -1):
35
+ super(RNNTokenLM, self).__init__(seed=seed)
36
+ device = self.device
37
+ u_dim = emb_dim
38
+ self.embeddings = torch.nn.Embedding(num_emb, emb_dim)
39
+
40
+ self.proc_inputs = [
41
+ Data4Proc(data_type="tensor", tensor_shape=(u_dim, ), tensor_dtype=torch.float32,
42
+ pubsub=False, private_only=True)
43
+ ]
44
+ self.proc_outputs = [
45
+ Data4Proc(data_type="tensor", tensor_shape=(y_dim, ), tensor_dtype=torch.float32,
46
+ pubsub=False, private_only=True)
47
+ ]
48
+
49
+ self.A = torch.nn.Linear(h_dim, h_dim, bias=False, device=device)
50
+ self.B = torch.nn.Linear(u_dim, h_dim, bias=False, device=device)
51
+ self.C = torch.nn.Linear(h_dim, y_dim, bias=False, device=device)
52
+ self.h_init = torch.randn((batch_size, h_dim), device=device)
53
+ self.u_init = torch.zeros((batch_size, u_dim), device=device)
54
+ self.h = None
55
+ self.y = None
56
+
57
+ def forward(self, u: torch.Tensor | None = None, first: bool = True, last: bool = False):
58
+ if first:
59
+ h = self.h_init
60
+ u = self.u_init if u is None else u
61
+ else:
62
+ h = self.h.detach()
63
+ u = self.embeddings((torch.argmax(self.y.detach(), dim=1) if self.y.shape[1] > 1
64
+ else self.y.squeeze(1).detach()).to(self.device))
65
+
66
+ self.h = torch.tanh(self.A(h) + self.B(u))
67
+ self.y = self.C(self.h)
68
+ return self.y
69
+
70
+
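As a rough usage sketch (vocabulary size and dimensions are illustrative, and it is assumed that ModuleWrapper dispatches calls to forward() like a standard torch.nn.Module on CPU), the token LM above is stepped one prediction at a time and feeds its own argmax back in:

    lm = RNNTokenLM(num_emb=100, emb_dim=16, y_dim=100, h_dim=32)
    logits = lm(first=True)        # state reset to h_init, zero input u_init
    for _ in range(5):
        logits = lm(first=False)   # embeds the argmax of the previous logits as the next input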
71
+ class RNN(ModuleWrapper):
72
+
73
+ def __init__(self, u_shape: tuple[int], d_dim: int, y_dim: int, h_dim: int, batch_size: int = 1, seed: int = -1):
74
+ super(RNN, self).__init__(seed=seed)
75
+ device = self.device
76
+ u_shape = torch.Size(u_shape)
77
+ u_dim = u_shape.numel()
78
+ du_dim = d_dim
79
+ self.proc_inputs, self.proc_outputs = get_proc_inputs_and_proc_outputs_for_rnn(u_shape, du_dim, y_dim)
80
+
81
+ self.A = torch.nn.Linear(h_dim, h_dim, bias=False, device=device)
82
+ self.B = torch.nn.Linear(u_dim + du_dim, h_dim, bias=False, device=device)
83
+ self.C = torch.nn.Linear(h_dim, y_dim, bias=False, device=device)
84
+ self.register_buffer('h_init', torch.randn((batch_size, h_dim), device=device))
85
+ self.h = None
86
+ self.u_dim = u_dim
87
+ self.du_dim = du_dim
88
+
89
+ def forward(self, u: torch.Tensor, du: torch.Tensor, first: bool = True, last: bool = False):
90
+ if first:
91
+ h = self.h_init.data
92
+ else:
93
+ h = self.h.detach()
94
+ if u is None:
95
+ u = torch.zeros((h.shape[0], self.u_dim), dtype=torch.float32, device=self.device)
96
+ else:
97
+ u = u.to(self.device)
98
+ if du is None:
99
+ du = torch.zeros((h.shape[0], self.du_dim), dtype=torch.float32, device=self.device)
100
+ else:
101
+ du = du.to(self.device)
102
+
103
+ self.h = torch.tanh(self.A(h) + self.B(torch.cat([du, u], dim=1)))
104
+ y = self.C(self.h)
105
+ return y
106
+
107
+
108
+ class CSSM(ModuleWrapper):
109
+
110
+ def __init__(self, u_shape: tuple[int], d_dim: int, y_dim: int, h_dim: int, sigma: Callable = F.tanh,
111
+ project_every: int = 0, local: bool = False, batch_size: int = 1, seed: int = -1):
112
+ super(CSSM, self).__init__(seed=seed)
113
+ device = self.device
114
+ u_shape = torch.Size(u_shape)
115
+ u_dim = u_shape.numel()
116
+ du_dim = d_dim
117
+ self.batch_size = batch_size
118
+ self.proc_inputs, self.proc_outputs = get_proc_inputs_and_proc_outputs_for_rnn(u_shape, du_dim, y_dim)
119
+
120
+ # Define linear transformation matrices for state update and output mapping
121
+ self.A = torch.nn.Linear(h_dim, h_dim, bias=False, device=device) # Recurrent weight matrix
122
+ self.B = torch.nn.Linear(u_dim + du_dim, h_dim, bias=False, device=device) # Input-to-hidden mapping
123
+ self.C = torch.nn.Linear(h_dim, y_dim, bias=False, device=device) # Hidden-to-output mapping
124
+
125
+ # Hidden state initialization
126
+ self.register_buffer('h_init', torch.randn((batch_size, h_dim), device=device))
127
+ self.register_buffer('h_next', torch.randn((batch_size, h_dim), device=device))
128
+ self.h = None
129
+ self.dh = None
130
+ self.sigma = sigma # The non-linear activation function
131
+
132
+ # Store input dimensions and device
133
+ self.u_dim = u_dim
134
+ self.du_dim = du_dim
135
+ self.delta = 1. # Discrete time step
136
+ self.local = local # If True the state update is computed locally in time (i.e., kept out from the graph)
137
+ self.forward_count = 0
138
+ self.project_every = project_every
139
+
140
+ @torch.no_grad()
141
+ def adjust_eigs(self):
142
+ """Placeholder for eigenvalue adjustment method."""
143
+ pass
144
+
145
+ def init_h(self, udu: torch.Tensor) -> torch.Tensor:
146
+ return self.h_init.data
147
+
148
+ @staticmethod
149
+ def handle_inputs(du, u):
150
+ return du, u
151
+
152
+ def forward(self, u: torch.Tensor, du: torch.Tensor, first: bool = True, last: bool = False):
153
+ """Forward pass that updates the hidden state and computes the output."""
154
+
155
+ # Handle missing inputs
156
+ u = u.flatten(1).to(self.device) if u is not None else (
157
+ torch.zeros((self.batch_size, self.u_dim), device=self.device))
158
+ du = du.to(self.device) if du is not None else (
159
+ torch.zeros((self.batch_size, self.du_dim), device=self.device))
160
+
161
+ # Reset hidden state if first step
162
+ if first:
163
+ h = self.init_h(torch.cat([du, u], dim=1))
164
+ self.forward_count = 0
165
+ else:
166
+ h = self.h_next.data
167
+
168
+ # Track the gradients on h from here on
169
+ h.requires_grad_()
170
+
171
+ # Check if it's time to project the eigenvalues
172
+ if self.project_every:
173
+ if self.forward_count % self.project_every == 0:
174
+ self.adjust_eigs()
175
+
176
+ # Handle inputs
177
+ du, u = self.handle_inputs(du, u)
178
+
179
+ # Update hidden state based on input and previous hidden state
180
+ h_new = self.A(h) + self.B(torch.cat([du, u], dim=1))
181
+
182
+ if self.local:
183
+
184
+ # In the local version we keep track in self.h of the old value of the state
185
+ self.h = h
186
+ self.dh = (h_new - self.h) / self.delta # (h_new - h_old) / delta
187
+ else:
188
+
189
+ # In the non-local version we keep track in self.h of the new value of the state
190
+ self.h = h_new
191
+ self.dh = (self.h - h) / self.delta # (h_new - h_old) / delta
192
+
193
+ # Compute output using a nonlinear activation function
194
+ y = self.C(self.sigma(self.h))
195
+
196
+ # Store the new state for the next iteration
197
+ self.h_next.data = h_new.detach()
198
+ self.forward_count += 1
199
+
200
+ return y
201
+
202
+
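A minimal stepping sketch for the state-space wrapper above (shapes illustrative, same assumption on ModuleWrapper): the update is h_t = A h_{t-1} + B [du_t; u_t] and y_t = C sigma(h_t), with first=True resetting the state and local=True keeping the state update out of the autograd graph:

    import torch

    cssm = CSSM(u_shape=(8,), d_dim=4, y_dim=3, h_dim=16)
    u, du = torch.randn(1, 8), torch.randn(1, 4)   # input sample and its descriptor
    y0 = cssm(u, du, first=True)     # resets the hidden state to h_init
    y1 = cssm(u, du, first=False)    # continues from the stored h_next buffer
    dh = cssm.dh                     # discrete state derivative, (h_new - h_old) / delta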
203
+ class CDiagR(ModuleWrapper):
204
+ """Diagonal matrix-based generator with real-valued transformations."""
205
+ def __init__(self, u_shape: tuple[int], d_dim: int, y_dim: int, h_dim: int, sigma: Callable = lambda x: x,
206
+ project_every: int = 0, local: bool = False, batch_size: int = 1, seed: int = -1):
207
+ super(CDiagR, self).__init__(seed=seed)
208
+ device = self.device
209
+ u_shape = torch.Size(u_shape)
210
+ u_dim = u_shape.numel()
211
+ du_dim = d_dim
212
+ self.batch_size = batch_size
213
+ self.proc_inputs, self.proc_outputs = get_proc_inputs_and_proc_outputs_for_rnn(u_shape, du_dim, y_dim)
214
+
215
+ # Define diagonal transformation and linear layers
216
+ self.diag = torch.nn.Linear(in_features=1, out_features=h_dim, bias=False, device=device, dtype=torch.float32)
217
+ self.B = torch.nn.Linear(u_dim + du_dim, h_dim, bias=False, device=device)
218
+ self.C = torch.nn.Linear(h_dim, y_dim, bias=False, device=device)
219
+
220
+ # Hidden state initialization
221
+ self.register_buffer('h_init', torch.randn((batch_size, h_dim), device=device))
222
+ self.register_buffer('h_next', torch.randn((batch_size, h_dim), device=device))
223
+ self.h = None
224
+ self.dh = None
225
+ self.sigma = sigma # The non-linear activation function
226
+
227
+ # Store input dimensions and device
228
+ self.u_dim = u_dim
229
+ self.du_dim = du_dim
230
+ self.delta = 1.
231
+ self.local = local # If True the state update is computed locally in time (i.e., kept out from the graph)
232
+ self.forward_count = 0
233
+ self.project_every = project_every
234
+
235
+ @torch.no_grad()
236
+ def adjust_eigs(self):
237
+ """Normalize the diagonal weight matrix by setting signs."""
238
+ self.diag.weight.copy_(torch.sign(self.diag.weight))
239
+
240
+ def init_h(self, udu: torch.Tensor) -> torch.Tensor:
241
+ return self.h_init.data
242
+
243
+ @staticmethod
244
+ def handle_inputs(du, u):
245
+ return du, u
246
+
247
+ def forward(self, u: torch.Tensor, du: torch.Tensor, first: bool = True, last: bool = False):
248
+ """Forward pass with diagonal transformation."""
249
+
250
+ # Handle missing inputs
251
+ u = u.flatten(1).to(self.device) if u is not None else (
252
+ torch.zeros((self.batch_size, self.u_dim), device=self.device))
253
+ du = du.to(self.device) if du is not None else (
254
+ torch.zeros((self.batch_size, self.du_dim), device=self.device))
255
+
256
+ # Reset hidden state if first step
257
+ if first:
258
+ h = self.init_h(torch.cat([du, u], dim=1))
259
+ self.forward_count = 0
260
+ else:
261
+ h = self.h_next.data
262
+
263
+ # Track the gradients on h from here on
264
+ h.requires_grad_()
265
+
266
+ # Check if it's time to project the eigenvalues
267
+ if self.project_every:
268
+ if self.forward_count % self.project_every == 0:
269
+ self.adjust_eigs()
270
+
271
+ # Handle inputs
272
+ du, u = self.handle_inputs(du, u)
273
+
274
+ # Apply diagonal transformation to hidden state
275
+ h_new = self.diag.weight.view(self.diag.out_features) * h + self.B(torch.cat([du, u], dim=1))
276
+
277
+ if self.local:
278
+
279
+ # In the local version we keep track in self.h of the old value of the state
280
+ self.h = h
281
+ self.dh = (h_new - self.h) / self.delta # (h_new - h_old) / delta
282
+ else:
283
+
284
+ # In the non-local version we keep track in self.h of the new value of the state
285
+ self.h = h_new
286
+ self.dh = (self.h - h) / self.delta # (h_new - h_old) / delta
287
+
288
+ # Compute output using a nonlinear activation function
289
+ y = self.C(self.sigma(self.h))
290
+
291
+ # Store the new state for the next iteration
292
+ self.h_next.data = h_new.detach()
293
+ self.forward_count += 1
294
+
295
+ return y
296
+
297
+
298
+ class CDiagC(ModuleWrapper):
299
+ """Diagonal matrix-based generator with complex-valued transformations."""
300
+ def __init__(self, u_shape: tuple[int], d_dim: int, y_dim: int, h_dim: int, sigma: Callable = lambda x: x,
301
+ project_every: int = 0, local: bool = False, batch_size: int = 1, seed: int = -1):
302
+ super(CDiagC, self).__init__(seed=seed)
303
+ device = self.device
304
+ u_shape = torch.Size(u_shape)
305
+ u_dim = u_shape.numel()
306
+ du_dim = d_dim
307
+ self.batch_size = batch_size
308
+ self.proc_inputs, self.proc_outputs = get_proc_inputs_and_proc_outputs_for_rnn(u_shape, du_dim, y_dim)
309
+
310
+ # Define diagonal transformation with complex numbers
311
+ self.diag = torch.nn.Linear(in_features=1, out_features=h_dim, bias=False, device=device, dtype=torch.cfloat)
312
+ self.B = torch.nn.Linear(u_dim + du_dim, h_dim, bias=False, device=device, dtype=torch.cfloat)
313
+ self.C = torch.nn.Linear(h_dim, y_dim, bias=False, device=device, dtype=torch.cfloat)
314
+
315
+ # Hidden state initialization
316
+ self.register_buffer('h_init', torch.randn((batch_size, h_dim), device=device))
317
+ self.register_buffer('h_next', torch.randn((batch_size, h_dim), device=device))
318
+ self.h = None
319
+ self.dh = None
320
+ self.sigma = sigma # The non-linear activation function
321
+
322
+ # Store input dimensions and device
323
+ self.u_dim = u_dim
324
+ self.du_dim = du_dim
325
+ self.delta = 1.
326
+ self.local = local # If True the state update is computed locally in time (i.e., kept out from the graph)
327
+ self.forward_count = 0
328
+ self.project_every = project_every
329
+
330
+ @torch.no_grad()
331
+ def adjust_eigs(self):
332
+ """ Normalize the diagonal weight matrix by dividing by its magnitude. """
333
+ self.diag.weight.div_(self.diag.weight.abs())
334
+
335
+ def init_h(self, udu: torch.Tensor) -> torch.Tensor:
336
+ return self.h_init.data
337
+
338
+ @staticmethod
339
+ def handle_inputs(du, u):
340
+ return du, u
341
+
342
+ def forward(self, u: torch.Tensor, du: torch.Tensor, first: bool = True, last: bool = False):
343
+ """Forward pass with complex-valued transformation."""
344
+
345
+ # Handle missing inputs
346
+ u = u.flatten(1).to(self.device) if u is not None else torch.zeros((self.batch_size, self.u_dim),
347
+ device=self.device, dtype=torch.cfloat)
348
+ du = du.to(self.device) if du is not None else torch.zeros((self.batch_size, self.du_dim),
349
+ device=self.device, dtype=torch.cfloat)
350
+
351
+ # Reset hidden state if first step
352
+ if first:
353
+ h = self.init_h(torch.cat([du, u], dim=1))
354
+ self.forward_count = 0
355
+ else:
356
+ h = self.h_next.data
357
+
358
+ # Track the gradients on h from here on
359
+ h.requires_grad_()
360
+
361
+ # Check if it's time to project the eigenvalues
362
+ if self.project_every:
363
+ if self.forward_count % self.project_every == 0:
364
+ self.adjust_eigs()
365
+
366
+ # Handle inputs
367
+ du, u = self.handle_inputs(du, u)
368
+
369
+ # Apply complex diagonal transformation
370
+ h_new = self.diag.weight.view(self.diag.out_features) * h + self.B(torch.cat([du, u], dim=1))
371
+
372
+ if self.local:
373
+
374
+ # In the local version we keep track in self.h of the old value of the state
375
+ self.h = h
376
+ self.dh = (h_new - self.h) / self.delta # (h_new - h_old) / delta
377
+ else:
378
+
379
+ # In the non-local version we keep track in self.h of the new value of the state
380
+ self.h = h_new
381
+ self.dh = (self.h - h) / self.delta # (h_new - h_old) / delta
382
+
383
+ # Compute output using a nonlinear activation function
384
+ y = self.C(self.sigma(self.h))
385
+
386
+ # Store the new state for the next iteration
387
+ self.h_next.data = h_new.detach()
388
+ self.forward_count += 1
389
+
390
+ return y.real
391
+
392
+
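For reference, the projection performed by CDiagC.adjust_eigs (and, via torch.sign, by CDiagR) simply pushes each diagonal eigenvalue back onto the unit circle; a stand-alone sketch:

    import torch

    d = torch.randn(4, dtype=torch.cfloat)   # complex diagonal entries (eigenvalues)
    d_on_circle = d / d.abs()                # every entry now has modulus 1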
393
+ class CTE(ModuleWrapper):
394
+ """Antisymmetric Matrix Exponential Generator implementing continuous-time dynamics.
395
+
396
+ Uses antisymmetric weight matrix with matrix exponential for stable hidden state evolution.
397
+
398
+ Args:
399
+ u_shape: Input shape (tuple of integers)
400
+ d_dim: Input descriptor dimension
401
+ y_dim: Output dimension
402
+ h_dim: Hidden state dimension
403
+ delta: Time step for discrete approximation
404
+ sigma: Non-linearity applied to the state before the output mapping (defaults to identity)
+ project_every: Run adjust_eigs() every this many forward steps (0 disables the projection)
+ local: Local computations (bool)
405
+ cnu_memories: Number of CNU key memories for the output head (0 uses a plain linear layer)
+ batch_size: Batch size of the recurrent state
+ seed: Random seed (positive int)
406
+ """
407
+
408
+ def __init__(self, u_shape: tuple[int], d_dim: int, y_dim: int, h_dim: int, delta: float,
409
+ sigma: Callable = lambda x: x, project_every: int = 0, local: bool = False,
410
+ cnu_memories: int = 0, batch_size: int = 1, seed: int = -1):
411
+ super(CTE, self).__init__(seed=seed)
412
+ device = self.device
413
+ u_shape = torch.Size(u_shape)
414
+ u_dim = u_shape.numel()
415
+ du_dim = d_dim
416
+ self.batch_size = batch_size
417
+ self.proc_inputs, self.proc_outputs = get_proc_inputs_and_proc_outputs_for_rnn(u_shape, du_dim, y_dim)
418
+
419
+ # Antisymmetric weight matrix (W - W^T)
420
+ self.W = torch.nn.Linear(h_dim, h_dim, bias=False, device=device)
421
+ self.Id = torch.eye(h_dim, device=device) # Identity matrix
422
+
423
+ # Input projection matrix
424
+ self.B = torch.nn.Linear(u_dim + du_dim, h_dim, bias=False, device=device)
425
+
426
+ # Output projection matrix
427
+ if cnu_memories <= 0:
428
+ self.C = torch.nn.Linear(h_dim, y_dim, bias=False, device=device)
429
+ else:
430
+ self.C = LinearCNU(h_dim, y_dim, bias=False, device=device, key_size=u_dim + du_dim,
431
+ delta=1, beta_k=delta, scramble=False, key_mem_units=cnu_memories, shared_keys=True)
432
+
433
+ # Hidden state initialization
434
+ self.register_buffer('h_init', torch.randn((batch_size, h_dim), device=device))
435
+ self.register_buffer('h_next', torch.randn((batch_size, h_dim), device=device))
436
+ self.h = None
437
+ self.dh = None
438
+ self.sigma = sigma # The non-linear activation function
439
+
440
+ # System parameters
441
+ self.u_dim = u_dim
442
+ self.du_dim = du_dim
443
+ self.delta = delta
444
+ self.local = local
445
+ self.forward_count = 0
446
+ self.project_every = project_every
447
+
448
+ @torch.no_grad()
449
+ def adjust_eigs(self):
450
+ """Placeholder for eigenvalue adjustment method"""
451
+ pass
452
+
453
+ def init_h(self, udu: torch.Tensor) -> torch.Tensor:
454
+ return self.h_init.data
455
+
456
+ @staticmethod
457
+ def handle_inputs(du, u):
458
+ return du, u
459
+
460
+ def forward(self, u: torch.Tensor, du: torch.Tensor, first: bool = True, last: bool = False) -> torch.Tensor:
461
+ """Forward pass through the system dynamics.
462
+
463
+ Args:
464
+ u: Input tensor of shape (batch_size, u_dim)
465
+ du: Input descriptor tensor of shape (batch_size, du_dim)
466
+ first: Flag indicating first step (resets hidden state)
467
+ last: Flag indicating last step (does nothing)
468
+
469
+ Returns:
470
+ y: Output tensor of shape (batch_size, y_dim)
471
+ """
472
+
473
+ # Handle missing inputs
474
+ u = u.flatten(1).to(self.device) if u is not None else (
475
+ torch.zeros((self.batch_size, self.u_dim), device=self.device))
476
+ du = du.to(self.device) if du is not None else (
477
+ torch.zeros((self.batch_size, self.du_dim), device=self.device))
478
+
479
+ # Reset hidden state if first step
480
+ if first:
481
+ h = self.init_h(torch.cat([du, u], dim=1))
482
+ self.forward_count = 0
483
+ else:
484
+ h = self.h_next.data
485
+
486
+ # Track the gradients on h from here on
487
+ h.requires_grad_()
488
+
489
+ # Check if it's time to project the eigenvalues
490
+ if self.project_every:
491
+ if self.forward_count % self.project_every == 0:
492
+ self.adjust_eigs()
493
+
494
+ if not isinstance(self.C, LinearCNU):
495
+ C = self.C
496
+ else:
497
+ udu = torch.cat([du, u], dim=1)
498
+ weight_C = self.C.compute_weights(udu).view(self.C.out_features, self.C.in_features)
499
+
500
+ def C(x):
501
+ return torch.nn.functional.linear(x, weight_C)
502
+
503
+ # Handle inputs
504
+ du, u = self.handle_inputs(du, u)
505
+
506
+ # Antisymmetric matrix construction
507
+ A = 0.5 * (self.W.weight - self.W.weight.t())
508
+ A_expm = torch.linalg.matrix_exp(A * self.delta) # Matrix exponential
509
+ rec = F.linear(h, A_expm, self.W.bias) # Recurrent component
510
+
511
+ # Input processing component
512
+ A_inv = torch.linalg.inv(A)
513
+ inp = A_inv @ (A_expm - self.Id) @ self.B(torch.cat([du, u], dim=1)).unsqueeze(-1)
514
+
515
+ # Handle locality
516
+ h_new = rec + inp.squeeze(-1) # Updated hidden state
517
+ if self.local:
518
+
519
+ # In the local version we keep track in self.h of the old value of the state
520
+ self.h = h
521
+ self.dh = (h_new - self.h) / self.delta # (h_new - h_old) / delta
522
+ else:
523
+
524
+ # In the non-local version we keep track in self.h of the new value of the state
525
+ self.h = h_new
526
+ self.dh = (self.h - h) / self.delta # (h_new - h_old) / delta
527
+
528
+ # Compute output using a nonlinear activation function
529
+ y = C(self.sigma(self.h))
530
+
531
+ # Store the new state for the next iteration
532
+ self.h_next.data = h_new
533
+ self.forward_count += 1
534
+
535
+ return y
536
+
537
+
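The update in CTE.forward is the zero-order-hold discretization of h' = A h + B[du; u] with antisymmetric A = (W - W^T) / 2; a minimal stand-alone sketch of the same two terms (dimensions illustrative; note that a real antisymmetric matrix is generically invertible only when h_dim is even):

    import torch

    h_dim, delta = 6, 0.1
    W = torch.randn(h_dim, h_dim)
    A = 0.5 * (W - W.t())                        # antisymmetric part
    A_expm = torch.linalg.matrix_exp(A * delta)  # e^{A*delta}
    h = torch.randn(1, h_dim)                    # current state (batch of 1)
    Bu = torch.randn(1, h_dim)                   # stands in for B([du; u])
    rec = h @ A_expm.t()                         # e^{A*delta} h
    inp = (torch.linalg.inv(A) @ (A_expm - torch.eye(h_dim)) @ Bu.t()).t()  # A^-1 (e^{A*delta} - I) B u
    h_next = rec + inp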
538
+ class CTEInitStateBZeroInput(CTE):
539
+
540
+ def __init__(self, u_shape: tuple[int], d_dim: int, y_dim: int, h_dim: int, delta: float,
541
+ sigma: Callable = lambda x: x, project_every: int = 0, local: bool = False,
542
+ cnu_memories: int = 0, batch_size: int = 1, seed: int = -1):
543
+ super(CTEInitStateBZeroInput, self).__init__(u_shape, d_dim, y_dim, h_dim, delta, sigma, project_every,
544
+ local, cnu_memories, batch_size, seed)
545
+
546
+ @torch.no_grad()
547
+ def init_h(self, udu: torch.Tensor) -> torch.Tensor:
548
+ return self.B(udu).detach() / torch.sum(udu, dim=1)
549
+
550
+ @staticmethod
551
+ def handle_inputs(du, u):
552
+ return torch.zeros_like(du), torch.zeros_like(u)
553
+
554
+
555
+ class CTEToken(CTE):
556
+
557
+ def __init__(self, num_emb: int, emb_dim: int, d_dim: int, y_dim: int, h_dim: int, seed: int = -1):
558
+ super(CTEToken, self).__init__((emb_dim,), d_dim, y_dim, h_dim, delta=1.0, local=False, seed=seed)
559
+ self.embeddings = torch.nn.Embedding(num_emb, emb_dim)
560
+
561
+ def forward(self, u: torch.Tensor, du: torch.Tensor, first: bool = True, last: bool = False):
562
+ if u is not None:
563
+ u = self.embeddings(u.to(self.device))
564
+ y = super().forward(u, du, first=first, last=last)
565
+ return y
566
+
567
+
568
+ class CTB(ModuleWrapper):
569
+ """Block Antisymmetric Generator using 2x2 parameterized rotation blocks.
570
+
571
+ Implements structured antisymmetric dynamics through learnable rotational frequencies.
572
+
573
+ Args:
574
+ u_shape: Input shape (tuple of integers)
575
+ d_dim: Input descriptor dimension
576
+ y_dim: Output dimension
577
+ h_dim: Hidden state dimension
578
+ delta: Time step for discrete approximation
579
+ alpha: Dissipation added on the diagonal (also controls the eigenvalue projections method)
580
+ """
581
+
582
+ def __init__(self, u_shape: tuple[int], d_dim: int, y_dim: int, h_dim: int, delta: float = None,
583
+ alpha: float = 0., sigma: Callable = lambda x: x, project_every: int = 0, local: bool = False,
584
+ batch_size: int = 1, seed: int = -1):
585
+ super(CTB, self).__init__(seed=seed)
586
+ device = self.device
587
+ u_shape = torch.Size(u_shape)
588
+ u_dim = u_shape.numel()
589
+ du_dim = d_dim
590
+ self.batch_size = batch_size
591
+ self.proc_inputs, self.proc_outputs = get_proc_inputs_and_proc_outputs_for_rnn(u_shape, du_dim, y_dim)
592
+
593
+ assert h_dim % 2 == 0, "Hidden dimension must be even for 2x2 blocks"
594
+ self.order = h_dim // 2 # Number of 2x2 blocks
595
+
596
+ # Learnable rotational frequencies
597
+ self.omega = torch.nn.Parameter(torch.empty(self.order, device=device))
598
+ self.register_buffer('ones', torch.ones(self.order, requires_grad=False, device=device))
599
+
600
+ # Projection matrices
601
+ self.B = torch.nn.Linear(u_dim + du_dim, h_dim, bias=False, device=device)
602
+ self.C = torch.nn.Linear(h_dim, y_dim, bias=False, device=device)
603
+
604
+ # Damping configuration
605
+ if alpha > 0.:
606
+
607
+ # In this case we want to add the feedback parameter alpha and use it to move eigenvalues on the unit circle
608
+ self.project_method = 'const'
609
+ self.register_buffer('alpha', torch.full_like(self.omega.data, alpha, device=device))
610
+ elif alpha == 0.:
611
+
612
+ # This is the case in which we want to divide by the modulus
613
+ self.project_method = 'modulus'
614
+ self.register_buffer('alpha', torch.zeros_like(self.omega.data, device=device))
615
+ elif alpha == -1.:
616
+ self.project_method = 'alpha'
617
+ self.register_buffer('alpha', torch.zeros_like(self.omega.data, device=device))
618
+
619
+ # Hidden state initialization
620
+ self.register_buffer('h_init', torch.randn((batch_size, h_dim), device=device))
621
+ self.register_buffer('h_next', torch.randn((batch_size, h_dim), device=device))
622
+ self.h = None
623
+ self.dh = None
624
+ self.sigma = sigma # The non-linear activation function
625
+
626
+ # System parameters
627
+ self.u_dim = u_dim
628
+ self.du_dim = du_dim
629
+ self.delta = delta
630
+ self.local = local # If True the state update is computed locally in time (i.e., kept out from the graph)
631
+ self.reset_parameters()
632
+ self.forward_count = 0
633
+ self.project_every = project_every
634
+
635
+ def reset_parameters(self) -> None:
636
+ """Initialize rotational frequencies with uniform distribution"""
637
+ torch.nn.init.uniform_(self.omega)
638
+
639
+ @torch.no_grad()
640
+ def adjust_eigs(self):
641
+ """Adjust eigenvalues to maintain stability"""
642
+ with torch.no_grad():
643
+ if self.project_method == 'alpha':
644
+
645
+ # Compute damping to maintain eigenvalues on unit circle
646
+ self.alpha.copy_((1. - torch.sqrt(1. - (self.delta * self.omega) ** 2) / self.delta))
647
+ elif self.project_method == 'modulus':
648
+
649
+ # Normalize by modulus for unit circle stability
650
+ module = torch.sqrt(self.ones ** 2 + (self.delta * self.omega) ** 2)
651
+ self.omega.div_(module)
652
+ self.ones.div_(module)
653
+
654
+ def init_h(self, udu: torch.Tensor) -> torch.Tensor:
655
+ return self.h_init.data
656
+
657
+ @staticmethod
658
+ def handle_inputs(du, u):
659
+ return du, u
660
+
661
+ def forward(self, u: torch.Tensor, du: torch.Tensor, first: bool = True, last: bool = False) -> torch.Tensor:
662
+ """Forward pass through block-structured dynamics"""
663
+
664
+ # Handle missing inputs
665
+ u = u.flatten(1).to(self.device) if u is not None \
666
+ else torch.zeros((self.batch_size, self.u_dim), device=self.device)
667
+ du = du.to(self.device) if du is not None \
668
+ else torch.zeros((self.batch_size, self.du_dim), device=self.device)
669
+
670
+ # Reset hidden state if first step
671
+ if first:
672
+ h = self.init_h(torch.cat([du, u], dim=1))
673
+ self.forward_count = 0
674
+ else:
675
+ h = self.h_next.data
676
+
677
+ # Track the gradients on h from here on
678
+ h.requires_grad_()
679
+ h_pair = h.view(-1, self.order, 2) # Reshape to (batch, blocks, 2)
680
+
681
+ # Check if it's time to project the eigenvalues
682
+ if self.project_every:
683
+ if self.forward_count % self.project_every == 0:
684
+ self.adjust_eigs()
685
+
686
+ # Handle inputs
687
+ du, u = self.handle_inputs(du, u)
688
+
689
+ # Block-wise rotation with damping
690
+ h1 = (self.ones - self.delta * self.alpha) * h_pair[..., 0] + self.delta * self.omega * h_pair[..., 1]
691
+ h2 = -self.delta * self.omega * h_pair[..., 0] + (self.ones - self.delta * self.alpha) * h_pair[..., 1]
692
+
693
+ # Recurrent and input components
694
+ rec = torch.stack([h1, h2], dim=-1).flatten(start_dim=1)
695
+ inp = self.delta * self.B(torch.cat([du, u], dim=1))
696
+
697
+ # Handle locality
698
+ h_new = rec + inp # Updated hidden state
699
+ if self.local:
700
+
701
+ # In the local version we keep track in self.h of the old value of the state
702
+ self.h = h
703
+ self.dh = (h_new - self.h) / self.delta # (h_new - h_old) / delta
704
+ else:
705
+
706
+ # In the non-local version we keep track in self.h of the new value of the state
707
+ self.h = h_new
708
+ self.dh = (self.h - h) / self.delta # (h_new - h_old) / delta
709
+
710
+ # Compute output using a nonlinear activation function
711
+ y = self.C(self.sigma(self.h))
712
+
713
+ # Store the new state for the next iteration
714
+ self.h_next.data = h_new.detach()
715
+ self.forward_count += 1
716
+
717
+ return y
718
+
719
+
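Each 2x2 block in CTB.forward applies one explicit-Euler step of a damped rotation; a worked stand-alone example with illustrative values:

    import torch

    delta, alpha, omega = 0.1, 0.05, 2.0
    M = torch.tensor([[1 - delta * alpha, delta * omega],
                      [-delta * omega, 1 - delta * alpha]])
    h_pair = torch.tensor([1.0, 0.0])
    h_pair_next = M @ h_pair    # same arithmetic as the h1/h2 update above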
720
+ class CTBE(ModuleWrapper):
721
+ """Antisymmetric Generator with Exact Matrix Exponential Blocks.
722
+
723
+ Implements precise rotational dynamics using trigonometric parameterization.
724
+
725
+ Args:
726
+ u_shape: Input shape (tuple of integers)
727
+ d_dim: Input descriptor dimension
728
+ y_dim: Output dimension
729
+ h_dim: Hidden state dimension
730
+ delta: Time step for discrete approximation
731
+ """
732
+
733
+ def __init__(self, u_shape: tuple[int], d_dim: int, y_dim: int, h_dim: int, delta: float,
734
+ sigma: Callable = lambda x: x, project_every: int = 0, local: bool = False,
735
+ cnu_memories: int = 0, batch_size: int = 1, seed: int = -1):
736
+ super(CTBE, self).__init__(seed=seed)
737
+ device = self.device
738
+ u_shape = torch.Size(u_shape)
739
+ u_dim = u_shape.numel()
740
+ du_dim = d_dim
741
+ self.batch_size = batch_size
742
+ self.proc_inputs, self.proc_outputs = get_proc_inputs_and_proc_outputs_for_rnn(u_shape, du_dim, y_dim)
743
+
744
+ assert h_dim % 2 == 0, "Hidden dimension must be even for 2x2 blocks"
745
+ self.order = h_dim // 2
746
+
747
+ # Learnable rotational frequencies
748
+ self.omega = torch.nn.Parameter(torch.empty(self.order, device=device))
749
+ self.B = torch.nn.Linear(u_dim + du_dim, h_dim, bias=False, device=device)
750
+ if cnu_memories <= 0:
751
+ self.C = torch.nn.Linear(h_dim, y_dim, bias=False, device=device)
752
+ else:
753
+ self.C = LinearCNU(h_dim, y_dim, bias=False, device=device, key_size=u_dim + du_dim,
754
+ delta=1, beta_k=delta, scramble=False, key_mem_units=cnu_memories, shared_keys=True)
755
+
756
+ # Hidden state initialization
757
+ self.register_buffer('h_init', torch.randn((batch_size, h_dim), device=device))
758
+ self.register_buffer('h_next', torch.randn((batch_size, h_dim), device=device))
759
+ self.h = None
760
+ self.dh = None
761
+ self.sigma = sigma # The non-linear activation function
762
+
763
+ # System parameters
764
+ self.u_dim = u_dim
765
+ self.du_dim = du_dim
766
+ self.delta = delta
767
+ self.local = local # If True the state update is computed locally in time (i.e., kept out from the graph)
768
+ self.reset_parameters()
769
+ self.forward_count = 0
770
+ self.project_every = project_every
771
+
772
+ def reset_parameters(self) -> None:
773
+ """Initialize rotational frequencies"""
774
+ if not isinstance(self.omega, CNUs):
775
+ torch.nn.init.uniform_(self.omega)
776
+ else:
777
+ torch.nn.init.uniform_(self.omega.M)
778
+
779
+ @torch.no_grad()
780
+ def adjust_eigs(self):
781
+ """Placeholder for eigenvalue adjustment"""
782
+ pass
783
+
784
+ def init_h(self, udu: torch.Tensor) -> torch.Tensor:
785
+ return self.h_init.data
786
+
787
+ @staticmethod
788
+ def handle_inputs(du, u):
789
+ return du, u
790
+
791
+ def forward(self, u: torch.Tensor, du: torch.Tensor, first: bool = True, last: bool = False) -> torch.Tensor:
792
+ """Exact matrix exponential forward pass"""
793
+
794
+ # Handle missing inputs
795
+ u = u.flatten(1).to(self.device) if u is not None \
796
+ else torch.zeros((self.batch_size, self.u_dim), device=self.device)
797
+ du = du.to(self.device) if du is not None \
798
+ else torch.zeros((self.batch_size, self.du_dim), device=self.device)
799
+
800
+ # Reset hidden state if first step
801
+ if first:
802
+ h = self.init_h(torch.cat([du, u], dim=1))
803
+ self.forward_count = 0
804
+ else:
805
+ h = self.h_next.data
806
+
807
+ # Track the gradients on h from here on
808
+ h.requires_grad_()
809
+ h_pair = h.view(-1, self.order, 2)
810
+
811
+ # Check if it's time to project the eigenvalues
812
+ if self.project_every:
813
+ if self.forward_count % self.project_every == 0:
814
+ self.adjust_eigs()
815
+
816
+ if not isinstance(self.C, LinearCNU):
817
+ C = self.C
818
+ else:
819
+ udu = torch.cat([du, u], dim=1)
820
+ weight_C = self.C.compute_weights(udu).view(self.C.out_features, self.C.in_features)
821
+
822
+ def C(x):
823
+ return torch.nn.functional.linear(x, weight_C)
824
+
825
+ # Handle inputs
826
+ du, u = self.handle_inputs(du, u)
827
+ udu = torch.cat([du, u], dim=1)
828
+
829
+ # Trigonometric terms for exact rotation
830
+ cos_t = torch.cos(self.omega * self.delta)
831
+ sin_t = torch.sin(self.omega * self.delta)
832
+
833
+ # Rotational update
834
+ h1 = cos_t * h_pair[..., 0] + sin_t * h_pair[..., 1]
835
+ h2 = -sin_t * h_pair[..., 0] + cos_t * h_pair[..., 1]
836
+ rec = torch.stack([h1, h2], dim=-1).flatten(start_dim=1)
837
+
838
+ # Input processing
839
+ u_hat = self.B(udu).view(-1, self.order, 2)
840
+ inp1 = (sin_t * u_hat[..., 0] - (cos_t - 1) * u_hat[..., 1]) / self.omega
841
+ inp2 = ((cos_t - 1) * u_hat[..., 0] + sin_t * u_hat[..., 1]) / self.omega
842
+ inp = torch.stack([inp1, inp2], dim=-1).flatten(start_dim=1)
843
+
844
+ # Handle locality
845
+ h_new = rec + inp # Updated hidden state
846
+ if self.local:
847
+
848
+ # In the local version we keep track in self.h of the old value of the state
849
+ self.h = h
850
+ self.dh = (h_new - self.h) / self.delta # (h_new - h_old) / delta
851
+ else:
852
+
853
+ # In the non-local version we keep track in self.h of the new value of the state
854
+ self.h = h_new
855
+ self.dh = (self.h - h) / self.delta # (h_new - h_old) / delta
856
+
857
+ # Compute output using a nonlinear activation function
858
+ y = C(self.sigma(self.h))
859
+
860
+ # Store the new state for the next iteration
861
+ self.h_next.data = h_new.detach()
862
+ self.forward_count += 1
863
+
864
+ return y
865
+
866
+
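CTBE.forward replaces the Euler step with the exact matrix exponential of each 2x2 rotation block plus the exact integral of the input over one step; a worked stand-alone example with illustrative values:

    import math
    import torch

    delta, omega = 0.1, 2.0
    c, s = math.cos(omega * delta), math.sin(omega * delta)
    R = torch.tensor([[c, s], [-s, c]])                         # exact rotation e^{A*delta}
    G = torch.tensor([[s, -(c - 1.0)], [c - 1.0, s]]) / omega   # exact input-integral term
    h_pair = torch.tensor([1.0, 0.0])
    u_hat = torch.tensor([0.5, 0.2])         # one 2x2 block of B([du; u])
    h_pair_next = R @ h_pair + G @ u_hat     # matches the h1/h2 and inp1/inp2 arithmetic above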
867
+ class CTBEInitStateBZeroInput(CTBE):
868
+ def __init__(self, u_shape, d_dim, y_dim, h_dim, delta, local, cnu_memories: int = 0,
869
+ batch_size: int = 1, seed: int = -1):
870
+ super().__init__(u_shape=u_shape, d_dim=d_dim, y_dim=y_dim, h_dim=h_dim, delta=delta, local=local,
871
+ cnu_memories=cnu_memories, batch_size=batch_size, seed=seed)
872
+
873
+ @torch.no_grad()
874
+ def init_h(self, udu: torch.Tensor) -> torch.Tensor:
875
+ return self.B(udu).detach() / torch.sum(udu)
876
+
877
+ @staticmethod
878
+ def handle_inputs(du, u):
879
+ return torch.zeros_like(du), torch.zeros_like(u)
880
+
881
+
882
+ class CNN(ModuleWrapper):
883
+
884
+ def __init__(self, d_dim: int, in_channels: int = 3, in_res: int = 32, return_input: bool = False,
885
+ seed: int = -1):
886
+ super(CNN, self).__init__(seed=seed)
887
+ self.proc_inputs, self.proc_outputs = get_proc_inputs_and_proc_outputs_for_image_classification(d_dim)
888
+ self.return_input = return_input
889
+ if self.return_input:
890
+ self.proc_outputs.insert(0, Data4Proc(data_type="img", pubsub=False, private_only=True))
891
+ self.transforms = transforms_factory("rgb" + str(in_res) if in_channels == 3 else "gray" + str(in_res))
892
+
893
+ self.module = torch.nn.Sequential(
894
+ torch.nn.Conv2d(in_channels, 64, kernel_size=5, padding=2),
895
+ torch.nn.ReLU(inplace=True),
896
+ torch.nn.AvgPool2d(kernel_size=3, stride=2),
897
+ torch.nn.Conv2d(64, 128, kernel_size=5, padding=2),
898
+ torch.nn.ReLU(inplace=True),
899
+ torch.nn.AvgPool2d(kernel_size=3, stride=2),
900
+ torch.nn.Conv2d(128, 256, kernel_size=3, padding=1),
901
+ torch.nn.ReLU(inplace=True),
902
+ torch.nn.AvgPool2d(kernel_size=3, stride=2),
903
+ torch.nn.Flatten(),
904
+ torch.nn.LazyLinear(2048),
905
+ torch.nn.ReLU(inplace=True),
906
+ torch.nn.Linear(2048, d_dim),
907
+ torch.nn.Sigmoid()
908
+ ).to(self.device)
909
+
910
+ def forward(self, y: Image.Image, first: bool = True, last: bool = False):
911
+ o = self.module(self.transforms(y).to(self.device))
912
+ if not self.return_input:
913
+ return o
914
+ else:
915
+ return y, o
916
+
917
+
918
+ class CNNCNU(ModuleWrapper):
919
+
920
+ def __init__(self, d_dim: int, cnu_memories: int, in_channels: int = 3, in_res: int = 32,
921
+ delta: int = 1, scramble: bool = False, return_input: bool = False, seed: int = -1):
922
+ super(CNNCNU, self).__init__(seed=seed)
923
+ self.proc_inputs, self.proc_outputs = get_proc_inputs_and_proc_outputs_for_image_classification(d_dim)
924
+ self.return_input = return_input
925
+ if self.return_input:
926
+ self.proc_outputs.insert(0, Data4Proc(data_type="img", pubsub=False, private_only=True))
927
+ self.transforms = transforms_factory("rgb" + str(in_res) if in_channels == 3 else "gray" + str(in_res))
928
+
929
+ self.module = torch.nn.Sequential(
930
+ torch.nn.Conv2d(in_channels, 64, kernel_size=5, padding=2),
931
+ torch.nn.ReLU(inplace=True),
932
+ torch.nn.AvgPool2d(kernel_size=3, stride=2),
933
+ torch.nn.Conv2d(64, 128, kernel_size=5, padding=2),
934
+ torch.nn.ReLU(inplace=True),
935
+ torch.nn.AvgPool2d(kernel_size=3, stride=2),
936
+ torch.nn.Conv2d(128, 256, kernel_size=3, padding=1),
937
+ torch.nn.ReLU(inplace=True),
938
+ torch.nn.AvgPool2d(kernel_size=3, stride=2),
939
+ torch.nn.Flatten(),
940
+ torch.nn.LazyLinear(2048),
941
+ torch.nn.ReLU(inplace=True),
942
+ LinearCNU(2048, d_dim, key_mem_units=cnu_memories, delta=delta, scramble=scramble),
943
+ torch.nn.Sigmoid()
944
+ ).to(self.device)
945
+
946
+ def forward(self, y: Image.Image, first: bool = True, last: bool = False):
947
+ o = self.module(self.transforms(y).to(self.device))
948
+ if not self.return_input:
949
+ return o
950
+ else:
951
+ return y, o
952
+
953
+
954
+ class SingleLayerCNU(ModuleWrapper):
955
+
956
+ def __init__(self, d_dim: int, cnu_memories: int, in_channels: int = 3, in_res: int = 32,
957
+ delta: int = 1, scramble: bool = False, return_input: bool = False, seed: int = -1):
958
+ super(SingleLayerCNU, self).__init__(seed=seed)
959
+ self.proc_inputs, self.proc_outputs = get_proc_inputs_and_proc_outputs_for_image_classification(d_dim)
960
+ self.return_input = return_input
961
+ if self.return_input:
962
+ self.proc_outputs.insert(0, Data4Proc(data_type="img", pubsub=False, private_only=True))
963
+ self.transforms = transforms_factory("rgb" + str(in_res) if in_channels == 3 else "gray" + str(in_res))
964
+
965
+ self.module = torch.nn.Sequential(
966
+ torch.nn.Flatten(),
967
+ LinearCNU(in_res * in_res * in_channels, d_dim, key_mem_units=cnu_memories, delta=delta, scramble=scramble),
968
+ torch.nn.Sigmoid()
969
+ ).to(self.device)
970
+
971
+ def forward(self, y: Image.Image, first: bool = True, last: bool = False):
972
+ o = self.module(self.transforms(y).to(self.device))
973
+ if not self.return_input:
974
+ return o
975
+ else:
976
+ return y, o
977
+
978
+
979
+ class CNNMNIST(CNN):
980
+
981
+ def __init__(self, *args, **kwargs):
982
+ kwargs['in_channels'] = 1
983
+ kwargs['in_res'] = 28
984
+ super(CNNMNIST, self).__init__(*args, **kwargs)
985
+ self.transforms = transforms_factory("gray_mnist")
986
+
987
+
988
+ class CNNCNUMNIST(CNNCNU):
989
+
990
+ def __init__(self, *args, **kwargs):
991
+ kwargs['in_channels'] = 1
992
+ kwargs['in_res'] = 28
993
+ super(CNNCNUMNIST, self).__init__(*args, **kwargs)
994
+ self.transforms = transforms_factory("gray_mnist")
995
+
996
+
997
+ class SingleLayerCNUMNIST(SingleLayerCNU):
998
+
999
+ def __init__(self, *args, **kwargs):
1000
+ kwargs['in_channels'] = 1
1001
+ kwargs['in_res'] = 28
1002
+ super(SingleLayerCNUMNIST, self).__init__(*args, **kwargs)
1003
+ self.transforms = transforms_factory("gray_mnist")
1004
+
1005
+
1006
+ class ResNet(ModuleWrapper):
1007
+ def __init__(self, d_dim: int = -1, return_input: bool = False, seed: int = -1, freeze_backbone: bool = True):
1008
+ super(ResNet, self).__init__(seed=seed)
1009
+ self.return_input = return_input
1010
+ if self.return_input:
1011
+ self.proc_outputs.insert(0, Data4Proc(data_type="img", pubsub=False, private_only=True))
1012
+ self.transforms = transforms_factory("rgb224")
1013
+ self.proc_inputs, self.proc_outputs = get_proc_inputs_and_proc_outputs_for_image_classification(d_dim)
1014
+ resnet = torchvision.models.resnet50(weights="IMAGENET1K_V1")
1015
+
1016
+ if freeze_backbone:
1017
+ # Freeze every pretrained parameter; the replacement fc head assigned below stays trainable
1018
+ for param in resnet.parameters():
1019
+ param.requires_grad = False
1020
+
1021
+ if d_dim > 0:
1022
+ resnet.fc = torch.nn.Sequential(
1023
+ torch.nn.Linear(resnet.fc.in_features, d_dim),
1024
+ torch.nn.Sigmoid()
1025
+ )
1026
+
1027
+ self.module = resnet.to(self.device)
1028
+
1029
+ def forward(self, y: Image.Image, first: bool = True, last: bool = False):
1030
+ o = self.module(self.transforms(y).to(self.device))
1031
+ if not self.return_input:
1032
+ return o
1033
+ else:
1034
+ return y, o
1035
+
1036
+
1037
+ class ResNetCNU(ModuleWrapper):
1038
+ def __init__(self, d_dim: int, cnu_memories: int,
1039
+ delta: int = 1, scramble: bool = False, return_input: bool = False, seed: int = -1):
1040
+ super(ResNetCNU, self).__init__(seed=seed)
1041
+ self.return_input = return_input
1042
+ if self.return_input:
1043
+ self.proc_outputs.insert(0, Data4Proc(data_type="img", pubsub=False, private_only=True))
1044
+ self.transforms = transforms_factory("rgb224")
1045
+ self.proc_inputs, self.proc_outputs = get_proc_inputs_and_proc_outputs_for_image_classification(d_dim)
1046
+ resnet = torchvision.models.resnet50(weights="IMAGENET1K_V1")
1047
+
1048
+ resnet.fc = torch.nn.Sequential(
1049
+ LinearCNU(resnet.fc.in_features, d_dim, key_mem_units=cnu_memories, delta=delta, scramble=scramble),
1050
+ torch.nn.Sigmoid()
1051
+ )
1052
+
1053
+ self.module = resnet.to(self.device)
1054
+
1055
+ def forward(self, y: Image.Image, first: bool = True, last: bool = False):
1056
+ o = self.module(self.transforms(y).to(self.device))
1057
+ if not self.return_input:
1058
+ return o
1059
+ else:
1060
+ return y, o
1061
+
1062
+
1063
+ class ViT(ModuleWrapper):
1064
+ def __init__(self, d_dim: int = -1, return_input: bool = False, seed: int = -1):
1065
+ super(ViT, self).__init__(seed=seed)
1066
+ self.return_input = return_input
1067
+ if self.return_input:
1068
+ self.proc_outputs.insert(0, Data4Proc(data_type="img", pubsub=False, private_only=True))
1069
+ weights = torchvision.models.ViT_B_16_Weights.IMAGENET1K_V1
1070
+ self.transforms = torchvision.transforms.Compose([
1071
+ weights.transforms(),
1072
+ torchvision.transforms.Lambda(lambda x: x.unsqueeze(0)) # Add batch dimension
1073
+ ])
1074
+ self.proc_inputs, self.proc_outputs = get_proc_inputs_and_proc_outputs_for_image_classification(d_dim)
1075
+ vit = torchvision.models.vit_b_16(weights=weights)
1076
+
1077
+ if d_dim > 0:
1078
+ vit.heads = torch.nn.Sequential(
1079
+ torch.nn.Linear(vit.heads.head.in_features, 2048),
1080
+ torch.nn.ReLU(inplace=True),
1081
+ torch.nn.Linear(2048, d_dim),
1082
+ torch.nn.Sigmoid()
1083
+ )
1084
+ self.labels = ["unk"] * d_dim
1085
+ else:
1086
+ url = "https://raw.githubusercontent.com/pytorch/hub/master/imagenet_classes.txt"
1087
+ self.labels = []
1088
+ with urllib.request.urlopen(url) as f:
1089
+ self.labels = [line.strip().decode('utf-8') for line in f.readlines()]
1090
+
1091
+ self.module = vit.to(self.device)
1092
+
1093
+ def forward(self, y: Image.Image, first: bool = True, last: bool = False):
1094
+ o = self.module(self.transforms(y).to(self.device))
1095
+ if not self.return_input:
1096
+ return o
1097
+ else:
1098
+ return y, o
1099
+
1100
+
1101
+ class DenseNet(ModuleWrapper):
1102
+ def __init__(self, d_dim: int = -1, return_input: bool = False, seed: int = -1):
1103
+ super(DenseNet, self).__init__(seed=seed)
1104
+ self.return_input = return_input
1105
+ if self.return_input:
1106
+ self.proc_outputs.insert(0, Data4Proc(data_type="img", pubsub=False, private_only=True))
1107
+ self.transforms = transforms_factory("rgb224")
1108
+ self.proc_inputs, self.proc_outputs = get_proc_inputs_and_proc_outputs_for_image_classification(d_dim)
1109
+ densenet = torchvision.models.densenet121(weights=None)
1110
+
1111
+ if d_dim > 0:
1112
+ densenet.classifier = torch.nn.Sequential(
1113
+ torch.nn.Linear(densenet.classifier.in_features, 2048),
1114
+ torch.nn.ReLU(inplace=True),
1115
+ torch.nn.Linear(2048, d_dim),
1116
+ torch.nn.Sigmoid()
1117
+ )
1118
+
1119
+ self.module = densenet.to(self.device)
1120
+
1121
+ def forward(self, y: Image.Image, first: bool = True, last: bool = False):
1122
+ o = self.module(self.transforms(y).to(self.device))
1123
+ if not self.return_input:
1124
+ return o
1125
+ else:
1126
+ return y, o
1127
+
1128
+
1129
+ class EfficientNet(ModuleWrapper):
1130
+ def __init__(self, d_dim: int = -1, return_input: bool = False, seed: int = -1):
1131
+ super(EfficientNet, self).__init__(seed=seed)
1132
+ self.return_input = return_input
1133
+ if self.return_input:
1134
+ self.proc_outputs.insert(0, Data4Proc(data_type="img", pubsub=False, private_only=True))
1135
+ weights = torchvision.models.EfficientNet_B0_Weights.IMAGENET1K_V1
1136
+ self.transforms = torchvision.transforms.Compose([
+ weights.transforms(), # Instantiate the preset transform (the bare attribute is only a factory)
+ torchvision.transforms.Lambda(lambda x: x.unsqueeze(0)) # Add batch dimension
+ ])
1137
+ self.proc_inputs, self.proc_outputs = get_proc_inputs_and_proc_outputs_for_image_classification(d_dim)
1138
+ effnet = torchvision.models.efficientnet_b0(weights=weights)
1139
+
1140
+ if d_dim > 0:
1141
+ effnet.classifier = torch.nn.Sequential(
1142
+ torch.nn.Linear(effnet.classifier[1].in_features, 2048),
1143
+ torch.nn.ReLU(inplace=True),
1144
+ torch.nn.Linear(2048, d_dim),
1145
+ torch.nn.Sigmoid()
1146
+ )
1147
+
1148
+ self.module = effnet.to(self.device)
1149
+
1150
+ def forward(self, y: Image.Image, first: bool = True, last: bool = False):
1151
+ o = self.module(self.transforms(y).to(self.device))
1152
+ if o.dim() == 1:
1153
+ o = o.unsqueeze(0)
1154
+ if not self.return_input:
1155
+ return o
1156
+ else:
1157
+ return y, o
1158
+
1159
+
1160
+ class FasterRCNN(ModuleWrapper):
1161
+ def __init__(self, seed: int = -1):
1162
+ self.labels = ['__background__', 'person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus',
1163
+ 'train', 'truck', 'boat', 'traffic light', 'fire hydrant', 'N/A', 'stop sign',
1164
+ 'parking meter', 'bench', 'bird', 'cat', 'dog', 'horse', 'sheep', 'cow',
1165
+ 'elephant', 'bear', 'zebra', 'giraffe', 'N/A', 'backpack', 'umbrella', 'N/A', 'N/A',
1166
+ 'handbag', 'tie', 'suitcase', 'frisbee', 'skis', 'snowboard', 'sports ball',
1167
+ 'kite', 'baseball bat', 'baseball glove', 'skateboard', 'surfboard', 'tennis racket',
1168
+ 'bottle', 'N/A', 'wine glass', 'cup', 'fork', 'knife', 'spoon', 'bowl',
1169
+ 'banana', 'apple', 'sandwich', 'orange', 'broccoli', 'carrot', 'hot dog', 'pizza',
1170
+ 'donut', 'cake', 'chair', 'couch', 'potted plant', 'bed', 'N/A', 'dining table',
1171
+ 'N/A', 'N/A', 'toilet', 'N/A', 'tv', 'laptop', 'mouse', 'remote', 'keyboard', 'cell phone',
1172
+ 'microwave', 'oven', 'toaster', 'sink', 'refrigerator', 'N/A', 'book',
1173
+ 'clock', 'vase', 'scissors', 'teddy bear', 'hair drier', 'toothbrush'
1174
+ ]
1175
+
1176
+ weights = torchvision.models.detection.FasterRCNN_ResNet50_FPN_Weights.DEFAULT
1177
+ faster_rcnn = torchvision.models.detection.fasterrcnn_resnet50_fpn(weights=weights)
1178
+ faster_rcnn.eval()
1179
+ self.transforms = torchvision.transforms.Compose([transforms_factory("rgb-no_norm"),
1180
+ torchvision.transforms.Lambda(lambda x: x.squeeze(0)),
1181
+ weights.transforms()])
1182
+
1183
+ super(FasterRCNN, self).__init__(
1184
+ module=faster_rcnn,
1185
+ proc_inputs=[Data4Proc(data_type="img", pubsub=False, private_only=True)],
1186
+ proc_outputs=[Data4Proc(data_type="tensor", tensor_dtype=torch.long, tensor_shape=(None,),
1187
+ pubsub=False, private_only=True),
1188
+ Data4Proc(data_type="tensor", tensor_dtype=torch.float32, tensor_shape=(None,),
1189
+ pubsub=False, private_only=True),
1190
+ Data4Proc(data_type="tensor", tensor_dtype=torch.float32, tensor_shape=(None, 4),
1191
+ pubsub=False, private_only=True),
1192
+ Data4Proc(data_type="text",
1193
+ pubsub=False, private_only=True)],
1194
+ seed=seed)
1195
+
1196
+ def forward(self, y: Image.Image, first: bool = True, last: bool = False):
1197
+ o = self.module([self.transforms(y).to(self.device)]) # List with 1 image per element (no batch dim)
1198
+
1199
+ found_class_indices = o[0]['labels']
1200
+ found_class_scores = o[0]['scores']
1201
+ found_class_boxes = o[0]['boxes']
1202
+ valid = found_class_scores > 0.8
1203
+
1204
+ found_class_indices = found_class_indices[valid]
1205
+ found_class_scores = found_class_scores[valid]
1206
+ found_class_boxes = found_class_boxes[valid]
1207
+ found_class_names = [self.labels[i.item()] for i in found_class_indices]
1208
+
1209
+ return found_class_indices, found_class_scores, found_class_boxes, ", ".join(found_class_names)
1210
+
1211
+
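A minimal usage sketch for the detector wrapper above (the image path is illustrative; torchvision downloads the pretrained weights on first use, and ModuleWrapper is assumed to dispatch the call to forward()):

    from PIL import Image

    detector = FasterRCNN()
    img = Image.open("example.jpg")               # illustrative path
    indices, scores, boxes, names = detector(img)
    print(names)                                  # comma-separated labels of detections with score > 0.8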
1212
+ class TinyLLama(ModuleWrapper):
1213
+ def __init__(self):
1214
+ super(TinyLLama, self).__init__(
1215
+ proc_inputs=[Data4Proc(data_type="text", pubsub=False, private_only=True)],
1216
+ proc_outputs=[Data4Proc(data_type="text", pubsub=False, private_only=True)]
1217
+ )
1218
+ self.module = pipeline("text-generation", model="TinyLlama/TinyLlama-1.1B-Chat-v1.0",
1219
+ torch_dtype=torch.bfloat16, device=self.device)
1220
+
1221
+ def forward(self, msg: str, first: bool = False, last: bool = False):
1222
+ msg_struct = [{"role": "system", "content": "You are a helpful assistant"},
1223
+ {"role": "user", "content": msg}]
1224
+ prompt = self.module.tokenizer.apply_chat_template(msg_struct, tokenize=False, add_generation_prompt=True)
1225
+
1226
+ out = self.module(prompt, max_new_tokens=256, do_sample=True, temperature=0.7, top_k=50, top_p=0.95)
1227
+ out = out[0]["generated_text"] if (out is not None and len(out) > 0 and "generated_text" in out[0])\
1228
+ else "Error!"
1229
+ if "<|assistant|>\n" in out:
1230
+ out = out.split("<|assistant|>\n")[1]
1231
+ return out.strip()
1232
+
1233
+
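A minimal usage sketch for the chat wrapper above (it downloads the TinyLlama-1.1B-Chat checkpoint from the Hugging Face Hub on first use; ModuleWrapper is again assumed to dispatch the call to forward()):

    llm = TinyLLama()
    reply = llm("Explain in one sentence what a recurrent state-space model is.")
    print(reply)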
1234
+ class LLama(ModuleWrapper):
1235
+ def __init__(self):
1236
+ super(LLama, self).__init__(
1237
+ proc_inputs=[Data4Proc(data_type="text", pubsub=False, private_only=True)],
1238
+ proc_outputs=[Data4Proc(data_type="text", pubsub=False, private_only=True)]
1239
+ )
1240
+ self.module = pipeline("text-generation", model="meta-llama/Llama-3.2-3B",
1241
+ torch_dtype=torch.bfloat16, device=self.device)
1242
+
1243
+ def forward(self, msg: str, first: bool = False, last: bool = False):
1244
+ msg_struct = [{"role": "system", "content": "You are a helpful assistant"},
1245
+ {"role": "user", "content": msg}]
1246
+ prompt = self.module.tokenizer.apply_chat_template(msg_struct, tokenize=False, add_generation_prompt=True)
1247
+
1248
+ out = self.module(prompt, max_new_tokens=256, do_sample=True, temperature=0.7, top_k=50, top_p=0.95)
1249
+ out = out[0]["generated_text"] if (out is not None and len(out) > 0 and "generated_text" in out[0])\
1250
+ else "Error!"
1251
+ if "<|assistant|>\n" in out:
1252
+ out = out.split("<|assistant|>\n")[1]
1253
+ return out.strip()
1254
+
1255
+
1256
+ class Phi(ModuleWrapper):
1257
+ def __init__(self):
1258
+ super(Phi, self).__init__(
1259
+ proc_inputs=[Data4Proc(data_type="text", pubsub=False, private_only=True)],
1260
+ proc_outputs=[Data4Proc(data_type="text", pubsub=False, private_only=True)]
1261
+ )
1262
+ self.module = pipeline("text-generation", model="microsoft/Phi-3.5-mini-instruct",
1263
+ torch_dtype="auto", device=self.device)
1264
+
1265
+ def forward(self, msg: str, first: bool = False, last: bool = False):
1266
+ msg_struct = [{"role": "system", "content": "You are a helpful assistant"},
1267
+ {"role": "user", "content": msg}]
1268
+ prompt = self.module.tokenizer.apply_chat_template(msg_struct, tokenize=False, add_generation_prompt=True)
1269
+
1270
+ out = self.module(prompt, max_new_tokens=256, do_sample=True, return_full_text=False)
1271
+ out = out[0]["generated_text"] if (out is not None and len(out) > 0 and "generated_text" in out[0])\
1272
+ else "Error!"
1273
+ if "<|assistant|>\n" in out:
1274
+ out = out.split("<|assistant|>\n")[1]
1275
+ return out.strip()
1276
+
1277
+
1278
+ class LangSegmentAnything(ModuleWrapper):
1279
+ def __init__(self):
1280
+ super(LangSegmentAnything, self).__init__(
1281
+ proc_inputs=[Data4Proc(data_type="img", pubsub=False, private_only=True),
1282
+ Data4Proc(data_type="text", pubsub=False, private_only=True)],
1283
+ proc_outputs=[Data4Proc(data_type="img", pubsub=False, private_only=True)]
1284
+ )
1285
+ from lang_sam import LangSAM
1286
+ self.module = LangSAM(device=self.device)
1287
+
1288
+ # Generate a 64x64 error image (with text "Error" on it)
1289
+ from PIL import ImageDraw, ImageFont
1290
+ self.error_img = Image.new("RGB", (64, 64), color="white")
1291
+ draw = ImageDraw.Draw(self.error_img)
1292
+ font = ImageFont.load_default()
1293
+ text = "Error"
1294
+ bbox = draw.textbbox((0, 0), text, font=font)
1295
+ text_width = bbox[2] - bbox[0]
1296
+ text_height = bbox[3] - bbox[1]
1297
+ position = ((64 - text_width) // 2, (64 - text_height) // 2)
1298
+ draw.text(position, text, fill="black", font=font)
1299
+
1300
+ def forward(self, image_pil: Image.Image, msg: str, first: bool = False, last: bool = False):
1301
+ try:
1302
+ image_pil = image_pil.convert("RGB") if image_pil.mode != "RGB" else image_pil # Forcing RGB
1303
+ out = self.module.predict([image_pil], [msg])
1304
+
1305
+ if (out is None or not isinstance(out, list) or len(out) < 1 or
1306
+ not isinstance(out[0], dict) or 'masks' not in out[0]) or out[0]['masks'].ndim != 3:
1307
+ return image_pil
1308
+ else:
1309
+ return LangSegmentAnything.__highlight_masks_on_image(image_pil, out[0]['masks'])
1310
+ except Exception as e:
1311
+ return self.error_img
1312
+
1313
+ @staticmethod
1314
+ def __highlight_masks_on_image(image_pil: Image.Image, masks: np.ndarray, alpha: float = 0.75):
1315
+ img_np = np.array(image_pil, dtype=np.float32) / 255.0
1316
+ height, width, _ = img_np.shape
1317
+ num_masks = masks.shape[0]
1318
+
1319
+ overlay_np = np.zeros((height, width, 3), dtype=np.float32)
1320
+ alpha_mask_combined = np.zeros((height, width, 1), dtype=np.float32)
1321
+
1322
+ color_palette = [
1323
+ (255, 102, 102), # Light Red
1324
+ (102, 255, 102), # Light Green
1325
+ (102, 102, 255), # Light Blue
1326
+ (255, 255, 102), # Light Yellow
1327
+ (255, 102, 255), # Light Magenta
1328
+ (102, 255, 255), # Light Cyan
1329
+ (255, 178, 102), # Orange
1330
+ (178, 102, 255), # Purple
1331
+ (102, 178, 255), # Sky Blue
1332
+ ]
1333
+
1334
+ for i in range(num_masks):
1335
+ mask = masks[i, :, :].astype(np.bool)
1336
+
1337
+ color_rgb_int = color_palette[i % len(color_palette)]
1338
+ color = np.array(color_rgb_int, dtype=np.float32) / 255.0
1339
+ overlay_np[mask] = (1 - alpha) * overlay_np[mask] + alpha * color
1340
+ alpha_mask_combined[mask] = np.maximum(alpha_mask_combined[mask], alpha)
1341
+
1342
+ # Final blending and conversion ...
1343
+ final_np = (1 - alpha_mask_combined) * img_np + alpha_mask_combined * overlay_np
1344
+ final_np = (final_np * 255).astype(np.uint8)
1345
+ final_image = Image.fromarray(final_np)
1346
+ return final_image
1347
+
1348
+
1349
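The private __highlight_masks_on_image helper composites each mask colour over the photo with the per-pixel blend final = (1 - a) * image + a * overlay, where a is the combined alpha of all masks covering a pixel. A self-contained NumPy sketch of that blend on a toy 4x4 image (all values invented, only the formula matches the code above; not part of the diffed file):

# Toy illustration of the per-pixel alpha blend used by __highlight_masks_on_image.
import numpy as np

img = np.full((4, 4, 3), 0.5, dtype=np.float32)                # grey base image in [0, 1]
overlay = np.zeros_like(img)
alpha = np.zeros((4, 4, 1), dtype=np.float32)

mask = np.zeros((4, 4), dtype=bool)
mask[1:3, 1:3] = True                                          # fake 2x2 segmentation mask
overlay[mask] = np.array([1.0, 0.4, 0.4], dtype=np.float32)    # light red
alpha[mask] = 0.75

blended = (1 - alpha) * img + alpha * overlay                  # the same blending formula
print((blended * 255).astype(np.uint8)[1, 1])                  # -> [223 108 108]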
+ class SmolVLM(ModuleWrapper):
+     def __init__(self):
+         super(SmolVLM, self).__init__(
+             proc_inputs=[Data4Proc(data_type="img", pubsub=False, private_only=True),
+                          Data4Proc(data_type="text", pubsub=False, private_only=True)],
+             proc_outputs=[Data4Proc(data_type="text", pubsub=False, private_only=True)]
+         )
+         model_id = "HuggingFaceTB/SmolVLM2-500M-Video-Instruct"
+         self.pre_post_processor = AutoProcessor.from_pretrained(model_id, device_map=self.device)
+
+         from transformers import AutoModelForImageTextToText
+         self.module = AutoModelForImageTextToText.from_pretrained(model_id, torch_dtype=torch.bfloat16,
+                                                                   device_map=self.device)
+         self.module = self.module.to(self.device)
+
+     def forward(self, image_pil: Image.Image, msg: str = "what is this?", first: bool = False, last: bool = False):
+         image_pil = image_pil.convert("RGB") if image_pil.mode != "RGB" else image_pil  # Forcing RGB
+
+         msg_struct = [{"role": "user", "content": [{"type": "text", "text": f"{msg}"},
+                                                    {"type": "image", "image": image_pil}]}]
+
+         prompt = self.pre_post_processor.apply_chat_template(msg_struct,
+                                                              tokenize=True,
+                                                              add_generation_prompt=True,
+                                                              return_dict=True,
+                                                              return_tensors="pt").to(self.device, dtype=torch.bfloat16)
+
+         out = self.module.generate(**prompt, do_sample=False, max_new_tokens=128)
+         out = self.pre_post_processor.batch_decode(out, skip_special_tokens=True)[0] if out is not None else "Error!"
+         if "Assistant:" in out:
+             out = out.split("Assistant:")[1]
+         return out.strip()
+
+
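A rough usage sketch of the SmolVLM wrapper above (not part of the diffed file): it assumes the HuggingFaceTB/SmolVLM2-500M-Video-Instruct weights can be downloaded and that the host supports bfloat16; the image is created in memory so no file path is needed:

# Sketch only: exercising the SmolVLM wrapper with an in-memory image.
from PIL import Image

vlm = SmolVLM()
dummy = Image.new("RGB", (224, 224), color="red")
print(vlm.forward(dummy, msg="What colour is this image?"))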
+ class SiteRAG(ModuleWrapper):
+
+     def __init__(self,
+                  site_url: str,
+                  site_folder: str = os.path.join("rag", "downloaded_site"),
+                  db_folder: str = os.path.join("rag", "chroma_db")):
+         super(SiteRAG, self).__init__(
+             proc_inputs=[Data4Proc(data_type="text", pubsub=False, private_only=True)],
+             proc_outputs=[Data4Proc(data_type="text", pubsub=False, private_only=True)],
+         )
+
+         # Saving options
+         self.site_url = site_url
+         self.site_folder = site_folder
+         self.db_folder = db_folder
+
+         # Loading the neural generator
+         model_id = "TheBloke/vicuna-7b-1.1-HF"
+         tokenizer = AutoTokenizer.from_pretrained(model_id)
+         model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.float16,
+                                                      device_map=self.device, offload_folder="offload")
+         self.module = pipeline("text-generation", model=model, tokenizer=tokenizer, max_new_tokens=200)
+
+         # Embedder
+         from langchain.embeddings import SentenceTransformerEmbeddings
+         self.embedder = SentenceTransformerEmbeddings(model_name="all-MiniLM-L6-v2",
+                                                       model_kwargs={"device": self.device.type})
+
+         # Crawling the site and building the knowledge base
+         self.crawl_website()
+         self.crawled_site_to_rag_knowledge_base()
+
+         # Setting up the retriever over the persisted vector store
+         from langchain.vectorstores import Chroma
+         db = Chroma(persist_directory=db_folder, embedding_function=self.embedder)
+         self.retriever = db.as_retriever(search_kwargs={"k": 3})
+
+     def forward(self, msg: str, first: bool = False, last: bool = False):
+
+         # Build the context from the retrieved chunks
+         docs = self.retriever.get_relevant_documents(msg)
+         context = "\n\n".join(doc.page_content for doc in docs)
+         prompt = f"Answer the question based on the following context:\n\n{context}\n\nQuestion: {msg}\nAnswer:"
+
+         # Generate the answer
+         out = self.module(prompt, max_new_tokens=256, do_sample=True, temperature=0.7)
+         out = out[0]['generated_text'][len(prompt):].strip() if (out is not None and len(out) > 0 and
+                                                                  "generated_text" in out[0]) else "Error!"
+
+         # Append the URL of the best-matching source document
+         best_doc_with_score = self.retriever.vectorstore.similarity_search_with_score(msg, k=1)
+         best_doc, _ = best_doc_with_score[0]
+         docs = [best_doc]
+         sources = set("<a href='" +
+                       doc.metadata['source'] +
+                       "' onclick='window.open(this.href); return false;' style='color: blue;'>" +
+                       doc.metadata['source'] + "</a>" for doc in docs)
+         sources_text = "<br/><br/>\nURLs:\n" + "\n".join(sources)
+
+         return out.strip() + sources_text
+
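SiteRAG.forward stitches the retrieved chunks into one context block before prompting the generator. A dependency-free sketch of that prompt assembly, with invented chunks and question (not part of the diffed file):

# Stand-alone illustration of the prompt layout that SiteRAG.forward builds.
chunks = ["Install the package with pip.",
          "Agents expose processors described by Data4Proc records."]
question = "How do agents describe their processors?"

context = "\n\n".join(chunks)
prompt = f"Answer the question based on the following context:\n\n{context}\n\nQuestion: {question}\nAnswer:"
print(prompt)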
+     def crawl_website(self, max_pages=300):
+         import requests
+         from bs4 import BeautifulSoup
+         from urllib.parse import urljoin, urlparse
+
+         if os.path.exists(self.site_folder):
+             shutil.rmtree(self.site_folder)
+         os.makedirs(self.site_folder)
+         visited = set()
+         to_visit = [self.site_url]
+
+         while to_visit and len(visited) < max_pages:
+             url = to_visit.pop(0)
+             if url in visited:
+                 continue
+             visited.add(url)
+
+             try:
+                 r = requests.get(url, timeout=10)
+                 if "text/html" not in r.headers.get("Content-Type", ""):
+                     continue
+
+                 parsed = urlparse(url)
+                 filename = parsed.path.strip("/") or "index.html"
+                 filename += ".crawled"
+                 file_path = os.path.join(self.site_folder, filename.replace("/", "__"))
+                 with open(file_path, "w", encoding="utf-8") as f:
+                     f.write(r.text)
+
+                 soup = BeautifulSoup(r.text, "html.parser")
+                 for link in soup.find_all("a", href=True):
+                     full_url = urljoin(url, link["href"])
+                     if full_url.startswith(self.site_url) and full_url not in visited:
+                         to_visit.append(full_url)
+             except Exception as e:
+                 print(f"Error fetching {url}: {e}")
+
+         print(f"Crawled {len(visited)} pages.")
+
+     def crawled_site_to_rag_knowledge_base(self):
+         from bs4 import BeautifulSoup
+         from urllib.parse import urljoin
+         from langchain.vectorstores import Chroma
+         from langchain.docstore.document import Document
+         from langchain.text_splitter import RecursiveCharacterTextSplitter
+
+         docs = []
+         for filename in os.listdir(self.site_folder):
+             if filename.endswith(".crawled"):
+                 file_path = os.path.join(self.site_folder, filename)
+                 with open(file_path, encoding="utf-8") as f:
+                     html = f.read()
+
+                 soup: BeautifulSoup = BeautifulSoup(html, "html.parser")
+                 text = soup.get_text(separator=" ", strip=True)  # type: ignore
+
+                 page_path = filename.replace("__", "/").replace(".crawled", "")
+                 url = urljoin(self.site_url, page_path)
+
+                 docs.append(Document(page_content=text, metadata={"source": url}))
+
+         splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=100)
+         split_docs = splitter.split_documents(docs)
+
+         chroma_db = Chroma.from_documents(split_docs, self.embedder, persist_directory=self.db_folder)
+         chroma_db.persist()
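Finally, an end-to-end usage sketch of SiteRAG (not part of the diffed file). Instantiation crawls the whole site and loads a 7B generator plus a sentence-transformer embedder, so treat this as illustrative rather than something to run casually; the URL and question are invented:

# Sketch only: build the RAG index for a site and ask one question.
rag = SiteRAG(site_url="https://example.org/")
print(rag.forward("What does this site say about installation?"))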