agilerl 2.4.1.dev3__py3-none-any.whl → 2.4.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- agilerl/algorithms/bc_lm.py +2 -2
- agilerl/algorithms/core/base.py +4 -2
- agilerl/algorithms/dqn.py +5 -25
- agilerl/algorithms/ilql.py +5 -5
- agilerl/algorithms/ippo.py +1 -1
- agilerl/networks/q_networks.py +2 -2
- agilerl/protocols.py +2 -2
- agilerl/training/train_off_policy.py +1 -1
- agilerl/utils/algo_utils.py +2 -2
- agilerl/utils/llm_utils.py +0 -3
- agilerl/wrappers/agent.py +2 -2
- {agilerl-2.4.1.dev3.dist-info → agilerl-2.4.2.dist-info}/METADATA +42 -44
- {agilerl-2.4.1.dev3.dist-info → agilerl-2.4.2.dist-info}/RECORD +19 -19
- {agilerl-2.4.1.dev3.dist-info → agilerl-2.4.2.dist-info}/WHEEL +1 -1
- {agilerl-2.4.1.dev3.dist-info → agilerl-2.4.2.dist-info}/licenses/LICENSE +0 -0
agilerl/algorithms/bc_lm.py
CHANGED

@@ -55,7 +55,7 @@ class BC_LM(nn.Module):
  prefix_embs: Optional[torch.Tensor] = None,
  prefix_attn_mask: Optional[torch.Tensor] = None,
  remove_prefix_position_embs: bool = False,
- **kwargs
+ **kwargs,
  ):
  # tokens – b,t
  # attn_mask – b,t

@@ -83,7 +83,7 @@ class BC_LM(nn.Module):
  tok_emb=input_embeddings,
  attn_mask=input_attn_mask,
  pos=position_ids,
- **kwargs
+ **kwargs,
  )
  return model_outputs, model_past_key_values
agilerl/algorithms/core/base.py
CHANGED

@@ -652,6 +652,9 @@ class EvolvableAlgorithm(ABC, metaclass=RegistryMeta):
  :type training: bool
  """
  self.training = training
+ for name, network in self.evolvable_attributes(networks_only=True).items():
+     if "actor" in name:
+         network.train(mode=training)

  def get_lr_names(self) -> list[str]:
      """Returns the learning rates of the algorithm."""

@@ -2063,8 +2066,7 @@ class LLMAlgorithm(EvolvableAlgorithm, ABC):
  accelerator: Optional[Accelerator] = None,
  ) -> None:
  raise NotImplementedError(
-     "The load class method is not supported for this algorithm class."
-     """
+     "The load class method is not supported for this algorithm class." """
  To load a saved LLM, please load the model as follows, and then re-instantiate the GRPO
  class, using the pre-trained model.
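The `set_training_mode` addition above forwards the training flag to every registered evolvable network whose name contains "actor". A minimal, self-contained sketch of that pattern follows; the `TinyAgent` class and its `networks` dict are illustrative stand-ins, not AgileRL's `EvolvableAlgorithm` API.

```python
import torch.nn as nn

class TinyAgent:
    """Illustrative agent holding named networks; not AgileRL's EvolvableAlgorithm."""

    def __init__(self) -> None:
        self.networks = {
            "actor": nn.Sequential(nn.Linear(4, 32), nn.Dropout(0.1), nn.Linear(32, 2)),
            "critic": nn.Sequential(nn.Linear(4, 32), nn.Linear(32, 1)),
        }
        self.training = True

    def set_training_mode(self, training: bool) -> None:
        self.training = training
        # Mirror of the diff: only networks whose registered name contains "actor"
        # follow the flag, so layers like Dropout switch correctly at inference time.
        for name, network in self.networks.items():
            if "actor" in name:
                network.train(mode=training)

agent = TinyAgent()
agent.set_training_mode(False)
assert agent.networks["actor"].training is False
assert agent.networks["critic"].training is True  # untouched, matching the name filter
```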
agilerl/algorithms/dqn.py
CHANGED

@@ -7,7 +7,6 @@ import torch._dynamo
  import torch.nn as nn
  import torch.optim as optim
  from gymnasium import spaces
- from tensordict import TensorDict, from_module
  from tensordict.nn import CudaGraphModule

  from agilerl.algorithms.core import OptimizerWrapper, RLAlgorithm

@@ -144,8 +143,8 @@ class DQN(RLAlgorithm):
  self.actor = create_actor()
  self.actor_target = create_actor()

- #
- self.
+ # Initialize target network (same pattern as DDPG; post-mutation sync via reinit_shared_networks)
+ self.actor_target.load_state_dict(self.actor.state_dict())

  # Initialize optimizer with OptimizerWrapper
  self.optimizer = OptimizerWrapper(

@@ -172,7 +171,7 @@ class DQN(RLAlgorithm):
  self.update = CudaGraphModule(self.update)
  self._get_action = CudaGraphModule(self._get_action)

- # Register DQN network groups
+ # Register DQN network groups
  self.register_network_group(
  NetworkGroup(
  eval_network=self.actor,

@@ -180,27 +179,6 @@ class DQN(RLAlgorithm):
  policy=True,
  )
  )
- self.register_mutation_hook(self.init_hook)
-
- def init_hook(self) -> None:
-     """Resets module parameters for the detached and target networks."""
-     param_vals: TensorDict = from_module(self.actor).detach()
-
-     # NOTE: This removes the target params from the computation graph which
-     # reduces memory overhead and speeds up training, however these won't
-     # appear in the modules parameters
-     target_params: TensorDict = param_vals.clone().lock_()
-
-     # This hook is prompted after performing architecture mutations on policy / evaluation
-     # networks, which will fail since the target network is a shared network that won't be
-     # reintiialized until the end. We can bypass the error safely for this reason.
-     try:
-         target_params.to_module(self.actor_target)
-     except KeyError:
-         pass
-     finally:
-         self.param_vals = param_vals
-         self.target_params = target_params

  def get_action(
      self,

@@ -260,8 +238,10 @@ class DQN(RLAlgorithm):
  :return: Selected action(s) as tensor
  :rtype: torch.Tensor
  """
+ self.actor.eval()
  with torch.no_grad():
      q_values = self.actor(obs)
+ self.actor.train()

  # Masked random actions
  masked_random_values = torch.rand_like(q_values) * action_mask
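The DQN changes above replace the TensorDict-based `init_hook` with a plain `load_state_dict` hard sync and wrap greedy Q-value inference in `eval()` / `torch.no_grad()` / `train()`. A small sketch of both patterns in isolation, using toy networks rather than AgileRL's `DQN` class:

```python
import torch
import torch.nn as nn

def make_q_net() -> nn.Module:
    return nn.Sequential(nn.Linear(4, 64), nn.ReLU(), nn.Dropout(0.1), nn.Linear(64, 2))

q_net = make_q_net()
q_target = make_q_net()

# Hard target sync: copy the online network's weights into the target network,
# the same idea as `self.actor_target.load_state_dict(self.actor.state_dict())`.
q_target.load_state_dict(q_net.state_dict())

def greedy_action(obs: torch.Tensor) -> torch.Tensor:
    q_net.eval()               # put layers like Dropout into inference mode
    with torch.no_grad():      # no gradient graph needed for action selection
        q_values = q_net(obs)
    q_net.train()              # restore training mode for subsequent learn steps
    return q_values.argmax(dim=-1)

print(greedy_action(torch.randn(8, 4)))  # tensor of 8 greedy action indices
```

The `eval()` / `train()` toggle matters when the actor contains layers such as Dropout or BatchNorm, whose forward pass differs between training and inference.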
agilerl/algorithms/ilql.py
CHANGED

@@ -331,7 +331,7 @@ class ILQL(nn.Module):
  tok_emb=input_embeddings,
  attn_mask=input_attn_mask,
  pos=position_ids,
- **qv_kwargs
+ **qv_kwargs,
  )
  hidden_states = model_hidden_states[-1][:, prefix_t:, :]

@@ -345,7 +345,7 @@ class ILQL(nn.Module):
  tok_emb=target_input_embeddings,
  attn_mask=input_attn_mask,
  pos=position_ids,
- **target_kwargs
+ **target_kwargs,
  )
  target_hidden_states = target_hidden_states[-1][:, prefix_t:, :]

@@ -373,7 +373,7 @@ class ILQL(nn.Module):
  tok_emb=policy_input_embeddings,
  attn_mask=input_attn_mask,
  pos=position_ids,
- **policy_kwargs
+ **policy_kwargs,
  )
  else:
  (

@@ -385,7 +385,7 @@ class ILQL(nn.Module):
  tok_emb=policy_input_embeddings,
  attn_mask=input_attn_mask,
  pos=position_ids,
- **policy_kwargs
+ **policy_kwargs,
  )
  policy_hidden_states = policy_hidden_states[-1][:, prefix_t:, :]

@@ -626,7 +626,7 @@ class ILQL(nn.Module):
  qv_kwargs=None,
  policy_kwargs=None,
  target_kwargs=None,
- **kwargs
+ **kwargs,
  ):
  prepared_inputs = self.prepare_inputs(items)
  tokens, attn_mask = prepared_inputs["tokens"], prepared_inputs["attn_mask"]
agilerl/algorithms/ippo.py
CHANGED

@@ -671,7 +671,7 @@ class IPPO(MultiAgentRLAlgorithm):
  :param action_space: Action space for the agent
  :type action_space: gymnasium.spaces
  """
-
+ states, actions, log_probs, rewards, dones, values, next_state, next_done = (
  experiences
  )
agilerl/networks/q_networks.py
CHANGED

@@ -248,7 +248,7 @@ class RainbowQNetwork(EvolvableNetwork):
  num_atoms=self.num_atoms,
  support=self.support,
  device=self.device,
- **net_config
+ **net_config,
  )

  def forward(

@@ -279,7 +279,7 @@ class RainbowQNetwork(EvolvableNetwork):
  num_atoms=self.num_atoms,
  support=self.support,
  device=self.device,
- **self.head_net.net_config
+ **self.head_net.net_config,
  )

  self.head_net = EvolvableModule.preserve_parameters(self.head_net, head_net)
agilerl/protocols.py
CHANGED

@@ -385,7 +385,7 @@ class PreTrainedModelProtocol(Protocol):
  input_ids: torch.Tensor,
  attention_mask: Optional[torch.Tensor] = None,
  generation_config: Optional["GenerationConfigProtocol"] = None,
- **kwargs: Any
+ **kwargs: Any,
  ) -> torch.Tensor: ...
  def forward(self, *args: Any, **kwargs: Any) -> Any: ...
  def parameters(self) -> Generator: ...

@@ -416,7 +416,7 @@ class PeftModelProtocol(Protocol):
  input_ids: torch.Tensor,
  attention_mask: Optional[torch.Tensor] = None,
  generation_config: Optional["GenerationConfigProtocol"] = None,
- **kwargs: Any
+ **kwargs: Any,
  ) -> torch.Tensor: ...
  def forward(self, *args: Any, **kwargs: Any) -> Any: ...
  def parameters(self) -> Generator: ...

agilerl/training/train_off_policy.py
CHANGED

@@ -145,7 +145,7 @@ def train_off_policy(
  assert isinstance(max_steps, int), "Number of steps must be an integer."
  assert isinstance(evo_steps, int), "Evolution frequency must be an integer."
  assert isinstance(eps_start, float), "Starting epsilon must be a float."
- assert isinstance(eps_end, float), "Final value of
+ assert isinstance(eps_end, float), "Final value of epsilon must be a float."
  assert isinstance(eps_decay, float), "Epsilon decay rate must be a float."
  if target is not None:
      assert isinstance(
agilerl/utils/algo_utils.py
CHANGED

@@ -6,7 +6,7 @@ from collections import OrderedDict, defaultdict
  from dataclasses import dataclass
  from functools import singledispatch
  from numbers import Number
- from typing import Any, Optional, Union
+ from typing import Any, ForwardRef, Optional, Union

  import numpy as np
  import torch

@@ -50,7 +50,7 @@ if HAS_LLM_DEPENDENCIES:
  PreTrainedModelType = Union[PeftModel, PreTrainedModel]
  else:
-     PreTrainedModelType = Union["PeftModel", "PreTrainedModel"]
+     PreTrainedModelType = Union[ForwardRef("PeftModel"), ForwardRef("PreTrainedModel")]

  def check_supported_space(observation_space: GymSpaceType) -> None:
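The `algo_utils.py` change swaps bare string forward references for explicit `typing.ForwardRef` objects when the optional LLM dependencies are missing. A hedged sketch of the guarded-alias pattern; the try/except import block is an assumption about how `HAS_LLM_DEPENDENCIES` is set, and only the two alias lines are taken from the diff:

```python
from typing import ForwardRef, Union

try:
    from peft import PeftModel                # optional LLM dependency
    from transformers import PreTrainedModel  # optional LLM dependency
    HAS_LLM_DEPENDENCIES = True
except ImportError:
    HAS_LLM_DEPENDENCIES = False

if HAS_LLM_DEPENDENCIES:
    # Real classes are importable, so the alias can reference them directly.
    PreTrainedModelType = Union[PeftModel, PreTrainedModel]
else:
    # Build the alias from explicit ForwardRef objects instead of bare strings,
    # keeping the Union a valid runtime object without importing either library.
    PreTrainedModelType = Union[ForwardRef("PeftModel"), ForwardRef("PreTrainedModel")]
```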
agilerl/utils/llm_utils.py
CHANGED

@@ -105,7 +105,6 @@ class HuggingFaceGym(gym.Env, ABC):
  accelerator: Accelerator | None = None,
  seed: int = 42,
  ) -> None:
-
  self.name = train_dataset.info.dataset_name
  self.tokenizer = tokenizer
  self.data_batch_size_per_gpu = data_batch_size_per_gpu

@@ -431,7 +430,6 @@ class ReasoningGym(HuggingFaceGym):
  """

  def collate_fn(batch):
-
  questions = [item["question"] for item in batch]
  answers = [item["answer"] for item in batch]

@@ -551,7 +549,6 @@ class PreferenceGym(HuggingFaceGym):
  """

  def collate_fn(batch: list[dict[str, str]]) -> dict[str, str]:
-
  prompts = [item["prompt"] for item in batch]
  chosen = [item["chosen"] for item in batch]
  rejected = [item["rejected"] for item in batch]
agilerl/wrappers/agent.py
CHANGED

@@ -597,8 +597,8 @@ class AsyncAgentsWrapper(AgentWrapper[MultiAgentRLAlgorithm]):
  :return: Learning information
  :rtype: Any
  """
-
-
+ states, actions, log_probs, rewards, dones, values, next_state, next_done = map(
+     self.stack_experiences, experiences
  )

  # Handle case where we haven't collected a next state for each sub-agent
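The `AsyncAgentsWrapper` change above unpacks the experience tuple by mapping `self.stack_experiences` over each field. A toy illustration of that unpacking; the `stack_experiences` helper below is a hypothetical stand-in for the wrapper's own method:

```python
import numpy as np

def stack_experiences(field):
    """Hypothetical stand-in: stack one field's per-step entries into an array."""
    return np.stack(field)

# experiences is a tuple of per-field lists, one entry per collected step.
experiences = (
    [np.zeros(4), np.ones(4)],   # states
    [0, 1],                      # actions
    [0.1, 0.2],                  # rewards
)
states, actions, rewards = map(stack_experiences, experiences)
print(states.shape, actions.shape, rewards.shape)  # (2, 4) (2,) (2,)
```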
{agilerl-2.4.1.dev3.dist-info → agilerl-2.4.2.dist-info}/METADATA
CHANGED

@@ -1,52 +1,51 @@
  Metadata-Version: 2.4
  Name: agilerl
- Version: 2.4.
+ Version: 2.4.2
  Summary: AgileRL is a deep reinforcement learning library focused on improving RL development through RLOps.
-
+ Author-email: Nick Ustaran-Anderegg <dev@agilerl.com>
+ License-Expression: Apache-2.0
  License-File: LICENSE
-
-
- Requires-
-
-
-
-
-
+ Requires-Python: <3.13,>=3.10
+ Requires-Dist: accelerate~=1.7.0
+ Requires-Dist: dill~=0.3.7
+ Requires-Dist: fastrand~=1.3.0
+ Requires-Dist: flatten-dict~=0.4.2
+ Requires-Dist: google-cloud-storage~=2.5.0
+ Requires-Dist: gymnasium~=1.0.0
+ Requires-Dist: h5py~=3.8.0
+ Requires-Dist: hydra-core~=1.3.2
+ Requires-Dist: jax[cpu]~=0.4.31
+ Requires-Dist: matplotlib<3.10,~=3.9.4
+ Requires-Dist: minari[all]==0.5.2
+ Requires-Dist: numpy~=1.26.4
+ Requires-Dist: omegaconf~=2.3.0
+ Requires-Dist: packaging>=20.0
+ Requires-Dist: pandas~=2.2.3
+ Requires-Dist: pettingzoo~=1.23.1
+ Requires-Dist: pre-commit~=3.4.0
+ Requires-Dist: pygame~=2.6.0
+ Requires-Dist: pymunk~=6.2.0
+ Requires-Dist: redis~=4.4.4
+ Requires-Dist: scipy~=1.12.0
+ Requires-Dist: supersuit~=3.9.0
+ Requires-Dist: tensordict~=0.8
+ Requires-Dist: termcolor~=1.1.0
+ Requires-Dist: torch==2.7.1
+ Requires-Dist: tqdm~=4.66.4
+ Requires-Dist: ucimlrepo~=0.0.3
+ Requires-Dist: wandb~=0.17.6
  Provides-Extra: all
+ Requires-Dist: datasets==4.4.1; extra == 'all'
+ Requires-Dist: deepspeed~=0.17.1; extra == 'all'
+ Requires-Dist: peft~=0.18.0; extra == 'all'
+ Requires-Dist: transformers~=4.57.1; extra == 'all'
+ Requires-Dist: vllm~=0.10.0; extra == 'all'
  Provides-Extra: llm
- Requires-Dist:
- Requires-Dist:
- Requires-Dist:
- Requires-Dist:
- Requires-Dist:
- Requires-Dist: fastrand (>=1.3.0,<2.0.0)
- Requires-Dist: flatten_dict (>=0.4.2,<0.5.0)
- Requires-Dist: google-cloud-storage (>=2.5.0,<3.0.0)
- Requires-Dist: gymnasium (>=1.0.0,<2.0.0)
- Requires-Dist: h5py (>=3.8.0,<4.0.0)
- Requires-Dist: hydra-core (>=1.3.2,<2.0.0)
- Requires-Dist: jax[cpu] (>=0.4.31,<0.5.0)
- Requires-Dist: matplotlib (>=3.9.4,<3.10.0)
- Requires-Dist: minari[all] (==0.5.2)
- Requires-Dist: numpy (>=1.26.4,<2.0.0)
- Requires-Dist: omegaconf (>=2.3.0,<3.0.0)
- Requires-Dist: packaging (>=20.0)
- Requires-Dist: pandas (>=2.2.3,<3.0.0)
- Requires-Dist: peft (>=0.18.0,<0.19.0) ; extra == "llm" or extra == "all"
- Requires-Dist: pettingzoo (>=1.23.1,<2.0.0)
- Requires-Dist: pre-commit (>=3.4.0,<4.0.0)
- Requires-Dist: pygame (>=2.6.0,<3.0.0)
- Requires-Dist: pymunk (>=6.2.0,<7.0.0)
- Requires-Dist: redis (>=4.4.4,<5.0.0)
- Requires-Dist: scipy (>=1.12.0,<2.0.0)
- Requires-Dist: tensordict (>=0.8,<0.9)
- Requires-Dist: termcolor (>=1.1.0,<2.0.0)
- Requires-Dist: torch (==2.7.1)
- Requires-Dist: tqdm (>=4.66.4,<5.0.0)
- Requires-Dist: transformers (>=4.57.1,<5.0.0) ; extra == "llm" or extra == "all"
- Requires-Dist: ucimlrepo (>=0.0.3,<0.0.4)
- Requires-Dist: vllm (==0.10.0) ; extra == "llm" or extra == "all"
- Requires-Dist: wandb (>=0.17.6,<0.18.0)
+ Requires-Dist: datasets==4.4.1; extra == 'llm'
+ Requires-Dist: deepspeed~=0.17.1; extra == 'llm'
+ Requires-Dist: peft~=0.18.0; extra == 'llm'
+ Requires-Dist: transformers~=4.57.1; extra == 'llm'
+ Requires-Dist: vllm~=0.10.0; extra == 'llm'
  Description-Content-Type: text/markdown

  <p align="center">

@@ -363,4 +362,3 @@ title = {{AgileRL}},
  url = {https://github.com/AgileRL/AgileRL}
  }
  ```
-
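The rebuilt METADATA replaces the old parenthesised version pins with compact `~=` specifiers and quoted extras. The diff itself doesn't state why, but the new strings are ordinary PEP 508 requirements; a quick sketch with the `packaging` library (itself listed as a dependency above) shows how they parse:

```python
from packaging.requirements import Requirement

# Two Requires-Dist strings taken verbatim from the new METADATA above.
core = Requirement("tensordict~=0.8")
extra = Requirement("transformers~=4.57.1; extra == 'llm'")

print(core.name, core.specifier)     # tensordict ~=0.8
print(extra.name, extra.specifier)   # transformers ~=4.57.1

# The environment marker gates the dependency on the 'llm' (or 'all') extra.
print(extra.marker.evaluate({"extra": "llm"}))    # True
print(extra.marker.evaluate({"extra": "none"}))   # False
```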
{agilerl-2.4.1.dev3.dist-info → agilerl-2.4.2.dist-info}/RECORD
CHANGED

@@ -1,24 +1,26 @@
  agilerl/__init__.py,sha256=0hZjnAULURFWpshG_mhNdaHhf8nlc7h2sR7CLEqup54,572
+ agilerl/protocols.py,sha256=AEOOsCc4zbYWqAfuZDb1Eki0Cu3QLTB42NU3-kNDZXI,14054
+ agilerl/typing.py,sha256=JtLhZMNyFzrnSeos6ltWyD_8yWFkc8Zx-OIC3d1CPQc,5442
  agilerl/algorithms/__init__.py,sha256=5N4DqCEETuFBlhnzf7XEQzIClRXX9e-FxQqQHgLh3Es,661
- agilerl/algorithms/bc_lm.py,sha256=
- agilerl/algorithms/core/__init__.py,sha256=kKGnzj4TGRZKk2J6jcaKkK3s1LjCYu979o8u8OJUZjI,268
- agilerl/algorithms/core/base.py,sha256=LeFN0l17oCUxp23zFayq8tr9RFbSw--68TPa1FwobuA,121970
- agilerl/algorithms/core/optimizer_wrapper.py,sha256=UQTlnv-mbNGlQ3RX9ocHtczXhTZq1MBKO6OdoQ879uM,13086
- agilerl/algorithms/core/registry.py,sha256=ndaw9U814tHrPBhEPO9kLIDNKmLStTwLXPsnu-nnj8c,19991
+ agilerl/algorithms/bc_lm.py,sha256=aL1ibo8Itv--A4yaW5I55fbx6sKRtRqqTZw3UDTAl-s,22948
  agilerl/algorithms/cqn.py,sha256=3zE6LPWPV8ut5hLPllw3yhY_amonbiSmbBXJU0-7Zo4,12583
  agilerl/algorithms/ddpg.py,sha256=uau1E37D9SARlf_bTswfZQGQRobh9tOcB6hoRpszx_g,21365
  agilerl/algorithms/dpo.py,sha256=kN2wp2Ms_2sFiJcmqpVPxG4XHoJis6l6BQlSCsj07pk,15777
- agilerl/algorithms/dqn.py,sha256=
+ agilerl/algorithms/dqn.py,sha256=3WYga_sVDflP1xVUJ2u-24jcmC_a5F0EXfthDJ5fbpQ,16210
  agilerl/algorithms/dqn_rainbow.py,sha256=HyP-jkiVOkBUJmvpUlrB6VHo8m-AO2Z84M3Zb_ZP6fQ,20483
  agilerl/algorithms/grpo.py,sha256=9VvRf4jQNDOfUlkKDZBNiiBACUybgeOxSQgnszjm2BM,19237
- agilerl/algorithms/ilql.py,sha256=
- agilerl/algorithms/ippo.py,sha256=
+ agilerl/algorithms/ilql.py,sha256=yQ6v6Y7n4JtsknCyhXOoJWMu-jbZX8CsLoitsEG2_YY,79849
+ agilerl/algorithms/ippo.py,sha256=2JBPYnXGBxVbgkvy5BEa_m3Y4knKuIMA0EFNR3YADsQ,39083
  agilerl/algorithms/maddpg.py,sha256=qVXDyb_W51lZtvst4K3yiosSy58BEBYbck8wF8CViBA,33908
  agilerl/algorithms/matd3.py,sha256=n17y6PvM51r290Def_QeFT4p7TMo54MIDLN30XqlMk8,37926
  agilerl/algorithms/neural_ts_bandit.py,sha256=jL_5mnExjMZdiIdwMWXT1XH-hWtaIiSokxi_n_qGTDY,11790
  agilerl/algorithms/neural_ucb_bandit.py,sha256=wwo2sUNkIFtDDEOHIOp9aWhf5oeO9goi9p48tdH1Uno,11960
  agilerl/algorithms/ppo.py,sha256=yAkgZT7WbZKn2oq62DFDPcfAmnRHomVPm4yNlI9-B-c,53025
  agilerl/algorithms/td3.py,sha256=gFYlwwxYQgaWGDT5a-c3AOwI5WGQv4J4eeBotw1-fZY,23017
+ agilerl/algorithms/core/__init__.py,sha256=kKGnzj4TGRZKk2J6jcaKkK3s1LjCYu979o8u8OJUZjI,268
+ agilerl/algorithms/core/base.py,sha256=R0GyAIC3CaRzNKkufU_omfKzpasGefxCFNq5yvvwQ78,122119
+ agilerl/algorithms/core/optimizer_wrapper.py,sha256=UQTlnv-mbNGlQ3RX9ocHtczXhTZq1MBKO6OdoQ879uM,13086
+ agilerl/algorithms/core/registry.py,sha256=ndaw9U814tHrPBhEPO9kLIDNKmLStTwLXPsnu-nnj8c,19991
  agilerl/components/__init__.py,sha256=cc3bYeOdsNp-Puj_4_Ukj3kwmEqUqFeUo-5dZ3tP47o,292
  agilerl/components/data.py,sha256=KiXS4OPgC0VpM9cP8HMDoDhxYX-khL9vKEi2pYIWd7E,3832
  agilerl/components/multi_agent_replay_buffer.py,sha256=VfT90DhlrgMomzW_8Nw5zQrD4908hFLMgg0kXpy1ZHE,8604

@@ -53,9 +55,8 @@ agilerl/networks/base.py,sha256=Lkhj0yujVgNDKXKn_ea0hBFDSyStDwL5AuEyFnyGjmE,2163
  agilerl/networks/custom_modules.py,sha256=n6WR5DsXBQwQtvQk6lHiXP-DR-Ma6lGjrOtySSrIAiA,6843
  agilerl/networks/distributions.py,sha256=mzntWgwoEdZKAspInbmvfc6_0rGuPdquqQyQkVSWvoo,18252
  agilerl/networks/distributions_experimental.py,sha256=K6_EYflAlR6qRouRr6SJXnT19w7QhOA1bwN7kCl3DJ8,18890
- agilerl/networks/q_networks.py,sha256=
+ agilerl/networks/q_networks.py,sha256=pgX7lg-_bo724A7BC-b0ViLgXhWA5xbvoLnRnn8f9sU,17286
  agilerl/networks/value_networks.py,sha256=ZLX5vQIxeV65uxOzv2r5QMxF_-fzFT8N1et3lHdQP7E,4630
- agilerl/protocols.py,sha256=SQ8T79jmZAqlm2fJ1Qo0kefU5w2c4Mh_wUk9RtiPego,14052
  agilerl/rollouts/__init__.py,sha256=dGR9BnXliQI6yvXPwecV7g5TCtCEPbyIB-W1a5evBBY,130
  agilerl/rollouts/on_policy.py,sha256=VOxUjwzyYngzrTEW9asXsgz1O6lRTUn_PijmjqtzGwQ,8036
  agilerl/training/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0

@@ -63,16 +64,15 @@ agilerl/training/train_bandits.py,sha256=pi6GFQrGsFkqgD8V69ayVlzcNUPMIF3PYaBEgJU
  agilerl/training/train_llm.py,sha256=1KxiJQGPLCVxMqsVNUWzsGHEwDL9ehvQ7a3gEELr2zM,27602
  agilerl/training/train_multi_agent_off_policy.py,sha256=p1VOBDqyt14LD5HUQ-YF5m2jce_LphgYa38DP4asY30,23349
  agilerl/training/train_multi_agent_on_policy.py,sha256=WDtUTpIpPuQpPdZN-1H_gwqHICyPRLfWIJeyYtClQKc,24427
- agilerl/training/train_off_policy.py,sha256=
+ agilerl/training/train_off_policy.py,sha256=iyMHnFrOjjuPxcIesrg9WFRmDFxXXI1guqeXMVb5XXg,23511
  agilerl/training/train_offline.py,sha256=qAlr3lGQf7EfSSmTtmohi80rUN4HMha955q3pae6TCY,13406
  agilerl/training/train_on_policy.py,sha256=iQEIHq_JgBIBH2GPJeLN6QmPRho-_beUdro1H9DPkUA,19360
- agilerl/typing.py,sha256=JtLhZMNyFzrnSeos6ltWyD_8yWFkc8Zx-OIC3d1CPQc,5442
  agilerl/utils/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
- agilerl/utils/algo_utils.py,sha256=
+ agilerl/utils/algo_utils.py,sha256=tFH89AD63qsN80xeyn9m85qRfjD12VN0nCan2_qRjT4,60412
  agilerl/utils/cache.py,sha256=8Q1SYbTxQYzIn40UMy32EWMvtgaduY1k5jqwPihxJ_Q,3418
  agilerl/utils/evolvable_networks.py,sha256=cIJHzadFOaK0aAqwn96HvnuH4atLBxrQ3cwpR1nxvUo,23265
  agilerl/utils/ilql_utils.py,sha256=dU_vbwOB6VsODGGu_hOyDN_xRtFKVhZbxMISFlAUM5s,2293
- agilerl/utils/llm_utils.py,sha256=
+ agilerl/utils/llm_utils.py,sha256=Rdwfo3L3TDyqfk4QfHrRRZ4-r8nblvMXp3-Qnf9W5k8,26591
  agilerl/utils/log_utils.py,sha256=OIhj86V97-ijlUENic2WKIWipB5ITJyBIGM_ZPZg5Vo,4401
  agilerl/utils/minari_utils.py,sha256=WNFzt9ZQuvWy3w84MFhhGkA0e9MAgc4KSI_cmPgFTBo,5109
  agilerl/utils/probe_envs.py,sha256=q2uyPQW7mbo9x4c_Yq9vi2Yu1X9qyLm43adET9SFf9Y,39796

@@ -84,12 +84,12 @@ agilerl/vector/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
  agilerl/vector/pz_async_vec_env.py,sha256=uj9TyCn0SWksTUOW84RGspMkXqdGG-wjr86w08uCMb0,36742
  agilerl/vector/pz_vec_env.py,sha256=sFVqm8eecxVHahTpFZEE3fvyZrmp2vMu0GECik8el6M,5978
  agilerl/wrappers/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
- agilerl/wrappers/agent.py,sha256=
+ agilerl/wrappers/agent.py,sha256=VyWju26YIfN4g8NCxTM-zMc3IGf5YKCXtyoeHynfEXE,23158
  agilerl/wrappers/learning.py,sha256=nSVMg6eUBWn13NNdIFgCEHj31CaN_dGryQa13SmMvBw,2774
  agilerl/wrappers/make_evolvable.py,sha256=sb9oAorGAayrD_6lNbyvHhefA_RKO4bSSNjqS6u9UhI,51079
  agilerl/wrappers/pettingzoo_wrappers.py,sha256=Pw8VzabxfYCw5ad15y5J3rAH1teA6nVVo0RHCTTdOPQ,2063
  agilerl/wrappers/utils.py,sha256=pENFH2AxsXd22s8HGUeM-jRowC0tmjHLWjqDwIq12l8,2194
- agilerl-2.4.
- agilerl-2.4.
- agilerl-2.4.
- agilerl-2.4.
+ agilerl-2.4.2.dist-info/METADATA,sha256=ohWA3cJL3JmVZr5svRWRslLWbCX5Lbx46ceqL9MWwgQ,20164
+ agilerl-2.4.2.dist-info/WHEEL,sha256=WLgqFyCfm_KASv4WHyYy0P3pM_m7J5L9k2skdKLirC8,87
+ agilerl-2.4.2.dist-info/licenses/LICENSE,sha256=vPX_VnIseflXJ30mQvwbXZoe208EtIr9ZVrl6cfdQNs,11720
+ agilerl-2.4.2.dist-info/RECORD,,
{agilerl-2.4.1.dev3.dist-info → agilerl-2.4.2.dist-info}/licenses/LICENSE
File without changes