saferl_lite-0.1.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
agents/__init__.py ADDED
File without changes
agents/constrained_dqn.py ADDED
@@ -0,0 +1,82 @@
+ # agents/constrained_dqn.py
+
+ import numpy as np
+ import random
+ import torch
+ import torch.nn as nn
+ import torch.optim as optim
+ from collections import deque
+ from agents.constraints import Constraint
+
+
+ class DQNetwork(nn.Module):
+     def __init__(self, input_dim, output_dim):
+         super().__init__()
+         self.net = nn.Sequential(
+             nn.Linear(input_dim, 128),
+             nn.ReLU(),
+             nn.Linear(128, 128),
+             nn.ReLU(),
+             nn.Linear(128, output_dim),
+         )
+
+     def forward(self, x):
+         return self.net(x)
+
+
+ class ConstrainedDQNAgent:
+     def __init__(self, state_dim, action_dim, constraint: Constraint = None):
+         self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+         self.q_net = DQNetwork(state_dim, action_dim).to(self.device)
+         self.target_net = DQNetwork(state_dim, action_dim).to(self.device)
+         self.target_net.load_state_dict(self.q_net.state_dict())
+
+         self.optimizer = optim.Adam(self.q_net.parameters(), lr=1e-3)
+         self.memory = deque(maxlen=10000)
+         self.gamma = 0.99
+         self.batch_size = 64
+
+         self.constraint = constraint
+         self.action_dim = action_dim
+
+     def select_action(self, state, epsilon):
+         if random.random() < epsilon:
+             return random.randint(0, self.action_dim - 1)
+         state = torch.tensor(state, dtype=torch.float32).to(self.device)
+         with torch.no_grad():
+             q_values = self.q_net(state)
+         return q_values.argmax().item()
+
+     def store_transition(self, s, a, r, s_next, done):
+         self.memory.append((s, a, r, s_next, done))
+
+     def update(self):
+         if len(self.memory) < self.batch_size:
+             return
+         batch = random.sample(self.memory, self.batch_size)
+         s, a, r, s_next, done = zip(*batch)
+
+         s = torch.tensor(s, dtype=torch.float32).to(self.device)
+         a = torch.tensor(a).to(self.device)
+         r = torch.tensor(r, dtype=torch.float32).to(self.device)
+         s_next = torch.tensor(s_next, dtype=torch.float32).to(self.device)
+         done = torch.tensor(done, dtype=torch.float32).to(self.device)
+
+         q_values = self.q_net(s).gather(1, a.unsqueeze(1)).squeeze()
+         with torch.no_grad():
+             target_q = r + self.gamma * self.target_net(s_next).max(1)[0] * (1 - done)
+
+         loss = nn.functional.mse_loss(q_values, target_q)
+         self.optimizer.zero_grad()
+         loss.backward()
+         self.optimizer.step()
+
+     def apply_constraint(self, state, action, reward):
+         if self.constraint:
+             penalty = self.constraint.compute_penalty(state, action, reward)
+             return reward - penalty, penalty
+         return reward, 0.0
+
+     def reset_constraints(self):
+         if self.constraint:
+             self.constraint.reset()
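
For context, here is a minimal sketch of how this agent is meant to be driven. The surrounding loop, the environment choice, and the epsilon value are illustrative assumptions, not code shipped in this wheel; note also that the class never syncs its target network on its own, so the caller is expected to copy `q_net` into `target_net` periodically.

```python
# Illustrative sketch only; env choice, epsilon, and the target-network sync
# interval are assumptions, not defaults shipped with saferl-lite.
import gymnasium as gym

from agents.constrained_dqn import ConstrainedDQNAgent
from agents.constraints import ActionBudgetConstraint

env = gym.make("CartPole-v1")
agent = ConstrainedDQNAgent(
    state_dim=env.observation_space.shape[0],
    action_dim=env.action_space.n,
    constraint=ActionBudgetConstraint(max_actions=200),
)

for episode in range(5):
    obs, info = env.reset()
    agent.reset_constraints()
    done = False
    while not done:
        action = agent.select_action(obs, epsilon=0.1)
        next_obs, reward, terminated, truncated, info = env.step(action)
        done = terminated or truncated
        # Constraint-aware reward shaping before the transition is stored.
        shaped_reward, penalty = agent.apply_constraint(obs, action, reward)
        agent.store_transition(obs, action, shaped_reward, next_obs, float(done))
        agent.update()
        obs = next_obs
    # The class exposes no sync method, so refresh the target network manually.
    agent.target_net.load_state_dict(agent.q_net.state_dict())
```
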
agents/constraints.py ADDED
@@ -0,0 +1,39 @@
+ # agents/constraints.py
+
+ from abc import ABC, abstractmethod
+
+
+ class Constraint(ABC):
+     """Abstract base class for constraints in SafeRL agents."""
+
+     @abstractmethod
+     def compute_penalty(self, state, action, reward) -> float:
+         """Return penalty value for a given step."""
+         pass
+
+     def reset(self):
+         """Optional: reset internal counters for new episode."""
+         pass
+
+
+ class ActionBudgetConstraint(Constraint):
+     def __init__(self, max_actions: int):
+         self.max_actions = max_actions
+         self.counter = 0
+
+     def compute_penalty(self, state, action, reward) -> float:
+         self.counter += 1
+         return 1.0 if self.counter > self.max_actions else 0.0
+
+     def reset(self):
+         self.counter = 0
+
+
+ class EnergyPenaltyConstraint(Constraint):
+     def __init__(self, energy_fn, max_energy):
+         self.energy_fn = energy_fn
+         self.max_energy = max_energy
+
+     def compute_penalty(self, state, action, reward) -> float:
+         energy = self.energy_fn(state, action)
+         return max(0.0, energy - self.max_energy)
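
The abstract base class above is the extension point for user-defined safety logic. As a hypothetical example (not part of the package), a pole-angle constraint for CartPole could be written like this; the 0.1 rad threshold and the unit penalty are illustrative choices:

```python
# Hypothetical custom constraint; threshold and penalty size are assumptions.
from agents.constraints import Constraint


class PoleAngleConstraint(Constraint):
    """Penalize steps where the CartPole pole angle exceeds a threshold."""

    def __init__(self, max_angle: float = 0.1):
        self.max_angle = max_angle  # radians

    def compute_penalty(self, state, action, reward) -> float:
        # CartPole observation layout: [cart pos, cart vel, pole angle, pole vel]
        return 1.0 if abs(state[2]) > self.max_angle else 0.0
```

An instance passed as `constraint=` to `ConstrainedDQNAgent` is then applied through `apply_constraint`, which subtracts the returned penalty from the raw reward.
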
envs/__init__.py ADDED
File without changes
envs/wrappers.py ADDED
@@ -0,0 +1,58 @@
+ # envs/wrappers.py
+
+ import gymnasium as gym
+ import numpy as np
+
+
+ class SafeEnvWrapper(gym.Wrapper):
+     def __init__(self, env, max_force: float = None, max_energy: float = None):
+         super().__init__(env)
+         self.max_force = max_force
+         self.max_energy = max_energy
+         self.violation_log = []
+         self.episode_log = []
+
+     def reset(self, **kwargs):
+         obs, info = self.env.reset(**kwargs)
+         self.violation_log.clear()
+         self.episode_log.clear()
+         return obs, info
+
+     def step(self, action):
+         obs, reward, terminated, truncated, info = self.env.step(action)
+         done = terminated or truncated
+
+         violation = 0.0
+
+         # Constraint: limit max force (CartPole)
+         if self.max_force is not None:
+             force = self._get_force(action)
+             if abs(force) > self.max_force:
+                 violation = 1.0
+                 reward -= 1.0  # penalize
+
+         # Note: remove _get_energy for now (MountainCar support will come later)
+
+         self.violation_log.append(violation)
+         self.episode_log.append(
+             {
+                 "obs": obs,
+                 "action": action,
+                 "reward": reward,
+                 "violation": violation,
+             }
+         )
+
+         return obs, reward, terminated, truncated, info
+
+     def _get_force(self, action):
+         # Assumes CartPole: force is ±10
+         return 10.0 if action == 1 else -10.0
+
+     def _get_energy(self, prev_obs, action, next_obs):
+         if prev_obs is None:
+             return 0.0
+         # Simplified energy calculation: KE + PE
+         velocity = next_obs[1]
+         height = np.cos(3 * next_obs[0])  # approximates potential
+         return 0.5 * velocity**2 + 9.8 * height
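
A short sketch of how the wrapper is used. The force limit is an illustrative assumption: since CartPole's push is a fixed ±10, any limit below 10 flags every step, so the example mainly shows where violations are recorded.

```python
# Illustrative sketch; the max_force value is an assumption for demonstration.
import gymnasium as gym

from envs.wrappers import SafeEnvWrapper

env = SafeEnvWrapper(gym.make("CartPole-v1"), max_force=5.0)
obs, info = env.reset()

done = False
while not done:
    obs, reward, terminated, truncated, info = env.step(env.action_space.sample())
    done = terminated or truncated

# One 0/1 entry per step; episode_log keeps obs/action/reward/violation dicts.
print("violations this episode:", sum(env.violation_log))
print("steps logged:", len(env.episode_log))
```
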
explainability/__init__.py ADDED
File without changes
explainability/saliency.py ADDED
@@ -0,0 +1,23 @@
+ # explainability/saliency.py
+
+ import torch
+ from captum.attr import Saliency
+
+
+ class SaliencyExplainer:
+     def __init__(self, model, device="cpu"):
+         self.model = model
+         self.device = device
+         self.saliency = Saliency(self.model)
+
+     def explain(self, state_tensor, target_action: int):
+         """
+         state_tensor: 1D torch tensor (state) on correct device
+         target_action: int, index of the action you want to explain
+         Returns: 1D saliency values (array) for input features
+         """
+         state_tensor = state_tensor.unsqueeze(
+             0
+         ).requires_grad_()  # Shape: [1, input_dim]
+         attr = self.saliency.attribute(state_tensor, target=target_action)
+         return attr.squeeze().detach().cpu().numpy()
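
Based on the docstring above, a usage sketch might look like the following; the untrained agent and the greedy action choice are illustrative assumptions.

```python
# Illustrative sketch; pairs the explainer with the DQN defined in this wheel.
import gymnasium as gym
import torch

from agents.constrained_dqn import ConstrainedDQNAgent
from explainability.saliency import SaliencyExplainer

env = gym.make("CartPole-v1")
agent = ConstrainedDQNAgent(env.observation_space.shape[0], env.action_space.n)
explainer = SaliencyExplainer(agent.q_net, device=agent.device)

obs, info = env.reset()
state = torch.tensor(obs, dtype=torch.float32).to(agent.device)
action = agent.select_action(obs, epsilon=0.0)  # greedy action to explain

# One gradient-magnitude score per observation feature for the chosen action.
scores = explainer.explain(state, target_action=action)
print(scores)
```
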
explainability/shap_explainer.py ADDED
@@ -0,0 +1,31 @@
+ # explainability/shap_explainer.py
+
+ import shap
+ import numpy as np
+ import torch
+
+
+ class SHAPExplainer:
+     def __init__(self, model, input_dim, device="cpu"):
+         """
+         model: a function that maps np.array -> Q-values
+         input_dim: size of observation space
+         """
+         self.input_dim = input_dim
+         self.device = device
+
+         def model_wrapper(x_np):
+             x_tensor = torch.tensor(x_np, dtype=torch.float32).to(self.device)
+             with torch.no_grad():
+                 return model(x_tensor).cpu().numpy()
+
+         self.explainer = shap.Explainer(
+             model_wrapper, shap.maskers.Independent(np.zeros((1, input_dim)))
+         )
+
+     def explain(self, state):
+         """
+         state: np.array of shape (input_dim,)
+         Returns: SHAP values for each input dimension
+         """
+         return self.explainer(np.array([state]))
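
A comparable sketch for the SHAP explainer; the environment and untrained network are illustrative. The first argument is the PyTorch module itself, which `model_wrapper` above adapts to the NumPy interface SHAP expects.

```python
# Illustrative sketch; explains the Q-values of an untrained network.
import gymnasium as gym

from agents.constrained_dqn import ConstrainedDQNAgent
from explainability.shap_explainer import SHAPExplainer

env = gym.make("CartPole-v1")
obs, info = env.reset()

agent = ConstrainedDQNAgent(env.observation_space.shape[0], env.action_space.n)
explainer = SHAPExplainer(
    agent.q_net, input_dim=env.observation_space.shape[0], device=agent.device
)

# Explanation object: one attribution per observation feature and per action.
explanation = explainer.explain(obs)
print(explanation.values.shape)
```
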
saferl_lite-0.1.0.dist-info/METADATA ADDED
@@ -0,0 +1,139 @@
+ Metadata-Version: 2.4
+ Name: saferl-lite
+ Version: 0.1.0
+ Summary: A lightweight, explainable, and constrained reinforcement learning toolkit.
+ Home-page: https://github.com/satyamcser/saferl-lite
+ Author: Satyam Mishra
+ Author-email: satyam@example.com
+ Project-URL: Documentation, https://satyamcser.github.io/saferl-lite/
+ Project-URL: Source, https://github.com/satyamcser/saferl-lite
+ Project-URL: Bug Tracker, https://github.com/satyamcser/saferl-lite/issues
+ Classifier: Programming Language :: Python :: 3
+ Classifier: License :: OSI Approved :: MIT License
+ Classifier: Operating System :: OS Independent
+ Classifier: Intended Audience :: Science/Research
+ Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
+ Requires-Python: >=3.8
+ Description-Content-Type: text/markdown
+ License-File: LICENSE
+ Requires-Dist: gym
+ Requires-Dist: gymnasium
+ Requires-Dist: numpy
+ Requires-Dist: torch
+ Requires-Dist: matplotlib
+ Requires-Dist: seaborn
+ Requires-Dist: pre-commit
+ Requires-Dist: flake8
+ Requires-Dist: pyyaml
+ Requires-Dist: shap
+ Requires-Dist: captum
+ Requires-Dist: typer
+ Requires-Dist: scikit-learn
+ Requires-Dist: pandas
+ Requires-Dist: pytest
+ Requires-Dist: pytest-cov
+ Requires-Dist: coverage
+ Requires-Dist: mkdocs
+ Requires-Dist: wandb
+ Requires-Dist: mkdocs>=1.5
+ Requires-Dist: mkdocs-material>=9.5
+ Requires-Dist: mkdocstrings[python]
+ Dynamic: author
+ Dynamic: author-email
+ Dynamic: classifier
+ Dynamic: description
+ Dynamic: description-content-type
+ Dynamic: home-page
+ Dynamic: license-file
+ Dynamic: project-url
+ Dynamic: requires-dist
+ Dynamic: requires-python
+ Dynamic: summary
+
+ # 🔐 SafeRL-Lite
+
+ A **lightweight, explainable, and modular** Python library for **Constrained Reinforcement Learning (Safe RL)** with real-time **SHAP & saliency-based explainability**, custom metrics, and Gym-compatible wrappers.
+
+ <p align="center">
+ <img src="https://img.shields.io/github/license/satyamcser/saferl-lite?style=flat-square">
+ <img src="https://img.shields.io/github/stars/satyamcser/saferl-lite?style=flat-square">
+ <img src="https://img.shields.io/pypi/v/saferl-lite?style=flat-square">
+ <img src="https://img.shields.io/github/actions/workflow/status/satyamcser/saferl-lite/ci.yml?branch=main&style=flat-square">
+ </p>
+
+ ---
+
+ ## 🌟 Overview
+
+ **SafeRL-Lite** lets reinforcement learning agents act under **safety constraints** while remaining **interpretable** and **modular** enough for fast experimentation. It wraps standard Gym environments and DQN-based agents with:
+
+ - ✅ Safety constraint logic
+ - 🔍 Visual explainability (SHAP, saliency maps)
+ - 📊 Violation and reward tracking
+ - 🧪 Built-in testing and evaluations
+
+ ---
+
+ ## 🔧 Installation
+
+ > 📦 PyPI (coming soon)
+ ```bash
+ pip install saferl-lite
+ ```
+
+ ## 🛠️ From source:
+
+ ```bash
+ git clone https://github.com/satyamcser/saferl-lite.git
+ cd saferl-lite
+ pip install -e .
+ ```
+
+ ## 🚀 Quickstart
+ Train a constrained DQN agent with SHAP- or saliency-based explainability:
+
+ ```bash
+ python train.py --env CartPole-v1 --constraint pole_angle --explain shap
+ ```
+
+ 🔹 This:
+
+ - Adds a pole-angle constraint wrapper to the Gym env
+
+ - Logs violations
+
+ - Displays SHAP or saliency explanations for agent decisions (see the Python sketch below)
+
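
For readers who want the Python-level view of what that command wires together, a compressed sketch follows. It uses only the classes visible in this wheel; `train.py` itself is not included in the distribution, so the exact wiring and hyperparameters are assumptions.

```python
# Illustrative mapping of the CLI flags above onto the library's classes.
# train.py itself is not shipped in this wheel, so this wiring is an assumption.
import gymnasium as gym

from agents.constrained_dqn import ConstrainedDQNAgent
from envs.wrappers import SafeEnvWrapper
from explainability.shap_explainer import SHAPExplainer

# --env CartPole-v1: the safety wrapper adds violation logging and reward shaping
env = SafeEnvWrapper(gym.make("CartPole-v1"), max_force=10.0)

# --constraint pole_angle: a Constraint instance would be attached to the agent here
agent = ConstrainedDQNAgent(env.observation_space.shape[0], env.action_space.n)

# --explain shap: wrap the agent's Q-network for per-feature attributions
explainer = SHAPExplainer(
    agent.q_net, input_dim=env.observation_space.shape[0], device=agent.device
)

# Training proceeds with select_action / store_transition / update (see the
# agents/constrained_dqn.py sketch above); explainer.explain(obs) then reports
# SHAP values for any observation obs.
```
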
+ ## 🧠 Features
+ #### ✅ Constrained RL
+ - Add custom constraints via wrapper or logic class
+
+ - Violation logging and reward shaping
+
+ - Safe vs unsafe episode tracking
+
+ #### 🔍 Explainability
+ - SaliencyExplainer — gradient-based visual heatmaps
+
+ - SHAPExplainer — feature contribution values per decision
+
+ - Compatible with any PyTorch-based agent
+
+ #### 📊 Metrics
+ - Constraint violation rate
+
+ - Episode reward
+
+ - Cumulative safe reward
+
+ - Action entropy & temporal behavior stats (a sketch of deriving these from the wrapper logs follows below)
+
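
The wheel does not include a metrics module, so how these quantities are computed internally is not visible here; still, the first three fall out of the wrapper's logs directly. A hedged sketch, assuming `env` is a `SafeEnvWrapper` instance at the end of an episode, and treating "safe reward" as reward earned on non-violating steps:

```python
# Illustrative sketch only; assumes `env` is a SafeEnvWrapper after an episode,
# and that "cumulative safe reward" means reward collected on non-violating steps.
import numpy as np

violations = np.array(env.violation_log)                    # 0/1 per step
rewards = np.array([e["reward"] for e in env.episode_log])  # shaped rewards
actions = np.array([e["action"] for e in env.episode_log])

violation_rate = violations.mean() if violations.size else 0.0
episode_reward = rewards.sum()
cumulative_safe_reward = rewards[violations == 0].sum()

# Action entropy over the empirical action distribution of the episode.
counts = np.bincount(actions)
probs = counts[counts > 0] / counts.sum()
action_entropy = float(-(probs * np.log(probs)).sum())
```
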
+ #### 📚 Modularity
+ - Swap out agents, constraints, evaluators, or explainers
+
+ - Supports Gym environments
+
+ - Configurable training pipeline
+
+ ## 📜 Citation
+ Coming soon after arXiv/preprint release.
saferl_lite-0.1.0.dist-info/RECORD ADDED
@@ -0,0 +1,13 @@
+ agents/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
+ agents/constrained_dqn.py,sha256=dkVfgxBUEpT1gh4L2PJBXvwqGbIGFj8mYPFnacqNBgU,2814
+ agents/constraints.py,sha256=en8uB2gluDI6JDsz96lM0yUnFA-0DB9m_a3ycrVky8c,1075
+ envs/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
+ envs/wrappers.py,sha256=rfk3cfsTsfD8NqUjEcJ-o7XGMmkBBHt5kfaCiE3AgAw,1749
+ explainability/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
+ explainability/saliency.py,sha256=EpvrpkRZWqYqd3lkRIkfIbJ0pw7G_hJ8GEiVfgPo88U,767
+ explainability/shap_explainer.py,sha256=Tj-fP947z8ixFdWRXHdR6D3a_wtznGN5x-DomU34xbc,883
+ saferl_lite-0.1.0.dist-info/licenses/LICENSE,sha256=WRhQPkdFDzbMFEhvoaq9gSNnbsy0lhSC8tFH3stLntY,1070
+ saferl_lite-0.1.0.dist-info/METADATA,sha256=k9EwE0Clqv-yIANmGdhJPemW4EhBI9kqAnw6xc74WJE,3868
+ saferl_lite-0.1.0.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
+ saferl_lite-0.1.0.dist-info/top_level.txt,sha256=f1IuezLA5sRnSuKZbl-VrS_Hh9pekOW2smLrpJLuiGg,27
+ saferl_lite-0.1.0.dist-info/RECORD,,
saferl_lite-0.1.0.dist-info/WHEEL ADDED
@@ -0,0 +1,5 @@
+ Wheel-Version: 1.0
+ Generator: setuptools (80.9.0)
+ Root-Is-Purelib: true
+ Tag: py3-none-any
+
saferl_lite-0.1.0.dist-info/licenses/LICENSE ADDED
@@ -0,0 +1,21 @@
+ MIT License
+
+ Copyright (c) 2025 Satyam Mishra
+
+ Permission is hereby granted, free of charge, to any person obtaining a copy
+ of this software and associated documentation files (the "Software"), to deal
+ in the Software without restriction, including without limitation the rights
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ copies of the Software, and to permit persons to whom the Software is
+ furnished to do so, subject to the following conditions:
+
+ The above copyright notice and this permission notice shall be included in all
+ copies or substantial portions of the Software.
+
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ SOFTWARE.
saferl_lite-0.1.0.dist-info/top_level.txt ADDED
@@ -0,0 +1,3 @@
+ agents
+ envs
+ explainability