homa 0.3.2__py3-none-any.whl → 0.3.11__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- homa/core/concerns/__init__.py +0 -1
- homa/device.py +2 -3
- homa/ensemble/concerns/CalculatesMetricNecessities.py +2 -2
- homa/ensemble/concerns/PredictsProbabilities.py +2 -2
- homa/ensemble/concerns/ReportsClassificationMetrics.py +1 -2
- homa/ensemble/concerns/ReportsEnsembleAccuracy.py +2 -2
- homa/ensemble/concerns/ReportsEnsembleF1.py +2 -2
- homa/ensemble/concerns/ReportsEnsembleKappa.py +2 -2
- homa/ensemble/concerns/ReportsEnsembleSize.py +2 -2
- homa/ensemble/concerns/ReportsLogits.py +4 -23
- homa/ensemble/concerns/SavesEnsembleModels.py +2 -2
- homa/ensemble/concerns/StoresModels.py +2 -2
- homa/rl/SoftActorCritic.py +6 -9
- homa/rl/buffers/Buffer.py +1 -3
- homa/rl/buffers/SoftActorCriticBuffer.py +2 -10
- homa/rl/sac/SoftActor.py +10 -11
- homa/rl/sac/SoftCritic.py +18 -16
- homa/rl/sac/modules/SoftActorModule.py +1 -1
- {homa-0.3.2.dist-info → homa-0.3.11.dist-info}/METADATA +1 -1
- {homa-0.3.2.dist-info → homa-0.3.11.dist-info}/RECORD +23 -25
- homa/core/concerns/TracksTime.py +0 -7
- homa/rl/utils.py +0 -7
- {homa-0.3.2.dist-info → homa-0.3.11.dist-info}/WHEEL +0 -0
- {homa-0.3.2.dist-info → homa-0.3.11.dist-info}/entry_points.txt +0 -0
- {homa-0.3.2.dist-info → homa-0.3.11.dist-info}/top_level.txt +0 -0
homa/core/concerns/__init__.py
CHANGED
homa/device.py
CHANGED
|
@@ -3,8 +3,8 @@ from .ReportsLogits import ReportsLogits
|
|
|
3
3
|
|
|
4
4
|
|
|
5
5
|
class PredictsProbabilities(ReportsLogits):
|
|
6
|
-
def __init__(self):
|
|
7
|
-
super().__init__()
|
|
6
|
+
def __init__(self, *args, **kwargs):
|
|
7
|
+
super().__init__(*args, **kwargs)
|
|
8
8
|
|
|
9
9
|
def predict(self, x: torch.Tensor) -> torch.Tensor:
|
|
10
10
|
logits = self.logits(x)
|
|
@@ -3,8 +3,8 @@ from torch.utils.data import DataLoader
|
|
|
3
3
|
|
|
4
4
|
|
|
5
5
|
class ReportsEnsembleAccuracy:
|
|
6
|
-
def __init__(self):
|
|
7
|
-
super().__init__()
|
|
6
|
+
def __init__(self, *args, **kwargs):
|
|
7
|
+
super().__init__(*args, **kwargs)
|
|
8
8
|
|
|
9
9
|
def accuracy(self, dataloader: DataLoader) -> float:
|
|
10
10
|
predictions, labels = self.metric_necessities(dataloader)
|
|
@@ -2,8 +2,8 @@ from sklearn.metrics import f1_score as f1
|
|
|
2
2
|
|
|
3
3
|
|
|
4
4
|
class ReportsEnsembleF1:
|
|
5
|
-
def __init__(self):
|
|
6
|
-
super().__init__()
|
|
5
|
+
def __init__(self, *args, **kwargs):
|
|
6
|
+
super().__init__(*args, **kwargs)
|
|
7
7
|
|
|
8
8
|
def f1(self) -> float:
|
|
9
9
|
predictions, labels = self.metric_necessities()
|
|
@@ -2,8 +2,8 @@ from sklearn.metrics import cohen_kappa_score as kappa
|
|
|
2
2
|
|
|
3
3
|
|
|
4
4
|
class ReportsEnsembleKappa:
|
|
5
|
-
def __init__(self):
|
|
6
|
-
super().__init__()
|
|
5
|
+
def __init__(self, *args, **kwargs):
|
|
6
|
+
super().__init__(*args, **kwargs)
|
|
7
7
|
|
|
8
8
|
def accuracy(self) -> float:
|
|
9
9
|
predictions, labels = self.metric_necessities()
|
|
@@ -2,14 +2,11 @@ import torch
|
|
|
2
2
|
|
|
3
3
|
|
|
4
4
|
class ReportsLogits:
|
|
5
|
-
def __init__(self):
|
|
6
|
-
super().__init__()
|
|
5
|
+
def __init__(self, *args, **kwargs):
|
|
6
|
+
super().__init__(*args, **kwargs)
|
|
7
7
|
|
|
8
|
-
def
|
|
9
|
-
|
|
10
|
-
|
|
11
|
-
def logits_sum(self, x: torch.Tensor) -> torch.Tensor:
|
|
12
|
-
batch_size = x.size(0)
|
|
8
|
+
def logits(self, x: torch.Tensor) -> torch.Tensor:
|
|
9
|
+
batch_size = x.shape[0]
|
|
13
10
|
logits = torch.zeros((batch_size, self.num_classes))
|
|
14
11
|
for factory, weight in zip(self.factories, self.weights):
|
|
15
12
|
model = factory(num_classes=self.num_classes)
|
|
@@ -17,22 +14,6 @@ class ReportsLogits:
|
|
|
17
14
|
logits += model(x)
|
|
18
15
|
return logits
|
|
19
16
|
|
|
20
|
-
def check_aggregation_strategy(self, aggregation: str):
|
|
21
|
-
if aggregation not in ["mean", "average", "sum"]:
|
|
22
|
-
raise ValueError(
|
|
23
|
-
f"Ensemble aggregation strategy must be in [mean, average, sum], but found {aggregation}."
|
|
24
|
-
)
|
|
25
|
-
|
|
26
|
-
def logits(self, x: torch.Tensor, aggregation: str = "mean") -> torch.Tensor:
|
|
27
|
-
self.check_aggregation_strategy(aggregation=aggregation)
|
|
28
|
-
logits_handlers = {
|
|
29
|
-
"mean": self.logits_average,
|
|
30
|
-
"average": self.logits_average,
|
|
31
|
-
"sum": self.logits_sum,
|
|
32
|
-
}
|
|
33
|
-
handler = logits_handlers.get(aggregation)
|
|
34
|
-
return handler(x)
|
|
35
|
-
|
|
36
17
|
@torch.no_grad()
|
|
37
18
|
def logits_(self, *args, **kwargs):
|
|
38
19
|
return self.logits(*args, **kwargs)
|
|
@@ -5,8 +5,8 @@ from ...vision import Model
|
|
|
5
5
|
|
|
6
6
|
|
|
7
7
|
class StoresModels:
|
|
8
|
-
def __init__(self):
|
|
9
|
-
super().__init__()
|
|
8
|
+
def __init__(self, *args, **kwargs):
|
|
9
|
+
super().__init__(*args, **kwargs)
|
|
10
10
|
self.factories: List[Type[torch.nn.Module]] = []
|
|
11
11
|
self.weights: List[OrderedDict] = []
|
|
12
12
|
|
homa/rl/SoftActorCritic.py
CHANGED
|
@@ -1,15 +1,14 @@
|
|
|
1
1
|
from .sac import SoftActor, SoftCritic
|
|
2
2
|
from .buffers import SoftActorCriticBuffer
|
|
3
|
-
from ..core.concerns import TracksTime
|
|
4
3
|
|
|
5
4
|
|
|
6
|
-
class SoftActorCritic
|
|
5
|
+
class SoftActorCritic:
|
|
7
6
|
def __init__(
|
|
8
7
|
self,
|
|
9
8
|
state_dimension: int,
|
|
10
9
|
action_dimension: int,
|
|
11
10
|
hidden_dimension: int = 256,
|
|
12
|
-
buffer_capacity: int =
|
|
11
|
+
buffer_capacity: int = 1_000_000,
|
|
13
12
|
batch_size: int = 256,
|
|
14
13
|
actor_lr: float = 0.0002,
|
|
15
14
|
critic_lr: float = 0.0003,
|
|
@@ -20,13 +19,10 @@ class SoftActorCritic(TracksTime):
|
|
|
20
19
|
gamma: float = 0.99,
|
|
21
20
|
min_std: float = -20,
|
|
22
21
|
max_std: float = 2,
|
|
23
|
-
warmup: int =
|
|
22
|
+
warmup: int = 10_000,
|
|
24
23
|
):
|
|
25
|
-
super().__init__()
|
|
26
|
-
|
|
27
24
|
self.batch_size: int = batch_size
|
|
28
25
|
self.warmup: int = warmup
|
|
29
|
-
self.tau: float = tau
|
|
30
26
|
|
|
31
27
|
self.actor = SoftActor(
|
|
32
28
|
state_dimension=state_dimension,
|
|
@@ -44,6 +40,7 @@ class SoftActorCritic(TracksTime):
|
|
|
44
40
|
hidden_dimension=hidden_dimension,
|
|
45
41
|
lr=critic_lr,
|
|
46
42
|
weight_decay=critic_decay,
|
|
43
|
+
tau=tau,
|
|
47
44
|
gamma=gamma,
|
|
48
45
|
alpha=alpha,
|
|
49
46
|
)
|
|
@@ -63,5 +60,5 @@ class SoftActorCritic(TracksTime):
|
|
|
63
60
|
next_states=data.next_states,
|
|
64
61
|
actor=self.actor,
|
|
65
62
|
)
|
|
66
|
-
self.actor.train(states=data.states,
|
|
67
|
-
self.critic.update(
|
|
63
|
+
self.actor.train(states=data.states, critic_network=self.critic.network)
|
|
64
|
+
self.critic.update()
|
homa/rl/buffers/Buffer.py
CHANGED
|
@@ -1,12 +1,10 @@
|
|
|
1
|
-
from collections import deque
|
|
2
|
-
from typing import Type
|
|
3
1
|
from .concerns import ResetsCollection, HasRecordAlternatives
|
|
4
2
|
|
|
5
3
|
|
|
6
4
|
class Buffer(ResetsCollection, HasRecordAlternatives):
|
|
7
5
|
def __init__(self, capacity: int):
|
|
8
6
|
self.capacity: int = capacity
|
|
9
|
-
self.
|
|
7
|
+
self.reset()
|
|
10
8
|
|
|
11
9
|
@property
|
|
12
10
|
def size(self):
|
|
@@ -3,7 +3,6 @@ import random
|
|
|
3
3
|
import torch
|
|
4
4
|
from types import SimpleNamespace
|
|
5
5
|
from .Buffer import Buffer
|
|
6
|
-
from ...device import move
|
|
7
6
|
|
|
8
7
|
|
|
9
8
|
class SoftActorCriticBuffer(Buffer):
|
|
@@ -23,7 +22,7 @@ class SoftActorCriticBuffer(Buffer):
|
|
|
23
22
|
(state, action, reward, next_state, termination, probability)
|
|
24
23
|
)
|
|
25
24
|
|
|
26
|
-
def sample(self, k: int, as_tensor: bool = False
|
|
25
|
+
def sample(self, k: int, as_tensor: bool = False):
|
|
27
26
|
batch = random.sample(self.collection, k)
|
|
28
27
|
states, actions, rewards, next_states, terminations, probabilities = zip(*batch)
|
|
29
28
|
|
|
@@ -34,21 +33,14 @@ class SoftActorCriticBuffer(Buffer):
|
|
|
34
33
|
terminations = numpy.array(terminations)
|
|
35
34
|
probabilities = numpy.array(probabilities)
|
|
36
35
|
|
|
37
|
-
# add one dimension to both rewards and terminations
|
|
38
|
-
rewards = numpy.expand_dims(rewards, axis=-1)
|
|
39
|
-
terminations = numpy.expand_dims(terminations, axis=-1)
|
|
40
|
-
|
|
41
36
|
if as_tensor:
|
|
42
37
|
states = torch.from_numpy(states).float()
|
|
43
|
-
actions = torch.from_numpy(actions).
|
|
38
|
+
actions = torch.from_numpy(actions).long()
|
|
44
39
|
rewards = torch.from_numpy(rewards).float()
|
|
45
40
|
next_states = torch.from_numpy(next_states).float()
|
|
46
41
|
terminations = torch.from_numpy(terminations).float()
|
|
47
42
|
probabilities = torch.from_numpy(probabilities).float()
|
|
48
43
|
|
|
49
|
-
if move_to_device:
|
|
50
|
-
move(states, actions, rewards, next_states, terminations, probabilities)
|
|
51
|
-
|
|
52
44
|
return SimpleNamespace(
|
|
53
45
|
**{
|
|
54
46
|
"states": states,
|
homa/rl/sac/SoftActor.py
CHANGED
|
@@ -1,11 +1,9 @@
|
|
|
1
1
|
import torch
|
|
2
2
|
import numpy
|
|
3
|
-
from .SoftCritic import SoftCritic
|
|
4
3
|
from .modules import SoftActorModule
|
|
5
|
-
from ...core.concerns import MovesNetworkToDevice
|
|
6
4
|
|
|
7
5
|
|
|
8
|
-
class SoftActor
|
|
6
|
+
class SoftActor:
|
|
9
7
|
def __init__(
|
|
10
8
|
self,
|
|
11
9
|
state_dimension: int,
|
|
@@ -30,22 +28,23 @@ class SoftActor(MovesNetworkToDevice):
|
|
|
30
28
|
self.network.parameters(), lr=lr, weight_decay=weight_decay
|
|
31
29
|
)
|
|
32
30
|
|
|
33
|
-
def train(self, states: torch.Tensor,
|
|
34
|
-
self.network.train()
|
|
31
|
+
def train(self, states: torch.Tensor, critic_network: torch.nn.Module):
|
|
35
32
|
self.optimizer.zero_grad()
|
|
36
|
-
loss = self.loss(states=states,
|
|
33
|
+
loss = self.loss(states=states, critic_network=critic_network)
|
|
37
34
|
loss.backward()
|
|
38
35
|
self.optimizer.step()
|
|
39
36
|
|
|
40
|
-
def loss(
|
|
37
|
+
def loss(
|
|
38
|
+
self, states: torch.Tensor, critic_network: torch.nn.Module
|
|
39
|
+
) -> torch.Tensor:
|
|
41
40
|
actions, probabilities = self.sample(states)
|
|
42
|
-
q_alpha, q_beta =
|
|
41
|
+
q_alpha, q_beta = critic_network(states, actions)
|
|
43
42
|
q = torch.min(q_alpha, q_beta)
|
|
44
43
|
return (self.alpha * probabilities - q).mean()
|
|
45
44
|
|
|
46
45
|
def process_state(self, state: numpy.ndarray | torch.Tensor) -> torch.Tensor:
|
|
47
46
|
if isinstance(state, numpy.ndarray):
|
|
48
|
-
state = torch.from_numpy(state)
|
|
47
|
+
state = torch.from_numpy(state)
|
|
49
48
|
|
|
50
49
|
if state.ndim < 2:
|
|
51
50
|
state = state.unsqueeze(0)
|
|
@@ -65,6 +64,6 @@ class SoftActor(MovesNetworkToDevice):
|
|
|
65
64
|
action = torch.tanh(pre_tanh)
|
|
66
65
|
|
|
67
66
|
probabilities = distribution.log_prob(pre_tanh).sum(dim=1, keepdim=True)
|
|
68
|
-
|
|
67
|
+
correction = torch.log(1 - action.pow(2) + 1e-6).sum(dim=1, keepdim=True)
|
|
69
68
|
|
|
70
|
-
return action, probabilities
|
|
69
|
+
return action, probabilities - correction
|
homa/rl/sac/SoftCritic.py
CHANGED
|
@@ -1,12 +1,11 @@
|
|
|
1
1
|
import torch
|
|
2
2
|
from torch.nn.functional import mse_loss as mse
|
|
3
|
+
from typing import Type
|
|
3
4
|
from .modules import DualSoftCriticModule
|
|
4
5
|
from .SoftActor import SoftActor
|
|
5
|
-
from ..utils import soft_update
|
|
6
|
-
from ...core.concerns import MovesNetworkToDevice
|
|
7
6
|
|
|
8
7
|
|
|
9
|
-
class SoftCritic
|
|
8
|
+
class SoftCritic:
|
|
10
9
|
def __init__(
|
|
11
10
|
self,
|
|
12
11
|
state_dimension: int,
|
|
@@ -14,9 +13,11 @@ class SoftCritic(MovesNetworkToDevice):
|
|
|
14
13
|
action_dimension: int,
|
|
15
14
|
lr: float,
|
|
16
15
|
weight_decay: float,
|
|
16
|
+
tau: float,
|
|
17
17
|
gamma: float,
|
|
18
18
|
alpha: float,
|
|
19
19
|
):
|
|
20
|
+
self.tau: float = tau
|
|
20
21
|
self.gamma: float = gamma
|
|
21
22
|
self.alpha: float = alpha
|
|
22
23
|
|
|
@@ -30,10 +31,6 @@ class SoftCritic(MovesNetworkToDevice):
|
|
|
30
31
|
hidden_dimension=hidden_dimension,
|
|
31
32
|
action_dimension=action_dimension,
|
|
32
33
|
)
|
|
33
|
-
|
|
34
|
-
# copy source to target when initiated
|
|
35
|
-
self.target.load_state_dict(self.network.state_dict())
|
|
36
|
-
|
|
37
34
|
self.optimizer = torch.optim.AdamW(
|
|
38
35
|
self.network.parameters(), lr=lr, weight_decay=weight_decay
|
|
39
36
|
)
|
|
@@ -45,9 +42,8 @@ class SoftCritic(MovesNetworkToDevice):
|
|
|
45
42
|
rewards: torch.Tensor,
|
|
46
43
|
terminations: torch.Tensor,
|
|
47
44
|
next_states: torch.Tensor,
|
|
48
|
-
actor:
|
|
45
|
+
actor: torch.nn.Module,
|
|
49
46
|
):
|
|
50
|
-
self.network.train()
|
|
51
47
|
self.optimizer.zero_grad()
|
|
52
48
|
loss = self.loss(
|
|
53
49
|
states=states,
|
|
@@ -69,7 +65,7 @@ class SoftCritic(MovesNetworkToDevice):
|
|
|
69
65
|
next_states: torch.Tensor,
|
|
70
66
|
actor: torch.nn.Module,
|
|
71
67
|
):
|
|
72
|
-
q_alpha, q_beta = self.
|
|
68
|
+
q_alpha, q_beta = self.target(states, actions)
|
|
73
69
|
target = self.calculate_target(
|
|
74
70
|
rewards=rewards,
|
|
75
71
|
terminations=terminations,
|
|
@@ -86,13 +82,19 @@ class SoftCritic(MovesNetworkToDevice):
|
|
|
86
82
|
next_states: torch.Tensor,
|
|
87
83
|
actor: SoftActor,
|
|
88
84
|
):
|
|
89
|
-
termination_mask = 1 - terminations
|
|
90
85
|
next_actions, next_probabilities = actor.sample(next_states)
|
|
91
86
|
q_alpha, q_beta = self.target(next_states, next_actions)
|
|
92
87
|
q = torch.min(q_alpha, q_beta)
|
|
93
|
-
|
|
94
|
-
|
|
88
|
+
termination_mask = 1 - terminations
|
|
89
|
+
entropy_q = q - self.alpha * next_probabilities * termination_mask
|
|
90
|
+
return rewards + self.gamma * entropy_q
|
|
91
|
+
|
|
92
|
+
def soft_update(
|
|
93
|
+
self, network: Type[torch.nn.Module], target: Type[torch.nn.Module]
|
|
94
|
+
):
|
|
95
|
+
for s, t in zip(network.parameters(), target.parameters()):
|
|
96
|
+
t.data.copy_(self.tau * s.data + (1 - self.tau) * t.data)
|
|
95
97
|
|
|
96
|
-
def update(self
|
|
97
|
-
soft_update(
|
|
98
|
-
soft_update(
|
|
98
|
+
def update(self):
|
|
99
|
+
self.soft_update(self.network.alpha, self.target.alpha)
|
|
100
|
+
self.soft_update(self.network.beta, self.target.beta)
|
|
@@ -1,5 +1,5 @@
|
|
|
1
1
|
homa/__init__.py,sha256=NBYFKizG8UASiz5HLsEBqzXNGlWr78xm4sLr5hxKvjU,46
|
|
2
|
-
homa/device.py,sha256=
|
|
2
|
+
homa/device.py,sha256=dpKI-ah_kPgNfFH_ism8YXHndEndGngBrTVnuZZ2J2I,408
|
|
3
3
|
homa/settings.py,sha256=CPZDPvs1380O7SY7FcSKol8kBVFVVYFgSJl3YEyJuZ0,263
|
|
4
4
|
homa/utils.py,sha256=dPp6TItJwWxBqxmkMzUuCtX_BzdPT-kMOZyXRGVMCbQ,70
|
|
5
5
|
homa/activations/APLU.py,sha256=cUf6LUjY8TewXe_V1avO_7IcOtY66Hd6Dyk_1K4R3Ms,1555
|
|
@@ -73,21 +73,20 @@ homa/cli/namespaces/MakeNamespace.py,sha256=5G6LHk3lDkXROz7uq4jYE0DyO_V7JvnhJ33I
|
|
|
73
73
|
homa/cli/namespaces/__init__.py,sha256=zAKUGPH4wcacxfH5Qvidp-uOuHdfzhan6kvVI6eMKA8,84
|
|
74
74
|
homa/core/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
75
75
|
homa/core/concerns/MovesNetworkToDevice.py,sha256=OPMvO7scsM6NNy_fM0cJdkRdoVc-b2j6l4bz88cBif0,348
|
|
76
|
-
homa/core/concerns/
|
|
77
|
-
homa/core/concerns/__init__.py,sha256=6jL3_kiqmmMs8BV789ZBbwEYNQNAhq1otVOrDJJrSXo,90
|
|
76
|
+
homa/core/concerns/__init__.py,sha256=O9OXMIMYrkIgp11lAyEv-OgT3Wq0IvNdDVZr2bOmpQU,55
|
|
78
77
|
homa/ensemble/Ensemble.py,sha256=mrqwbEm8OtiBmEgKuO6RzO1V8v80vrQFIJ4WHl8Yqgk,356
|
|
79
78
|
homa/ensemble/__init__.py,sha256=1pk2W-NbgfDFh9WLKZVLUk2E3PTjVZ5Bap9dQEnrs9o,31
|
|
80
79
|
homa/ensemble/utils.py,sha256=nn6eAgGW7ZafjjOVJWzGUWE0XYeyJAOMNEHm-lHxd6A,200
|
|
81
|
-
homa/ensemble/concerns/CalculatesMetricNecessities.py,sha256=
|
|
82
|
-
homa/ensemble/concerns/PredictsProbabilities.py,sha256=
|
|
83
|
-
homa/ensemble/concerns/ReportsClassificationMetrics.py,sha256=
|
|
84
|
-
homa/ensemble/concerns/ReportsEnsembleAccuracy.py,sha256=
|
|
85
|
-
homa/ensemble/concerns/ReportsEnsembleF1.py,sha256=
|
|
86
|
-
homa/ensemble/concerns/ReportsEnsembleKappa.py,sha256=
|
|
87
|
-
homa/ensemble/concerns/ReportsEnsembleSize.py,sha256=
|
|
88
|
-
homa/ensemble/concerns/ReportsLogits.py,sha256=
|
|
89
|
-
homa/ensemble/concerns/SavesEnsembleModels.py,sha256=
|
|
90
|
-
homa/ensemble/concerns/StoresModels.py,sha256=
|
|
80
|
+
homa/ensemble/concerns/CalculatesMetricNecessities.py,sha256=QccROg_FOp_X2T_lZDg8p1DMZhPYdO-7aEdnebRXMsY,825
|
|
81
|
+
homa/ensemble/concerns/PredictsProbabilities.py,sha256=7rmI66DzE7-QGoJgZEk-9fu5YQvJW-4ZnMn_dWEEhqU,440
|
|
82
|
+
homa/ensemble/concerns/ReportsClassificationMetrics.py,sha256=bg__cdCKp2U1H9qN1aOJH4BoX98oIvt8XaPDGApJhSM,395
|
|
83
|
+
homa/ensemble/concerns/ReportsEnsembleAccuracy.py,sha256=AX5X3VGOm7DfdonW0N7FFgUwEr7wnsojRSVEULEii7c,380
|
|
84
|
+
homa/ensemble/concerns/ReportsEnsembleF1.py,sha256=hdtdCQrWaFJNUn1KP9cAmi_q_EA4FYnpkBMlYLjzRZg,296
|
|
85
|
+
homa/ensemble/concerns/ReportsEnsembleKappa.py,sha256=ZRbtrFCTD84EDql6ZL1xeWtTLFxpO5Y5tQaUlR6_0jw,300
|
|
86
|
+
homa/ensemble/concerns/ReportsEnsembleSize.py,sha256=eIweQHpLcfGnNLwiMuTho-9rDgxV0xXGHPTOaEOABzw,240
|
|
87
|
+
homa/ensemble/concerns/ReportsLogits.py,sha256=sJZGJwTISZo2DFmJbI5zqhrt7CblNi09iGn1zaEk-ro,593
|
|
88
|
+
homa/ensemble/concerns/SavesEnsembleModels.py,sha256=VIXT9wJ8FiCspIvI2-F4WPa6mBBe9SWvMLFyad3TgRg,275
|
|
89
|
+
homa/ensemble/concerns/StoresModels.py,sha256=uAYbdUtadZsAJ9-Fj4jJFLWC23qfiXKo1mBm6-PZkN4,963
|
|
91
90
|
homa/ensemble/concerns/__init__.py,sha256=IF5mHIgzCuCpA2EmpkctbjAr0kYW4P96v7RffK2V_iQ,548
|
|
92
91
|
homa/graph/GraphAttention.py,sha256=oPXuc1s-3BXwGkHuomEIxnOcZSRBbL8b8fO0432RdDo,478
|
|
93
92
|
homa/graph/__init__.py,sha256=NCtMUB-awe9UvkwDYqWXxTAZ1RW-AwSW1DD9X_kFkD0,43
|
|
@@ -100,13 +99,12 @@ homa/loss/__init__.py,sha256=4mPVzme2_-M64bgBu1cANIfBFAL0voa5I71-ceMr_qk,64
|
|
|
100
99
|
homa/rl/DQN.py,sha256=PaNq9Z1K87IQ7Y7mhiJ1CE4TofgV7c7m1py8qT09vE4,20
|
|
101
100
|
homa/rl/DRQN.py,sha256=zooojji9aeeubOP7cRPSHg31u2Assxk-qjXyGUWIO3A,49
|
|
102
101
|
homa/rl/DiversityIsAllYouNeed.py,sha256=8yKzlVdLisForGyXqxaXUAWG_dozq7dNY8MBasCvniE,3322
|
|
103
|
-
homa/rl/SoftActorCritic.py,sha256=
|
|
102
|
+
homa/rl/SoftActorCritic.py,sha256=N8EsiYbsLH-dpT2EmqdYFG9KvHNfO3JX8SG2LPTy94s,1962
|
|
104
103
|
homa/rl/__init__.py,sha256=EaNDkIzLH1Oy0Wc0aAyyVs4HVMcZS1tdHDh631LKSXs,146
|
|
105
|
-
homa/rl/
|
|
106
|
-
homa/rl/buffers/Buffer.py,sha256=YCESh9tFxgWOLzGQj_IA0zLJoZWDmz6gCNu1iYsGp1s,388
|
|
104
|
+
homa/rl/buffers/Buffer.py,sha256=wOk8MH0Wf0cpvavpHIK2O7PrbGP6MwHTH5YFkq2Ints,288
|
|
107
105
|
homa/rl/buffers/DiversityIsAllYouNeedBuffer.py,sha256=Nwcqs3Q10x6OKZ-zWug4IcBc6RR1TwEIybuFQOtmftA,1612
|
|
108
106
|
homa/rl/buffers/ImageBuffer.py,sha256=HSmMt82hmkL3ooBYo7c6YUtTsMz9TAA8CvPh3y8z3yg,65
|
|
109
|
-
homa/rl/buffers/SoftActorCriticBuffer.py,sha256=
|
|
107
|
+
homa/rl/buffers/SoftActorCriticBuffer.py,sha256=iDC2C5XFvONT3f7YX_gYXQJGU9wz2usvPOVGbQUd22M,1796
|
|
110
108
|
homa/rl/buffers/__init__.py,sha256=h1AkCHs6isXbNtxpaZfLp6YudHj1KlnOvURE64vhRa4,190
|
|
111
109
|
homa/rl/buffers/concerns/HasRecordAlternatives.py,sha256=D5aVlPZlnGm0GyGtikKb4wZqyO6zpyqR1IOETmAgLx4,362
|
|
112
110
|
homa/rl/buffers/concerns/ResetsCollection.py,sha256=bZ8q4czYXo1jMtVCnnlG69OgiJ0AqSGY6CiKzJC6xtQ,215
|
|
@@ -119,11 +117,11 @@ homa/rl/diayn/modules/ContinuousActorModule.py,sha256=yeC117I5gkXZSidQhjwakjiY7G
|
|
|
119
117
|
homa/rl/diayn/modules/CriticModule.py,sha256=OUenwCG0dG4PnK7Iq-jy7oCTv_Cn9s7bXRpro6Pvb40,956
|
|
120
118
|
homa/rl/diayn/modules/DiscriminatorModule.py,sha256=D58dKBv4f6gtrpqMKLK8XAZpiMqKfS4sG6s3QcF8iGE,891
|
|
121
119
|
homa/rl/diayn/modules/__init__.py,sha256=1Pgjr4FT5WG-AMh26NPEfbf5pK6I02B1x8HYsgyUCJ4,149
|
|
122
|
-
homa/rl/sac/SoftActor.py,sha256=
|
|
123
|
-
homa/rl/sac/SoftCritic.py,sha256=
|
|
120
|
+
homa/rl/sac/SoftActor.py,sha256=CxR58IFrZ6xlmBj_gq_abZfgdzlVD71c6wA6wQiVL2c,2142
|
|
121
|
+
homa/rl/sac/SoftCritic.py,sha256=wFIunTgKGBy64Igu7zuvE2BvGz2e-DTplviLyq4tQ7M,3031
|
|
124
122
|
homa/rl/sac/__init__.py,sha256=8EIkOcVvxN94gGzcZoX2XTnvTsHqW6yBaZ2RdFwIveM,68
|
|
125
123
|
homa/rl/sac/modules/DualSoftCriticModule.py,sha256=Ax28i7U-KnP4QJig-AeeCfpPYNvTT3DfvRMJI-f-TGY,749
|
|
126
|
-
homa/rl/sac/modules/SoftActorModule.py,sha256=
|
|
124
|
+
homa/rl/sac/modules/SoftActorModule.py,sha256=AiWnsWkmQONjOAWAp06eO-lLWEYNJDmx8FSjPKTcjI0,1152
|
|
127
125
|
homa/rl/sac/modules/SoftCriticModule.py,sha256=aOfhDZTB5og-BLTsmdBdIcRufygCJUas7P-ikBvWQ34,928
|
|
128
126
|
homa/rl/sac/modules/__init__.py,sha256=h-22B5CAK1xhn75tolI5J5sQMxl--kOXbQ6r_JfHIOA,147
|
|
129
127
|
homa/vision/Classifier.py,sha256=bAypqREQVuPamnc8hpbLCwmW9Uly3T1rvrlbMxXp1eA,61
|
|
@@ -144,8 +142,8 @@ homa/vision/concerns/__init__.py,sha256=mrw1YvN-GpQPvMwDF00KxnFkksPKo23RWM4KRioU
|
|
|
144
142
|
homa/vision/modules/ResnetModule.py,sha256=eFudBnILD6OmgQtcW_CQQ8aZ62NEa4HyZ15-lobTtt0,712
|
|
145
143
|
homa/vision/modules/SwinModule.py,sha256=3ZtUcfyJt0NMGmIlGpN35MIJG9QsgcLdFniZH7NxZQo,1227
|
|
146
144
|
homa/vision/modules/__init__.py,sha256=zVMYB9IAO_xZylC1-N3p8ymHgEkAE2sBbuVz8K5Y1kk,74
|
|
147
|
-
homa-0.3.
|
|
148
|
-
homa-0.3.
|
|
149
|
-
homa-0.3.
|
|
150
|
-
homa-0.3.
|
|
151
|
-
homa-0.3.
|
|
145
|
+
homa-0.3.11.dist-info/METADATA,sha256=SvSxNXB1IsX3N5IfhOsnWYtvhjpfzauJPanVH7i5cRs,1760
|
|
146
|
+
homa-0.3.11.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
|
|
147
|
+
homa-0.3.11.dist-info/entry_points.txt,sha256=tJZzjs-f2QvFe3ES8Qta8IE5sAbeE8-cyZ_UtbgqG4s,51
|
|
148
|
+
homa-0.3.11.dist-info/top_level.txt,sha256=tmOfy2tuaAwc3W5-i6j61_vYJsXgR4ivBWkhJ3ZtJDc,5
|
|
149
|
+
homa-0.3.11.dist-info/RECORD,,
|
homa/core/concerns/TracksTime.py
DELETED
homa/rl/utils.py
DELETED
|
File without changes
|
|
File without changes
|
|
File without changes
|