homa 0.3.11__py3-none-any.whl → 0.3.15__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- homa/rl/buffers/Buffer.py +3 -1
- homa/rl/buffers/SoftActorCriticBuffer.py +1 -1
- homa/rl/sac/SoftActor.py +3 -2
- homa/rl/sac/SoftCritic.py +13 -14
- homa/rl/sac/modules/SoftActorModule.py +1 -1
- homa/rl/utils.py +7 -0
- {homa-0.3.11.dist-info → homa-0.3.15.dist-info}/METADATA +1 -1
- {homa-0.3.11.dist-info → homa-0.3.15.dist-info}/RECORD +11 -10
- {homa-0.3.11.dist-info → homa-0.3.15.dist-info}/WHEEL +0 -0
- {homa-0.3.11.dist-info → homa-0.3.15.dist-info}/entry_points.txt +0 -0
- {homa-0.3.11.dist-info → homa-0.3.15.dist-info}/top_level.txt +0 -0
homa/rl/buffers/Buffer.py
CHANGED
|
@@ -1,10 +1,12 @@
|
|
|
1
|
+
from collections import deque
|
|
2
|
+
from typing import Type
|
|
1
3
|
from .concerns import ResetsCollection, HasRecordAlternatives
|
|
2
4
|
|
|
3
5
|
|
|
4
6
|
class Buffer(ResetsCollection, HasRecordAlternatives):
|
|
5
7
|
def __init__(self, capacity: int):
|
|
6
8
|
self.capacity: int = capacity
|
|
7
|
-
self.
|
|
9
|
+
self.collection: Type[deque] = deque(maxlen=self.capacity)
|
|
8
10
|
|
|
9
11
|
@property
|
|
10
12
|
def size(self):
|
|
@@ -35,7 +35,7 @@ class SoftActorCriticBuffer(Buffer):
|
|
|
35
35
|
|
|
36
36
|
if as_tensor:
|
|
37
37
|
states = torch.from_numpy(states).float()
|
|
38
|
-
actions = torch.from_numpy(actions).
|
|
38
|
+
actions = torch.from_numpy(actions).float()
|
|
39
39
|
rewards = torch.from_numpy(rewards).float()
|
|
40
40
|
next_states = torch.from_numpy(next_states).float()
|
|
41
41
|
terminations = torch.from_numpy(terminations).float()
|
homa/rl/sac/SoftActor.py
CHANGED
|
@@ -29,6 +29,7 @@ class SoftActor:
|
|
|
29
29
|
)
|
|
30
30
|
|
|
31
31
|
def train(self, states: torch.Tensor, critic_network: torch.nn.Module):
|
|
32
|
+
self.network.train()
|
|
32
33
|
self.optimizer.zero_grad()
|
|
33
34
|
loss = self.loss(states=states, critic_network=critic_network)
|
|
34
35
|
loss.backward()
|
|
@@ -64,6 +65,6 @@ class SoftActor:
|
|
|
64
65
|
action = torch.tanh(pre_tanh)
|
|
65
66
|
|
|
66
67
|
probabilities = distribution.log_prob(pre_tanh).sum(dim=1, keepdim=True)
|
|
67
|
-
|
|
68
|
+
probabilities -= torch.log(1 - action.pow(2) + 1e-6).sum(dim=1, keepdim=True)
|
|
68
69
|
|
|
69
|
-
return action, probabilities
|
|
70
|
+
return action, probabilities
|
homa/rl/sac/SoftCritic.py
CHANGED
|
@@ -1,8 +1,8 @@
|
|
|
1
1
|
import torch
|
|
2
2
|
from torch.nn.functional import mse_loss as mse
|
|
3
|
-
from typing import Type
|
|
4
3
|
from .modules import DualSoftCriticModule
|
|
5
4
|
from .SoftActor import SoftActor
|
|
5
|
+
from ..utils import soft_update
|
|
6
6
|
|
|
7
7
|
|
|
8
8
|
class SoftCritic:
|
|
@@ -31,6 +31,10 @@ class SoftCritic:
|
|
|
31
31
|
hidden_dimension=hidden_dimension,
|
|
32
32
|
action_dimension=action_dimension,
|
|
33
33
|
)
|
|
34
|
+
|
|
35
|
+
# copy source to target when initiated
|
|
36
|
+
self.target.load_state_dict(self.network.state_dict())
|
|
37
|
+
|
|
34
38
|
self.optimizer = torch.optim.AdamW(
|
|
35
39
|
self.network.parameters(), lr=lr, weight_decay=weight_decay
|
|
36
40
|
)
|
|
@@ -42,8 +46,9 @@ class SoftCritic:
|
|
|
42
46
|
rewards: torch.Tensor,
|
|
43
47
|
terminations: torch.Tensor,
|
|
44
48
|
next_states: torch.Tensor,
|
|
45
|
-
actor:
|
|
49
|
+
actor: SoftActor,
|
|
46
50
|
):
|
|
51
|
+
self.network.train()
|
|
47
52
|
self.optimizer.zero_grad()
|
|
48
53
|
loss = self.loss(
|
|
49
54
|
states=states,
|
|
@@ -65,7 +70,7 @@ class SoftCritic:
|
|
|
65
70
|
next_states: torch.Tensor,
|
|
66
71
|
actor: torch.nn.Module,
|
|
67
72
|
):
|
|
68
|
-
q_alpha, q_beta = self.
|
|
73
|
+
q_alpha, q_beta = self.network(states, actions)
|
|
69
74
|
target = self.calculate_target(
|
|
70
75
|
rewards=rewards,
|
|
71
76
|
terminations=terminations,
|
|
@@ -82,19 +87,13 @@ class SoftCritic:
|
|
|
82
87
|
next_states: torch.Tensor,
|
|
83
88
|
actor: SoftActor,
|
|
84
89
|
):
|
|
90
|
+
termination_mask = 1 - terminations
|
|
85
91
|
next_actions, next_probabilities = actor.sample(next_states)
|
|
86
92
|
q_alpha, q_beta = self.target(next_states, next_actions)
|
|
87
93
|
q = torch.min(q_alpha, q_beta)
|
|
88
|
-
|
|
89
|
-
|
|
90
|
-
return rewards + self.gamma * entropy_q
|
|
91
|
-
|
|
92
|
-
def soft_update(
|
|
93
|
-
self, network: Type[torch.nn.Module], target: Type[torch.nn.Module]
|
|
94
|
-
):
|
|
95
|
-
for s, t in zip(network.parameters(), target.parameters()):
|
|
96
|
-
t.data.copy_(self.tau * s.data + (1 - self.tau) * t.data)
|
|
94
|
+
entropy_q = q - self.alpha * next_probabilities
|
|
95
|
+
return rewards + self.gamma * termination_mask * entropy_q
|
|
97
96
|
|
|
98
97
|
def update(self):
|
|
99
|
-
|
|
100
|
-
|
|
98
|
+
soft_update(network=self.network.alpha, target=self.target.alpha, tau=self.tau)
|
|
99
|
+
soft_update(network=self.network.beta, target=self.target.beta, tau=self.tau)
|
homa/rl/utils.py
ADDED
|
@@ -101,10 +101,11 @@ homa/rl/DRQN.py,sha256=zooojji9aeeubOP7cRPSHg31u2Assxk-qjXyGUWIO3A,49
|
|
|
101
101
|
homa/rl/DiversityIsAllYouNeed.py,sha256=8yKzlVdLisForGyXqxaXUAWG_dozq7dNY8MBasCvniE,3322
|
|
102
102
|
homa/rl/SoftActorCritic.py,sha256=N8EsiYbsLH-dpT2EmqdYFG9KvHNfO3JX8SG2LPTy94s,1962
|
|
103
103
|
homa/rl/__init__.py,sha256=EaNDkIzLH1Oy0Wc0aAyyVs4HVMcZS1tdHDh631LKSXs,146
|
|
104
|
-
homa/rl/
|
|
104
|
+
homa/rl/utils.py,sha256=IqbN5aDLwovocpPbxgywuetjz7GQwh9aJ4WFIOtLP3g,232
|
|
105
|
+
homa/rl/buffers/Buffer.py,sha256=YCESh9tFxgWOLzGQj_IA0zLJoZWDmz6gCNu1iYsGp1s,388
|
|
105
106
|
homa/rl/buffers/DiversityIsAllYouNeedBuffer.py,sha256=Nwcqs3Q10x6OKZ-zWug4IcBc6RR1TwEIybuFQOtmftA,1612
|
|
106
107
|
homa/rl/buffers/ImageBuffer.py,sha256=HSmMt82hmkL3ooBYo7c6YUtTsMz9TAA8CvPh3y8z3yg,65
|
|
107
|
-
homa/rl/buffers/SoftActorCriticBuffer.py,sha256=
|
|
108
|
+
homa/rl/buffers/SoftActorCriticBuffer.py,sha256=JQ9Y6KeeQS5naO_JPONiks-HYXw7hiZZAbqpoWDZlNI,1797
|
|
108
109
|
homa/rl/buffers/__init__.py,sha256=h1AkCHs6isXbNtxpaZfLp6YudHj1KlnOvURE64vhRa4,190
|
|
109
110
|
homa/rl/buffers/concerns/HasRecordAlternatives.py,sha256=D5aVlPZlnGm0GyGtikKb4wZqyO6zpyqR1IOETmAgLx4,362
|
|
110
111
|
homa/rl/buffers/concerns/ResetsCollection.py,sha256=bZ8q4czYXo1jMtVCnnlG69OgiJ0AqSGY6CiKzJC6xtQ,215
|
|
@@ -117,11 +118,11 @@ homa/rl/diayn/modules/ContinuousActorModule.py,sha256=yeC117I5gkXZSidQhjwakjiY7G
|
|
|
117
118
|
homa/rl/diayn/modules/CriticModule.py,sha256=OUenwCG0dG4PnK7Iq-jy7oCTv_Cn9s7bXRpro6Pvb40,956
|
|
118
119
|
homa/rl/diayn/modules/DiscriminatorModule.py,sha256=D58dKBv4f6gtrpqMKLK8XAZpiMqKfS4sG6s3QcF8iGE,891
|
|
119
120
|
homa/rl/diayn/modules/__init__.py,sha256=1Pgjr4FT5WG-AMh26NPEfbf5pK6I02B1x8HYsgyUCJ4,149
|
|
120
|
-
homa/rl/sac/SoftActor.py,sha256=
|
|
121
|
-
homa/rl/sac/SoftCritic.py,sha256=
|
|
121
|
+
homa/rl/sac/SoftActor.py,sha256=NSTqnv_BZzTqfgEEOIEtOgYV2_VycicIF0alD1O5Nk8,2162
|
|
122
|
+
homa/rl/sac/SoftCritic.py,sha256=rOgPR8zRUtjEwF9W4q5nZQaGXFmf_9tmXqaRWzUkAm8,2980
|
|
122
123
|
homa/rl/sac/__init__.py,sha256=8EIkOcVvxN94gGzcZoX2XTnvTsHqW6yBaZ2RdFwIveM,68
|
|
123
124
|
homa/rl/sac/modules/DualSoftCriticModule.py,sha256=Ax28i7U-KnP4QJig-AeeCfpPYNvTT3DfvRMJI-f-TGY,749
|
|
124
|
-
homa/rl/sac/modules/SoftActorModule.py,sha256=
|
|
125
|
+
homa/rl/sac/modules/SoftActorModule.py,sha256=LQ4z7s8mE3wwb1JgxPs0QvnriZULK3_ULdhkt60Ffpw,1152
|
|
125
126
|
homa/rl/sac/modules/SoftCriticModule.py,sha256=aOfhDZTB5og-BLTsmdBdIcRufygCJUas7P-ikBvWQ34,928
|
|
126
127
|
homa/rl/sac/modules/__init__.py,sha256=h-22B5CAK1xhn75tolI5J5sQMxl--kOXbQ6r_JfHIOA,147
|
|
127
128
|
homa/vision/Classifier.py,sha256=bAypqREQVuPamnc8hpbLCwmW9Uly3T1rvrlbMxXp1eA,61
|
|
@@ -142,8 +143,8 @@ homa/vision/concerns/__init__.py,sha256=mrw1YvN-GpQPvMwDF00KxnFkksPKo23RWM4KRioU
|
|
|
142
143
|
homa/vision/modules/ResnetModule.py,sha256=eFudBnILD6OmgQtcW_CQQ8aZ62NEa4HyZ15-lobTtt0,712
|
|
143
144
|
homa/vision/modules/SwinModule.py,sha256=3ZtUcfyJt0NMGmIlGpN35MIJG9QsgcLdFniZH7NxZQo,1227
|
|
144
145
|
homa/vision/modules/__init__.py,sha256=zVMYB9IAO_xZylC1-N3p8ymHgEkAE2sBbuVz8K5Y1kk,74
|
|
145
|
-
homa-0.3.
|
|
146
|
-
homa-0.3.
|
|
147
|
-
homa-0.3.
|
|
148
|
-
homa-0.3.
|
|
149
|
-
homa-0.3.
|
|
146
|
+
homa-0.3.15.dist-info/METADATA,sha256=jo-jsI9A6KvK95nGmxrh_IExb1WeNp_3DWTZgMDSxcI,1760
|
|
147
|
+
homa-0.3.15.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
|
|
148
|
+
homa-0.3.15.dist-info/entry_points.txt,sha256=tJZzjs-f2QvFe3ES8Qta8IE5sAbeE8-cyZ_UtbgqG4s,51
|
|
149
|
+
homa-0.3.15.dist-info/top_level.txt,sha256=tmOfy2tuaAwc3W5-i6j61_vYJsXgR4ivBWkhJ3ZtJDc,5
|
|
150
|
+
homa-0.3.15.dist-info/RECORD,,
|
|
File without changes
|
|
File without changes
|
|
File without changes
|