lucid-dl 2.7.4__py3-none-any.whl → 2.7.6__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- lucid/nn/functional/__init__.py +7 -5
- lucid/nn/functional/_loss.py +46 -11
- lucid/nn/module.py +4 -6
- lucid/optim/_base.py +109 -24
- lucid/optim/lr_scheduler/_base.py +45 -23
- lucid/port.py +13 -12
- {lucid_dl-2.7.4.dist-info → lucid_dl-2.7.6.dist-info}/METADATA +1 -1
- {lucid_dl-2.7.4.dist-info → lucid_dl-2.7.6.dist-info}/RECORD +11 -11
- {lucid_dl-2.7.4.dist-info → lucid_dl-2.7.6.dist-info}/WHEEL +0 -0
- {lucid_dl-2.7.4.dist-info → lucid_dl-2.7.6.dist-info}/licenses/LICENSE +0 -0
- {lucid_dl-2.7.4.dist-info → lucid_dl-2.7.6.dist-info}/top_level.txt +0 -0
lucid/nn/functional/__init__.py
CHANGED

@@ -395,12 +395,12 @@ def binary_cross_entropy(
 def binary_cross_entropy_with_logits(
     input_: Tensor,
     target: Tensor,
-
+    weight: Tensor | None = None,
+    pos_weight: Tensor | None = None,
     reduction: _ReductionType | None = "mean",
-    eps: float = 1e-7,
 ) -> Tensor:
     return _loss.binary_cross_entropy_with_logits(
-        input_, target,
+        input_, target, weight, pos_weight, reduction
     )


@@ -410,8 +410,9 @@ def cross_entropy(
     weight: Tensor | None = None,
     reduction: _ReductionType | None = "mean",
     eps: float = 1e-7,
+    ignore_index: int | None = None,
 ) -> Tensor:
-    return _loss.cross_entropy(input_, target, weight, reduction, eps)
+    return _loss.cross_entropy(input_, target, weight, reduction, eps, ignore_index)


 def nll_loss(

@@ -419,8 +420,9 @@ def nll_loss(
     target: Tensor,
     weight: Tensor | None = None,
     reduction: _ReductionType | None = "mean",
+    ignore_index: int | None = None,
 ) -> Tensor:
-    return _loss.nll_loss(input_, target, weight, reduction)
+    return _loss.nll_loss(input_, target, weight, reduction, ignore_index)


 def huber_loss(
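For reference, the reworked wrappers can be called like this. A minimal usage sketch, not taken from the package; it assumes lucid.Tensor is exported at the package root and accepts nested Python lists.

```python
import lucid
import lucid.nn.functional as F

# Binary case: pos_weight up-weights the positive class, weight rescales elementwise.
logits = lucid.Tensor([[0.2], [-1.3], [2.1]])
targets = lucid.Tensor([[1.0], [0.0], [1.0]])
bce = F.binary_cross_entropy_with_logits(
    logits, targets, pos_weight=lucid.Tensor([3.0]), reduction="mean"
)

# Multi-class case: rows whose label equals ignore_index are excluded from the mean.
class_logits = lucid.Tensor([[2.0, 0.1, -1.0], [0.3, 0.2, 0.4]])
labels = lucid.Tensor([2, 0])
ce = F.cross_entropy(class_logits, labels, ignore_index=0)  # only the first row counts
```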
lucid/nn/functional/_loss.py
CHANGED

@@ -20,6 +20,27 @@ def _loss_reduction(loss: Tensor, reduction: _ReductionType | None) -> Tensor:
     )


+def _ignore_index_loss(
+    loss: Tensor,
+    target_int: Tensor,
+    ignore_index: int,
+    reduction: _ReductionType | None,
+) -> Tensor:
+    mask = (target_int != ignore_index).astype(lucid.Float32)
+    if reduction is None:
+        return loss * mask
+
+    loss_sum = (loss * mask).sum()
+    if reduction == "sum":
+        return loss_sum
+
+    valid_count = mask.sum()
+    if valid_count.item() == 0:
+        return lucid.zeros_like(valid_count)
+
+    return loss_sum / valid_count
+
+
 def mse_loss(
     input_: Tensor, target: Tensor, reduction: _ReductionType | None = "mean"
 ) -> Tensor:

@@ -46,18 +67,21 @@ def binary_cross_entropy(
 def binary_cross_entropy_with_logits(
     input_: Tensor,
     target: Tensor,
-
+    weight: Tensor | None = None,
+    pos_weight: Tensor | None = None,
     reduction: _ReductionType | None = "mean",
-    eps: float = 1e-7,
 ) -> Tensor:
-    max_val = lucid.
-
-
-
-
-
-
-    loss
+    max_val = lucid.maximum(-input_, 0)
+    sp = max_val + lucid.log(lucid.exp(-max_val) + lucid.exp(-input_ - max_val))
+
+    if pos_weight is not None:
+        coeff = 1 + (pos_weight - 1) * target
+        loss = (1 - target) * input_ + coeff * sp
+    else:
+        loss = lucid.maximum(input_, 0) - input_ * target + sp
+
+    if weight is not None:
+        loss *= weight

     return _loss_reduction(loss, reduction)

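The new binary_cross_entropy_with_logits body evaluates softplus(-x) = log(1 + exp(-x)) in the shifted form max(-x, 0) + log(exp(-max(-x, 0)) + exp(-x - max(-x, 0))), which never exponentiates a large positive number. A quick NumPy check of that identity (illustration only, not package code):

```python
import numpy as np

def softplus_neg_stable(x):
    # log(1 + exp(-x)) written exactly as in the new implementation.
    m = np.maximum(-x, 0.0)
    return m + np.log(np.exp(-m) + np.exp(-x - m))

x = np.array([-20.0, -1.0, 0.0, 1.0, 20.0])
print(np.allclose(softplus_neg_stable(x), np.log1p(np.exp(-x))))  # True

# For an extreme logit the naive form overflows while the shifted form stays finite.
print(np.log1p(np.exp(800.0)))                    # inf, with an overflow warning
print(softplus_neg_stable(np.array([-800.0])))    # [800.]
```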
@@ -68,6 +92,7 @@ def cross_entropy(
     weight: Tensor | None = None,
     reduction: _ReductionType | None = "mean",
     eps: float = 1e-7,
+    ignore_index: int | None = None,
 ) -> Tensor:
     exp_logits = lucid.exp(input_ - lucid.max(input_, axis=1, keepdims=True))
     prob = exp_logits / lucid.sum(exp_logits, axis=1, keepdims=True)

@@ -79,6 +104,9 @@ def cross_entropy(
     if weight is not None:
         loss *= weight[target_int]

+    if ignore_index is not None:
+        return _ignore_index_loss(loss, target_int, ignore_index, reduction)
+
     return _loss_reduction(loss, reduction)


@@ -87,12 +115,19 @@ def nll_loss(
     target: Tensor,
     weight: Tensor | None = None,
     reduction: _ReductionType | None = "mean",
+    ignore_index: int | None = None,
 ) -> Tensor:
     target_int = target.astype(lucid.Int)
-
+    n = input_.shape[0]
+    idx = lucid.arange(n, device=input_.device).astype(lucid.Int)
+
+    loss = -input_[idx, target_int]
     if weight is not None:
         loss *= weight[target_int]

+    if ignore_index is not None:
+        return _ignore_index_loss(loss, target_int, ignore_index, reduction)
+
     return _loss_reduction(loss, reduction)

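_ignore_index_loss masks out entries whose integer target equals ignore_index; with reduction="mean" it divides by the number of valid targets rather than the batch size, and returns zero when nothing is valid. The same arithmetic in plain NumPy, as an illustration:

```python
import numpy as np

per_sample_loss = np.array([0.7, 1.2, 0.1, 2.0])
targets = np.array([3, -1, 0, -1])          # -1 plays the role of ignore_index here
mask = (targets != -1).astype(np.float32)

masked_mean = (per_sample_loss * mask).sum() / mask.sum()   # (0.7 + 0.1) / 2 = 0.4
plain_mean = per_sample_loss.mean()                         # 1.0: padding inflates it
print(masked_mean, plain_mean)
```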
lucid/nn/module.py
CHANGED

@@ -118,7 +118,7 @@ class Module:
         for param in self.parameters():
             param.zero()

-    def forward(self
+    def forward(self) -> Tensor | tuple[Tensor, ...]:
         raise NotImplementedError(
             "The forward method must be implemented by the subclass."
         )

@@ -204,7 +204,7 @@ class Module:
                 destination=destination, prefix=prefix + name + ".", keep_vars=keep_vars
             )

-        for key in destination.keys():
+        for key in list(destination.keys()):
             if key in self._state_dict_pass_attr:
                 del destination[key]

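Wrapping the keys in list(...) matters because state_dict deletes pass-through attributes while looping over the destination: deleting from a dict during iteration of a live key view raises at runtime. A standalone reproduction in plain Python:

```python
d = {"weight": 1, "running_mean": 2, "bias": 3}
try:
    for key in d.keys():               # live view: mutation during iteration fails
        if key == "running_mean":
            del d[key]
except RuntimeError as exc:
    print(exc)                         # dictionary changed size during iteration

d = {"weight": 1, "running_mean": 2, "bias": 3}
for key in list(d.keys()):             # snapshot of the keys: safe to delete
    if key == "running_mean":
        del d[key]
print(d)                               # {'weight': 1, 'bias': 3}
```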
@@ -229,10 +229,8 @@ class Module:
             if key in own_state:
                 attr = own_state[key]
                 if isinstance(attr, (nn.Parameter, nn.Buffer)):
-
-
-                    else:
-                        attr.data = value
+                    value_t = Tensor(value, device=self.device)
+                    attr.data = value_t.data
                 else:
                     setattr(self, key, value)
             elif strict:
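With this load_state_dict change, each incoming value is re-materialized as a Tensor on the module's device before its raw data is assigned, so state dicts carrying plain arrays or tensors from another device load cleanly. A hedged round-trip sketch; nn.Linear(in_features, out_features) is assumed to exist in lucid.nn and is not established by this diff:

```python
import lucid.nn as nn

model = nn.Linear(4, 2)                      # hypothetical module, for illustration only
clone = nn.Linear(4, 2)
clone.load_state_dict(model.state_dict())    # values are rewrapped on clone's device
```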
lucid/optim/_base.py
CHANGED

@@ -1,5 +1,5 @@
 from collections import defaultdict
-from typing import Any, Iterable
+from typing import Any, Iterable
 from abc import ABC, abstractmethod
 import copy

@@ -12,19 +12,19 @@ class Optimizer(ABC):
     def __init__(
         self, params: Iterable[nn.Parameter], defaults: dict[str, Any]
     ) -> None:
-        super().__init__()
         if not isinstance(params, Iterable):
             raise TypeError("params should be an iterable of Parameters.")

-
-
-
-
+        param_list = list(params)
+        for p in param_list:
+            if not isinstance(p, nn.Parameter):
+                raise TypeError(f"Expected nn.Parameter, got {type(p).__name__}.")

-
-        self
-
-
+        self.defaults: dict[str, Any] = dict(defaults)
+        self.param_groups: list[dict[str, Any]] = self.param_groups_setup(
+            param_list, self.defaults
+        )
+        self.state: dict[nn.Parameter, dict[str, Any]] = defaultdict(dict)

     @abstractmethod
     def step(self, closure: _OptimClosure | None = None) -> Any | None:
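To see how the rebuilt constructor fits together (validate the parameters, copy defaults, build param_groups through param_groups_setup, start an empty state), here is a toy SGD-style subclass. Only the base-class behavior comes from this diff; the update rule and the param.grad / param.data attribute names are assumptions for illustration:

```python
from lucid.optim._base import Optimizer   # defined here; a public re-export may also exist

class PlainSGD(Optimizer):
    def __init__(self, params, lr: float = 0.01) -> None:
        # defaults are copied into self.defaults and folded into every param group.
        super().__init__(params, {"lr": lr})

    def step(self, closure=None):
        for group in self.param_groups:
            for param in group["params"]:
                if param.grad is None:         # assumed attribute names
                    continue
                param.data -= group["lr"] * param.grad
```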
@@ -33,25 +33,110 @@ class Optimizer(ABC):
     def zero_grad(self) -> None:
         for group in self.param_groups:
             for param in group["params"]:
-                param.
+                if isinstance(param, nn.Parameter):
+                    param.zero_grad()
+
+    def param_groups_setup(
+        self, params: list[nn.Parameter], defaults: dict[str, Any]
+    ) -> list[dict[str, Any]]:
+        return [{"params": list(params), **defaults}]

     def add_param_group(self, param_group: dict[str, Any]) -> None:
-
-
-
-
+        if "params" not in param_group:
+            raise ValueError("param_group must have a 'params' key.")
+
+        params = list(param_group["params"])
+        if len(params) == 0:
+            raise ValueError("param_group['params'] must be non-empty.")
+
+        for p in params:
+            if not isinstance(p, nn.Parameter):
+                raise TypeError(
+                    f"Expected nn.Parameter in param_group, got {type(p).__name__}."
                 )
-        self.param_groups.append(param_group)

-
-
-    "
-
+        existing = set()
+        for g in self.param_groups:
+            existing.update(g["params"])
+
+        if any(p in existing for p in params):
+            raise ValueError("Some parameters appear in more than one parameter group.")
+
+        filled = {
+            **self.defaults,
+            **{k: v for k, v in param_group.items() if k != "params"},
         }
+        filled["params"] = params
+        self.param_groups.append(filled)
+
+    def _flat_params(self) -> list[nn.Parameter]:
+        flat: list[nn.Parameter] = []
+        for g in self.param_groups:
+            flat.extend(g["params"])
+
+        return flat
+
+    def state_dict(self) -> dict:
+        param_to_idx: dict[nn.Parameter, int] = {}
+        for idx, p in enumerate(self._flat_params()):
+            if p not in param_to_idx:
+                param_to_idx[p] = idx
+
+        packed_state: dict[int, dict[str, Any]] = {}
+        for p, st in self.state.items():
+            if p in param_to_idx:
+                packed_state[param_to_idx[p]] = copy.deepcopy(st)
+
+        packed_groups: list[dict[str, Any]] = []
+        for g in self.param_groups:
+            new_g: dict[str, Any] = {}
+
+            for k, v in g.items():
+                if k == "params":
+                    new_g[k] = [param_to_idx[p] for p in v]
+                else:
+                    new_g[k] = copy.deepcopy(v)
+
+            packed_groups.append(new_g)
+
+        return {"state": packed_state, "param_groups": packed_groups}
+
+    def load_state_dict(self, state_dict: dict) -> None:
+        if (
+            not isinstance(state_dict, dict)
+            or "state" not in state_dict
+            or "param_groups" not in state_dict
+        ):
+            raise TypeError("Invalid state_dict format for Optimizer.")
+
+        saved_groups = state_dict["param_groups"]
+        saved_state = state_dict["state"]
+
+        current_params = self._flat_params()
+        n_current = len(current_params)
+
+        new_groups: list[dict[str, Any]] = []
+        for sg in saved_groups:
+            if "params" not in sg:
+                raise KeyError("Saved param_group missing 'params'.")
+            indices: list[int] = list(sg["params"])
+
+            if any(i < 0 or i >= n_current for i in indices):
+                raise IndexError("Saved state refers to parameter index out of range.")
+
+            params = [current_params[i] for i in indices]
+            ng = {
+                k: (params if k == "params" else copy.deepcopy(v))
+                for k, v in sg.items()
+            }
+            new_groups.append(ng)
+
+        self.param_groups = new_groups

-
-
-
+        self.state = defaultdict(dict)
+        for i, p in enumerate(self._flat_params()):
+            if i in saved_state:
+                self.state[p] = copy.deepcopy(saved_state[i])

     def __repr__(self) -> str:
-        return f"{type(self).__name__}({self.
+        return f"{type(self).__name__}({self.defaults})"
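The new state_dict / load_state_dict pair packs per-parameter state by flat parameter index, so an optimizer rebuilt over the same parameters can resume from a snapshot, and add_param_group now validates and default-fills incoming groups. A usage sketch continuing the toy subclass above; model and extra_params are placeholders, not names from the package:

```python
optimizer = PlainSGD(model.parameters(), lr=0.1)

snapshot = optimizer.state_dict()       # {"state": {idx: {...}}, "param_groups": [...]}

rebuilt = PlainSGD(model.parameters(), lr=0.1)
rebuilt.load_state_dict(snapshot)       # per-parameter state re-attached by index

# A later group gets its own hyperparameters merged over the defaults.
rebuilt.add_param_group({"params": extra_params, "lr": 0.01})
```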
lucid/optim/lr_scheduler/_base.py
CHANGED

@@ -8,56 +8,78 @@ class LRScheduler(ABC):
     def __init__(
         self, optimizer: Optimizer, last_epoch: int = -1, verbose: bool = False
     ) -> None:
-        super().__init__()
         if not hasattr(optimizer, "param_groups"):
             raise TypeError(f"{type(optimizer).__name__} is not a valid optimizer.")

         self.optimizer = optimizer
         self.last_epoch = last_epoch
         self.verbose = verbose
-        self.base_lrs = [
+        self.base_lrs: list[float] = [float(g["lr"]) for g in optimizer.param_groups]

         self._step_count = 0
-        self._last_lr = [
+        self._last_lr: list[float] = [float(g["lr"]) for g in optimizer.param_groups]

     @abstractmethod
     def get_lr(self) -> list[float]:
-        raise NotImplementedError
+        raise NotImplementedError

     def step(self, epoch: int | None = None) -> None:
-        if epoch is
-            self.last_epoch
+        if epoch is None:
+            self.last_epoch += 1
         else:
-            self.
-
+            self.last_epoch = int(epoch)
+        self._step_count += 1

         new_lrs = self.get_lr()
-
-
+        if len(new_lrs) != len(self.optimizer.param_groups):
+            raise ValueError(
+                f"get_lr returned {len(new_lrs)} values, "
+                f"but optimizer has {len(self.optimizer.param_groups)} param groups."
+            )

-        self.
+        for group, lr in zip(self.optimizer.param_groups, new_lrs):
+            group["lr"] = float(lr)
+
+        self._last_lr = [float(g["lr"]) for g in self.optimizer.param_groups]

         if self.verbose:
-            print(
+            print(
+                f"Epoch {self.last_epoch}: setting learning rates to {self._last_lr}."
+            )

     def state_dict(self) -> dict[str, Any]:
         return {
-            "last_epoch": self.last_epoch,
-            "base_lrs": self.base_lrs,
-            "_step_count": self._step_count,
-            "_last_lr": self._last_lr,
+            "last_epoch": int(self.last_epoch),
+            "base_lrs": [float(x) for x in self.base_lrs],
+            "_step_count": int(self._step_count),
+            "_last_lr": [float(x) for x in self._last_lr],
+            "_group_count": len(self.optimizer.param_groups),
         }

     def load_state_dict(self, state_dict: dict[str, Any]) -> None:
-
-
+        required = {"last_epoch", "base_lrs", "_step_count", "_last_lr"}
+        missing = required - set(state_dict)
+        if missing:
+            raise KeyError(f"Missing keys in scheduler state_dict: {missing}")
+
+        saved_group_count = int(
+            state_dict.get("_group_count", len(state_dict["_last_lr"]))
+        )
+        current_group_count = len(self.optimizer.param_groups)
+        if saved_group_count != current_group_count:
+            raise ValueError(
+                "Cannot load scheduler state: param group count mismatch "
+                f"(saved={saved_group_count}, current={current_group_count})."
+            )

-        self.
-        self.
+        self.last_epoch = int(state_dict["last_epoch"])
+        self.base_lrs = [float(x) for x in state_dict["base_lrs"]]
+        self._step_count = int(state_dict["_step_count"])
+        self._last_lr = [float(x) for x in state_dict["_last_lr"]]

-        for
-
+        for group, lr in zip(self.optimizer.param_groups, self._last_lr):
+            group["lr"] = float(lr)

     @property
     def last_lr(self) -> list[float]:
-        return self._last_lr
+        return list(self._last_lr)
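Subclasses still only implement get_lr(); the base step() now checks that it returns one value per param group, writes the values back, and records them in last_lr. A small sketch of a halving schedule built on the new base; the import targets the module this diff touches, and a friendlier public re-export may exist:

```python
from lucid.optim.lr_scheduler._base import LRScheduler

class HalveLR(LRScheduler):
    # One learning rate per param group, halved after each completed epoch.
    def get_lr(self) -> list[float]:
        return [base * (0.5 ** self.last_epoch) for base in self.base_lrs]

scheduler = HalveLR(optimizer)           # optimizer from the sketch above
for _ in range(3):
    scheduler.step()
print(scheduler.last_lr)                 # [0.025] if the group started at lr=0.1
```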
lucid/port.py
CHANGED

@@ -5,14 +5,13 @@ from typing import Literal

 from lucid._tensor import Tensor
 from lucid.nn import Module
-from lucid.types import _NumPyArray


 __all__ = ["save", "load"]

-_LucidPortable = Tensor | Module | OrderedDict
+_LucidPortable = Tensor | Module | OrderedDict | dict

-FORMAT_VERSION: float = 1.
+FORMAT_VERSION: float = 1.1

 EXTENSIONS = Literal[".lct", ".lcd", ".safetensors"]

@@ -30,7 +29,7 @@ def save(obj: _LucidPortable, path: Path | str, safetensors: bool = False) -> Pa
     if path.suffix == "":
         if isinstance(obj, Tensor):
             path = path.with_suffix(".lct")
-        elif isinstance(obj, (Module, OrderedDict)):
+        elif isinstance(obj, (Module, OrderedDict, dict)):
             path = (
                 path.with_suffix(".safetensors")
                 if safetensors

@@ -56,10 +55,14 @@ def save(obj: _LucidPortable, path: Path | str, safetensors: bool = False) -> Pa
     elif suffix == ".lcd":
         if isinstance(obj, Module):
             obj = obj.state_dict()
-        if not isinstance(obj, OrderedDict):
-            raise TypeError("Expected a state_dict
+        if not isinstance(obj, (OrderedDict, dict)):
+            raise TypeError("Expected a state_dict for .lcd file.")

-        data = {
+        data = {
+            "type": type(obj).__name__,
+            "format_version": FORMAT_VERSION,
+            "content": obj,
+        }

     elif suffix == ".safetensors":
         try:

@@ -72,10 +75,8 @@ def save(obj: _LucidPortable, path: Path | str, safetensors: bool = False) -> Pa

         if isinstance(obj, Module):
             obj = obj.state_dict()
-        if not isinstance(obj, OrderedDict):
-            raise TypeError(
-                "Expected a state_dict (OrderedDict) for .safetensors file."
-            )
+        if not isinstance(obj, (OrderedDict, dict)):
+            raise TypeError("Expected a state_dict for .safetensors file.")

         save_file(obj, str(path))
         return path.resolve()

@@ -122,7 +123,7 @@ def load(path: Path | str) -> _LucidPortable:
         array = data["content"]
         return Tensor(array)

-    elif file_type
+    elif file_type in {"OrderedDict", "dict"}:
         return data["content"]

     else:
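Because _LucidPortable now admits plain dict and the .lcd / .safetensors branches accept it, a state dict can be saved and reloaded without converting to OrderedDict first. A short sketch, assuming save and load (defined in lucid/port.py) are re-exported at the top-level lucid namespace:

```python
import lucid

state = model.state_dict()                    # plain dict or OrderedDict both work now
path = lucid.save(state, "checkpoint.lcd")    # stored with type, format_version, content
restored = lucid.load(path)                   # returns the saved mapping
model.load_state_dict(restored)
```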
{lucid_dl-2.7.4.dist-info → lucid_dl-2.7.6.dist-info}/RECORD
CHANGED

@@ -1,6 +1,6 @@
 lucid/__init__.py,sha256=NzXUnsIgDEq0Yx146zxhqID44X_0yNirkxt68NuL6Kg,7900
 lucid/error.py,sha256=qnTiVuZm3c5-DIt-OOyobZ7RUm7E1K4NR0j998LG1ug,709
-lucid/port.py,sha256=
+lucid/port.py,sha256=Kt1YaSWef_eKF4KRj-UFhirvFC5urEESfYQ_BSlBZGE,3811
 lucid/types.py,sha256=3yj99eZNv8N3bupP4htHU_SEvrqNyRuq2k7WGLVD-X0,3605
 lucid/_backend/__init__.py,sha256=n1bnYdeb_bNDBKASWGywTRa0Ne9hMAkal3AuVZJgovI,5
 lucid/_backend/core.py,sha256=xMyhNoEswIxCr69wTNo3QXtNVUfX7xsmDP4t0JRlRzQ,5798

@@ -73,15 +73,15 @@ lucid/models/seq2seq/__init__.py,sha256=wjsrhj4H_AcqwwbebAN8b68QBA8L6p1_12dkG299
 lucid/models/seq2seq/transformer.py,sha256=y5rerCs1s6jXTsVvbgscWScKpQKuSu1fezsBe7PNTRA,3513
 lucid/nn/__init__.py,sha256=Kc6_wlpWo0_AtywX8aEWtzjKb0ju2c2cKGNsEY9ho4E,153
 lucid/nn/fused.py,sha256=ZGOQmDThaGNQLC59y3M7s993K_K09ce6IZP8cFX8FUE,5498
-lucid/nn/module.py,sha256=
+lucid/nn/module.py,sha256=XvFWJ8NqXeZpr3RmKBQBz5eqT535Oi_7DaPN1Zi9gJc,21971
 lucid/nn/parameter.py,sha256=jDaWukWecCcH9ri65SefNls66MmyTyucFolWbzSjapc,856
-lucid/nn/functional/__init__.py,sha256=
+lucid/nn/functional/__init__.py,sha256=90Zi7jClPOiiSYx-Qkg0QTideKD6GigbWON9eFCoxzg,13869
 lucid/nn/functional/_activation.py,sha256=nQVwArvPuwkUpLMLCNABTw96Zgw9VsPB8SyXCL6t2LM,1331
 lucid/nn/functional/_attention.py,sha256=nrZF3-2AR03kNo1PGNszujhWlAVcab_FNQwOCWZT47I,946
 lucid/nn/functional/_conv.py,sha256=E8PF5UBivTb6zIYj6DVCbdzJSt10rW1-vepaZB9_QPc,7125
 lucid/nn/functional/_drop.py,sha256=99zcj-06BGHaAPEFHeXJWlfxa-jinZrKc5pteXYtxL0,2351
 lucid/nn/functional/_linear.py,sha256=0KPs14tdjTqJDhLaHTQp_OAOpDP0i-Pd4kJes2BTR9I,663
-lucid/nn/functional/_loss.py,sha256=
+lucid/nn/functional/_loss.py,sha256=H_IB08O-7XTBXVZwLgsgNA32zmUspOS0Hl5TqOkjlRk,3871
 lucid/nn/functional/_norm.py,sha256=hmMkHjtUOnisi7BBD4I6CrU4-O1UGo213TtceP4m3v8,5862
 lucid/nn/functional/_pool.py,sha256=FgjqfWvH2rNB39U9Ktpnz63X26xBH-kLXcppfc6kKis,7811
 lucid/nn/functional/_spatial.py,sha256=lazoSvVMFcauBWRbMOqmkgixA5bDes6scGHVWCgVmHE,3911

@@ -102,13 +102,13 @@ lucid/nn/modules/sparse.py,sha256=EpjiviED2nI55wUjh1twFwa4Lvlrzw0TR6lpCDGeSbo,11
 lucid/nn/modules/transformer.py,sha256=z56emF_eX18pxRELjfmmsY-7Bn9h2yjIdxCaxs6YDwA,11246
 lucid/nn/modules/vision.py,sha256=8xYasT7TNj4NXwMwwJIw1nbV1paeWEFg_ZohXn9kZBg,1579
 lucid/optim/__init__.py,sha256=21EcCCPwrhPGP9TXvDje075_S2hPr0pHToygCaq8keI,201
-lucid/optim/_base.py,sha256=
+lucid/optim/_base.py,sha256=KxM5h5ONeO8hCpAzD2_vverFRKeymu2XC6AHN_L_v3g,4859
 lucid/optim/ada.py,sha256=POIl7dbv3qqwKxGGaceSrs-lZF1tD-vyvDxjtZdx--E,5807
 lucid/optim/adam.py,sha256=pVlZIcXD1s-IYK-WAfFognId8RhxzmlS5227-i0Vhq4,10347
 lucid/optim/prop.py,sha256=CbsWmoBb_g_8z16M3T6dMoSR9c72hm8M375IT1UHjpw,4740
 lucid/optim/sgd.py,sha256=DBZ1ZXQ9TfKZCRECfNRMDH9mvqUWCOPdY5TobnVxpz8,4477
 lucid/optim/lr_scheduler/__init__.py,sha256=kUoyN2g9nwTtEAqEVij832WSRvzEpKZywSJdfD7MQvY,58
-lucid/optim/lr_scheduler/_base.py,sha256=
+lucid/optim/lr_scheduler/_base.py,sha256=NNJnjwmJpsRXathrbLtH4tjfBHtwOiJ5HwF1_S6Ym5c,3092
 lucid/optim/lr_scheduler/_schedulers.py,sha256=OIzduTXV6Ut4qcvw6neMPr3jlv6BgTSsys0-6KoHxK4,8140
 lucid/random/__init__.py,sha256=s8EAaKhEiTKT_vYjP4IFHx0xQVa1jqc_qIyvMauUu7M,2727
 lucid/random/_func.py,sha256=1Lu4m-ciEK037chNDGqv_j00RgGGzQ7UfslSfYActUk,2232

@@ -119,8 +119,8 @@ lucid/visual/__init__.py,sha256=6TuFDfmXTwpLyHl7_KqBfdzW6zqHjGzIFvymjFPlvjI,21
 lucid/visual/graph.py,sha256=YjpIDM_lloZARw3sCBiXPl_hT5A2gTk2fEHvwvJWXTk,4599
 lucid/weights/__init__.py,sha256=z1AikA3rOEeckWGkYWlcZkxNlJo9Xwa39PL6ly3hWnc,8801
 lucid/weights/__init__.pyi,sha256=lFonYC3cUx2Idolf3AEPnjFcyqcn3UDU84oJlZafqLY,3013
-lucid_dl-2.7.
-lucid_dl-2.7.
-lucid_dl-2.7.
-lucid_dl-2.7.
-lucid_dl-2.7.
+lucid_dl-2.7.6.dist-info/licenses/LICENSE,sha256=vxRFYnVD1IeYtsvw-KmoElfqrjxKHv1h9YTvsG54loQ,1065
+lucid_dl-2.7.6.dist-info/METADATA,sha256=GgdsxwpPv_EGIOIARff6VweVKfgDi_A8vs9aHvFQcCI,11260
+lucid_dl-2.7.6.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
+lucid_dl-2.7.6.dist-info/top_level.txt,sha256=uzP_qBx9iNWIHKJRlElYcBLYVqMpdm9Q1Ma63QPYbFc,6
+lucid_dl-2.7.6.dist-info/RECORD,,
{lucid_dl-2.7.4.dist-info → lucid_dl-2.7.6.dist-info}/WHEEL
File without changes

{lucid_dl-2.7.4.dist-info → lucid_dl-2.7.6.dist-info}/licenses/LICENSE
File without changes

{lucid_dl-2.7.4.dist-info → lucid_dl-2.7.6.dist-info}/top_level.txt
File without changes