agentmesh_lightning-2.3.0-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- agent_lightning_gov/__init__.py +36 -0
- agent_lightning_gov/emitter.py +314 -0
- agent_lightning_gov/environment.py +330 -0
- agent_lightning_gov/reward.py +339 -0
- agent_lightning_gov/runner.py +344 -0
- agentmesh_lightning-2.3.0.dist-info/METADATA +206 -0
- agentmesh_lightning-2.3.0.dist-info/RECORD +10 -0
- agentmesh_lightning-2.3.0.dist-info/WHEEL +5 -0
- agentmesh_lightning-2.3.0.dist-info/licenses/LICENSE +21 -0
- agentmesh_lightning-2.3.0.dist-info/top_level.txt +1 -0
--- /dev/null
+++ b/agent_lightning_gov/__init__.py
@@ -0,0 +1,36 @@
+# Copyright (c) Microsoft Corporation.
+# Licensed under the MIT License.
+"""
+Agent-Lightning Governance Integration
+=======================================
+
+Provides kernel-level safety during Agent-Lightning RL training.
+
+Key components:
+- GovernedRunner: Agent-Lightning runner with policy enforcement
+- PolicyReward: Convert policy violations to RL penalties
+- FlightRecorderEmitter: Export audit logs to LightningStore
+- GovernedEnvironment: Training environment with governance constraints
+
+Example:
+    >>> from agent_lightning_gov import GovernedRunner, PolicyReward
+    >>> from agent_os import KernelSpace
+    >>> from agent_os.policies import SQLPolicy
+    >>>
+    >>> kernel = KernelSpace(policy=SQLPolicy())
+    >>> runner = GovernedRunner(kernel)
+    >>> reward_fn = PolicyReward(kernel, base_reward_fn=accuracy)
+"""
+
+from .emitter import FlightRecorderEmitter
+from .environment import GovernedEnvironment
+from .reward import PolicyReward, policy_penalty
+from .runner import GovernedRunner
+
+__all__ = [
+    "GovernedRunner",
+    "PolicyReward",
+    "policy_penalty",
+    "FlightRecorderEmitter",
+    "GovernedEnvironment",
+]
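Taken together, the exports above are meant to be wired up roughly as follows. This is a minimal sketch based only on the docstring example: the agent_os imports come from that example and are not verified here, and accuracy() is a hypothetical stand-in for a task reward.

# Sketch only: agent_os names follow the docstring example above, and
# accuracy() is a hypothetical placeholder, not part of the package.
from agent_lightning_gov import GovernedRunner, PolicyReward
from agent_os import KernelSpace
from agent_os.policies import SQLPolicy

def accuracy(task, action, result) -> float:
    # Placeholder task reward: 1.0 whenever the rollout produced a result.
    return 1.0 if result is not None else 0.0

kernel = KernelSpace(policy=SQLPolicy())   # kernel with a loaded policy
runner = GovernedRunner(kernel)            # runner with policy enforcement
reward_fn = PolicyReward(kernel, base_reward_fn=accuracy)  # violations become penalties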
--- /dev/null
+++ b/agent_lightning_gov/emitter.py
@@ -0,0 +1,314 @@
+# Copyright (c) Microsoft Corporation.
+# Licensed under the MIT License.
+"""
+FlightRecorderEmitter - Export Audit Logs to LightningStore
+============================================================
+
+Adapts Agent OS Flight Recorder logs to Agent-Lightning's
+span format for unified training and audit trail.
+"""
+
+from __future__ import annotations
+
+import json
+import logging
+from collections.abc import AsyncIterator
+from dataclasses import dataclass, field
+from datetime import datetime, timezone
+from typing import Any
+
+logger = logging.getLogger(__name__)
+
+
+@dataclass
+class LightningSpan:
+    """
+    Span compatible with Agent-Lightning's LightningStore.
+
+    Maps Agent OS Flight Recorder entries to the span format
+    expected by Agent-Lightning's training and analysis tools.
+    """
+    span_id: str
+    trace_id: str
+    name: str
+    start_time: datetime
+    end_time: datetime | None = None
+    attributes: dict[str, Any] = field(default_factory=dict)
+    events: list[dict[str, Any]] = field(default_factory=list)
+
+    def to_dict(self) -> dict[str, Any]:
+        """Convert to dictionary for serialization."""
+        return {
+            "span_id": self.span_id,
+            "trace_id": self.trace_id,
+            "name": self.name,
+            "start_time": self.start_time.isoformat(),
+            "end_time": self.end_time.isoformat() if self.end_time else None,
+            "attributes": self.attributes,
+            "events": self.events,
+        }
+
+    def to_json(self) -> str:
+        """Convert to JSON string."""
+        return json.dumps(self.to_dict())
+
+
+class FlightRecorderEmitter:
+    """
+    Emits Agent OS Flight Recorder entries to Agent-Lightning.
+
+    This adapter enables:
+    1. Complete audit trail from training to production
+    2. RL algorithms learning from policy violations
+    3. Compliance-friendly training logs
+
+    Example:
+        >>> from agent_os import FlightRecorder
+        >>> from agent_os.integrations.agent_lightning import FlightRecorderEmitter
+        >>>
+        >>> recorder = FlightRecorder()
+        >>> emitter = FlightRecorderEmitter(recorder)
+        >>>
+        >>> # Emit all logs to LightningStore
+        >>> emitter.emit_to_store(store)
+        >>>
+        >>> # Or stream continuously
+        >>> async for span in emitter.stream():
+        ...     store.emit_span(span)
+    """
+
+    def __init__(
+        self,
+        flight_recorder: Any,  # FlightRecorder
+        *,
+        include_policy_checks: bool = True,
+        include_signals: bool = True,
+        include_tool_calls: bool = True,
+        trace_id_prefix: str = "agentos",
+    ):
+        """
+        Initialize the emitter.
+
+        Args:
+            flight_recorder: Agent OS FlightRecorder instance
+            include_policy_checks: Include policy check spans
+            include_signals: Include signal dispatch spans
+            include_tool_calls: Include tool call spans
+            trace_id_prefix: Prefix for generated trace IDs
+        """
+        self.recorder = flight_recorder
+        self.include_policy_checks = include_policy_checks
+        self.include_signals = include_signals
+        self.include_tool_calls = include_tool_calls
+        self.trace_id_prefix = trace_id_prefix
+
+        self._emitted_count = 0
+        self._last_position = 0
+
+    def _convert_entry(self, entry: Any) -> LightningSpan | None:
+        """Convert a Flight Recorder entry to a Lightning span."""
+        entry_type = getattr(entry, 'type', getattr(entry, 'entry_type', 'unknown'))
+
+        # Filter by entry type
+        if entry_type == 'policy_check' and not self.include_policy_checks:
+            return None
+        if entry_type == 'signal' and not self.include_signals:
+            return None
+        if entry_type == 'tool_call' and not self.include_tool_calls:
+            return None
+
+        # Extract common fields
+        span_id = getattr(entry, 'id', getattr(entry, 'entry_id', str(self._emitted_count)))
+        timestamp = getattr(entry, 'timestamp', datetime.now(timezone.utc))
+        agent_id = getattr(entry, 'agent_id', 'unknown')
+
+        # Build attributes
+        attributes = {
+            "agent_os.entry_type": entry_type,
+            "agent_os.agent_id": agent_id,
+        }
+
+        # Add type-specific attributes
+        if entry_type == 'policy_check':
+            attributes.update({
+                "agent_os.policy_name": getattr(entry, 'policy_name', 'unknown'),
+                "agent_os.policy_result": getattr(entry, 'result', 'unknown'),
+                "agent_os.policy_violated": getattr(entry, 'violated', False),
+            })
+        elif entry_type == 'signal':
+            attributes.update({
+                "agent_os.signal_type": getattr(entry, 'signal', 'unknown'),
+                "agent_os.signal_target": getattr(entry, 'target', 'unknown'),
+            })
+        elif entry_type == 'tool_call':
+            attributes.update({
+                "agent_os.tool_name": getattr(entry, 'tool_name', 'unknown'),
+                "agent_os.tool_args": str(getattr(entry, 'args', {}))[:1000],  # Truncate
+                "agent_os.tool_result": str(getattr(entry, 'result', None))[:1000],
+            })
+
+        # Copy any additional attributes
+        if hasattr(entry, 'metadata'):
+            for key, value in entry.metadata.items():
+                attributes[f"agent_os.{key}"] = value
+
+        return LightningSpan(
+            span_id=str(span_id),
+            trace_id=f"{self.trace_id_prefix}-{agent_id}",
+            name=f"agent_os.{entry_type}",
+            start_time=timestamp,
+            end_time=timestamp,
+            attributes=attributes,
+        )
+
+    def get_spans(self) -> list[LightningSpan]:
+        """
+        Get all Flight Recorder entries as Lightning spans.
+
+        Returns:
+            List of LightningSpan objects
+        """
+        spans = []
+
+        # Get entries from recorder
+        if hasattr(self.recorder, 'get_entries'):
+            entries = self.recorder.get_entries()
+        elif hasattr(self.recorder, 'entries'):
+            entries = self.recorder.entries
+        elif hasattr(self.recorder, 'get_logs'):
+            entries = self.recorder.get_logs()
+        else:
+            logger.warning("Flight recorder has no recognized entry accessor")
+            return []
+
+        for entry in entries:
+            span = self._convert_entry(entry)
+            if span:
+                spans.append(span)
+                self._emitted_count += 1
+
+        return spans
+
+    def get_new_spans(self) -> list[LightningSpan]:
+        """
+        Get only new entries since last call.
+
+        Returns:
+            List of new LightningSpan objects
+        """
+        all_spans = self.get_spans()
+        new_spans = all_spans[self._last_position:]
+        self._last_position = len(all_spans)
+        return new_spans
+
+    async def stream(self) -> AsyncIterator[LightningSpan]:
+        """
+        Stream spans as they are recorded.
+
+        Yields:
+            LightningSpan objects as they become available
+        """
+        import asyncio
+
+        while True:
+            new_spans = self.get_new_spans()
+            for span in new_spans:
+                yield span
+
+            # Wait before checking again
+            await asyncio.sleep(0.1)
+
+    def emit_to_store(self, store: Any) -> int:
+        """
+        Emit all spans to a LightningStore.
+
+        Args:
+            store: Agent-Lightning LightningStore instance
+
+        Returns:
+            Number of spans emitted
+        """
+        spans = self.get_spans()
+
+        for span in spans:
+            try:
+                if hasattr(store, 'emit_span'):
+                    store.emit_span(span.to_dict())
+                elif hasattr(store, 'add_span'):
+                    store.add_span(span.to_dict())
+                else:
+                    logger.warning("Store has no recognized span emitter")
+                    break
+            except Exception as e:
+                logger.error(f"Failed to emit span: {e}")
+
+        logger.info(f"Emitted {len(spans)} spans to LightningStore")
+        return len(spans)
+
+    def export_to_file(self, filepath: str) -> int:
+        """
+        Export spans to a JSON file.
+
+        Args:
+            filepath: Path to output file
+
+        Returns:
+            Number of spans exported
+        """
+        spans = self.get_spans()
+
+        with open(filepath, 'w') as f:
+            json.dump([s.to_dict() for s in spans], f, indent=2)
+
+        logger.info(f"Exported {len(spans)} spans to {filepath}")
+        return len(spans)
+
+    def get_violation_summary(self) -> dict[str, Any]:
+        """
+        Get summary of policy violations from recorded entries.
+
+        Returns:
+            Summary dictionary with violation statistics
+        """
+        spans = self.get_spans()
+
+        violations = [
+            s for s in spans
+            if s.attributes.get("agent_os.policy_violated", False)
+        ]
+
+        policies_violated = {}
+        for v in violations:
+            policy = v.attributes.get("agent_os.policy_name", "unknown")
+            policies_violated[policy] = policies_violated.get(policy, 0) + 1
+
+        return {
+            "total_entries": len(spans),
+            "total_violations": len(violations),
+            "violation_rate": len(violations) / max(len(spans), 1),
+            "policies_violated": policies_violated,
+        }
+
+    def get_stats(self) -> dict[str, Any]:
+        """Get emitter statistics."""
+        return {
+            "emitted_count": self._emitted_count,
+            "last_position": self._last_position,
+        }
+
+
+def create_emitter(
+    flight_recorder: Any,
+    **kwargs: Any,
+) -> FlightRecorderEmitter:
+    """
+    Factory function to create a FlightRecorderEmitter.
+
+    Args:
+        flight_recorder: Agent OS FlightRecorder
+        **kwargs: Additional configuration
+
+    Returns:
+        Configured FlightRecorderEmitter
+    """
+    return FlightRecorderEmitter(flight_recorder, **kwargs)
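A short usage sketch for the emitter above. It calls only methods defined in this file; the FlightRecorder import follows the class docstring and is assumed rather than verified.

# Sketch: FlightRecorder is assumed from the docstring; the emitter calls
# below use only methods defined in emitter.py.
from agent_os import FlightRecorder
from agent_lightning_gov import FlightRecorderEmitter

recorder = FlightRecorder()
emitter = FlightRecorderEmitter(recorder, include_signals=False)

# One-shot export of all recorded entries for offline analysis.
emitter.export_to_file("training_spans.json")

# Quick compliance check over the same entries.
summary = emitter.get_violation_summary()
print(f"{summary['total_violations']} violations "
      f"({summary['violation_rate']:.1%} of {summary['total_entries']} entries)")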
--- /dev/null
+++ b/agent_lightning_gov/environment.py
@@ -0,0 +1,330 @@
+# Copyright (c) Microsoft Corporation.
+# Licensed under the MIT License.
+"""
+GovernedEnvironment - Training Environment with Governance
+==========================================================
+
+Wraps an Agent OS kernel as a training environment for
+Agent-Lightning's RL algorithms.
+"""
+
+from __future__ import annotations
+
+import logging
+from dataclasses import dataclass, field
+from datetime import datetime, timezone
+from typing import Any, Callable, Generic, TypeVar
+
+logger = logging.getLogger(__name__)
+
+T_state = TypeVar("T_state")
+T_action = TypeVar("T_action")
+
+
+@dataclass
+class EnvironmentConfig:
+    """Configuration for governed training environment."""
+
+    # Maximum steps per episode
+    max_steps: int = 100
+
+    # Penalty for policy violations
+    violation_penalty: float = -10.0
+
+    # Terminate episode on critical violation
+    terminate_on_critical: bool = True
+
+    # Reward shaping
+    step_penalty: float = -0.1  # Small penalty per step to encourage efficiency
+    success_bonus: float = 10.0
+
+    # Reset behavior
+    reset_kernel_state: bool = True
+
+
+@dataclass
+class EnvironmentState:
+    """State of the governed environment."""
+
+    step_count: int = 0
+    total_reward: float = 0.0
+    violations: list = field(default_factory=list)
+    terminated: bool = False
+    truncated: bool = False
+    info: dict = field(default_factory=dict)
+
+
+class GovernedEnvironment(Generic[T_state, T_action]):
+    """
+    RL training environment with Agent OS governance.
+
+    This environment wraps an Agent OS kernel and can be used
+    directly with Agent-Lightning or other RL frameworks.
+
+    The environment:
+    1. Enforces policies on each action
+    2. Converts violations to negative rewards
+    3. Optionally terminates on critical violations
+    4. Tracks compliance metrics during training
+
+    Example:
+        >>> from agent_os import KernelSpace
+        >>> from agent_os.policies import SQLPolicy
+        >>>
+        >>> kernel = KernelSpace(policy=SQLPolicy())
+        >>> env = GovernedEnvironment(kernel)
+        >>>
+        >>> state, info = env.reset()
+        >>> while not env.terminated:
+        ...     action = agent.get_action(state)
+        ...     state, reward, terminated, truncated, info = env.step(action)
+
+    Compatible with:
+    - Agent-Lightning trainers
+    - OpenAI Gym / Gymnasium
+    - Stable Baselines3
+    - Any environment with step/reset interface
+    """
+
+    def __init__(
+        self,
+        kernel: Any,  # KernelSpace
+        *,
+        task_generator: Callable[[], T_state] | None = None,
+        reward_fn: Callable[[T_state, T_action, Any], float] | None = None,
+        config: EnvironmentConfig | None = None,
+    ):
+        """
+        Initialize the governed environment.
+
+        Args:
+            kernel: Agent OS KernelSpace with loaded policies
+            task_generator: Optional function to generate initial states
+            reward_fn: Optional custom reward function
+            config: Environment configuration
+        """
+        self.kernel = kernel
+        self.task_generator = task_generator
+        self.reward_fn = reward_fn or self._default_reward
+        self.config = config or EnvironmentConfig()
+
+        # Current episode state
+        self._state = EnvironmentState()
+        self._current_task: T_state | None = None
+        self._current_violations: list = []
+
+        # Metrics
+        self._total_episodes = 0
+        self._total_steps = 0
+        self._total_violations = 0
+        self._successful_episodes = 0
+
+        # Set up kernel hooks
+        self._setup_hooks()
+
+        logger.info("GovernedEnvironment initialized")
+
+    def _setup_hooks(self) -> None:
+        """Set up hooks to capture violations."""
+        if hasattr(self.kernel, 'on_policy_violation'):
+            self.kernel.on_policy_violation(self._handle_violation)
+
+    def _handle_violation(
+        self,
+        policy_name: str,
+        description: str,
+        severity: str,
+        blocked: bool,
+    ) -> None:
+        """Handle policy violation during step."""
+        violation = {
+            "policy": policy_name,
+            "description": description,
+            "severity": severity,
+            "blocked": blocked,
+            "step": self._state.step_count,
+            "timestamp": datetime.now(timezone.utc).isoformat(),
+        }
+        self._current_violations.append(violation)
+        self._state.violations.append(violation)
+        self._total_violations += 1
+
+    def _default_reward(
+        self,
+        state: T_state,
+        action: T_action,
+        result: Any,
+    ) -> float:
+        """Default reward function."""
+        # Base reward for task completion
+        if result is not None:
+            return 1.0
+        return 0.0
+
+    def reset(
+        self,
+        *,
+        seed: int | None = None,
+        options: dict[str, Any] | None = None,
+    ) -> tuple[T_state, dict[str, Any]]:
+        """
+        Reset environment for new episode.
+
+        Args:
+            seed: Random seed (for compatibility)
+            options: Additional options
+
+        Returns:
+            Tuple of (initial_state, info_dict)
+        """
+        # Reset episode state
+        self._state = EnvironmentState()
+        self._current_violations = []
+        self._total_episodes += 1
+
+        # Reset kernel state if configured
+        if self.config.reset_kernel_state and hasattr(self.kernel, 'reset'):
+            self.kernel.reset()
+
+        # Generate initial task
+        if self.task_generator:
+            self._current_task = self.task_generator()
+        else:
+            self._current_task = None
+
+        info = {
+            "episode": self._total_episodes,
+            "kernel_policies": self._get_policy_names(),
+        }
+
+        return self._current_task, info
+
+    def step(
+        self,
+        action: T_action,
+    ) -> tuple[T_state, float, bool, bool, dict[str, Any]]:
+        """
+        Execute one step in the environment.
+
+        Args:
+            action: Agent's action
+
+        Returns:
+            Tuple of (next_state, reward, terminated, truncated, info)
+        """
+        self._current_violations = []
+        self._state.step_count += 1
+        self._total_steps += 1
+
+        # Execute action through kernel
+        try:
+            if hasattr(self.kernel, 'execute'):
+                result = self.kernel.execute(action)
+            else:
+                result = action  # No kernel execution, passthrough
+
+            success = True
+        except Exception as e:
+            logger.error(f"Step failed: {e}")
+            result = None
+            success = False
+
+        # Calculate reward
+        reward = self.reward_fn(self._current_task, action, result)
+
+        # Apply step penalty
+        reward += self.config.step_penalty
+
+        # Apply violation penalties
+        for violation in self._current_violations:
+            penalty = self.config.violation_penalty
+            if violation["severity"] == "critical":
+                penalty *= 10
+            elif violation["severity"] == "high":
+                penalty *= 5
+            reward += penalty
+
+        # Check termination conditions
+        terminated = False
+        truncated = False
+
+        # Terminate on critical violation if configured
+        if self.config.terminate_on_critical:
+            if any(v["severity"] == "critical" for v in self._current_violations):
+                terminated = True
+                logger.info("Episode terminated due to critical violation")
+
+        # Truncate on max steps
+        if self._state.step_count >= self.config.max_steps:
+            truncated = True
+
+        # Mark success: bonus for a clean step; count the episode only when it ends
+        if success and not self._current_violations:
+            reward += self.config.success_bonus
+            if terminated or truncated:
+                self._successful_episodes += 1
+        # Accumulate after the bonus so episode totals include it
+        self._state.total_reward += reward
+        self._state.terminated = terminated
+        self._state.truncated = truncated
+
+        info = {
+            "violations": self._current_violations,
+            "step": self._state.step_count,
+            "total_reward": self._state.total_reward,
+            "success": success,
+        }
+        self._state.info = info
+
+        return self._current_task, reward, terminated, truncated, info
+
+    def _get_policy_names(self) -> list[str]:
+        """Get names of loaded policies."""
+        if hasattr(self.kernel, 'get_policies'):
+            return [p.name for p in self.kernel.get_policies()]
+        if hasattr(self.kernel, 'policies'):
+            return [getattr(p, 'name', str(p)) for p in self.kernel.policies]
+        return []
+
+    @property
+    def terminated(self) -> bool:
+        """Whether current episode is terminated."""
+        return self._state.terminated or self._state.truncated
+
+    def get_metrics(self) -> dict[str, Any]:
+        """Get environment metrics."""
+        return {
+            "total_episodes": self._total_episodes,
+            "total_steps": self._total_steps,
+            "total_violations": self._total_violations,
+            "successful_episodes": self._successful_episodes,
+            "success_rate": self._successful_episodes / max(self._total_episodes, 1),
+            "violations_per_episode": self._total_violations / max(self._total_episodes, 1),
+            "steps_per_episode": self._total_steps / max(self._total_episodes, 1),
+        }
+
+    def close(self) -> None:
+        """Clean up environment resources."""
+        logger.info(f"Environment closed. Metrics: {self.get_metrics()}")
+
+
+def create_governed_env(
+    kernel: Any,
+    **kwargs: Any,
+) -> GovernedEnvironment:
+    """
+    Factory function to create a GovernedEnvironment.
+
+    Args:
+        kernel: Agent OS KernelSpace
+        **kwargs: Environment configuration
+
+    Returns:
+        Configured GovernedEnvironment
+    """
+    config = EnvironmentConfig()
+    for key, value in kwargs.items():
+        if hasattr(config, key):
+            setattr(config, key, value)
+
+    return GovernedEnvironment(kernel, config=config)
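Finally, a sketch of an episode loop against the environment interface above. The echo agent and SQL task string are hypothetical stand-ins, and object() substitutes for a real agent_os KernelSpace (the environment falls back to passthrough when the kernel has no execute method); the reset/step/terminated/get_metrics/close calls match this file.

# Sketch: _EchoAgent and the task string are hypothetical stand-ins; the
# environment API used here is exactly the one defined in environment.py.
from agent_lightning_gov import GovernedEnvironment
from agent_lightning_gov.environment import EnvironmentConfig

class _EchoAgent:
    def get_action(self, state):
        return state  # trivially replay the task as the action

env = GovernedEnvironment(
    object(),  # no-op kernel stand-in; a real KernelSpace would enforce policies
    task_generator=lambda: "SELECT name FROM users",  # hypothetical task
    config=EnvironmentConfig(max_steps=50, violation_penalty=-20.0),
)
agent = _EchoAgent()

state, info = env.reset()
while not env.terminated:
    action = agent.get_action(state)
    state, reward, terminated, truncated, info = env.step(action)

print(env.get_metrics())  # success_rate, violations_per_episode, ...
env.close()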