lt-tensor 0.0.1a11__py3-none-any.whl → 0.0.1a13__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
lt_tensor/model_base.py CHANGED
@@ -1,12 +1,14 @@
-__all__ = ["Model", "_ModelExtended", "LossTracker"]
+__all__ = ["Model", "LossTracker"]
 
 import gc
 import json
 import math
 import warnings
-from .torch_commons import *
+from lt_tensor.torch_commons import *
 from lt_utils.common import *
+from lt_tensor.misc_utils import plot_view, get_weights, get_config, updateDict
 from lt_utils.misc_utils import log_traceback, get_current_time
+from lt_utils.file_ops import load_json, save_json
 
 T = TypeVar("T")
 
@@ -29,79 +31,47 @@ class LossTracker:
             "eval": [],
         }
 
-    def append(self, loss: float, mode: Literal["train", "eval"] = "train"):
-        assert mode in self.history, f"Invalid mode '{mode}'. Use 'train' or 'eval'."
+    def append(
+        self, loss: float, mode: Union[Literal["train", "eval"], "str"] = "train"
+    ):
         self.history[mode].append(float(loss))
         if len(self.history[mode]) > self.max_len:
             self.history[mode] = self.history[mode][-self.max_len :]
 
-    def get(self, mode: Literal["train", "eval"] = "train"):
+    def get(self, mode: Union[Literal["train", "eval"], "str"] = "train"):
         return self.history.get(mode, [])
 
     def save(self, path: Optional[PathLike] = None):
         if path is None:
             path = f"logs/history_{get_current_time()}.json"
-
-        Path(path).parent.mkdir(exist_ok=True, parents=True)
-        with open(path, "w") as f:
-            json.dump(self.history, f, indent=2)
-
+        save_json(path, self.history, indent=2)
         self.last_file = path
 
     def load(self, path: Optional[PathLike] = None):
         if path is None:
-            _path = self.last_file
-        else:
-            _path = path
-        with open(_path) as f:
-            self.history = json.load(f)
-        if path is not None:
-            self.last_file = path
-
-    def plot(self, backend: Literal["matplotlib", "plotly"] = "plotly"):
-        if backend == "plotly":
-            try:
-                import plotly.graph_objs as go
-            except ModuleNotFoundError:
-                warnings.warn(
-                    "No installation of plotly was found. To use it use 'pip install plotly' and restart this application!"
-                )
-                return
-            fig = go.Figure()
-            for mode, losses in self.history.items():
-                if losses:
-                    fig.add_trace(go.Scatter(y=losses, name=mode.capitalize()))
-            fig.update_layout(
-                title="Training vs Evaluation Loss",
-                xaxis_title="Step",
-                yaxis_title="Loss",
-                template="plotly_dark",
-            )
-            fig.show()
-
-        elif backend == "matplotlib":
-            import matplotlib.pyplot as plt
-
-            for mode, losses in self.history.items():
-                if losses:
-                    plt.plot(losses, label=f"{mode.capitalize()} Loss")
-            plt.title("Loss over Time")
-            plt.xlabel("Step")
-            plt.ylabel("Loss")
-            plt.legend()
-            plt.grid(True)
-            plt.show()
+            path = self.last_file
+        self.history = load_json(path, [])
+        self.last_file = path
 
+    def plot(
+        self,
+        history_keys: Union[str, List[str]],
+        max_amount: int = 0,
+        title: str = "Loss",
+    ):
+        if isinstance(history_keys, str):
+            history_keys = [history_keys]
+        return plot_view(
+            {k: v for k, v in self.history.items() if k in history_keys},
+            title,
+            max_amount,
+        )
 
-class Model(nn.Module, ABC):
-    """
-    This makes it easier to assign a device and retrieves it later
-    """
 
+class _Devices_Base(nn.Module):
     _device: torch.device = ROOT_DEVICE
     _autocast: bool = False
     _loss_history: LossTracker = LossTracker(100_000)
-    _is_unfrozen: bool = False
 
     @property
     def autocast(self):
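Note: the reworked LossTracker now persists its history through lt_utils.file_ops and renders plots through lt_tensor.misc_utils.plot_view instead of bundling plotly/matplotlib logic. A minimal usage sketch (the plot keywords and JSON round-trip follow the hunk above; the positional constructor argument is assumed from the LossTracker(100_000) default, and the load_json fallback is inferred from its second argument):

    from lt_tensor.model_base import LossTracker

    tracker = LossTracker(1000)          # keep at most the last 1000 values per key (positional arg assumed)
    tracker.append(0.73, mode="train")
    tracker.append(0.68, mode="eval")
    tracker.save("logs/history.json")    # delegates to lt_utils.file_ops.save_json
    tracker.load("logs/history.json")    # load_json(path, []) appears to fall back to an empty history
    tracker.plot(["train", "eval"], max_amount=500, title="Loss")  # forwarded to plot_view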
@@ -119,7 +89,96 @@ class Model(nn.Module, ABC):
     def device(self, device: Union[torch.device, str]):
         assert isinstance(device, (str, torch.device))
         self._device = torch.device(device) if isinstance(device, str) else device
-        self._apply_device_to()
+
+    def to(self, *args, **kwargs):
+        device, dtype, non_blocking, convert_to_format = torch._C._nn._parse_to(
+            *args, **kwargs
+        )
+
+        if dtype is not None:
+            if not (dtype.is_floating_point or dtype.is_complex):
+                raise TypeError(
+                    "nn.Module.to only accepts floating point or complex "
+                    f"dtypes, but got desired dtype={dtype}"
+                )
+            if dtype.is_complex:
+                warnings.warn(
+                    "Complex modules are a new feature under active development whose design may change, "
+                    "and some modules might not work as expected when using complex tensors as parameters or buffers. "
+                    "Please file an issue at https://github.com/pytorch/pytorch/issues/new?template=bug-report.yml "
+                    "if a complex module does not work as expected."
+                )
+
+        def convert(t: Tensor):
+            try:
+                if convert_to_format is not None and t.dim() in (4, 5):
+                    return t.to(
+                        device,
+                        dtype if t.is_floating_point() or t.is_complex() else None,
+                        non_blocking,
+                        memory_format=convert_to_format,
+                    )
+                return t.to(
+                    device,
+                    dtype if t.is_floating_point() or t.is_complex() else None,
+                    non_blocking,
+                )
+            except NotImplementedError as e:
+                if str(e) == "Cannot copy out of meta tensor; no data!":
+                    raise NotImplementedError(
+                        f"{e} Please use torch.nn.Module.to_empty() instead of torch.nn.Module.to() "
+                        f"when moving module from meta to a different device."
+                    ) from None
+                else:
+                    raise
+
+        self._apply(convert)
+        self.device = device
+        return self
+
+    def _to_dvc(
+        self, device_name: str, device_id: Optional[Union[int, torch.device]] = None
+    ):
+        device = device_name
+        if device_id is not None:
+            if isinstance(device_id, Number):
+                device += ":" + str(int(device_id))
+            elif hasattr(device_id, "index"):
+                device += ":" + str(device_id.index)
+        self.device = device
+
+    def ipu(self, device: Optional[Union[int, torch.device]] = None) -> T:
+        super().ipu(device)
+        self._to_dvc("ipu", device)
+        return self
+
+    def xpu(self, device: Optional[Union[int, torch.device]] = None) -> T:
+        super().xpu(device)
+        self._to_dvc("xpu", device)
+        return self
+
+    def cuda(self, device: Optional[Union[int, torch.device]] = None) -> T:
+        super().cuda(device)
+        self._to_dvc("cuda", device)
+        return self
+
+    def mtia(self, device: Optional[Union[int, torch.device]] = None) -> T:
+        super().mtia(device)
+        self._to_dvc("mtia", device)
+        return self
+
+    def cpu(self) -> T:
+        super().cpu()
+        self._to_dvc("cpu", None)
+        return self
+
+
+class Model(_Devices_Base, ABC):
+    """
+    This makes it easier to assign a device and retrieves it later
+    """
+
+    _is_unfrozen: bool = False
 
     def _apply_device_to(self):
         """Add here components that are needed to have device applied to them,
@@ -239,103 +298,6 @@ class Model(nn.Module, ABC):
                 not_unfrozen.append((name, "excluded"))
         return dict(unfrozen=unfrozen, not_unfrozen=not_unfrozen)
 
-    def to(self, *args, **kwargs):
-        device, dtype, non_blocking, convert_to_format = torch._C._nn._parse_to(
-            *args, **kwargs
-        )
-
-        if dtype is not None:
-            if not (dtype.is_floating_point or dtype.is_complex):
-                raise TypeError(
-                    "nn.Module.to only accepts floating point or complex "
-                    f"dtypes, but got desired dtype={dtype}"
-                )
-            if dtype.is_complex:
-                warnings.warn(
-                    "Complex modules are a new feature under active development whose design may change, "
-                    "and some modules might not work as expected when using complex tensors as parameters or buffers. "
-                    "Please file an issue at https://github.com/pytorch/pytorch/issues/new?template=bug-report.yml "
-                    "if a complex module does not work as expected."
-                )
-
-        def convert(t: Tensor):
-            try:
-                if convert_to_format is not None and t.dim() in (4, 5):
-                    return t.to(
-                        device,
-                        dtype if t.is_floating_point() or t.is_complex() else None,
-                        non_blocking,
-                        memory_format=convert_to_format,
-                    )
-                return t.to(
-                    device,
-                    dtype if t.is_floating_point() or t.is_complex() else None,
-                    non_blocking,
-                )
-            except NotImplementedError as e:
-                if str(e) == "Cannot copy out of meta tensor; no data!":
-                    raise NotImplementedError(
-                        f"{e} Please use torch.nn.Module.to_empty() instead of torch.nn.Module.to() "
-                        f"when moving module from meta to a different device."
-                    ) from None
-                else:
-                    raise
-
-        self._apply(convert)
-        self.device = device
-        self._apply_device_to()
-        return self
-
-    def ipu(self, device: Optional[Union[int, torch.device]] = None) -> T:
-        super().ipu(device)
-        dvc = "ipu"
-        if device is not None:
-            dvc += (
-                ":" + str(device) if isinstance(device, (int, float)) else device.index
-            )
-        self.device = dvc
-        self._apply_device_to()
-        return self
-
-    def xpu(self, device: Optional[Union[int, torch.device]] = None) -> T:
-        super().xpu(device)
-        dvc = "xpu"
-        if device is not None:
-            dvc += (
-                ":" + str(device) if isinstance(device, (int, float)) else device.index
-            )
-        self.device = dvc
-        self._apply_device_to()
-        return self
-
-    def cuda(self, device: Optional[Union[int, torch.device]] = None) -> T:
-        super().cuda(device)
-        dvc = "cuda"
-        if device is not None:
-            dvc += (
-                ":" + str(device) if isinstance(device, (int, float)) else device.index
-            )
-        self.device = dvc
-        self._apply_device_to()
-        return self
-
-    def mtia(self, device: Optional[Union[int, torch.device]] = None) -> T:
-        super().mtia(device)
-        dvc = "mtia"
-        if device is not None:
-            dvc += (
-                ":" + str(device) if isinstance(device, (int, float)) else device.index
-            )
-        self.device = dvc
-        self._apply_device_to()
-        return self
-
-    def cpu(self) -> T:
-        super().cpu()
-        self.device = "cpu"
-        self._apply_device_to()
-        return self
-
     def count_trainable_parameters(self, module_name: Optional[str] = None):
         """Gets the number of trainable parameters from either the entire model or from a specific module."""
         if module_name is not None:
@@ -411,6 +373,19 @@ class Model(nn.Module, ABC):
         else:
             print(f"Non-Trainable Parameters: {params}")
 
+    @classmethod
+    def from_pretrained(
+        cls,
+        model_path: Union[str, PathLike],
+        model_config: Optional[Union[str, PathLike]] = None,
+        model_config_dict: Optional[Dict[str, Any]] = None,
+    ):
+        """TODO: Finish this implementation or leave it as an optional function."""
+        assert isinstance(model_path, (str, PathLike, bytes))
+        model_path = get_weights(model_path)
+        model_config = get_config(model_config, model_config_dict)
+        return model_path, model_config
+
     def save_weights(
         self,
         path: Union[Path, str],
@@ -494,7 +469,9 @@ class Model(nn.Module, ABC):
         pass
 
     def add_loss(
-        self, loss: Union[float, list[float]], mode: Literal["train", "eval"] = "train"
+        self,
+        loss: Union[float, list[float]],
+        mode: Union[Literal["train", "eval"]] = "train",
     ):
         if isinstance(loss, Number) and loss:
             self._loss_history.append(loss, mode)
@@ -513,8 +490,12 @@ class Model(nn.Module, ABC):
     def load_loss_history(self, path: Optional[PathLike] = None):
         self._loss_history.load(path)
 
-    def get_loss_avg(self, mode: Literal["train", "eval"], quantity: int = 0):
-        t_list = self._loss_history.get("train")
+    def get_loss_avg(
+        self,
+        mode: Union[Literal["train", "eval"], str],
+        quantity: int = 0,
+    ):
+        t_list = self._loss_history.get(mode)
         if not t_list:
             return float("nan")
         if quantity > 0:
@@ -524,9 +505,10 @@ class Model(nn.Module, ABC):
     def freeze_unfreeze_loss(
         self,
         losses: Optional[Union[float, List[float]]] = None,
-        trigger_loss: float = 0.1,
+        trigger_loss: Union[float, bool] = 0.1,
         excluded_modules: Optional[List[str]] = None,
-        eval_last: int = 1000,
+        max_items: int = 1000,
+        loss_name: str = "train",
    ):
        """If a certain threshold is reached the weights will freeze or unfreeze the modules.
        the biggest use-case for this function is when training GANs where the balance
@@ -534,18 +516,29 @@ class Model(nn.Module, ABC):
 
        Args:
            losses (Union[float, List[float]], Optional): The loss value or a list of losses that will be used to determine if it has reached or not the threshold. Defaults to None.
-            trigger_loss (float, optional): The value where the weights will be either freeze or unfreeze. Defaults to 0.1.
+            trigger_loss (float, bool, optional): The value where the weights will be either freeze or unfreeze. If set to a boolean it will freeze or unfreeze immediately according to the value (True = Freeze, False = Unfreeze). Defaults to 0.1.
            excluded_modules (list[str], optional): The list of modules (names) that is not to be changed by either freezing nor unfreezing. Defaults to None.
-            eval_last (float, optional): The number of previous losses to be locked behind to calculate the current averange. Default to 1000.
-
+            max_items (float, optional): The number of previous losses to be locked behind to calculate the current average. Default to 1000.
+            loss_name (str, optional): Responsible to define with key to recover the loss.
        returns:
            bool: True when its frozen and false when its trainable.
        """
        if losses is not None:
-            calculated = None
            self.add_loss(losses)
 
-        value = self.get_loss_avg("train", eval_last)
+        if isinstance(trigger_loss, bool):
+            if trigger_loss:
+                if self._is_unfrozen:
+                    self.freeze_all(excluded_modules)
+                    self._is_unfrozen = False
+                return True
+            # else
+            if not self._is_unfrozen:
+                self.unfreeze_all(excluded_modules)
+                self._is_unfrozen = True
+            return False
+
+        value = self.get_loss_avg(loss_name, max_items)
 
        if value <= trigger_loss:
            if self._is_unfrozen:
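Note: freeze_unfreeze_loss now also accepts a boolean trigger_loss for immediate freezing or unfreezing, and averages the losses stored under a configurable loss_name key. A hedged sketch of the new arguments, continuing the hypothetical TinyNet example shown earlier:

    net = TinyNet()  # hypothetical subclass defined in the earlier sketch

    # Float threshold: record the loss, then freeze once the average of the
    # last 500 values stored under "train" drops to or below 0.05.
    is_frozen = net.freeze_unfreeze_loss(
        losses=0.04, trigger_loss=0.05, max_items=500, loss_name="train"
    )

    # Boolean trigger: act immediately (True freezes, False unfreezes).
    net.freeze_unfreeze_loss(trigger_loss=True)
    net.freeze_unfreeze_loss(trigger_loss=False)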
@@ -557,42 +550,3 @@ class Model(nn.Module, ABC):
             self.unfreeze_all(excluded_modules)
             self._is_unfrozen = True
         return False
-
-
-class _ModelExtended(Model):
-    """Planed, but not ready, maybe in the near future?"""
-    criterion: Optional[Callable[[Tensor, Tensor], Tensor]] = None
-    optimizer: Optional[optim.Optimizer] = None
-
-    def train_step(
-        self,
-        *inputs,
-        loss_label: Optional[Tensor] = None,
-        **kwargs,
-    ):
-        if not self.training:
-            self.train()
-        if self.optimizer is not None:
-            self.optimizer.zero_grad()
-        if self.autocast:
-            if self.criterion is None:
-                raise RuntimeError(
-                    "To use autocast during training, you must assign a criterion first!"
-                )
-            with torch.autocast(device_type=self.device.type):
-                out = self.forward(*loss_label, **kwargs)
-                loss = self.criterion(out, loss_label)
-
-            if self.optimizer is not None:
-                loss.backward()
-                self.optimizer.step()
-            return loss
-        elif self.criterion is not None:
-            out = self.forward(*loss_label, **kwargs)
-            loss = self.criterion(out, loss_label)
-            if self.optimizer is not None:
-                loss.backward()
-                self.optimizer.step()
-            return loss
-        else:
-            return self(*inputs, **kwargs)
@@ -1,11 +1,20 @@
 __all__ = [
-    "bsc", # basic
-    "rsd", # residual
-    "tfrms", # transformer
-    "pos", # positional encoders
-    "fsn", # fusion
-    "gns", # generators
-    "disc", # discriminators
-    "istft" # self-explanatory
+    "basic", # basic
+    "residual", # residual
+    "transformer", # transformer
+    "pos_encoder",
+    "fusion",
+    "features",
+    "discriminator",
+    "istft",
 ]
-from . import bsc, fsn, gns, istft, pos, rsd, tfrms, disc
+from . import (
+    basic,
+    discriminator,
+    features,
+    fusion,
+    istft,
+    pos_encoder,
+    residual,
+    transformer,
+)
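Note: the abbreviated submodule names are replaced with descriptive ones, so downstream imports need updating. A migration sketch, assuming these are the lt_tensor.model_zoo submodules (the file header for this hunk is not shown in this view):

    # 0.0.1a11 (old, abbreviated names; path assumed):
    from lt_tensor.model_zoo import bsc, rsd, tfrms

    # 0.0.1a13 (new, descriptive names):
    from lt_tensor.model_zoo import basic, residual, transformer
    from lt_tensor.model_zoo import pos_encoder, fusion, features, discriminator, istft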
@@ -8,11 +8,14 @@ __all__ = [
     "StyleEncoder",
     "PatchEmbed1D",
     "MultiScaleEncoder1D",
+    "UpSampleConv1D",
 ]
-
-from ..torch_commons import *
-from ..model_base import Model
-from ..transform import get_sinusoidal_embedding
+import torch.nn.functional as F
+from lt_tensor.torch_commons import *
+from lt_tensor.model_base import Model
+from lt_tensor.transform import get_sinusoidal_embedding
+from lt_utils.common import *
+import math
 
 
 class FeedForward(Model):
@@ -30,6 +33,9 @@ class FeedForward(Model):
             nn.Linear(d_model, ff_dim),
             activation,
             nn.Dropout(dropout),
+            nn.Linear(ff_dim, ff_dim),
+            activation,
+            nn.Dropout(dropout),
             nn.Linear(ff_dim, d_model),
             normalizer,
         )
@@ -38,6 +44,24 @@ class FeedForward(Model):
         return self.net(x)
 
 
+class UpSampleConv1D(nn.Module):
+    def __init__(self, upsample: bool = False, dim_in: int = 0, dim_out: int = 0):
+        super().__init__()
+        if upsample:
+            self.upsample = lambda x: F.interpolate(x, scale_factor=2, mode="nearest")
+        else:
+            self.upsample = nn.Identity()
+
+        if dim_in == dim_out:
+            self.learned = nn.Identity()
+        else:
+            self.learned = weight_norm(nn.Conv1d(dim_in, dim_out, 1, 1, 0, bias=False))
+
+    def forward(self, x):
+        x = self.upsample(x)
+        return self.learned(x)
+
+
 class MLP(Model):
     def __init__(
         self,
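Note: UpSampleConv1D combines an optional nearest-neighbour 2x upsample with a weight-normalised 1x1 convolution (or an identity when the channel counts already match). A small shape sketch; the module path is assumed, since the file header is not visible in this diff:

    import torch
    from lt_tensor.model_zoo.basic import UpSampleConv1D  # path assumed

    block = UpSampleConv1D(upsample=True, dim_in=64, dim_out=128)
    x = torch.randn(2, 64, 50)   # [batch, channels, time]
    y = block(x)                 # time doubled by F.interpolate, channels mapped 64 -> 128
    print(y.shape)               # expected: torch.Size([2, 128, 100])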
@@ -161,7 +185,6 @@ class StyleEncoder(Model):
         self.linear = nn.Linear(hidden, out_dim)
 
     def forward(self, x: torch.Tensor) -> torch.Tensor:
-        # x: [B, Mels, T]
         x = self.net(x).squeeze(-1) # [B, hidden]
         return self.linear(x) # [B, out_dim]
 
@@ -208,10 +231,10 @@ class MultiScaleEncoder1D(Model):
     def forward(self, x: torch.Tensor) -> torch.Tensor:
         # x: [B, C, T]
         return self.net(x) # [B, hidden, T]
-
-
+
+
 class AudioClassifier(Model):
-    def __init__(self, n_mels:int=80, num_classes=5):
+    def __init__(self, n_mels: int = 80, num_classes=5):
         super().__init__()
         self.model = nn.Sequential(
             nn.Conv1d(n_mels, 256, kernel_size=3, padding=1),
@@ -230,3 +253,96 @@ class AudioClassifier(Model):
 
     def forward(self, x):
         return self.model(x)
+
+
+class LoRALinearLayer(nn.Module):
+
+    def __init__(
+        self,
+        in_features: int,
+        out_features: int,
+        rank: int = 4,
+        alpha: float = 1.0,
+    ):
+        super().__init__()
+        self.down = nn.Linear(in_features, rank, bias=False)
+        self.up = nn.Linear(rank, out_features, bias=False)
+        self.alpha = alpha
+        self.rank = rank
+        self.out_features = out_features
+        self.in_features = in_features
+        self.ah = self.alpha / self.rank
+        self._down_dt = self.down.weight.dtype
+
+        nn.init.normal_(self.down.weight, std=1 / rank)
+        nn.init.zeros_(self.up.weight)
+
+    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+        orig_dtype = hidden_states.dtype
+        down_hidden_states = self.down(hidden_states.to(self._down_dt))
+        up_hidden_states = self.up(down_hidden_states) * self.ah
+        return up_hidden_states.to(orig_dtype)
+
+
+class LoRAConv1DLayer(nn.Module):
+    def __init__(
+        self,
+        in_features: int,
+        out_features: int,
+        kernel_size: Union[int, Tuple[int, ...]] = 1,
+        rank: int = 4,
+        alpha: float = 1.0,
+    ):
+        super().__init__()
+        self.down = nn.Conv1d(
+            in_features, rank, kernel_size, padding=kernel_size // 2, bias=False
+        )
+        self.up = nn.Conv1d(
+            rank, out_features, kernel_size, padding=kernel_size // 2, bias=False
+        )
+        self.ah = alpha / rank
+        self._down_dt = self.down.weight.dtype
+        nn.init.kaiming_uniform_(self.down.weight, a=math.sqrt(5))
+        nn.init.zeros_(self.up.weight)
+
+    def forward(self, inputs: Tensor) -> Tensor:
+        orig_dtype = inputs.dtype
+        down_hidden_states = self.down(inputs.to(self._down_dt))
+        up_hidden_states = self.up(down_hidden_states) * self.ah
+        return up_hidden_states.to(orig_dtype)
+
+
+class LoRAConv2DLayer(nn.Module):
+    def __init__(
+        self,
+        in_features: int,
+        out_features: int,
+        kernel_size: Union[int, Tuple[int, ...]] = (1, 1),
+        rank: int = 4,
+        alpha: float = 1.0,
+    ):
+        super().__init__()
+        self.down = nn.Conv2d(
+            in_features,
+            rank,
+            kernel_size,
+            padding="same",
+            bias=False,
+        )
+        self.up = nn.Conv2d(
+            rank,
+            out_features,
+            kernel_size,
+            padding="same",
+            bias=False,
+        )
+        self.ah = alpha / rank
+
+        nn.init.kaiming_normal_(self.down.weight, a=0.2)
+        nn.init.zeros_(self.up.weight)
+
+    def forward(self, inputs: Tensor) -> Tensor:
+        orig_dtype = inputs.dtype
+        down_hidden_states = self.down(inputs.to(self._down_dt))
+        up_hidden_states = self.up(down_hidden_states) * self.ah
+        return up_hidden_states.to(orig_dtype)
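Note: the new LoRA layers implement the usual low-rank update: a down-projection to `rank`, an up-projection initialised to zero, and a scale of alpha / rank, so the adapter contributes nothing at initialisation. A hedged sketch of the additive composition (the import path is assumed, as the file header is not shown here):

    import torch
    from torch import nn
    from lt_tensor.model_zoo.basic import LoRALinearLayer  # path assumed

    base = nn.Linear(128, 128)
    base.requires_grad_(False)                 # freeze the pretrained projection
    lora = LoRALinearLayer(128, 128, rank=4, alpha=8.0)

    x = torch.randn(4, 128)
    y = base(x) + lora(x)                      # low-rank residual; equals base(x) at initialisation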
@@ -1,4 +1,4 @@
-from ..torch_commons import *
+from lt_tensor.torch_commons import *
 import torch.nn.functional as F
 from lt_tensor.model_base import Model
 from lt_utils.common import *