unaiverse 0.1.6__cp311-cp311-musllinux_1_2_x86_64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of unaiverse might be problematic.
- unaiverse/__init__.py +19 -0
- unaiverse/agent.py +2008 -0
- unaiverse/agent_basics.py +1846 -0
- unaiverse/clock.py +191 -0
- unaiverse/dataprops.py +1209 -0
- unaiverse/hsm.py +1880 -0
- unaiverse/modules/__init__.py +18 -0
- unaiverse/modules/cnu/__init__.py +17 -0
- unaiverse/modules/cnu/cnus.py +536 -0
- unaiverse/modules/cnu/layers.py +261 -0
- unaiverse/modules/cnu/psi.py +60 -0
- unaiverse/modules/hl/__init__.py +15 -0
- unaiverse/modules/hl/hl_utils.py +411 -0
- unaiverse/modules/networks.py +1509 -0
- unaiverse/modules/utils.py +680 -0
- unaiverse/networking/__init__.py +16 -0
- unaiverse/networking/node/__init__.py +18 -0
- unaiverse/networking/node/connpool.py +1261 -0
- unaiverse/networking/node/node.py +2223 -0
- unaiverse/networking/node/profile.py +446 -0
- unaiverse/networking/node/tokens.py +79 -0
- unaiverse/networking/p2p/__init__.py +198 -0
- unaiverse/networking/p2p/go.mod +127 -0
- unaiverse/networking/p2p/go.sum +548 -0
- unaiverse/networking/p2p/golibp2p.py +18 -0
- unaiverse/networking/p2p/golibp2p.pyi +135 -0
- unaiverse/networking/p2p/lib.go +2714 -0
- unaiverse/networking/p2p/lib.go.sha256 +1 -0
- unaiverse/networking/p2p/lib_types.py +312 -0
- unaiverse/networking/p2p/message_pb2.py +63 -0
- unaiverse/networking/p2p/messages.py +265 -0
- unaiverse/networking/p2p/mylogger.py +77 -0
- unaiverse/networking/p2p/p2p.py +929 -0
- unaiverse/networking/p2p/proto-go/message.pb.go +616 -0
- unaiverse/networking/p2p/unailib.cpython-311-x86_64-linux-musl.so +0 -0
- unaiverse/streamlib/__init__.py +15 -0
- unaiverse/streamlib/streamlib.py +210 -0
- unaiverse/streams.py +770 -0
- unaiverse/utils/__init__.py +16 -0
- unaiverse/utils/ask_lone_wolf.json +27 -0
- unaiverse/utils/lone_wolf.json +19 -0
- unaiverse/utils/misc.py +305 -0
- unaiverse/utils/sandbox.py +293 -0
- unaiverse/utils/server.py +435 -0
- unaiverse/world.py +175 -0
- unaiverse-0.1.6.dist-info/METADATA +365 -0
- unaiverse-0.1.6.dist-info/RECORD +50 -0
- unaiverse-0.1.6.dist-info/WHEEL +5 -0
- unaiverse-0.1.6.dist-info/licenses/LICENSE +43 -0
- unaiverse-0.1.6.dist-info/top_level.txt +1 -0
unaiverse/modules/networks.py
@@ -0,0 +1,1509 @@
"""
 █████ █████ ██████ █████ █████ █████ █████ ██████████ ███████████ █████████ ██████████
░░███ ░░███ ░░██████ ░░███ ░░███ ░░███ ░░███ ░░███░░░░░█░░███░░░░░███ ███░░░░░███░░███░░░░░█
 ░███ ░███ ░███░███ ░███ ██████ ░███ ░███ ░███ ░███ █ ░ ░███ ░███ ░███ ░░░ ░███ █ ░
 ░███ ░███ ░███░░███░███ ░░░░░███ ░███ ░███ ░███ ░██████ ░██████████ ░░█████████ ░██████
 ░███ ░███ ░███ ░░██████ ███████ ░███ ░░███ ███ ░███░░█ ░███░░░░░███ ░░░░░░░░███ ░███░░█
 ░███ ░███ ░███ ░░█████ ███░░███ ░███ ░░░█████░ ░███ ░ █ ░███ ░███ ███ ░███ ░███ ░ █
 ░░████████ █████ ░░█████░░████████ █████ ░░███ ██████████ █████ █████░░█████████ ██████████
 ░░░░░░░░ ░░░░░ ░░░░░ ░░░░░░░░ ░░░░░ ░░░ ░░░░░░░░░░ ░░░░░ ░░░░░ ░░░░░░░░░ ░░░░░░░░░░

A Collectionless AI Project (https://collectionless.ai)
Registration/Login: https://unaiverse.io
Code Repositories: https://github.com/collectionlessai/
Main Developers: Stefano Melacci (Project Leader), Christian Di Maio, Tommaso Guidi
"""
import os
import torch
import shutil
import numpy as np
import torchvision
from PIL import Image
import urllib.request
from typing import Callable
import torch.nn.functional as F
from unaiverse.dataprops import Data4Proc
from unaiverse.modules.cnu.cnus import CNUs
from unaiverse.modules.cnu.layers import LinearCNU
from transformers import pipeline, AutoProcessor, AutoModelForCausalLM, AutoTokenizer
from unaiverse.modules.utils import get_proc_inputs_and_proc_outputs_for_image_classification
from unaiverse.modules.utils import ModuleWrapper, transforms_factory, get_proc_inputs_and_proc_outputs_for_rnn
class RNNTokenLM(ModuleWrapper):

    def __init__(self, num_emb: int, emb_dim: int, y_dim: int, h_dim: int, batch_size: int = 1, seed: int = -1):
        super(RNNTokenLM, self).__init__(seed=seed)
        device = self.device
        u_dim = emb_dim
        self.embeddings = torch.nn.Embedding(num_emb, emb_dim)

        self.proc_inputs = [
            Data4Proc(data_type="tensor", tensor_shape=(u_dim,), tensor_dtype=torch.float32,
                      pubsub=False, private_only=True)
        ]
        self.proc_outputs = [
            Data4Proc(data_type="tensor", tensor_shape=(y_dim,), tensor_dtype=torch.float32,
                      pubsub=False, private_only=True)
        ]

        self.A = torch.nn.Linear(h_dim, h_dim, bias=False, device=device)
        self.B = torch.nn.Linear(u_dim, h_dim, bias=False, device=device)
        self.C = torch.nn.Linear(h_dim, y_dim, bias=False, device=device)
        self.h_init = torch.randn((batch_size, h_dim), device=device)
        self.u_init = torch.zeros((batch_size, u_dim), device=device)
        self.h = None
        self.y = None

    def forward(self, u: torch.Tensor | None = None, first: bool = True, last: bool = False):
        if first:
            h = self.h_init
            u = self.u_init if u is None else u
        else:
            h = self.h.detach()
            u = self.embeddings((torch.argmax(self.y.detach(), dim=1) if self.y.shape[1] > 1
                                 else self.y.squeeze(1).detach()).to(self.device))

        self.h = torch.tanh(self.A(h) + self.B(u))
        self.y = self.C(self.h)
        return self.y
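# Usage sketch for RNNTokenLM: it realizes h_t = tanh(A h_{t-1} + B u_t), y_t = C h_t, feeding
# back the embedding of the previously emitted token when no input is supplied. The sizes below
# are illustrative assumptions, not values taken from the package.
def _demo_rnn_token_lm(steps: int = 5) -> list[int]:
    lm = RNNTokenLM(num_emb=100, emb_dim=16, y_dim=100, h_dim=32)
    tokens = []
    for t in range(steps):
        y = lm.forward(first=(t == 0))  # (1, y_dim) scores; the first step starts from u_init
        tokens.append(int(torch.argmax(y, dim=1).item()))  # Greedy readout of the next token
    return tokens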
class RNN(ModuleWrapper):

    def __init__(self, u_shape: tuple[int], d_dim: int, y_dim: int, h_dim: int, batch_size: int = 1, seed: int = -1):
        super(RNN, self).__init__(seed=seed)
        device = self.device
        u_shape = torch.Size(u_shape)
        u_dim = u_shape.numel()
        du_dim = d_dim
        self.proc_inputs, self.proc_outputs = get_proc_inputs_and_proc_outputs_for_rnn(u_shape, du_dim, y_dim)

        self.A = torch.nn.Linear(h_dim, h_dim, bias=False, device=device)
        self.B = torch.nn.Linear(u_dim + du_dim, h_dim, bias=False, device=device)
        self.C = torch.nn.Linear(h_dim, y_dim, bias=False, device=device)
        self.register_buffer('h_init', torch.randn((batch_size, h_dim), device=device))
        self.h = None
        self.u_dim = u_dim
        self.du_dim = du_dim

    def forward(self, u: torch.Tensor, du: torch.Tensor, first: bool = True, last: bool = False):
        if first:
            h = self.h_init.data
        else:
            h = self.h.detach()
        if u is None:
            u = torch.zeros((h.shape[0], self.u_dim), dtype=torch.float32, device=self.device)
        else:
            u = u.to(self.device)
        if du is None:
            du = torch.zeros((h.shape[0], self.du_dim), dtype=torch.float32, device=self.device)
        else:
            du = du.to(self.device)

        self.h = torch.tanh(self.A(h) + self.B(torch.cat([du, u], dim=1)))
        y = self.C(self.h)
        return y
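# Usage sketch for RNN: the update is h_t = tanh(A h_{t-1} + B [du_t; u_t]) with y_t = C h_t,
# where du is a descriptor signal concatenated to the flattened input. Sizes are illustrative
# assumptions.
def _demo_rnn(seq_len: int = 4) -> torch.Tensor:
    net = RNN(u_shape=(3,), d_dim=2, y_dim=5, h_dim=8)
    y = None
    for t in range(seq_len):
        y = net.forward(torch.randn(1, 3), torch.randn(1, 2), first=(t == 0))
    return y  # (1, 5) output after the last step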
class CSSM(ModuleWrapper):

    def __init__(self, u_shape: tuple[int], d_dim: int, y_dim: int, h_dim: int, sigma: Callable = F.tanh,
                 project_every: int = 0, local: bool = False, batch_size: int = 1, seed: int = -1):
        super(CSSM, self).__init__(seed=seed)
        device = self.device
        u_shape = torch.Size(u_shape)
        u_dim = u_shape.numel()
        du_dim = d_dim
        self.batch_size = batch_size
        self.proc_inputs, self.proc_outputs = get_proc_inputs_and_proc_outputs_for_rnn(u_shape, du_dim, y_dim)

        # Define linear transformation matrices for state update and output mapping
        self.A = torch.nn.Linear(h_dim, h_dim, bias=False, device=device)  # Recurrent weight matrix
        self.B = torch.nn.Linear(u_dim + du_dim, h_dim, bias=False, device=device)  # Input-to-hidden mapping
        self.C = torch.nn.Linear(h_dim, y_dim, bias=False, device=device)  # Hidden-to-output mapping

        # Hidden state initialization
        self.register_buffer('h_init', torch.randn((batch_size, h_dim), device=device))
        self.register_buffer('h_next', torch.randn((batch_size, h_dim), device=device))
        self.h = None
        self.dh = None
        self.sigma = sigma  # The non-linear activation function

        # Store input dimensions and device
        self.u_dim = u_dim
        self.du_dim = du_dim
        self.delta = 1.  # Discrete time step
        self.local = local  # If True the state update is computed locally in time (i.e., kept out of the graph)
        self.forward_count = 0
        self.project_every = project_every

    @torch.no_grad()
    def adjust_eigs(self):
        """Placeholder for eigenvalue adjustment method."""
        pass

    def init_h(self, udu: torch.Tensor) -> torch.Tensor:
        return self.h_init.data

    @staticmethod
    def handle_inputs(du, u):
        return du, u

    def forward(self, u: torch.Tensor, du: torch.Tensor, first: bool = True, last: bool = False):
        """Forward pass that updates the hidden state and computes the output."""

        # Handle missing inputs
        u = u.flatten(1).to(self.device) if u is not None else (
            torch.zeros((self.batch_size, self.u_dim), device=self.device))
        du = du.to(self.device) if du is not None else (
            torch.zeros((self.batch_size, self.du_dim), device=self.device))

        # Reset hidden state if first step
        if first:
            h = self.init_h(torch.cat([du, u], dim=1))
            self.forward_count = 0
        else:
            h = self.h_next.data

        # Track the gradients on h from here on
        h.requires_grad_()

        # Check if it's time to project the eigenvalues
        if self.project_every:
            if self.forward_count % self.project_every == 0:
                self.adjust_eigs()

        # Handle inputs
        du, u = self.handle_inputs(du, u)

        # Update hidden state based on input and previous hidden state
        h_new = self.A(h) + self.B(torch.cat([du, u], dim=1))

        if self.local:

            # In the local version we keep track in self.h of the old value of the state
            self.h = h
            self.dh = (h_new - self.h) / self.delta  # (h_new - h_old) / delta
        else:

            # In the non-local version we keep track in self.h of the new value of the state
            self.h = h_new
            self.dh = (self.h - h) / self.delta  # (h_new - h_old) / delta

        # Compute output using a nonlinear activation function
        y = self.C(self.sigma(self.h))

        # Store the new state for the next iteration
        self.h_next.data = h_new.detach()
        self.forward_count += 1

        return y
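# Usage sketch for CSSM: the state update h_t = A h_{t-1} + B [du_t; u_t] is linear, sigma is
# applied only on the output path, and self.dh holds the finite difference (h_t - h_{t-1}) / delta.
# Sizes are illustrative assumptions.
def _demo_cssm() -> torch.Tensor:
    m = CSSM(u_shape=(4,), d_dim=2, y_dim=3, h_dim=8)
    m.forward(torch.randn(1, 4), torch.randn(1, 2), first=True)  # Resets the state
    return m.forward(torch.randn(1, 4), torch.randn(1, 2), first=False)  # (1, 3)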
class CDiagR(ModuleWrapper):
    """Diagonal matrix-based generator with real-valued transformations."""

    def __init__(self, u_shape: tuple[int], d_dim: int, y_dim: int, h_dim: int, sigma: Callable = lambda x: x,
                 project_every: int = 0, local: bool = False, batch_size: int = 1, seed: int = -1):
        super(CDiagR, self).__init__(seed=seed)
        device = self.device
        u_shape = torch.Size(u_shape)
        u_dim = u_shape.numel()
        du_dim = d_dim
        self.batch_size = batch_size
        self.proc_inputs, self.proc_outputs = get_proc_inputs_and_proc_outputs_for_rnn(u_shape, du_dim, y_dim)

        # Define diagonal transformation and linear layers
        self.diag = torch.nn.Linear(in_features=1, out_features=h_dim, bias=False, device=device, dtype=torch.float32)
        self.B = torch.nn.Linear(u_dim + du_dim, h_dim, bias=False, device=device)
        self.C = torch.nn.Linear(h_dim, y_dim, bias=False, device=device)

        # Hidden state initialization
        self.register_buffer('h_init', torch.randn((batch_size, h_dim), device=device))
        self.register_buffer('h_next', torch.randn((batch_size, h_dim), device=device))
        self.h = None
        self.dh = None
        self.sigma = sigma  # The non-linear activation function

        # Store input dimensions and device
        self.u_dim = u_dim
        self.du_dim = du_dim
        self.delta = 1.
        self.local = local  # If True the state update is computed locally in time (i.e., kept out of the graph)
        self.forward_count = 0
        self.project_every = project_every

    @torch.no_grad()
    def adjust_eigs(self):
        """Normalize the diagonal weight matrix by setting signs."""
        self.diag.weight.copy_(torch.sign(self.diag.weight))

    def init_h(self, udu: torch.Tensor) -> torch.Tensor:
        return self.h_init.data

    @staticmethod
    def handle_inputs(du, u):
        return du, u

    def forward(self, u: torch.Tensor, du: torch.Tensor, first: bool = True, last: bool = False):
        """Forward pass with diagonal transformation."""

        # Handle missing inputs
        u = u.flatten(1).to(self.device) if u is not None else (
            torch.zeros((self.batch_size, self.u_dim), device=self.device))
        du = du.to(self.device) if du is not None else (
            torch.zeros((self.batch_size, self.du_dim), device=self.device))

        # Reset hidden state if first step
        if first:
            h = self.init_h(torch.cat([du, u], dim=1))
            self.forward_count = 0
        else:
            h = self.h_next.data

        # Track the gradients on h from here on
        h.requires_grad_()

        # Check if it's time to project the eigenvalues
        if self.project_every:
            if self.forward_count % self.project_every == 0:
                self.adjust_eigs()

        # Handle inputs
        du, u = self.handle_inputs(du, u)

        # Apply diagonal transformation to hidden state
        h_new = self.diag.weight.view(self.diag.out_features) * h + self.B(torch.cat([du, u], dim=1))

        if self.local:

            # In the local version we keep track in self.h of the old value of the state
            self.h = h
            self.dh = (h_new - self.h) / self.delta  # (h_new - h_old) / delta
        else:

            # In the non-local version we keep track in self.h of the new value of the state
            self.h = h_new
            self.dh = (self.h - h) / self.delta  # (h_new - h_old) / delta

        # Compute output using a nonlinear activation function
        y = self.C(self.sigma(self.h))

        # Store the new state for the next iteration
        self.h_next.data = h_new.detach()
        self.forward_count += 1

        return y
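# Usage sketch for CDiagR: the dense recurrent matrix is replaced by a learnable diagonal, so the
# update is the element-wise product diag * h + B [du; u], and adjust_eigs() snaps every diagonal
# entry onto {-1, +1} via its sign. Sizes are illustrative assumptions.
def _demo_cdiag_r() -> torch.Tensor:
    m = CDiagR(u_shape=(4,), d_dim=2, y_dim=3, h_dim=8, project_every=1)
    return m.forward(torch.randn(1, 4), torch.randn(1, 2), first=True)  # (1, 3)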
class CDiagC(ModuleWrapper):
    """Diagonal matrix-based generator with complex-valued transformations."""

    def __init__(self, u_shape: tuple[int], d_dim: int, y_dim: int, h_dim: int, sigma: Callable = lambda x: x,
                 project_every: int = 0, local: bool = False, batch_size: int = 1, seed: int = -1):
        super(CDiagC, self).__init__(seed=seed)
        device = self.device
        u_shape = torch.Size(u_shape)
        u_dim = u_shape.numel()
        du_dim = d_dim
        self.batch_size = batch_size
        self.proc_inputs, self.proc_outputs = get_proc_inputs_and_proc_outputs_for_rnn(u_shape, du_dim, y_dim)

        # Define diagonal transformation with complex numbers
        self.diag = torch.nn.Linear(in_features=1, out_features=h_dim, bias=False, device=device, dtype=torch.cfloat)
        self.B = torch.nn.Linear(u_dim + du_dim, h_dim, bias=False, device=device, dtype=torch.cfloat)
        self.C = torch.nn.Linear(h_dim, y_dim, bias=False, device=device, dtype=torch.cfloat)

        # Hidden state initialization
        self.register_buffer('h_init', torch.randn((batch_size, h_dim), device=device))
        self.register_buffer('h_next', torch.randn((batch_size, h_dim), device=device))
        self.h = None
        self.dh = None
        self.sigma = sigma  # The non-linear activation function

        # Store input dimensions and device
        self.u_dim = u_dim
        self.du_dim = du_dim
        self.delta = 1.
        self.local = local  # If True the state update is computed locally in time (i.e., kept out of the graph)
        self.forward_count = 0
        self.project_every = project_every

    @torch.no_grad()
    def adjust_eigs(self):
        """Normalize the diagonal weight matrix by dividing by its magnitude."""
        self.diag.weight.div_(self.diag.weight.abs())

    def init_h(self, udu: torch.Tensor) -> torch.Tensor:
        return self.h_init.data

    @staticmethod
    def handle_inputs(du, u):
        return du, u

    def forward(self, u: torch.Tensor, du: torch.Tensor, first: bool = True, last: bool = False):
        """Forward pass with complex-valued transformation."""

        # Handle missing inputs
        u = u.flatten(1).to(self.device) if u is not None else torch.zeros((self.batch_size, self.u_dim),
                                                                           device=self.device, dtype=torch.cfloat)
        du = du.to(self.device) if du is not None else torch.zeros((self.batch_size, self.du_dim),
                                                                   device=self.device, dtype=torch.cfloat)

        # Reset hidden state if first step
        if first:
            h = self.init_h(torch.cat([du, u], dim=1))
            self.forward_count = 0
        else:
            h = self.h_next.data

        # Track the gradients on h from here on
        h.requires_grad_()

        # Check if it's time to project the eigenvalues
        if self.project_every:
            if self.forward_count % self.project_every == 0:
                self.adjust_eigs()

        # Handle inputs
        du, u = self.handle_inputs(du, u)

        # Apply complex diagonal transformation
        h_new = self.diag.weight.view(self.diag.out_features) * h + self.B(torch.cat([du, u], dim=1))

        if self.local:

            # In the local version we keep track in self.h of the old value of the state
            self.h = h
            self.dh = (h_new - self.h) / self.delta  # (h_new - h_old) / delta
        else:

            # In the non-local version we keep track in self.h of the new value of the state
            self.h = h_new
            self.dh = (self.h - h) / self.delta  # (h_new - h_old) / delta

        # Compute output using a nonlinear activation function
        y = self.C(self.sigma(self.h))

        # Store the new state for the next iteration
        self.h_next.data = h_new.detach()
        self.forward_count += 1

        return y.real
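# Usage sketch for CDiagC: the complex counterpart of CDiagR, where adjust_eigs() projects each
# complex eigenvalue onto the unit circle by dividing by its magnitude, and only the real part of
# the output is returned. Inputs are fed as cfloat tensors here because the complex-typed B layer
# expects them; sizes are illustrative assumptions.
def _demo_cdiag_c() -> torch.Tensor:
    m = CDiagC(u_shape=(4,), d_dim=2, y_dim=3, h_dim=8, project_every=1)
    u = torch.randn(1, 4, dtype=torch.cfloat)
    du = torch.randn(1, 2, dtype=torch.cfloat)
    return m.forward(u, du, first=True)  # (1, 3), real-valued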
class CTE(ModuleWrapper):
    """Antisymmetric Matrix Exponential Generator implementing continuous-time dynamics.

    Uses antisymmetric weight matrix with matrix exponential for stable hidden state evolution.

    Args:
        u_shape: Input shape (tuple of integers)
        d_dim: Input descriptor dimension
        y_dim: Output dimension
        h_dim: Hidden state dimension
        delta: Time step for discrete approximation
        local: Local computations (bool)
        seed: Random seed (positive int)
    """

    def __init__(self, u_shape: tuple[int], d_dim: int, y_dim: int, h_dim: int, delta: float,
                 sigma: Callable = lambda x: x, project_every: int = 0, local: bool = False,
                 cnu_memories: int = 0, batch_size: int = 1, seed: int = -1):
        super(CTE, self).__init__(seed=seed)
        device = self.device
        u_shape = torch.Size(u_shape)
        u_dim = u_shape.numel()
        du_dim = d_dim
        self.batch_size = batch_size
        self.proc_inputs, self.proc_outputs = get_proc_inputs_and_proc_outputs_for_rnn(u_shape, du_dim, y_dim)

        # Antisymmetric weight matrix (W - W^T)
        self.W = torch.nn.Linear(h_dim, h_dim, bias=False, device=device)
        self.Id = torch.eye(h_dim, device=device)  # Identity matrix

        # Input projection matrix
        self.B = torch.nn.Linear(u_dim + du_dim, h_dim, bias=False, device=device)

        # Output projection matrix
        if cnu_memories <= 0:
            self.C = torch.nn.Linear(h_dim, y_dim, bias=False, device=device)
        else:
            self.C = LinearCNU(h_dim, y_dim, bias=False, device=device, key_size=u_dim + du_dim,
                               delta=1, beta_k=delta, scramble=False, key_mem_units=cnu_memories, shared_keys=True)

        # Hidden state initialization
        self.register_buffer('h_init', torch.randn((batch_size, h_dim), device=device))
        self.register_buffer('h_next', torch.randn((batch_size, h_dim), device=device))
        self.h = None
        self.dh = None
        self.sigma = sigma  # The non-linear activation function

        # System parameters
        self.u_dim = u_dim
        self.du_dim = du_dim
        self.delta = delta
        self.local = local
        self.forward_count = 0
        self.project_every = project_every

    @torch.no_grad()
    def adjust_eigs(self):
        """Placeholder for eigenvalue adjustment method."""
        pass

    def init_h(self, udu: torch.Tensor) -> torch.Tensor:
        return self.h_init.data

    @staticmethod
    def handle_inputs(du, u):
        return du, u

    def forward(self, u: torch.Tensor, du: torch.Tensor, first: bool = True, last: bool = False) -> torch.Tensor:
        """Forward pass through the system dynamics.

        Args:
            u: Input tensor of shape (batch_size, u_dim)
            du: Input descriptor tensor of shape (batch_size, du_dim)
            first: Flag indicating first step (resets hidden state)
            last: Flag indicating last step (does nothing)

        Returns:
            y: Output tensor of shape (batch_size, y_dim)
        """

        # Handle missing inputs
        u = u.flatten(1).to(self.device) if u is not None else (
            torch.zeros((self.batch_size, self.u_dim), device=self.device))
        du = du.to(self.device) if du is not None else (
            torch.zeros((self.batch_size, self.du_dim), device=self.device))

        # Reset hidden state if first step
        if first:
            h = self.init_h(torch.cat([du, u], dim=1))
            self.forward_count = 0
        else:
            h = self.h_next.data

        # Track the gradients on h from here on
        h.requires_grad_()

        # Check if it's time to project the eigenvalues
        if self.project_every:
            if self.forward_count % self.project_every == 0:
                self.adjust_eigs()

        if not isinstance(self.C, LinearCNU):
            C = self.C
        else:
            udu = torch.cat([du, u], dim=1)
            weight_C = self.C.compute_weights(udu).view(self.C.out_features, self.C.in_features)

            def C(x):
                return torch.nn.functional.linear(x, weight_C)

        # Handle inputs
        du, u = self.handle_inputs(du, u)

        # Antisymmetric matrix construction
        A = 0.5 * (self.W.weight - self.W.weight.t())
        A_expm = torch.linalg.matrix_exp(A * self.delta)  # Matrix exponential
        rec = F.linear(h, A_expm, self.W.bias)  # Recurrent component

        # Input processing component
        A_inv = torch.linalg.inv(A)
        inp = A_inv @ (A_expm - self.Id) @ self.B(torch.cat([du, u], dim=1)).unsqueeze(-1)

        # Handle locality
        h_new = rec + inp.squeeze(-1)  # Updated hidden state
        if self.local:

            # In the local version we keep track in self.h of the old value of the state
            self.h = h
            self.dh = (h_new - self.h) / self.delta  # (h_new - h_old) / delta
        else:

            # In the non-local version we keep track in self.h of the new value of the state
            self.h = h_new
            self.dh = (self.h - h) / self.delta  # (h_new - h_old) / delta

        # Compute output using a nonlinear activation function
        y = C(self.sigma(self.h))

        # Store the new state for the next iteration
        self.h_next.data = h_new
        self.forward_count += 1

        return y
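# CTE integrates the linear ODE dh/dt = A h + B [du; u] exactly over one step of length delta,
# with A = (W - W^T) / 2 antisymmetric and a zero-order hold on the input:
#
#   h(t + delta) = e^{A delta} h(t) + A^{-1} (e^{A delta} - I) B [du; u]
#
# which is what the matrix_exp / inv pair in forward() computes. Usage sketch with illustrative
# sizes (cnu_memories=0 keeps C a plain Linear layer):
def _demo_cte() -> torch.Tensor:
    m = CTE(u_shape=(4,), d_dim=2, y_dim=3, h_dim=8, delta=0.1)
    return m.forward(torch.randn(1, 4), torch.randn(1, 2), first=True)  # (1, 3)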
class CTEInitStateBZeroInput(CTE):

    def __init__(self, u_shape: tuple[int], d_dim: int, y_dim: int, h_dim: int, delta: float,
                 sigma: Callable = lambda x: x, project_every: int = 0, local: bool = False,
                 cnu_memories: int = 0, batch_size: int = 1, seed: int = -1):
        super(CTEInitStateBZeroInput, self).__init__(u_shape, d_dim, y_dim, h_dim, delta, sigma, project_every,
                                                     local, cnu_memories, batch_size, seed)

    @torch.no_grad()
    def init_h(self, udu: torch.Tensor) -> torch.Tensor:
        return self.B(udu).detach() / torch.sum(udu, dim=1)

    @staticmethod
    def handle_inputs(du, u):
        return torch.zeros_like(du), torch.zeros_like(u)
class CTEToken(CTE):

    def __init__(self, num_emb: int, emb_dim: int, d_dim: int, y_dim: int, h_dim: int, seed: int = -1):
        super(CTEToken, self).__init__((emb_dim,), d_dim, y_dim, h_dim, delta=1.0, local=False, seed=seed)
        self.embeddings = torch.nn.Embedding(num_emb, emb_dim)

    def forward(self, u: torch.Tensor, du: torch.Tensor, first: bool = True, last: bool = False):
        if u is not None:
            u = self.embeddings(u.to(self.device))
        y = super().forward(u, du, first=first, last=last)
        return y
class CTB(ModuleWrapper):
    """Block Antisymmetric Generator using 2x2 parameterized rotation blocks.

    Implements structured antisymmetric dynamics through learnable rotational frequencies.

    Args:
        u_shape: Input shape (tuple of integers)
        d_dim: Input descriptor dimension
        y_dim: Output dimension
        h_dim: Hidden state dimension
        delta: Time step for discrete approximation
        alpha: Dissipation added on the diagonal (also selects the eigenvalue projection method)
    """

    def __init__(self, u_shape: tuple[int], d_dim: int, y_dim: int, h_dim: int, delta: float = None,
                 alpha: float = 0., sigma: Callable = lambda x: x, project_every: int = 0, local: bool = False,
                 batch_size: int = 1, seed: int = -1):
        super(CTB, self).__init__(seed=seed)
        device = self.device
        u_shape = torch.Size(u_shape)
        u_dim = u_shape.numel()
        du_dim = d_dim
        self.batch_size = batch_size
        self.proc_inputs, self.proc_outputs = get_proc_inputs_and_proc_outputs_for_rnn(u_shape, du_dim, y_dim)

        assert h_dim % 2 == 0, "Hidden dimension must be even for 2x2 blocks"
        self.order = h_dim // 2  # Number of 2x2 blocks

        # Learnable rotational frequencies
        self.omega = torch.nn.Parameter(torch.empty(self.order, device=device))
        self.register_buffer('ones', torch.ones(self.order, requires_grad=False, device=device))

        # Projection matrices
        self.B = torch.nn.Linear(u_dim + du_dim, h_dim, bias=False, device=device)
        self.C = torch.nn.Linear(h_dim, y_dim, bias=False, device=device)

        # Damping configuration
        if alpha > 0.:

            # In this case we want to add the feedback parameter alpha and use it to move eigenvalues on the unit circle
            self.project_method = 'const'
            self.register_buffer('alpha', torch.full_like(self.omega.data, alpha, device=device))
        elif alpha == 0.:

            # This is the case in which we want to divide by the modulus
            self.project_method = 'modulus'
            self.register_buffer('alpha', torch.zeros_like(self.omega.data, device=device))
        elif alpha == -1.:
            self.project_method = 'alpha'
            self.register_buffer('alpha', torch.zeros_like(self.omega.data, device=device))

        # Hidden state initialization
        self.register_buffer('h_init', torch.randn((batch_size, h_dim), device=device))
        self.register_buffer('h_next', torch.randn((batch_size, h_dim), device=device))
        self.h = None
        self.dh = None
        self.sigma = sigma  # The non-linear activation function

        # System parameters
        self.u_dim = u_dim
        self.du_dim = du_dim
        self.delta = delta
        self.local = local  # If True the state update is computed locally in time (i.e., kept out of the graph)
        self.reset_parameters()
        self.forward_count = 0
        self.project_every = project_every

    def reset_parameters(self) -> None:
        """Initialize rotational frequencies with uniform distribution."""
        torch.nn.init.uniform_(self.omega)

    @torch.no_grad()
    def adjust_eigs(self):
        """Adjust eigenvalues to maintain stability."""
        with torch.no_grad():
            if self.project_method == 'alpha':

                # Compute damping to maintain eigenvalues on unit circle
                self.alpha.copy_((1. - torch.sqrt(1. - (self.delta * self.omega) ** 2) / self.delta))
            elif self.project_method == 'modulus':

                # Normalize by modulus for unit circle stability
                module = torch.sqrt(self.ones ** 2 + (self.delta * self.omega) ** 2)
                self.omega.div_(module)
                self.ones.div_(module)

    def init_h(self, udu: torch.Tensor) -> torch.Tensor:
        return self.h_init.data

    @staticmethod
    def handle_inputs(du, u):
        return du, u

    def forward(self, u: torch.Tensor, du: torch.Tensor, first: bool = True, last: bool = False) -> torch.Tensor:
        """Forward pass through block-structured dynamics"""

        # Handle missing inputs
        u = u.flatten(1).to(self.device) if u is not None \
            else torch.zeros((self.batch_size, self.u_dim), device=self.device)
        du = du.to(self.device) if du is not None \
            else torch.zeros((self.batch_size, self.du_dim), device=self.device)

        # Reset hidden state if first step
        if first:
            h = self.init_h(torch.cat([du, u], dim=1))
            self.forward_count = 0
        else:
            h = self.h_next.data

        # Track the gradients on h from here on
        h.requires_grad_()
        h_pair = h.view(-1, self.order, 2)  # Reshape to (batch, blocks, 2)

        # Check if it's time to project the eigenvalues
        if self.project_every:
            if self.forward_count % self.project_every == 0:
                self.adjust_eigs()

        # Handle inputs
        du, u = self.handle_inputs(du, u)

        # Block-wise rotation with damping
        h1 = (self.ones - self.delta * self.alpha) * h_pair[..., 0] + self.delta * self.omega * h_pair[..., 1]
        h2 = -self.delta * self.omega * h_pair[..., 0] + (self.ones - self.delta * self.alpha) * h_pair[..., 1]

        # Recurrent and input components
        rec = torch.stack([h1, h2], dim=-1).flatten(start_dim=1)
        inp = self.delta * self.B(torch.cat([du, u], dim=1))

        # Handle locality
        h_new = rec + inp  # Updated hidden state
        if self.local:

            # In the local version we keep track in self.h of the old value of the state
            self.h = h
            self.dh = (h_new - self.h) / self.delta  # (h_new - h_old) / delta
        else:

            # In the non-local version we keep track in self.h of the new value of the state
            self.h = h_new
            self.dh = (self.h - h) / self.delta  # (h_new - h_old) / delta

        # Compute output using a nonlinear activation function
        y = self.C(self.sigma(self.h))

        # Store the new state for the next iteration
        self.h_next.data = h_new.detach()
        self.forward_count += 1

        return y
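# CTB is the Euler discretization of a block-antisymmetric system: the state is split into 2x2
# blocks and each pair (h1, h2) is rotated and damped as
#
#   [h1'; h2'] = [[1 - delta*alpha, delta*omega], [-delta*omega, 1 - delta*alpha]] [h1; h2]
#                + delta * B [du; u]
#
# with one learnable frequency omega per block. Usage sketch with illustrative sizes (alpha=0.
# selects the 'modulus' projection in adjust_eigs()):
def _demo_ctb() -> torch.Tensor:
    m = CTB(u_shape=(4,), d_dim=2, y_dim=3, h_dim=8, delta=0.1, alpha=0., project_every=1)
    return m.forward(torch.randn(1, 4), torch.randn(1, 2), first=True)  # (1, 3)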
class CTBE(ModuleWrapper):
    """Antisymmetric Generator with Exact Matrix Exponential Blocks.

    Implements precise rotational dynamics using trigonometric parameterization.

    Args:
        u_shape: Input shape (tuple of integers)
        d_dim: Input descriptor dimension
        y_dim: Output dimension
        h_dim: Hidden state dimension
        delta: Time step for discrete approximation
    """

    def __init__(self, u_shape: tuple[int], d_dim: int, y_dim: int, h_dim: int, delta: float,
                 sigma: Callable = lambda x: x, project_every: int = 0, local: bool = False,
                 cnu_memories: int = 0, batch_size: int = 1, seed: int = -1):
        super(CTBE, self).__init__(seed=seed)
        device = self.device
        u_shape = torch.Size(u_shape)
        u_dim = u_shape.numel()
        du_dim = d_dim
        self.batch_size = batch_size
        self.proc_inputs, self.proc_outputs = get_proc_inputs_and_proc_outputs_for_rnn(u_shape, du_dim, y_dim)

        assert h_dim % 2 == 0, "Hidden dimension must be even for 2x2 blocks"
        self.order = h_dim // 2

        # Learnable rotational frequencies
        self.omega = torch.nn.Parameter(torch.empty(self.order, device=device))
        self.B = torch.nn.Linear(u_dim + du_dim, h_dim, bias=False, device=device)
        if cnu_memories <= 0:
            self.C = torch.nn.Linear(h_dim, y_dim, bias=False, device=device)
        else:
            self.C = LinearCNU(h_dim, y_dim, bias=False, device=device, key_size=u_dim + du_dim,
                               delta=1, beta_k=delta, scramble=False, key_mem_units=cnu_memories, shared_keys=True)

        # Hidden state initialization
        self.register_buffer('h_init', torch.randn((batch_size, h_dim), device=device))
        self.register_buffer('h_next', torch.randn((batch_size, h_dim), device=device))
        self.h = None
        self.dh = None
        self.sigma = sigma  # The non-linear activation function

        # System parameters
        self.u_dim = u_dim
        self.du_dim = du_dim
        self.delta = delta
        self.local = local  # If True the state update is computed locally in time (i.e., kept out of the graph)
        self.reset_parameters()
        self.forward_count = 0
        self.project_every = project_every

    def reset_parameters(self) -> None:
        """Initialize rotational frequencies."""
        if not isinstance(self.omega, CNUs):
            torch.nn.init.uniform_(self.omega)
        else:
            torch.nn.init.uniform_(self.omega.M)

    @torch.no_grad()
    def adjust_eigs(self):
        """Placeholder for eigenvalue adjustment."""
        pass

    def init_h(self, udu: torch.Tensor) -> torch.Tensor:
        return self.h_init.data

    @staticmethod
    def handle_inputs(du, u):
        return du, u

    def forward(self, u: torch.Tensor, du: torch.Tensor, first: bool = True, last: bool = False) -> torch.Tensor:
        """Exact matrix exponential forward pass."""

        # Handle missing inputs
        u = u.flatten(1).to(self.device) if u is not None \
            else torch.zeros((self.batch_size, self.u_dim), device=self.device)
        du = du.to(self.device) if du is not None \
            else torch.zeros((self.batch_size, self.du_dim), device=self.device)

        # Reset hidden state if first step
        if first:
            h = self.init_h(torch.cat([du, u], dim=1))
            self.forward_count = 0
        else:
            h = self.h_next.data

        # Track the gradients on h from here on
        h.requires_grad_()
        h_pair = h.view(-1, self.order, 2)

        # Check if it's time to project the eigenvalues
        if self.project_every:
            if self.forward_count % self.project_every == 0:
                self.adjust_eigs()

        if not isinstance(self.C, LinearCNU):
            C = self.C
        else:
            udu = torch.cat([du, u], dim=1)
            weight_C = self.C.compute_weights(udu).view(self.C.out_features, self.C.in_features)

            def C(x):
                return torch.nn.functional.linear(x, weight_C)

        # Handle inputs
        du, u = self.handle_inputs(du, u)
        udu = torch.cat([du, u], dim=1)

        # Trigonometric terms for exact rotation
        cos_t = torch.cos(self.omega * self.delta)
        sin_t = torch.sin(self.omega * self.delta)

        # Rotational update
        h1 = cos_t * h_pair[..., 0] + sin_t * h_pair[..., 1]
        h2 = -sin_t * h_pair[..., 0] + cos_t * h_pair[..., 1]
        rec = torch.stack([h1, h2], dim=-1).flatten(start_dim=1)

        # Input processing
        u_hat = self.B(udu).view(-1, self.order, 2)
        inp1 = (sin_t * u_hat[..., 0] - (cos_t - 1) * u_hat[..., 1]) / self.omega
        inp2 = ((cos_t - 1) * u_hat[..., 0] + sin_t * u_hat[..., 1]) / self.omega
        inp = torch.stack([inp1, inp2], dim=-1).flatten(start_dim=1)

        # Handle locality
        h_new = rec + inp  # Updated hidden state
        if self.local:

            # In the local version we keep track in self.h of the old value of the state
            self.h = h
            self.dh = (h_new - self.h) / self.delta  # (h_new - h_old) / delta
        else:

            # In the non-local version we keep track in self.h of the new value of the state
            self.h = h_new
            self.dh = (self.h - h) / self.delta  # (h_new - h_old) / delta

        # Compute output using a nonlinear activation function
        y = C(self.sigma(self.h))

        # Store the new state for the next iteration
        self.h_next.data = h_new.detach()
        self.forward_count += 1

        return y
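# CTBE replaces the Euler step of CTB with the exact exponential of each 2x2 block: since
# exp(delta * [[0, omega], [-omega, 0]]) is the rotation [[cos, sin], [-sin, cos]] of angle
# omega*delta, the recurrent part is an exact rotation and the input term is its closed-form
# integral, which is why the projected input is divided by omega (almost surely non-zero after
# the uniform init). Usage sketch with illustrative sizes:
def _demo_ctbe() -> torch.Tensor:
    m = CTBE(u_shape=(4,), d_dim=2, y_dim=3, h_dim=8, delta=0.1)
    return m.forward(torch.randn(1, 4), torch.randn(1, 2), first=True)  # (1, 3)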
class CTBEInitStateBZeroInput(CTBE):

    def __init__(self, u_shape, d_dim, y_dim, h_dim, delta, local, cnu_memories: int = 0,
                 batch_size: int = 1, seed: int = -1):
        super().__init__(u_shape=u_shape, d_dim=d_dim, y_dim=y_dim, h_dim=h_dim, delta=delta, local=local,
                         cnu_memories=cnu_memories, batch_size=batch_size, seed=seed)

    @torch.no_grad()
    def init_h(self, udu: torch.Tensor) -> torch.Tensor:
        return self.B(udu).detach() / torch.sum(udu)

    @staticmethod
    def handle_inputs(du, u):
        return torch.zeros_like(du), torch.zeros_like(u)
class CNN(ModuleWrapper):

    def __init__(self, d_dim: int, in_channels: int = 3, in_res: int = 32, return_input: bool = False,
                 seed: int = -1):
        super(CNN, self).__init__(seed=seed)
        self.proc_inputs, self.proc_outputs = get_proc_inputs_and_proc_outputs_for_image_classification(d_dim)
        self.return_input = return_input
        if self.return_input:
            self.proc_outputs.insert(0, Data4Proc(data_type="img", pubsub=False, private_only=True))
        self.transforms = transforms_factory("rgb" + str(in_res) if in_channels == 3 else "gray" + str(in_res))

        self.module = torch.nn.Sequential(
            torch.nn.Conv2d(in_channels, 64, kernel_size=5, padding=2),
            torch.nn.ReLU(inplace=True),
            torch.nn.AvgPool2d(kernel_size=3, stride=2),
            torch.nn.Conv2d(64, 128, kernel_size=5, padding=2),
            torch.nn.ReLU(inplace=True),
            torch.nn.AvgPool2d(kernel_size=3, stride=2),
            torch.nn.Conv2d(128, 256, kernel_size=3, padding=1),
            torch.nn.ReLU(inplace=True),
            torch.nn.AvgPool2d(kernel_size=3, stride=2),
            torch.nn.Flatten(),
            torch.nn.LazyLinear(2048),
            torch.nn.ReLU(inplace=True),
            torch.nn.Linear(2048, d_dim),
            torch.nn.Sigmoid()
        ).to(self.device)

    def forward(self, y: Image.Image, first: bool = True, last: bool = False):
        o = self.module(self.transforms(y).to(self.device))
        if not self.return_input:
            return o
        else:
            return y, o
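# Usage sketch for CNN: a PIL image goes through three conv/pool stages into d_dim sigmoid
# outputs; LazyLinear infers its input size at the first call. This assumes that
# transforms_factory("rgb32") yields a transform producing a batched (1, 3, 32, 32) tensor.
def _demo_cnn() -> torch.Tensor:
    net = CNN(d_dim=10)
    return net.forward(Image.new("RGB", (32, 32)))  # (1, 10), values in (0, 1)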
class CNNCNU(ModuleWrapper):

    def __init__(self, d_dim: int, cnu_memories: int, in_channels: int = 3, in_res: int = 32,
                 delta: int = 1, scramble: bool = False, return_input: bool = False, seed: int = -1):
        super(CNNCNU, self).__init__(seed=seed)
        self.proc_inputs, self.proc_outputs = get_proc_inputs_and_proc_outputs_for_image_classification(d_dim)
        self.return_input = return_input
        if self.return_input:
            self.proc_outputs.insert(0, Data4Proc(data_type="img", pubsub=False, private_only=True))
        self.transforms = transforms_factory("rgb" + str(in_res) if in_channels == 3 else "gray" + str(in_res))

        self.module = torch.nn.Sequential(
            torch.nn.Conv2d(in_channels, 64, kernel_size=5, padding=2),
            torch.nn.ReLU(inplace=True),
            torch.nn.AvgPool2d(kernel_size=3, stride=2),
            torch.nn.Conv2d(64, 128, kernel_size=5, padding=2),
            torch.nn.ReLU(inplace=True),
            torch.nn.AvgPool2d(kernel_size=3, stride=2),
            torch.nn.Conv2d(128, 256, kernel_size=3, padding=1),
            torch.nn.ReLU(inplace=True),
            torch.nn.AvgPool2d(kernel_size=3, stride=2),
            torch.nn.Flatten(),
            torch.nn.LazyLinear(2048),
            torch.nn.ReLU(inplace=True),
            LinearCNU(2048, d_dim, key_mem_units=cnu_memories, delta=delta, scramble=scramble),
            torch.nn.Sigmoid()
        ).to(self.device)

    def forward(self, y: Image.Image, first: bool = True, last: bool = False):
        o = self.module(self.transforms(y).to(self.device))
        if not self.return_input:
            return o
        else:
            return y, o
class SingleLayerCNU(ModuleWrapper):

    def __init__(self, d_dim: int, cnu_memories: int, in_channels: int = 3, in_res: int = 32,
                 delta: int = 1, scramble: bool = False, return_input: bool = False, seed: int = -1):
        super(SingleLayerCNU, self).__init__(seed=seed)
        self.proc_inputs, self.proc_outputs = get_proc_inputs_and_proc_outputs_for_image_classification(d_dim)
        self.return_input = return_input
        if self.return_input:
            self.proc_outputs.insert(0, Data4Proc(data_type="img", pubsub=False, private_only=True))
        self.transforms = transforms_factory("rgb" + str(in_res) if in_channels == 3 else "gray" + str(in_res))

        self.module = torch.nn.Sequential(
            torch.nn.Flatten(),
            LinearCNU(in_res * in_res * in_channels, d_dim, key_mem_units=cnu_memories, delta=delta, scramble=scramble),
            torch.nn.Sigmoid()
        ).to(self.device)

    def forward(self, y: Image.Image, first: bool = True, last: bool = False):
        o = self.module(self.transforms(y).to(self.device))
        if not self.return_input:
            return o
        else:
            return y, o
class CNNMNIST(CNN):

    def __init__(self, *args, **kwargs):
        kwargs['in_channels'] = 1
        kwargs['in_res'] = 28
        super(CNNMNIST, self).__init__(*args, **kwargs)
        self.transforms = transforms_factory("gray_mnist")


class CNNCNUMNIST(CNNCNU):

    def __init__(self, *args, **kwargs):
        kwargs['in_channels'] = 1
        kwargs['in_res'] = 28
        super(CNNCNUMNIST, self).__init__(*args, **kwargs)
        self.transforms = transforms_factory("gray_mnist")


class SingleLayerCNUMNIST(SingleLayerCNU):

    def __init__(self, *args, **kwargs):
        kwargs['in_channels'] = 1
        kwargs['in_res'] = 28
        super(SingleLayerCNUMNIST, self).__init__(*args, **kwargs)
        self.transforms = transforms_factory("gray_mnist")
class ResNet(ModuleWrapper):

    def __init__(self, d_dim: int = -1, return_input: bool = False, seed: int = -1, freeze_backbone: bool = True):
        super(ResNet, self).__init__(seed=seed)
        self.return_input = return_input
        if self.return_input:
            self.proc_outputs.insert(0, Data4Proc(data_type="img", pubsub=False, private_only=True))
        self.transforms = transforms_factory("rgb224")
        self.proc_inputs, self.proc_outputs = get_proc_inputs_and_proc_outputs_for_image_classification(d_dim)
        resnet = torchvision.models.resnet50(weights="IMAGENET1K_V1")

        if freeze_backbone:
            for layer in resnet.parameters():
                if layer != resnet.fc:
                    layer.requires_grad = False

        if d_dim > 0:
            resnet.fc = torch.nn.Sequential(
                torch.nn.Linear(resnet.fc.in_features, d_dim),
                torch.nn.Sigmoid()
            )

        self.module = resnet.to(self.device)

    def forward(self, y: Image.Image, first: bool = True, last: bool = False):
        o = self.module(self.transforms(y).to(self.device))
        if not self.return_input:
            return o
        else:
            return y, o
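# Usage sketch for ResNet: a torchvision ResNet-50 with ImageNet weights, whose final fully
# connected layer is swapped for a Linear + Sigmoid head when d_dim > 0 (the only trainable part
# when freeze_backbone=True). The IMAGENET1K_V1 weights are downloaded on first use.
def _demo_resnet() -> torch.Tensor:
    net = ResNet(d_dim=10)
    return net.forward(Image.new("RGB", (224, 224)))  # (1, 10), values in (0, 1)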
class ResNetCNU(ModuleWrapper):

    def __init__(self, d_dim: int, cnu_memories: int,
                 delta: int = 1, scramble: bool = False, return_input: bool = False, seed: int = -1):
        super(ResNetCNU, self).__init__(seed=seed)
        self.return_input = return_input
        if self.return_input:
            self.proc_outputs.insert(0, Data4Proc(data_type="img", pubsub=False, private_only=True))
        self.transforms = transforms_factory("rgb224")
        self.proc_inputs, self.proc_outputs = get_proc_inputs_and_proc_outputs_for_image_classification(d_dim)
        resnet = torchvision.models.resnet50(weights="IMAGENET1K_V1")

        resnet.fc = torch.nn.Sequential(
            LinearCNU(resnet.fc.in_features, d_dim, key_mem_units=cnu_memories, delta=delta, scramble=scramble),
            torch.nn.Sigmoid()
        )

        self.module = resnet.to(self.device)

    def forward(self, y: Image.Image, first: bool = True, last: bool = False):
        o = self.module(self.transforms(y).to(self.device))
        if not self.return_input:
            return o
        else:
            return y, o
class ViT(ModuleWrapper):

    def __init__(self, d_dim: int = -1, return_input: bool = False, seed: int = -1):
        super(ViT, self).__init__(seed=seed)
        self.return_input = return_input
        if self.return_input:
            self.proc_outputs.insert(0, Data4Proc(data_type="img", pubsub=False, private_only=True))
        weights = torchvision.models.ViT_B_16_Weights.IMAGENET1K_V1
        self.transforms = torchvision.transforms.Compose([
            weights.transforms(),
            torchvision.transforms.Lambda(lambda x: x.unsqueeze(0))  # Add batch dimension
        ])
        self.proc_inputs, self.proc_outputs = get_proc_inputs_and_proc_outputs_for_image_classification(d_dim)
        vit = torchvision.models.vit_b_16(weights=weights)

        if d_dim > 0:
            vit.heads = torch.nn.Sequential(
                torch.nn.Linear(vit.heads.head.in_features, 2048),
                torch.nn.ReLU(inplace=True),
                torch.nn.Linear(2048, d_dim),
                torch.nn.Sigmoid()
            )
            self.labels = ["unk"] * d_dim
        else:
            url = "https://raw.githubusercontent.com/pytorch/hub/master/imagenet_classes.txt"
            self.labels = []
            with urllib.request.urlopen(url) as f:
                self.labels = [line.strip().decode('utf-8') for line in f.readlines()]

        self.module = vit.to(self.device)

    def forward(self, y: Image.Image, first: bool = True, last: bool = False):
        o = self.module(self.transforms(y).to(self.device))
        if not self.return_input:
            return o
        else:
            return y, o
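# Usage sketch for ViT: with d_dim <= 0 the full 1000-way ImageNet head is kept and the class
# names are fetched from the pytorch/hub label file, so the top prediction can be read off with
# argmax. Both the ViT-B/16 weights and the label file are downloaded on first use.
def _demo_vit() -> str:
    net = ViT()
    logits = net.forward(Image.new("RGB", (224, 224)))
    return net.labels[int(torch.argmax(logits, dim=1).item())]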
class DenseNet(ModuleWrapper):

    def __init__(self, d_dim: int = -1, return_input: bool = False, seed: int = -1):
        super(DenseNet, self).__init__(seed=seed)
        self.return_input = return_input
        if self.return_input:
            self.proc_outputs.insert(0, Data4Proc(data_type="img", pubsub=False, private_only=True))
        self.transforms = transforms_factory("rgb224")
        self.proc_inputs, self.proc_outputs = get_proc_inputs_and_proc_outputs_for_image_classification(d_dim)
        densenet = torchvision.models.densenet121(weights=None)

        if d_dim > 0:
            densenet.classifier = torch.nn.Sequential(
                torch.nn.Linear(densenet.classifier.in_features, 2048),
                torch.nn.ReLU(inplace=True),
                torch.nn.Linear(2048, d_dim),
                torch.nn.Sigmoid()
            )

        self.module = densenet.to(self.device)

    def forward(self, y: Image.Image, first: bool = True, last: bool = False):
        o = self.module(self.transforms(y).to(self.device))
        if not self.return_input:
            return o
        else:
            return y, o
class EfficientNet(ModuleWrapper):

    def __init__(self, d_dim: int = -1, return_input: bool = False, seed: int = -1):
        super(EfficientNet, self).__init__(seed=seed)
        self.return_input = return_input
        if self.return_input:
            self.proc_outputs.insert(0, Data4Proc(data_type="img", pubsub=False, private_only=True))
        weights = torchvision.models.EfficientNet_B0_Weights.IMAGENET1K_V1
        self.transforms = weights.transforms
        self.proc_inputs, self.proc_outputs = get_proc_inputs_and_proc_outputs_for_image_classification(d_dim)
        effnet = torchvision.models.efficientnet_b0(weights=weights)

        if d_dim > 0:
            effnet.classifier = torch.nn.Sequential(
                torch.nn.Linear(effnet.classifier[1].in_features, 2048),
                torch.nn.ReLU(inplace=True),
                torch.nn.Linear(2048, d_dim),
                torch.nn.Sigmoid()
            )

        self.module = effnet.to(self.device)

    def forward(self, y: Image.Image, first: bool = True, last: bool = False):
        o = self.module(self.transforms(y).to(self.device))
        if o.dim() == 1:
            o = o.unsqueeze(0)
        if not self.return_input:
            return o
        else:
            return y, o
class FasterRCNN(ModuleWrapper):

    def __init__(self, seed: int = -1):
        self.labels = ['__background__', 'person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus',
                       'train', 'truck', 'boat', 'traffic light', 'fire hydrant', 'N/A', 'stop sign',
                       'parking meter', 'bench', 'bird', 'cat', 'dog', 'horse', 'sheep', 'cow',
                       'elephant', 'bear', 'zebra', 'giraffe', 'N/A', 'backpack', 'umbrella', 'N/A', 'N/A',
                       'handbag', 'tie', 'suitcase', 'frisbee', 'skis', 'snowboard', 'sports ball',
                       'kite', 'baseball bat', 'baseball glove', 'skateboard', 'surfboard', 'tennis racket',
                       'bottle', 'N/A', 'wine glass', 'cup', 'fork', 'knife', 'spoon', 'bowl',
                       'banana', 'apple', 'sandwich', 'orange', 'broccoli', 'carrot', 'hot dog', 'pizza',
                       'donut', 'cake', 'chair', 'couch', 'potted plant', 'bed', 'N/A', 'dining table',
                       'N/A', 'N/A', 'toilet', 'N/A', 'tv', 'laptop', 'mouse', 'remote', 'keyboard', 'cell phone',
                       'microwave', 'oven', 'toaster', 'sink', 'refrigerator', 'N/A', 'book',
                       'clock', 'vase', 'scissors', 'teddy bear', 'hair drier', 'toothbrush'
                       ]

        weights = torchvision.models.detection.FasterRCNN_ResNet50_FPN_Weights.DEFAULT
        faster_rcnn = torchvision.models.detection.fasterrcnn_resnet50_fpn(weights=weights)
        faster_rcnn.eval()
        self.transforms = torchvision.transforms.Compose([transforms_factory("rgb-no_norm"),
                                                          torchvision.transforms.Lambda(lambda x: x.squeeze(0)),
                                                          weights.transforms()])

        super(FasterRCNN, self).__init__(
            module=faster_rcnn,
            proc_inputs=[Data4Proc(data_type="img", pubsub=False, private_only=True)],
            proc_outputs=[Data4Proc(data_type="tensor", tensor_dtype=torch.long, tensor_shape=(None,),
                                    pubsub=False, private_only=True),
                          Data4Proc(data_type="tensor", tensor_dtype=torch.float32, tensor_shape=(None,),
                                    pubsub=False, private_only=True),
                          Data4Proc(data_type="tensor", tensor_dtype=torch.float32, tensor_shape=(None, 4),
                                    pubsub=False, private_only=True),
                          Data4Proc(data_type="text",
                                    pubsub=False, private_only=True)],
            seed=seed)

    def forward(self, y: Image.Image, first: bool = True, last: bool = False):
        o = self.module([self.transforms(y).to(self.device)])  # List with 1 image per element (no batch dim)

        found_class_indices = o[0]['labels']
        found_class_scores = o[0]['scores']
        found_class_boxes = o[0]['boxes']
        valid = found_class_scores > 0.8

        found_class_indices = found_class_indices[valid]
        found_class_scores = found_class_scores[valid]
        found_class_boxes = found_class_boxes[valid]
        found_class_names = [self.labels[i.item()] for i in found_class_indices]

        return found_class_indices, found_class_scores, found_class_boxes, ", ".join(found_class_names)
|
|
1212
|
+
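# Usage sketch (illustrative, not part of the original file): unpacking the
# detector's four outputs for a hypothetical "street.jpg". Only detections
# scoring above the 0.8 threshold hard-coded in forward() are returned.
def _demo_faster_rcnn():
    from PIL import Image
    detector = FasterRCNN()
    indices, scores, boxes, names = detector(Image.open("street.jpg"))
    for score, box, name in zip(scores.tolist(), boxes.tolist(), names.split(", ")):
        print(f"{name}: {score:.2f} at {box}")  # box is [x1, y1, x2, y2]

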
class TinyLLama(ModuleWrapper):
    def __init__(self):
        super(TinyLLama, self).__init__(
            proc_inputs=[Data4Proc(data_type="text", pubsub=False, private_only=True)],
            proc_outputs=[Data4Proc(data_type="text", pubsub=False, private_only=True)]
        )
        self.module = pipeline("text-generation", model="TinyLlama/TinyLlama-1.1B-Chat-v1.0",
                               torch_dtype=torch.bfloat16, device=self.device)

    def forward(self, msg: str, first: bool = False, last: bool = False):
        msg_struct = [{"role": "system", "content": "You are a helpful assistant"},
                      {"role": "user", "content": msg}]
        prompt = self.module.tokenizer.apply_chat_template(msg_struct, tokenize=False, add_generation_prompt=True)

        out = self.module(prompt, max_new_tokens=256, do_sample=True, temperature=0.7, top_k=50, top_p=0.95)
        out = out[0]["generated_text"] if (out is not None and len(out) > 0 and "generated_text" in out[0]) \
            else "Error!"
        if "<|assistant|>\n" in out:
            out = out.split("<|assistant|>\n")[1]
        return out.strip()


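# Usage sketch (illustrative, not part of the original file): the class above
# and the LLama/Phi variants that follow share the same text-in/text-out
# contract, so each can be queried the same way.
def _demo_chat():
    bot = TinyLLama()
    print(bot("Explain what a peer-to-peer network is in one sentence."))

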
class LLama(ModuleWrapper):
    def __init__(self):
        super(LLama, self).__init__(
            proc_inputs=[Data4Proc(data_type="text", pubsub=False, private_only=True)],
            proc_outputs=[Data4Proc(data_type="text", pubsub=False, private_only=True)]
        )
        self.module = pipeline("text-generation", model="meta-llama/Llama-3.2-3B",
                               torch_dtype=torch.bfloat16, device=self.device)

    def forward(self, msg: str, first: bool = False, last: bool = False):
        msg_struct = [{"role": "system", "content": "You are a helpful assistant"},
                      {"role": "user", "content": msg}]
        prompt = self.module.tokenizer.apply_chat_template(msg_struct, tokenize=False, add_generation_prompt=True)

        out = self.module(prompt, max_new_tokens=256, do_sample=True, temperature=0.7, top_k=50, top_p=0.95)
        out = out[0]["generated_text"] if (out is not None and len(out) > 0 and "generated_text" in out[0]) \
            else "Error!"
        if "<|assistant|>\n" in out:
            out = out.split("<|assistant|>\n")[1]
        return out.strip()


class Phi(ModuleWrapper):
    def __init__(self):
        super(Phi, self).__init__(
            proc_inputs=[Data4Proc(data_type="text", pubsub=False, private_only=True)],
            proc_outputs=[Data4Proc(data_type="text", pubsub=False, private_only=True)]
        )
        self.module = pipeline("text-generation", model="microsoft/Phi-3.5-mini-instruct",
                               torch_dtype="auto", device=self.device)

    def forward(self, msg: str, first: bool = False, last: bool = False):
        msg_struct = [{"role": "system", "content": "You are a helpful assistant"},
                      {"role": "user", "content": msg}]
        prompt = self.module.tokenizer.apply_chat_template(msg_struct, tokenize=False, add_generation_prompt=True)

        out = self.module(prompt, max_new_tokens=256, do_sample=True, return_full_text=False)
        out = out[0]["generated_text"] if (out is not None and len(out) > 0 and "generated_text" in out[0]) \
            else "Error!"
        if "<|assistant|>\n" in out:
            out = out.split("<|assistant|>\n")[1]
        return out.strip()


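# Design sketch (not part of the original file): TinyLLama, LLama and Phi
# duplicate the same prompt-building and answer-cleaning steps; they could share
# a single helper like this one, parameterized on the pipeline, the sampling
# options and the assistant tag to strip.
def _chat_generate(text_pipeline, msg: str, assistant_tag: str = "<|assistant|>\n", **gen_kwargs) -> str:
    msg_struct = [{"role": "system", "content": "You are a helpful assistant"},
                  {"role": "user", "content": msg}]
    prompt = text_pipeline.tokenizer.apply_chat_template(msg_struct, tokenize=False, add_generation_prompt=True)
    out = text_pipeline(prompt, max_new_tokens=256, **gen_kwargs)
    out = out[0]["generated_text"] if (out is not None and len(out) > 0 and "generated_text" in out[0]) else "Error!"
    return out.split(assistant_tag)[1].strip() if assistant_tag in out else out.strip()

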
class LangSegmentAnything(ModuleWrapper):
    def __init__(self):
        super(LangSegmentAnything, self).__init__(
            proc_inputs=[Data4Proc(data_type="img", pubsub=False, private_only=True),
                         Data4Proc(data_type="text", pubsub=False, private_only=True)],
            proc_outputs=[Data4Proc(data_type="img", pubsub=False, private_only=True)]
        )
        from lang_sam import LangSAM
        self.module = LangSAM(device=self.device)

        # Generate a 64x64 error image (with text "Error" on it)
        from PIL import ImageDraw, ImageFont
        self.error_img = Image.new("RGB", (64, 64), color="white")
        draw = ImageDraw.Draw(self.error_img)
        font = ImageFont.load_default()
        text = "Error"
        bbox = draw.textbbox((0, 0), text, font=font)
        text_width = bbox[2] - bbox[0]
        text_height = bbox[3] - bbox[1]
        position = ((64 - text_width) // 2, (64 - text_height) // 2)
        draw.text(position, text, fill="black", font=font)

    def forward(self, image_pil: Image.Image, msg: str, first: bool = False, last: bool = False):
        try:
            image_pil = image_pil.convert("RGB") if image_pil.mode != "RGB" else image_pil  # Forcing RGB
            out = self.module.predict([image_pil], [msg])

            if (out is None or not isinstance(out, list) or len(out) < 1 or
                    not isinstance(out[0], dict) or 'masks' not in out[0]) or out[0]['masks'].ndim != 3:
                return image_pil
            else:
                return LangSegmentAnything.__highlight_masks_on_image(image_pil, out[0]['masks'])
        except Exception:
            return self.error_img

    @staticmethod
    def __highlight_masks_on_image(image_pil: Image.Image, masks: np.ndarray, alpha: float = 0.75):
        img_np = np.array(image_pil, dtype=np.float32) / 255.0
        height, width, _ = img_np.shape
        num_masks = masks.shape[0]

        overlay_np = np.zeros((height, width, 3), dtype=np.float32)
        alpha_mask_combined = np.zeros((height, width, 1), dtype=np.float32)

        color_palette = [
            (255, 102, 102),  # Light Red
            (102, 255, 102),  # Light Green
            (102, 102, 255),  # Light Blue
            (255, 255, 102),  # Light Yellow
            (255, 102, 255),  # Light Magenta
            (102, 255, 255),  # Light Cyan
            (255, 178, 102),  # Orange
            (178, 102, 255),  # Purple
            (102, 178, 255),  # Sky Blue
        ]

        for i in range(num_masks):
            mask = masks[i, :, :].astype(bool)  # np.bool was removed from NumPy; plain bool is the portable spelling

            color_rgb_int = color_palette[i % len(color_palette)]
            color = np.array(color_rgb_int, dtype=np.float32) / 255.0
            overlay_np[mask] = (1 - alpha) * overlay_np[mask] + alpha * color
            alpha_mask_combined[mask] = np.maximum(alpha_mask_combined[mask], alpha)

        # Final blending and conversion
        final_np = (1 - alpha_mask_combined) * img_np + alpha_mask_combined * overlay_np
        final_np = (final_np * 255).astype(np.uint8)
        final_image = Image.fromarray(final_np)
        return final_image


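# Usage sketch (illustrative, not part of the original file): text-prompted
# segmentation that returns the input image with matched regions tinted, or the
# 64x64 "Error" placeholder if prediction fails. File name and prompt are
# hypothetical.
def _demo_lang_sam():
    from PIL import Image
    segmenter = LangSegmentAnything()
    highlighted = segmenter(Image.open("garden.jpg"), "the flowers")
    highlighted.save("garden_highlighted.jpg")

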
class SmolVLM(ModuleWrapper):
    def __init__(self):
        super(SmolVLM, self).__init__(
            proc_inputs=[Data4Proc(data_type="img", pubsub=False, private_only=True),
                         Data4Proc(data_type="text", pubsub=False, private_only=True)],
            proc_outputs=[Data4Proc(data_type="text", pubsub=False, private_only=True)]
        )
        model_id = "HuggingFaceTB/SmolVLM2-500M-Video-Instruct"
        self.pre_post_processor = AutoProcessor.from_pretrained(model_id, device_map=self.device)

        from transformers import AutoModelForImageTextToText
        self.module = AutoModelForImageTextToText.from_pretrained(model_id, torch_dtype=torch.bfloat16,
                                                                  device_map=self.device)
        self.module = self.module.to(self.device)

    def forward(self, image_pil: Image.Image, msg: str = "what is this?", first: bool = False, last: bool = False):
        image_pil = image_pil.convert("RGB") if image_pil.mode != "RGB" else image_pil  # Forcing RGB

        msg_struct = [{"role": "user", "content": [{"type": "text", "text": f"{msg}"},
                                                   {"type": "image", "image": image_pil}]}]

        prompt = self.pre_post_processor.apply_chat_template(msg_struct,
                                                             tokenize=True,
                                                             add_generation_prompt=True,
                                                             return_dict=True,
                                                             return_tensors="pt").to(self.device, dtype=torch.bfloat16)

        out = self.module.generate(**prompt, do_sample=False, max_new_tokens=128)
        out = self.pre_post_processor.batch_decode(out, skip_special_tokens=True)[0] if out is not None else "Error!"
        if "Assistant:" in out:
            out = out.split("Assistant:")[1]
        return out.strip()


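# Usage sketch (illustrative, not part of the original file): visual question
# answering over a single hypothetical image.
def _demo_smolvlm():
    from PIL import Image
    vlm = SmolVLM()
    print(vlm(Image.open("desk.jpg"), "How many monitors are on the desk?"))

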
class SiteRAG(ModuleWrapper):

    def __init__(self,
                 site_url: str,
                 site_folder: str = os.path.join("rag", "downloaded_site"),
                 db_folder: str = os.path.join("rag", "chroma_db")):
        super(SiteRAG, self).__init__(
            proc_inputs=[Data4Proc(data_type="text", pubsub=False, private_only=True)],
            proc_outputs=[Data4Proc(data_type="text", pubsub=False, private_only=True)],
        )

        # Saving options
        self.site_url = site_url
        self.site_folder = site_folder
        self.db_folder = db_folder

        # Loading neural model
        model_id = "TheBloke/vicuna-7b-1.1-HF"
        tokenizer = AutoTokenizer.from_pretrained(model_id)
        model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.float16,
                                                     device_map=self.device, offload_folder="offload")
        self.module = pipeline("text-generation", model=model, tokenizer=tokenizer, max_new_tokens=200)

        # Embedder
        from langchain.embeddings import SentenceTransformerEmbeddings
        self.embedder = SentenceTransformerEmbeddings(model_name="all-MiniLM-L6-v2",
                                                      model_kwargs={"device": self.device.type})

        # Crawling the site and building the knowledge base
        self.crawl_website()
        self.crawled_site_to_rag_knowledge_base()

        # Setting up the retriever on top of the persisted vector store
        from langchain.vectorstores import Chroma
        db = Chroma(persist_directory=db_folder, embedding_function=self.embedder)
        self.retriever = db.as_retriever(search_kwargs={"k": 3})

    def forward(self, msg: str, first: bool = False, last: bool = False):

        # Build the context from the top-k retrieved chunks
        docs = self.retriever.get_relevant_documents(msg)
        context = "\n\n".join(doc.page_content for doc in docs)
        prompt = f"Answer the question based on the following context:\n\n{context}\n\nQuestion: {msg}\nAnswer:"

        # Generate the answer, stripping the echoed prompt from the output
        out = self.module(prompt, max_new_tokens=256, do_sample=True, temperature=0.7)
        out = out[0]['generated_text'][len(prompt):].strip() if (out is not None and len(out) > 0 and
                                                                 "generated_text" in out[0]) else "Error!"

        # Append the URL of the single best-matching source document
        best_doc_with_score = self.retriever.vectorstore.similarity_search_with_score(msg, k=1)
        best_doc, _ = best_doc_with_score[0]
        docs = [best_doc]
        sources = set("<a href='" +
                      doc.metadata['source'] +
                      "' onclick='window.open(this.href); return false;' style='color: blue;'>" +
                      doc.metadata['source'] + "</a>" for doc in docs)
        sources_text = "<br/><br/>\nURLs:\n" + "\n".join(sources)

        return out.strip() + sources_text

    def crawl_website(self, max_pages=300):
        import requests
        from bs4 import BeautifulSoup
        from urllib.parse import urljoin, urlparse

        if os.path.exists(self.site_folder):
            shutil.rmtree(self.site_folder)
        os.makedirs(self.site_folder)
        visited = set()
        to_visit = [self.site_url]

        while to_visit and len(visited) < max_pages:
            url = to_visit.pop(0)  # Breadth-first crawl
            if url in visited:
                continue
            visited.add(url)

            try:
                r = requests.get(url, timeout=10)
                if "text/html" not in r.headers.get("Content-Type", ""):
                    continue

                parsed = urlparse(url)
                filename = parsed.path.strip("/") or "index.html"
                filename += ".crawled"
                file_path = os.path.join(self.site_folder, filename.replace("/", "__"))
                with open(file_path, "w", encoding="utf-8") as f:
                    f.write(r.text)

                soup = BeautifulSoup(r.text, "html.parser")
                for link in soup.find_all("a", href=True):
                    full_url = urljoin(url, link["href"])
                    if full_url.startswith(self.site_url) and full_url not in visited:
                        to_visit.append(full_url)  # Follow internal links only
            except Exception as e:
                print(f"Error fetching {url}: {e}")

        print(f"Crawled {len(visited)} pages.")

    def crawled_site_to_rag_knowledge_base(self):
        from bs4 import BeautifulSoup
        from urllib.parse import urljoin
        from langchain.vectorstores import Chroma
        from langchain.docstore.document import Document
        from langchain.text_splitter import RecursiveCharacterTextSplitter

        docs = []
        for filename in os.listdir(self.site_folder):
            if filename.endswith(".crawled"):
                file_path = os.path.join(self.site_folder, filename)
                with open(file_path, encoding="utf-8") as f:
                    html = f.read()

                soup: BeautifulSoup = BeautifulSoup(html, "html.parser")
                text = soup.get_text(separator=" ", strip=True)  # type: ignore

                page_path = filename.replace("__", "/").replace(".crawled", "")
                url = urljoin(self.site_url, page_path)

                docs.append(Document(page_content=text, metadata={"source": url}))

        splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=100)
        split_docs = splitter.split_documents(docs)

        chroma_db = Chroma.from_documents(split_docs, self.embedder, persist_directory=self.db_folder)
        chroma_db.persist()
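

# Usage sketch (illustrative, not part of the original file): crawling a site
# and answering a question grounded in its pages. The URL is hypothetical, and
# the first call is slow: construction crawls up to 300 pages, embeds them and
# persists the Chroma index before any question can be answered.
def _demo_site_rag():
    rag = SiteRAG(site_url="https://example.org")
    print(rag("What does this site say about pricing?"))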