homa 0.3.2__py3-none-any.whl → 0.3.11__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
homa/core/concerns/__init__.py CHANGED
@@ -1,2 +1 @@
  from .MovesNetworkToDevice import MovesNetworkToDevice
- from .TracksTime import TracksTime
homa/device.py CHANGED
@@ -25,6 +25,5 @@ def device():
      return get_device()


- def move(*modules):
-     for module in modules:
-         module.to(get_device())
+ def move(module: torch.nn.Module):
+     module.to(get_device())
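
A quick usage sketch of the narrowed helper (the Linear layer is illustrative; move and get_device are the functions defined in this file). Note that torch.nn.Module.to() mutates the module in place, while torch.Tensor.to() returns a new tensor, which is presumably why the old variadic form, also called on plain tensors elsewhere in the package, was dropped:

import torch
from homa.device import move, get_device

net = torch.nn.Linear(4, 2)  # any torch.nn.Module works here
move(net)                    # in place: parameters now live on get_device()
print(next(net.parameters()).device, get_device())
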
homa/ensemble/concerns/CalculatesMetricNecessities.py CHANGED
@@ -3,8 +3,8 @@ from ...device import get_device


  class CalculatesMetricNecessities:
-     def __init__(self):
-         super().__init__()
+     def __init__(self, *args, **kwargs):
+         super().__init__(*args, **kwargs)

      @torch.no_grad()
      def metric_necessities(self, dataloader):
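
The same __init__ change repeats across the concern classes below. These concerns are mixins combined through multiple inheritance, and forwarding *args, **kwargs keeps every initializer in the MRO chain reachable. A minimal sketch of the pattern (class names here are illustrative, not from homa):

class StoresThings:
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)  # keep the chain going
        self.things = []

class CountsThings:
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.count = 0

class Inventory(StoresThings, CountsThings):
    pass

inv = Inventory()  # both mixin initializers run, in MRO order
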
homa/ensemble/concerns/PredictsProbabilities.py CHANGED
@@ -3,8 +3,8 @@ from .ReportsLogits import ReportsLogits


  class PredictsProbabilities(ReportsLogits):
-     def __init__(self):
-         super().__init__()
+     def __init__(self, *args, **kwargs):
+         super().__init__(*args, **kwargs)

      def predict(self, x: torch.Tensor) -> torch.Tensor:
          logits = self.logits(x)
homa/ensemble/concerns/ReportsClassificationMetrics.py CHANGED
@@ -10,5 +10,4 @@ class ReportsClassificationMetrics(
      ReportsEnsembleF1,
      ReportsEnsembleKappa,
  ):
-     def __init__(self):
-         super().__init__()
+     pass
homa/ensemble/concerns/ReportsEnsembleAccuracy.py CHANGED
@@ -3,8 +3,8 @@ from torch.utils.data import DataLoader


  class ReportsEnsembleAccuracy:
-     def __init__(self):
-         super().__init__()
+     def __init__(self, *args, **kwargs):
+         super().__init__(*args, **kwargs)

      def accuracy(self, dataloader: DataLoader) -> float:
          predictions, labels = self.metric_necessities(dataloader)
homa/ensemble/concerns/ReportsEnsembleF1.py CHANGED
@@ -2,8 +2,8 @@ from sklearn.metrics import f1_score as f1


  class ReportsEnsembleF1:
-     def __init__(self):
-         super().__init__()
+     def __init__(self, *args, **kwargs):
+         super().__init__(*args, **kwargs)

      def f1(self) -> float:
          predictions, labels = self.metric_necessities()
homa/ensemble/concerns/ReportsEnsembleKappa.py CHANGED
@@ -2,8 +2,8 @@ from sklearn.metrics import cohen_kappa_score as kappa


  class ReportsEnsembleKappa:
-     def __init__(self):
-         super().__init__()
+     def __init__(self, *args, **kwargs):
+         super().__init__(*args, **kwargs)

      def accuracy(self) -> float:
          predictions, labels = self.metric_necessities()
homa/ensemble/concerns/ReportsEnsembleSize.py CHANGED
@@ -1,6 +1,6 @@
  class ReportsEnsembleSize:
-     def __init__(self):
-         super().__init__()
+     def __init__(self, *args, **kwargs):
+         super().__init__(*args, **kwargs)

      @property
      def size(self):
homa/ensemble/concerns/ReportsLogits.py CHANGED
@@ -2,14 +2,11 @@ import torch


  class ReportsLogits:
-     def __init__(self):
-         super().__init__()
+     def __init__(self, *args, **kwargs):
+         super().__init__(*args, **kwargs)

-     def logits_average(self, x: torch.Tensor) -> torch.Tensor:
-         return self.logits_sim(x) / len(self.factories)
-
-     def logits_sum(self, x: torch.Tensor) -> torch.Tensor:
-         batch_size = x.size(0)
+     def logits(self, x: torch.Tensor) -> torch.Tensor:
+         batch_size = x.shape[0]
          logits = torch.zeros((batch_size, self.num_classes))
          for factory, weight in zip(self.factories, self.weights):
              model = factory(num_classes=self.num_classes)
@@ -17,22 +14,6 @@ class ReportsLogits:
              logits += model(x)
          return logits

-     def check_aggregation_strategy(self, aggregation: str):
-         if aggregation not in ["mean", "average", "sum"]:
-             raise ValueError(
-                 f"Ensemble aggregation strategy must be in [mean, average, sum], but found {aggregation}."
-             )
-
-     def logits(self, x: torch.Tensor, aggregation: str = "mean") -> torch.Tensor:
-         self.check_aggregation_strategy(aggregation=aggregation)
-         logits_handlers = {
-             "mean": self.logits_average,
-             "average": self.logits_average,
-             "sum": self.logits_sum,
-         }
-         handler = logits_handlers.get(aggregation)
-         return handler(x)
-
      @torch.no_grad()
      def logits_(self, *args, **kwargs):
          return self.logits(*args, **kwargs)
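
After this change, logits always sums the member outputs; the separate mean/sum strategies (and the mean's broken call to a nonexistent logits_sim) are gone. A self-contained sketch of the loop, assuming the one line elided between the two hunks loads each stored state_dict into the freshly built model:

import torch

class ToyEnsemble:
    def __init__(self, factories, weights, num_classes):
        self.factories = factories    # callables that build fresh modules
        self.weights = weights        # one state_dict per stored member
        self.num_classes = num_classes

    @torch.no_grad()
    def logits(self, x):
        out = torch.zeros((x.shape[0], self.num_classes))
        for factory, weight in zip(self.factories, self.weights):
            model = factory(num_classes=self.num_classes)
            model.load_state_dict(weight)   # assumption: the elided hunk line
            out += model(x)
        return out

factory = lambda num_classes: torch.nn.Linear(8, num_classes)
weights = [factory(3).state_dict() for _ in range(2)]
ensemble = ToyEnsemble([factory, factory], weights, num_classes=3)
print(ensemble.logits(torch.randn(5, 8)).shape)   # torch.Size([5, 3])
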
homa/ensemble/concerns/SavesEnsembleModels.py CHANGED
@@ -1,6 +1,6 @@
  class SavesEnsembleModels:
-     def __init__(self):
-         super().__init__()
+     def __init__(self, *args, **kwargs):
+         super().__init__(*args, **kwargs)

      def save(self):
          self.save_factories()
homa/ensemble/concerns/StoresModels.py CHANGED
@@ -5,8 +5,8 @@ from ...vision import Model


  class StoresModels:
-     def __init__(self):
-         super().__init__()
+     def __init__(self, *args, **kwargs):
+         super().__init__(*args, **kwargs)
          self.factories: List[Type[torch.nn.Module]] = []
          self.weights: List[OrderedDict] = []

homa/rl/SoftActorCritic.py CHANGED
@@ -1,15 +1,14 @@
  from .sac import SoftActor, SoftCritic
  from .buffers import SoftActorCriticBuffer
- from ..core.concerns import TracksTime


- class SoftActorCritic(TracksTime):
+ class SoftActorCritic:
      def __init__(
          self,
          state_dimension: int,
          action_dimension: int,
          hidden_dimension: int = 256,
-         buffer_capacity: int = 100_000,
+         buffer_capacity: int = 1_000_000,
          batch_size: int = 256,
          actor_lr: float = 0.0002,
          critic_lr: float = 0.0003,
@@ -20,13 +19,10 @@ class SoftActorCritic(TracksTime):
          gamma: float = 0.99,
          min_std: float = -20,
          max_std: float = 2,
-         warmup: int = 20_000,
+         warmup: int = 10_000,
      ):
-         super().__init__()
-
          self.batch_size: int = batch_size
          self.warmup: int = warmup
-         self.tau: float = tau

          self.actor = SoftActor(
              state_dimension=state_dimension,
@@ -44,6 +40,7 @@ class SoftActorCritic(TracksTime):
              hidden_dimension=hidden_dimension,
              lr=critic_lr,
              weight_decay=critic_decay,
+             tau=tau,
              gamma=gamma,
              alpha=alpha,
          )
@@ -63,5 +60,5 @@ class SoftActorCritic(TracksTime):
              next_states=data.next_states,
              actor=self.actor,
          )
-         self.actor.train(states=data.states, critic=self.critic)
-         self.critic.update(tau=self.tau)
+         self.actor.train(states=data.states, critic_network=self.critic.network)
+         self.critic.update()
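
Handing the actor self.critic.network (a bare torch.nn.Module) rather than the SoftCritic wrapper removes the class-level dependency between the two; the matching SoftActor.py diff below drops its from .SoftCritic import SoftCritic import. A toy stand-in for the new boundary (single-head critic for brevity; the real module returns a pair of Q-values):

import torch

critic_network = torch.nn.Linear(4 + 2, 1)   # (state, action) -> Q, toy only

def actor_loss(states, actions, log_probs, alpha: float = 0.2):
    q = critic_network(torch.cat([states, actions], dim=1))
    return (alpha * log_probs - q).mean()    # same shape as SoftActor.loss

loss = actor_loss(torch.randn(8, 4), torch.randn(8, 2), torch.randn(8, 1))
loss.backward()
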
homa/rl/buffers/Buffer.py CHANGED
@@ -1,12 +1,10 @@
- from collections import deque
- from typing import Type
  from .concerns import ResetsCollection, HasRecordAlternatives


  class Buffer(ResetsCollection, HasRecordAlternatives):
      def __init__(self, capacity: int):
          self.capacity: int = capacity
-         self.collection: Type[deque] = deque(maxlen=self.capacity)
+         self.reset()

      @property
      def size(self):
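
Delegating to self.reset() leaves the deque construction to the ResetsCollection concern, so creating and clearing the buffer share one code path. A hedged sketch of the arrangement (the reset body is an assumption based on the concern's name):

from collections import deque

class ResetsCollection:
    def reset(self):
        # assumption: the concern rebuilds the backing deque from capacity
        self.collection = deque(maxlen=self.capacity)

class Buffer(ResetsCollection):
    def __init__(self, capacity: int):
        self.capacity = capacity
        self.reset()               # construction and reset share one path

buffer = Buffer(capacity=3)
buffer.collection.extend([1, 2, 3, 4])
print(list(buffer.collection))     # [2, 3, 4]: oldest entry evicted
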
homa/rl/buffers/SoftActorCriticBuffer.py CHANGED
@@ -3,7 +3,6 @@ import random
  import torch
  from types import SimpleNamespace
  from .Buffer import Buffer
- from ...device import move


  class SoftActorCriticBuffer(Buffer):
@@ -23,7 +22,7 @@ class SoftActorCriticBuffer(Buffer):
              (state, action, reward, next_state, termination, probability)
          )

-     def sample(self, k: int, as_tensor: bool = False, move_to_device: bool = True):
+     def sample(self, k: int, as_tensor: bool = False):
          batch = random.sample(self.collection, k)
          states, actions, rewards, next_states, terminations, probabilities = zip(*batch)

@@ -34,21 +33,14 @@ class SoftActorCriticBuffer(Buffer):
          terminations = numpy.array(terminations)
          probabilities = numpy.array(probabilities)

-         # add one dimension to both rewards and terminations
-         rewards = numpy.expand_dims(rewards, axis=-1)
-         terminations = numpy.expand_dims(terminations, axis=-1)
-
          if as_tensor:
              states = torch.from_numpy(states).float()
-             actions = torch.from_numpy(actions).float()
+             actions = torch.from_numpy(actions).long()
              rewards = torch.from_numpy(rewards).float()
              next_states = torch.from_numpy(next_states).float()
              terminations = torch.from_numpy(terminations).float()
              probabilities = torch.from_numpy(probabilities).float()

-         if move_to_device:
-             move(states, actions, rewards, next_states, terminations, probabilities)
-
          return SimpleNamespace(
              **{
                  "states": states,
homa/rl/sac/SoftActor.py CHANGED
@@ -1,11 +1,9 @@
  import torch
  import numpy
- from .SoftCritic import SoftCritic
  from .modules import SoftActorModule
- from ...core.concerns import MovesNetworkToDevice


- class SoftActor(MovesNetworkToDevice):
+ class SoftActor:
      def __init__(
          self,
          state_dimension: int,
@@ -30,22 +28,23 @@ class SoftActor(MovesNetworkToDevice):
              self.network.parameters(), lr=lr, weight_decay=weight_decay
          )

-     def train(self, states: torch.Tensor, critic: SoftCritic):
-         self.network.train()
+     def train(self, states: torch.Tensor, critic_network: torch.nn.Module):
          self.optimizer.zero_grad()
-         loss = self.loss(states=states, critic=critic)
+         loss = self.loss(states=states, critic_network=critic_network)
          loss.backward()
          self.optimizer.step()

-     def loss(self, states: torch.Tensor, critic: SoftCritic) -> torch.Tensor:
+     def loss(
+         self, states: torch.Tensor, critic_network: torch.nn.Module
+     ) -> torch.Tensor:
          actions, probabilities = self.sample(states)
-         q_alpha, q_beta = critic.network(states, actions)
+         q_alpha, q_beta = critic_network(states, actions)
          q = torch.min(q_alpha, q_beta)
          return (self.alpha * probabilities - q).mean()

      def process_state(self, state: numpy.ndarray | torch.Tensor) -> torch.Tensor:
          if isinstance(state, numpy.ndarray):
-             state = torch.from_numpy(state).float()
+             state = torch.from_numpy(state)

          if state.ndim < 2:
              state = state.unsqueeze(0)
@@ -65,6 +64,6 @@ class SoftActor(MovesNetworkToDevice):
          action = torch.tanh(pre_tanh)

          probabilities = distribution.log_prob(pre_tanh).sum(dim=1, keepdim=True)
-         probabilities -= torch.log(1 - action.pow(2) + 1e-6).sum(dim=1, keepdim=True)
+         correction = torch.log(1 - action.pow(2) + 1e-6).sum(dim=1, keepdim=True)

-         return action, probabilities
+         return action, probabilities - correction
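
The sample rewrite is behavior-preserving: subtracting correction in the return is the same arithmetic as the old in-place -=. The term itself is the standard change-of-variables correction for a tanh-squashed Gaussian; a runnable check of the shapes involved:

import torch

mean = torch.zeros(4, 2)
std = torch.ones(4, 2)
distribution = torch.distributions.Normal(mean, std)
pre_tanh = distribution.rsample()               # reparameterized draw
action = torch.tanh(pre_tanh)                   # squashed into (-1, 1)
probabilities = distribution.log_prob(pre_tanh).sum(dim=1, keepdim=True)
correction = torch.log(1 - action.pow(2) + 1e-6).sum(dim=1, keepdim=True)
print((probabilities - correction).shape)       # torch.Size([4, 1])
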
homa/rl/sac/SoftCritic.py CHANGED
@@ -1,12 +1,11 @@
  import torch
  from torch.nn.functional import mse_loss as mse
+ from typing import Type
  from .modules import DualSoftCriticModule
  from .SoftActor import SoftActor
- from ..utils import soft_update
- from ...core.concerns import MovesNetworkToDevice


- class SoftCritic(MovesNetworkToDevice):
+ class SoftCritic:
      def __init__(
          self,
          state_dimension: int,
@@ -14,9 +13,11 @@ class SoftCritic(MovesNetworkToDevice):
          action_dimension: int,
          lr: float,
          weight_decay: float,
+         tau: float,
          gamma: float,
          alpha: float,
      ):
+         self.tau: float = tau
          self.gamma: float = gamma
          self.alpha: float = alpha

@@ -30,10 +31,6 @@ class SoftCritic(MovesNetworkToDevice):
              hidden_dimension=hidden_dimension,
              action_dimension=action_dimension,
          )
-
-         # copy source to target when initiated
-         self.target.load_state_dict(self.network.state_dict())
-
          self.optimizer = torch.optim.AdamW(
              self.network.parameters(), lr=lr, weight_decay=weight_decay
          )
@@ -45,9 +42,8 @@ class SoftCritic(MovesNetworkToDevice):
          rewards: torch.Tensor,
          terminations: torch.Tensor,
          next_states: torch.Tensor,
-         actor: SoftActor,
+         actor: torch.nn.Module,
      ):
-         self.network.train()
          self.optimizer.zero_grad()
          loss = self.loss(
              states=states,
@@ -69,7 +65,7 @@ class SoftCritic(MovesNetworkToDevice):
          next_states: torch.Tensor,
          actor: torch.nn.Module,
      ):
-         q_alpha, q_beta = self.network(states, actions)
+         q_alpha, q_beta = self.target(states, actions)
          target = self.calculate_target(
              rewards=rewards,
              terminations=terminations,
@@ -86,13 +82,19 @@ class SoftCritic(MovesNetworkToDevice):
          next_states: torch.Tensor,
          actor: SoftActor,
      ):
-         termination_mask = 1 - terminations
          next_actions, next_probabilities = actor.sample(next_states)
          q_alpha, q_beta = self.target(next_states, next_actions)
          q = torch.min(q_alpha, q_beta)
-         entropy_q = q - self.alpha * next_probabilities
-         return rewards + self.gamma * termination_mask * entropy_q
+         termination_mask = 1 - terminations
+         entropy_q = q - self.alpha * next_probabilities * termination_mask
+         return rewards + self.gamma * entropy_q
+
+     def soft_update(
+         self, network: Type[torch.nn.Module], target: Type[torch.nn.Module]
+     ):
+         for s, t in zip(network.parameters(), target.parameters()):
+             t.data.copy_(self.tau * s.data + (1 - self.tau) * t.data)

-     def update(self, tau: float):
-         soft_update(network=self.network.alpha, target=self.target.alpha, tau=tau)
-         soft_update(network=self.network.beta, target=self.target.beta, tau=tau)
+     def update(self):
+         self.soft_update(self.network.alpha, self.target.alpha)
+         self.soft_update(self.network.beta, self.target.beta)
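
soft_update is a Polyak average, t = tau * s + (1 - tau) * t, now parameterized by the critic's own tau rather than an argument threaded in by the trainer on every call. A standalone demonstration of one averaging step on toy networks:

import torch

tau = 0.5
network = torch.nn.Linear(2, 1)
target = torch.nn.Linear(2, 1)

with torch.no_grad():
    for s, t in zip(network.parameters(), target.parameters()):
        t.data.copy_(tau * s.data + (1 - tau) * t.data)

# each target parameter now sits halfway toward its online twin
gap = sum((s - t).abs().sum() for s, t in
          zip(network.parameters(), target.parameters()))
print(float(gap))
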
homa/rl/sac/modules/SoftActorModule.py CHANGED
@@ -30,6 +30,6 @@ class SoftActorModule(torch.nn.Module):
      def forward(self, state: torch.Tensor):
          features = self.phi(state)
          mean = self.mu(features)
-         std = self.xi(features)
+         std = self.mu(features)
          std = std.clamp(self.min_std, self.max_std)
          return mean, std
homa-0.3.2.dist-info/METADATA → homa-0.3.11.dist-info/METADATA RENAMED
@@ -1,6 +1,6 @@
  Metadata-Version: 2.4
  Name: homa
- Version: 0.3.2
+ Version: 0.3.11
  Summary: A curated list of machine learning and deep learning helpers.
  Author-email: Taha Shieenavaz <tahashieenavaz@gmail.com>
  Requires-Python: >=3.7
homa-0.3.2.dist-info/RECORD → homa-0.3.11.dist-info/RECORD RENAMED
@@ -1,5 +1,5 @@
  homa/__init__.py,sha256=NBYFKizG8UASiz5HLsEBqzXNGlWr78xm4sLr5hxKvjU,46
- homa/device.py,sha256=J_XpsqXOOXG15ea_9M_W4abtr1DA7VLPFuURUa0f2Qw,424
+ homa/device.py,sha256=dpKI-ah_kPgNfFH_ism8YXHndEndGngBrTVnuZZ2J2I,408
  homa/settings.py,sha256=CPZDPvs1380O7SY7FcSKol8kBVFVVYFgSJl3YEyJuZ0,263
  homa/utils.py,sha256=dPp6TItJwWxBqxmkMzUuCtX_BzdPT-kMOZyXRGVMCbQ,70
  homa/activations/APLU.py,sha256=cUf6LUjY8TewXe_V1avO_7IcOtY66Hd6Dyk_1K4R3Ms,1555
@@ -73,21 +73,20 @@ homa/cli/namespaces/MakeNamespace.py,sha256=5G6LHk3lDkXROz7uq4jYE0DyO_V7JvnhJ33I
  homa/cli/namespaces/__init__.py,sha256=zAKUGPH4wcacxfH5Qvidp-uOuHdfzhan6kvVI6eMKA8,84
  homa/core/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
  homa/core/concerns/MovesNetworkToDevice.py,sha256=OPMvO7scsM6NNy_fM0cJdkRdoVc-b2j6l4bz88cBif0,348
- homa/core/concerns/TracksTime.py,sha256=atg7iUH5HKqKJd03s9eHsl18iUO_4fzxuYmXgNtqSBQ,129
- homa/core/concerns/__init__.py,sha256=6jL3_kiqmmMs8BV789ZBbwEYNQNAhq1otVOrDJJrSXo,90
+ homa/core/concerns/__init__.py,sha256=O9OXMIMYrkIgp11lAyEv-OgT3Wq0IvNdDVZr2bOmpQU,55
  homa/ensemble/Ensemble.py,sha256=mrqwbEm8OtiBmEgKuO6RzO1V8v80vrQFIJ4WHl8Yqgk,356
  homa/ensemble/__init__.py,sha256=1pk2W-NbgfDFh9WLKZVLUk2E3PTjVZ5Bap9dQEnrs9o,31
  homa/ensemble/utils.py,sha256=nn6eAgGW7ZafjjOVJWzGUWE0XYeyJAOMNEHm-lHxd6A,200
- homa/ensemble/concerns/CalculatesMetricNecessities.py,sha256=HgrLbz8O9grGZZ0LG82Au5lZwq2D1zixDRjegM4f8Wk,793
- homa/ensemble/concerns/PredictsProbabilities.py,sha256=WWUaNXQxCJQ_NrLgeTdw0OXEsDb5xU7899_2d3Pzaoc,408
- homa/ensemble/concerns/ReportsClassificationMetrics.py,sha256=S9IBH6O7dmHhQ4Mxf5c7JFirOsPokKZUyqfOmED13mM,437
- homa/ensemble/concerns/ReportsEnsembleAccuracy.py,sha256=UuaPQ7v2sCaWuo3xa4PaTZzyjciRhcIluhnt6Zla2Fo,348
- homa/ensemble/concerns/ReportsEnsembleF1.py,sha256=aXKBK2-dTB133Rjg-X2a4Khb1sTdxuffXukcSMZlkzM,264
- homa/ensemble/concerns/ReportsEnsembleKappa.py,sha256=AkWTVGuCeIanDusNdtJOHwTgSEh5RJeWwLTQJJqSEKE,268
- homa/ensemble/concerns/ReportsEnsembleSize.py,sha256=lRyHIrK_zr7pE5RlwuNLIkqXoEMoVhNmT1nYDgCaNVI,208
- homa/ensemble/concerns/ReportsLogits.py,sha256=yZobLvxPL6ep70uMFIEtz5-l4rlfaG7m9mti6jJD1E8,1338
- homa/ensemble/concerns/SavesEnsembleModels.py,sha256=d1DcZnzfJABEfxcnYy5tV9N7YOghzO_ZdCdU80VTcno,243
- homa/ensemble/concerns/StoresModels.py,sha256=dg-xP1C4A9K8DrUTnR4VfqWU9iNAdS_0DlQcRThDka8,931
+ homa/ensemble/concerns/CalculatesMetricNecessities.py,sha256=QccROg_FOp_X2T_lZDg8p1DMZhPYdO-7aEdnebRXMsY,825
+ homa/ensemble/concerns/PredictsProbabilities.py,sha256=7rmI66DzE7-QGoJgZEk-9fu5YQvJW-4ZnMn_dWEEhqU,440
+ homa/ensemble/concerns/ReportsClassificationMetrics.py,sha256=bg__cdCKp2U1H9qN1aOJH4BoX98oIvt8XaPDGApJhSM,395
+ homa/ensemble/concerns/ReportsEnsembleAccuracy.py,sha256=AX5X3VGOm7DfdonW0N7FFgUwEr7wnsojRSVEULEii7c,380
+ homa/ensemble/concerns/ReportsEnsembleF1.py,sha256=hdtdCQrWaFJNUn1KP9cAmi_q_EA4FYnpkBMlYLjzRZg,296
+ homa/ensemble/concerns/ReportsEnsembleKappa.py,sha256=ZRbtrFCTD84EDql6ZL1xeWtTLFxpO5Y5tQaUlR6_0jw,300
+ homa/ensemble/concerns/ReportsEnsembleSize.py,sha256=eIweQHpLcfGnNLwiMuTho-9rDgxV0xXGHPTOaEOABzw,240
+ homa/ensemble/concerns/ReportsLogits.py,sha256=sJZGJwTISZo2DFmJbI5zqhrt7CblNi09iGn1zaEk-ro,593
+ homa/ensemble/concerns/SavesEnsembleModels.py,sha256=VIXT9wJ8FiCspIvI2-F4WPa6mBBe9SWvMLFyad3TgRg,275
+ homa/ensemble/concerns/StoresModels.py,sha256=uAYbdUtadZsAJ9-Fj4jJFLWC23qfiXKo1mBm6-PZkN4,963
  homa/ensemble/concerns/__init__.py,sha256=IF5mHIgzCuCpA2EmpkctbjAr0kYW4P96v7RffK2V_iQ,548
  homa/graph/GraphAttention.py,sha256=oPXuc1s-3BXwGkHuomEIxnOcZSRBbL8b8fO0432RdDo,478
  homa/graph/__init__.py,sha256=NCtMUB-awe9UvkwDYqWXxTAZ1RW-AwSW1DD9X_kFkD0,43
@@ -100,13 +99,12 @@ homa/loss/__init__.py,sha256=4mPVzme2_-M64bgBu1cANIfBFAL0voa5I71-ceMr_qk,64
  homa/rl/DQN.py,sha256=PaNq9Z1K87IQ7Y7mhiJ1CE4TofgV7c7m1py8qT09vE4,20
  homa/rl/DRQN.py,sha256=zooojji9aeeubOP7cRPSHg31u2Assxk-qjXyGUWIO3A,49
  homa/rl/DiversityIsAllYouNeed.py,sha256=8yKzlVdLisForGyXqxaXUAWG_dozq7dNY8MBasCvniE,3322
- homa/rl/SoftActorCritic.py,sha256=0xQcjAJQAiBsPCl8RORHz02K7tPaBWoQv45Zd12Ud6Q,2044
+ homa/rl/SoftActorCritic.py,sha256=N8EsiYbsLH-dpT2EmqdYFG9KvHNfO3JX8SG2LPTy94s,1962
  homa/rl/__init__.py,sha256=EaNDkIzLH1Oy0Wc0aAyyVs4HVMcZS1tdHDh631LKSXs,146
- homa/rl/utils.py,sha256=IqbN5aDLwovocpPbxgywuetjz7GQwh9aJ4WFIOtLP3g,232
- homa/rl/buffers/Buffer.py,sha256=YCESh9tFxgWOLzGQj_IA0zLJoZWDmz6gCNu1iYsGp1s,388
+ homa/rl/buffers/Buffer.py,sha256=wOk8MH0Wf0cpvavpHIK2O7PrbGP6MwHTH5YFkq2Ints,288
  homa/rl/buffers/DiversityIsAllYouNeedBuffer.py,sha256=Nwcqs3Q10x6OKZ-zWug4IcBc6RR1TwEIybuFQOtmftA,1612
  homa/rl/buffers/ImageBuffer.py,sha256=HSmMt82hmkL3ooBYo7c6YUtTsMz9TAA8CvPh3y8z3yg,65
- homa/rl/buffers/SoftActorCriticBuffer.py,sha256=N2etaAOA4xkOBdybQX6RQf-H4ivFgCKMJM5QugM9CYc,2154
+ homa/rl/buffers/SoftActorCriticBuffer.py,sha256=iDC2C5XFvONT3f7YX_gYXQJGU9wz2usvPOVGbQUd22M,1796
  homa/rl/buffers/__init__.py,sha256=h1AkCHs6isXbNtxpaZfLp6YudHj1KlnOvURE64vhRa4,190
  homa/rl/buffers/concerns/HasRecordAlternatives.py,sha256=D5aVlPZlnGm0GyGtikKb4wZqyO6zpyqR1IOETmAgLx4,362
  homa/rl/buffers/concerns/ResetsCollection.py,sha256=bZ8q4czYXo1jMtVCnnlG69OgiJ0AqSGY6CiKzJC6xtQ,215
@@ -119,11 +117,11 @@ homa/rl/diayn/modules/ContinuousActorModule.py,sha256=yeC117I5gkXZSidQhjwakjiY7G
  homa/rl/diayn/modules/CriticModule.py,sha256=OUenwCG0dG4PnK7Iq-jy7oCTv_Cn9s7bXRpro6Pvb40,956
  homa/rl/diayn/modules/DiscriminatorModule.py,sha256=D58dKBv4f6gtrpqMKLK8XAZpiMqKfS4sG6s3QcF8iGE,891
  homa/rl/diayn/modules/__init__.py,sha256=1Pgjr4FT5WG-AMh26NPEfbf5pK6I02B1x8HYsgyUCJ4,149
- homa/rl/sac/SoftActor.py,sha256=fCsAp5_KxzFmGplJaiF-4Cvn5qVaKGa51gst9bzHs0w,2221
- homa/rl/sac/SoftCritic.py,sha256=OPJoYgvbyBfkPfAt6DNxFCTsNydsi2__2p_4MmWBxiA,3004
+ homa/rl/sac/SoftActor.py,sha256=CxR58IFrZ6xlmBj_gq_abZfgdzlVD71c6wA6wQiVL2c,2142
+ homa/rl/sac/SoftCritic.py,sha256=wFIunTgKGBy64Igu7zuvE2BvGz2e-DTplviLyq4tQ7M,3031
  homa/rl/sac/__init__.py,sha256=8EIkOcVvxN94gGzcZoX2XTnvTsHqW6yBaZ2RdFwIveM,68
  homa/rl/sac/modules/DualSoftCriticModule.py,sha256=Ax28i7U-KnP4QJig-AeeCfpPYNvTT3DfvRMJI-f-TGY,749
- homa/rl/sac/modules/SoftActorModule.py,sha256=LQ4z7s8mE3wwb1JgxPs0QvnriZULK3_ULdhkt60Ffpw,1152
+ homa/rl/sac/modules/SoftActorModule.py,sha256=AiWnsWkmQONjOAWAp06eO-lLWEYNJDmx8FSjPKTcjI0,1152
  homa/rl/sac/modules/SoftCriticModule.py,sha256=aOfhDZTB5og-BLTsmdBdIcRufygCJUas7P-ikBvWQ34,928
  homa/rl/sac/modules/__init__.py,sha256=h-22B5CAK1xhn75tolI5J5sQMxl--kOXbQ6r_JfHIOA,147
  homa/vision/Classifier.py,sha256=bAypqREQVuPamnc8hpbLCwmW9Uly3T1rvrlbMxXp1eA,61
@@ -144,8 +142,8 @@ homa/vision/concerns/__init__.py,sha256=mrw1YvN-GpQPvMwDF00KxnFkksPKo23RWM4KRioU
  homa/vision/modules/ResnetModule.py,sha256=eFudBnILD6OmgQtcW_CQQ8aZ62NEa4HyZ15-lobTtt0,712
  homa/vision/modules/SwinModule.py,sha256=3ZtUcfyJt0NMGmIlGpN35MIJG9QsgcLdFniZH7NxZQo,1227
  homa/vision/modules/__init__.py,sha256=zVMYB9IAO_xZylC1-N3p8ymHgEkAE2sBbuVz8K5Y1kk,74
- homa-0.3.2.dist-info/METADATA,sha256=awU-Sftb68ejizlVGd4_otzjFwpq4EIopenhQHoHlFA,1759
- homa-0.3.2.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
- homa-0.3.2.dist-info/entry_points.txt,sha256=tJZzjs-f2QvFe3ES8Qta8IE5sAbeE8-cyZ_UtbgqG4s,51
- homa-0.3.2.dist-info/top_level.txt,sha256=tmOfy2tuaAwc3W5-i6j61_vYJsXgR4ivBWkhJ3ZtJDc,5
- homa-0.3.2.dist-info/RECORD,,
+ homa-0.3.11.dist-info/METADATA,sha256=SvSxNXB1IsX3N5IfhOsnWYtvhjpfzauJPanVH7i5cRs,1760
+ homa-0.3.11.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
+ homa-0.3.11.dist-info/entry_points.txt,sha256=tJZzjs-f2QvFe3ES8Qta8IE5sAbeE8-cyZ_UtbgqG4s,51
+ homa-0.3.11.dist-info/top_level.txt,sha256=tmOfy2tuaAwc3W5-i6j61_vYJsXgR4ivBWkhJ3ZtJDc,5
+ homa-0.3.11.dist-info/RECORD,,
homa/core/concerns/TracksTime.py DELETED
@@ -1,7 +0,0 @@
- class TracksTime:
-     def __init__(self):
-         super().__init__()
-         self.t = 0
-
-     def tick(self):
-         self.t += 1
homa/rl/utils.py DELETED
@@ -1,7 +0,0 @@
- import torch
-
-
- @torch.no_grad()
- def soft_update(network: torch.nn.Module, target: torch.nn.Module, tau: float):
-     for s, t in zip(network.parameters(), target.parameters()):
-         t.data.copy_(tau * s.data + (1 - tau) * t.data)