expt-logger 0.1.0.dev20__tar.gz → 0.1.0.dev22__tar.gz
This diff shows the contents of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between the package versions as they appear in their respective public registries.
- {expt_logger-0.1.0.dev20 → expt_logger-0.1.0.dev22}/PKG-INFO +1 -1
- {expt_logger-0.1.0.dev20 → expt_logger-0.1.0.dev22}/pyproject.toml +1 -1
- {expt_logger-0.1.0.dev20 → expt_logger-0.1.0.dev22}/src/expt_logger/__init__.py +5 -1
- {expt_logger-0.1.0.dev20 → expt_logger-0.1.0.dev22}/src/expt_logger/buffer.py +16 -4
- {expt_logger-0.1.0.dev20 → expt_logger-0.1.0.dev22}/src/expt_logger/client.py +5 -7
- {expt_logger-0.1.0.dev20 → expt_logger-0.1.0.dev22}/src/expt_logger/run.py +23 -7
- {expt_logger-0.1.0.dev20 → expt_logger-0.1.0.dev22}/src/expt_logger/types.py +7 -0
- {expt_logger-0.1.0.dev20 → expt_logger-0.1.0.dev22}/tests/test_buffer.py +73 -27
- {expt_logger-0.1.0.dev20 → expt_logger-0.1.0.dev22}/tests/test_client.py +12 -3
- {expt_logger-0.1.0.dev20 → expt_logger-0.1.0.dev22}/tests/test_client_integration.py +2 -2
- {expt_logger-0.1.0.dev20 → expt_logger-0.1.0.dev22}/tests/test_integration_e2e.py +51 -7
- {expt_logger-0.1.0.dev20 → expt_logger-0.1.0.dev22}/tests/test_run.py +13 -27
- {expt_logger-0.1.0.dev20 → expt_logger-0.1.0.dev22}/.gitignore +0 -0
- {expt_logger-0.1.0.dev20 → expt_logger-0.1.0.dev22}/.pre-commit-config.yaml +0 -0
- {expt_logger-0.1.0.dev20 → expt_logger-0.1.0.dev22}/DEVELOPMENT.md +0 -0
- {expt_logger-0.1.0.dev20 → expt_logger-0.1.0.dev22}/README.md +0 -0
- {expt_logger-0.1.0.dev20 → expt_logger-0.1.0.dev22}/src/expt_logger/config.py +0 -0
- {expt_logger-0.1.0.dev20 → expt_logger-0.1.0.dev22}/src/expt_logger/env.py +0 -0
- {expt_logger-0.1.0.dev20 → expt_logger-0.1.0.dev22}/src/expt_logger/exceptions.py +0 -0
- {expt_logger-0.1.0.dev20 → expt_logger-0.1.0.dev22}/src/expt_logger/py.typed +0 -0
- {expt_logger-0.1.0.dev20 → expt_logger-0.1.0.dev22}/src/expt_logger/validation.py +0 -0
- {expt_logger-0.1.0.dev20 → expt_logger-0.1.0.dev22}/tests/conftest.py +0 -0
- {expt_logger-0.1.0.dev20 → expt_logger-0.1.0.dev22}/tests/test_config.py +0 -0
- {expt_logger-0.1.0.dev20 → expt_logger-0.1.0.dev22}/tests/test_env.py +0 -0
- {expt_logger-0.1.0.dev20 → expt_logger-0.1.0.dev22}/tests/test_exceptions.py +0 -0
- {expt_logger-0.1.0.dev20 → expt_logger-0.1.0.dev22}/tests/test_global_api.py +0 -0
- {expt_logger-0.1.0.dev20 → expt_logger-0.1.0.dev22}/tests/test_validation.py +0 -0
|
@@ -159,19 +159,23 @@ def log_rollout(
|
|
|
159
159
|
def log_environment(
|
|
160
160
|
rollout_id: str,
|
|
161
161
|
content: str,
|
|
162
|
+
k: int | None = None,
|
|
163
|
+
commit: bool = True,
|
|
162
164
|
) -> None:
|
|
163
165
|
"""Log an environment log entry associated with a rollout.
|
|
164
166
|
|
|
165
167
|
Args:
|
|
166
168
|
rollout_id: ID of the rollout this log entry is associated with
|
|
167
169
|
content: Log content string
|
|
170
|
+
k: If set, commit only when the buffer has more than k env log entries.
|
|
171
|
+
commit: Whether to flush buffer after logging
|
|
168
172
|
|
|
169
173
|
Raises:
|
|
170
174
|
RuntimeError: If no active run exists
|
|
171
175
|
"""
|
|
172
176
|
if _active_run is None:
|
|
173
177
|
raise RuntimeError("No active run. Call init() first.")
|
|
174
|
-
_active_run.log_environment(rollout_id=rollout_id, content=content)
|
|
178
|
+
_active_run.log_environment(rollout_id=rollout_id, content=content, k=k, commit=commit)
|
|
175
179
|
|
|
176
180
|
|
|
177
181
|
def log_error(
|
|
@@ -2,7 +2,7 @@
|
|
|
2
2
|
|
|
3
3
|
import logging
|
|
4
4
|
|
|
5
|
-
from expt_logger.types import ErrorItem, RolloutItem, ScalarItem
|
|
5
|
+
from expt_logger.types import EnvLogItem, ErrorItem, RolloutItem, ScalarItem
|
|
6
6
|
|
|
7
7
|
logger = logging.getLogger(__name__)
|
|
8
8
|
|
|
@@ -20,6 +20,7 @@ class Buffer:
|
|
|
20
20
|
self._scalars: dict[str, float] = {} # full_key (mode/metric) -> value
|
|
21
21
|
self._rollouts: list[RolloutItem] = []
|
|
22
22
|
self._errors: list[ErrorItem] = []
|
|
23
|
+
self._env_logs: list[EnvLogItem] = []
|
|
23
24
|
|
|
24
25
|
def add_scalar(self, name: str, value: float, mode: str | None = None) -> None:
|
|
25
26
|
"""Add a scalar metric to the buffer.
|
|
@@ -88,7 +89,16 @@ class Buffer:
|
|
|
88
89
|
"""
|
|
89
90
|
self._errors.append(error)
|
|
90
91
|
|
|
91
|
-
def
|
|
92
|
+
def add_env_log(self, rollout_id: str, content: str) -> None:
|
|
93
|
+
"""Add an environment log entry to the buffer.
|
|
94
|
+
|
|
95
|
+
Args:
|
|
96
|
+
rollout_id: ID of the rollout this log entry is associated with
|
|
97
|
+
content: Log content string
|
|
98
|
+
"""
|
|
99
|
+
self._env_logs.append({"rollout_id": rollout_id, "content": content})
|
|
100
|
+
|
|
101
|
+
def get_and_clear(self) -> tuple[list[ScalarItem], list[RolloutItem], list[ErrorItem], list[EnvLogItem]]:
|
|
92
102
|
"""Get all buffered data and clear the buffer.
|
|
93
103
|
|
|
94
104
|
Returns:
|
|
@@ -110,12 +120,14 @@ class Buffer:
|
|
|
110
120
|
|
|
111
121
|
rollouts = self._rollouts.copy()
|
|
112
122
|
errors = self._errors.copy()
|
|
123
|
+
env_logs = self._env_logs.copy()
|
|
113
124
|
|
|
114
125
|
self._scalars.clear()
|
|
115
126
|
self._rollouts.clear()
|
|
116
127
|
self._errors.clear()
|
|
128
|
+
self._env_logs.clear()
|
|
117
129
|
|
|
118
|
-
return scalar_items, rollouts, errors
|
|
130
|
+
return scalar_items, rollouts, errors, env_logs
|
|
119
131
|
|
|
120
132
|
def is_empty(self) -> bool:
|
|
121
133
|
"""Check if buffer has any data.
|
|
@@ -123,4 +135,4 @@ class Buffer:
|
|
|
123
135
|
Returns:
|
|
124
136
|
True if buffer is empty, False otherwise
|
|
125
137
|
"""
|
|
126
|
-
return not self._scalars and not self._rollouts and not self._errors
|
|
138
|
+
return not self._scalars and not self._rollouts and not self._errors and not self._env_logs
|
|
@@ -7,7 +7,7 @@ from typing import Any
|
|
|
7
7
|
import requests
|
|
8
8
|
|
|
9
9
|
from expt_logger.exceptions import APIError, AuthenticationError
|
|
10
|
-
from expt_logger.types import ErrorItem, RolloutItem, ScalarItem, ScalarValue
|
|
10
|
+
from expt_logger.types import EnvLogItem, ErrorItem, RolloutItem, ScalarItem, ScalarValue
|
|
11
11
|
|
|
12
12
|
logger = logging.getLogger(__name__)
|
|
13
13
|
|
|
@@ -283,21 +283,19 @@ class APIClient:
|
|
|
283
283
|
def log_env_logs(
|
|
284
284
|
self,
|
|
285
285
|
experiment_id: str,
|
|
286
|
-
|
|
287
|
-
content: str,
|
|
286
|
+
logs: list[EnvLogItem],
|
|
288
287
|
) -> None:
|
|
289
|
-
"""Log
|
|
288
|
+
"""Log environment log entries for an experiment.
|
|
290
289
|
|
|
291
290
|
Args:
|
|
292
291
|
experiment_id: Experiment ID
|
|
293
|
-
|
|
294
|
-
content: Log content string
|
|
292
|
+
logs: List of env log entries with rollout_id and content
|
|
295
293
|
|
|
296
294
|
Raises:
|
|
297
295
|
APIError: If request fails
|
|
298
296
|
"""
|
|
299
297
|
url = f"{self.base_url}/api/experiments/{experiment_id}/env-logs"
|
|
300
|
-
payload = {"rolloutId": rollout_id, "content": content}
|
|
298
|
+
payload = {"logs": [{"rolloutId": log["rollout_id"], "content": log["content"]} for log in logs]}
|
|
301
299
|
self._request("POST", url, json=payload)
|
|
302
300
|
|
|
303
301
|
def get_env_logs(
|
|
@@ -82,7 +82,9 @@ class Run:
|
|
|
82
82
|
resolved_experiment_id = get_experiment_id(experiment_id, is_main_process=False)
|
|
83
83
|
if resolved_experiment_id is not None:
|
|
84
84
|
self._experiment_id = resolved_experiment_id
|
|
85
|
-
|
|
85
|
+
# Do not set the config if already attached to an experiment to avoid
|
|
86
|
+
# overwriting existing settings
|
|
87
|
+
self._validate_and_attach_experiment()
|
|
86
88
|
logger.info(
|
|
87
89
|
f"[expt_logger] Attached to experiment ID: {self._experiment_id} (subprocess)"
|
|
88
90
|
)
|
|
@@ -291,12 +293,16 @@ class Run:
|
|
|
291
293
|
self,
|
|
292
294
|
rollout_id: str,
|
|
293
295
|
content: str,
|
|
296
|
+
k: int | None = None,
|
|
297
|
+
commit: bool = True,
|
|
294
298
|
) -> None:
|
|
295
299
|
"""Log an environment log entry associated with a rollout.
|
|
296
300
|
|
|
297
301
|
Args:
|
|
298
302
|
rollout_id: ID of the rollout this log entry is associated with
|
|
299
303
|
content: Log content string
|
|
304
|
+
k: If set, commit when the buffer has more than k elements
|
|
305
|
+
commit: Whether to flush buffer after logging
|
|
300
306
|
"""
|
|
301
307
|
env_cmd: LogEnvironmentCommand = {
|
|
302
308
|
"rollout_id": rollout_id,
|
|
@@ -305,7 +311,12 @@ class Run:
|
|
|
305
311
|
try:
|
|
306
312
|
self._queue.put_nowait(("log_environment", env_cmd))
|
|
307
313
|
except Full:
|
|
308
|
-
logger.warning(
|
|
314
|
+
logger.warning(
|
|
315
|
+
f"Command queue full, dropping environment log for rollout: {rollout_id}"
|
|
316
|
+
)
|
|
317
|
+
|
|
318
|
+
if commit or (k is not None and len(self._buffer._env_logs) > k):
|
|
319
|
+
self.commit()
|
|
309
320
|
|
|
310
321
|
def log_error(
|
|
311
322
|
self,
|
|
@@ -569,10 +580,7 @@ class Run:
|
|
|
569
580
|
Args:
|
|
570
581
|
data: Log environment command data with rollout_id and content
|
|
571
582
|
"""
|
|
572
|
-
|
|
573
|
-
self._client.log_env_logs(self._experiment_id, data["rollout_id"], data["content"])
|
|
574
|
-
except Exception as e:
|
|
575
|
-
logger.error(f"Error logging environment log: {e}", exc_info=True)
|
|
583
|
+
self._buffer.add_env_log(data["rollout_id"], data["content"])
|
|
576
584
|
|
|
577
585
|
def _handle_log_error(self, data: LogErrorCommand) -> None:
|
|
578
586
|
"""Handle log_error command.
|
|
@@ -619,7 +627,7 @@ class Run:
|
|
|
619
627
|
if self._buffer.is_empty():
|
|
620
628
|
return
|
|
621
629
|
|
|
622
|
-
scalars, rollouts, errors = self._buffer.get_and_clear()
|
|
630
|
+
scalars, rollouts, errors, env_logs = self._buffer.get_and_clear()
|
|
623
631
|
|
|
624
632
|
# Send scalars if any
|
|
625
633
|
if scalars:
|
|
@@ -644,3 +652,11 @@ class Run:
|
|
|
644
652
|
logger.debug(f"Flushed {len(errors)} errors at step {self._step}")
|
|
645
653
|
except Exception as e:
|
|
646
654
|
logger.error(f"Error logging errors: {e}", exc_info=True)
|
|
655
|
+
|
|
656
|
+
# Send env logs if any
|
|
657
|
+
if env_logs:
|
|
658
|
+
try:
|
|
659
|
+
self._client.log_env_logs(self._experiment_id, env_logs)
|
|
660
|
+
logger.debug(f"Flushed {len(env_logs)} environment logs")
|
|
661
|
+
except Exception as e:
|
|
662
|
+
logger.error(f"Error logging environment logs: {e}", exc_info=True)
|
|
@@ -49,6 +49,13 @@ class ErrorItem(TypedDict):
|
|
|
49
49
|
traceback: str | None
|
|
50
50
|
|
|
51
51
|
|
|
52
|
+
class EnvLogItem(TypedDict):
|
|
53
|
+
"""An environment log item for API submission."""
|
|
54
|
+
|
|
55
|
+
rollout_id: str
|
|
56
|
+
content: str
|
|
57
|
+
|
|
58
|
+
|
|
52
59
|
class ScalarValue(TypedDict):
|
|
53
60
|
"""A single scalar value at a specific step (used in GET responses)."""
|
|
54
61
|
|
|
@@ -16,17 +16,18 @@ def buffer():
|
|
|
16
16
|
def test_buffer_initialization(buffer):
|
|
17
17
|
"""Test buffer starts empty."""
|
|
18
18
|
assert buffer.is_empty()
|
|
19
|
-
scalars, rollouts, errors = buffer.get_and_clear()
|
|
19
|
+
scalars, rollouts, errors, env_logs = buffer.get_and_clear()
|
|
20
20
|
assert scalars == []
|
|
21
21
|
assert rollouts == []
|
|
22
22
|
assert errors == []
|
|
23
|
+
assert env_logs == []
|
|
23
24
|
|
|
24
25
|
|
|
25
26
|
def test_add_scalar_simple(buffer):
|
|
26
27
|
"""Test adding a simple scalar metric."""
|
|
27
28
|
buffer.add_scalar("loss", 0.5)
|
|
28
29
|
|
|
29
|
-
scalars, rollouts, errors = buffer.get_and_clear()
|
|
30
|
+
scalars, rollouts, errors, _ = buffer.get_and_clear()
|
|
30
31
|
assert len(scalars) == 1
|
|
31
32
|
assert scalars[0] == {"step": 0, "mode": "train", "name": "loss", "value": 0.5}
|
|
32
33
|
assert rollouts == []
|
|
@@ -37,7 +38,7 @@ def test_add_scalar_with_mode_in_key(buffer):
|
|
|
37
38
|
"""Test adding scalar with mode already in key."""
|
|
38
39
|
buffer.add_scalar("train/loss", 0.5)
|
|
39
40
|
|
|
40
|
-
scalars, rollouts, errors = buffer.get_and_clear()
|
|
41
|
+
scalars, rollouts, errors, _ = buffer.get_and_clear()
|
|
41
42
|
assert len(scalars) == 1
|
|
42
43
|
assert scalars[0] == {"step": 0, "mode": "train", "name": "loss", "value": 0.5}
|
|
43
44
|
|
|
@@ -46,7 +47,7 @@ def test_add_scalar_with_mode_parameter(buffer):
|
|
|
46
47
|
"""Test adding scalar with explicit mode parameter."""
|
|
47
48
|
buffer.add_scalar("loss", 0.5, mode="eval")
|
|
48
49
|
|
|
49
|
-
scalars, rollouts, errors = buffer.get_and_clear()
|
|
50
|
+
scalars, rollouts, errors, _ = buffer.get_and_clear()
|
|
50
51
|
assert len(scalars) == 1
|
|
51
52
|
assert scalars[0] == {"step": 0, "mode": "eval", "name": "loss", "value": 0.5}
|
|
52
53
|
|
|
@@ -57,7 +58,7 @@ def test_add_scalar_mode_conflict_key_wins(buffer, caplog):
|
|
|
57
58
|
buffer.add_scalar("train/loss", 0.5, mode="eval")
|
|
58
59
|
|
|
59
60
|
# Should use "eval" mode from parameter, with full "train/loss" as the name
|
|
60
|
-
scalars, _, _ = buffer.get_and_clear()
|
|
61
|
+
scalars, _, _, _ = buffer.get_and_clear()
|
|
61
62
|
assert len(scalars) == 1
|
|
62
63
|
assert scalars[0] == {"step": 0, "mode": "eval", "name": "train/loss", "value": 0.5}
|
|
63
64
|
|
|
@@ -70,7 +71,7 @@ def test_add_scalar_mode_no_conflict(buffer, caplog):
|
|
|
70
71
|
with caplog.at_level(logging.WARNING):
|
|
71
72
|
buffer.add_scalar("train/loss", 0.5, mode="train")
|
|
72
73
|
|
|
73
|
-
scalars, _, _ = buffer.get_and_clear()
|
|
74
|
+
scalars, _, _, _ = buffer.get_and_clear()
|
|
74
75
|
assert len(scalars) == 1
|
|
75
76
|
assert scalars[0] == {"step": 0, "mode": "train", "name": "loss", "value": 0.5}
|
|
76
77
|
|
|
@@ -84,7 +85,7 @@ def test_add_multiple_scalars(buffer):
|
|
|
84
85
|
buffer.add_scalar("accuracy", 0.9)
|
|
85
86
|
buffer.add_scalar("eval/loss", 0.6)
|
|
86
87
|
|
|
87
|
-
scalars, _, _ = buffer.get_and_clear()
|
|
88
|
+
scalars, _, _, _ = buffer.get_and_clear()
|
|
88
89
|
assert len(scalars) == 3
|
|
89
90
|
# Convert to dict for easier comparison
|
|
90
91
|
scalars_dict = {f"{s['mode']}/{s['name']}": s["value"] for s in scalars}
|
|
@@ -101,7 +102,7 @@ def test_add_scalar_last_write_wins(buffer, caplog):
|
|
|
101
102
|
buffer.add_scalar("loss", 0.5)
|
|
102
103
|
buffer.add_scalar("loss", 0.3) # Overwrites
|
|
103
104
|
|
|
104
|
-
scalars, _, _ = buffer.get_and_clear()
|
|
105
|
+
scalars, _, _, _ = buffer.get_and_clear()
|
|
105
106
|
assert len(scalars) == 1
|
|
106
107
|
assert scalars[0] == {"step": 0, "mode": "train", "name": "loss", "value": 0.3}
|
|
107
108
|
|
|
@@ -117,7 +118,7 @@ def test_add_scalar_last_write_wins_different_modes(buffer):
|
|
|
117
118
|
buffer.add_scalar("loss", 0.5, mode="train")
|
|
118
119
|
buffer.add_scalar("loss", 0.6, mode="eval")
|
|
119
120
|
|
|
120
|
-
scalars, _, _ = buffer.get_and_clear()
|
|
121
|
+
scalars, _, _, _ = buffer.get_and_clear()
|
|
121
122
|
assert len(scalars) == 2
|
|
122
123
|
# Convert to dict for easier comparison
|
|
123
124
|
scalars_dict = {f"{s['mode']}/{s['name']}": s["value"] for s in scalars}
|
|
@@ -139,7 +140,7 @@ def test_add_rollout(buffer):
|
|
|
139
140
|
|
|
140
141
|
buffer.add_rollout(rollout)
|
|
141
142
|
|
|
142
|
-
scalars, rollouts, errors = buffer.get_and_clear()
|
|
143
|
+
scalars, rollouts, errors, _ = buffer.get_and_clear()
|
|
143
144
|
assert scalars == []
|
|
144
145
|
assert len(rollouts) == 1
|
|
145
146
|
assert rollouts[0] == rollout
|
|
@@ -165,7 +166,7 @@ def test_add_multiple_rollouts(buffer):
|
|
|
165
166
|
buffer.add_rollout(rollout1)
|
|
166
167
|
buffer.add_rollout(rollout2)
|
|
167
168
|
|
|
168
|
-
_, rollouts, _ = buffer.get_and_clear()
|
|
169
|
+
_, rollouts, _, _ = buffer.get_and_clear()
|
|
169
170
|
assert len(rollouts) == 2
|
|
170
171
|
assert rollouts[0] == rollout1
|
|
171
172
|
assert rollouts[1] == rollout2
|
|
@@ -185,7 +186,7 @@ def test_mixed_scalars_and_rollouts(buffer):
|
|
|
185
186
|
)
|
|
186
187
|
buffer.add_scalar("accuracy", 0.9)
|
|
187
188
|
|
|
188
|
-
scalars, rollouts, errors = buffer.get_and_clear()
|
|
189
|
+
scalars, rollouts, errors, _ = buffer.get_and_clear()
|
|
189
190
|
assert len(scalars) == 2
|
|
190
191
|
scalars_dict = {f"{s['mode']}/{s['name']}": s["value"] for s in scalars}
|
|
191
192
|
assert scalars_dict == {"train/loss": 0.5, "train/accuracy": 0.9}
|
|
@@ -208,7 +209,7 @@ def test_get_and_clear_empties_buffer(buffer):
|
|
|
208
209
|
assert not buffer.is_empty()
|
|
209
210
|
|
|
210
211
|
# First call returns data
|
|
211
|
-
scalars1, rollouts1, errors1 = buffer.get_and_clear()
|
|
212
|
+
scalars1, rollouts1, errors1, _ = buffer.get_and_clear()
|
|
212
213
|
assert len(scalars1) > 0
|
|
213
214
|
assert len(rollouts1) > 0
|
|
214
215
|
|
|
@@ -216,7 +217,7 @@ def test_get_and_clear_empties_buffer(buffer):
|
|
|
216
217
|
assert buffer.is_empty()
|
|
217
218
|
|
|
218
219
|
# Second call returns empty
|
|
219
|
-
scalars2, rollouts2, errors2 = buffer.get_and_clear()
|
|
220
|
+
scalars2, rollouts2, errors2, _ = buffer.get_and_clear()
|
|
220
221
|
assert scalars2 == []
|
|
221
222
|
assert rollouts2 == []
|
|
222
223
|
assert errors2 == []
|
|
@@ -226,13 +227,13 @@ def test_get_and_clear_returns_copy(buffer):
|
|
|
226
227
|
"""Test that get_and_clear returns a copy, not reference."""
|
|
227
228
|
buffer.add_scalar("loss", 0.5)
|
|
228
229
|
|
|
229
|
-
scalars1, _, _ = buffer.get_and_clear()
|
|
230
|
+
scalars1, _, _, _ = buffer.get_and_clear()
|
|
230
231
|
# Modify returned list
|
|
231
232
|
scalars1.append({"step": 999, "mode": "test", "name": "modified", "value": 999})
|
|
232
233
|
|
|
233
234
|
# Add new data
|
|
234
235
|
buffer.add_scalar("accuracy", 0.9)
|
|
235
|
-
scalars2, _, _ = buffer.get_and_clear()
|
|
236
|
+
scalars2, _, _, _ = buffer.get_and_clear()
|
|
236
237
|
|
|
237
238
|
# Should not contain the modification
|
|
238
239
|
assert len(scalars2) == 1
|
|
@@ -274,7 +275,7 @@ def test_metric_key_with_multiple_slashes(buffer):
|
|
|
274
275
|
# Edge case: what if key has multiple slashes?
|
|
275
276
|
buffer.add_scalar("train/sub/metric", 0.5)
|
|
276
277
|
|
|
277
|
-
scalars, _, _ = buffer.get_and_clear()
|
|
278
|
+
scalars, _, _, _ = buffer.get_and_clear()
|
|
278
279
|
# Should split on first slash only
|
|
279
280
|
assert len(scalars) == 1
|
|
280
281
|
assert scalars[0] == {"step": 0, "mode": "train", "name": "sub/metric", "value": 0.5}
|
|
@@ -285,7 +286,7 @@ def test_default_mode_is_train(buffer):
|
|
|
285
286
|
buffer.add_scalar("loss", 0.5)
|
|
286
287
|
buffer.add_scalar("accuracy", 0.9)
|
|
287
288
|
|
|
288
|
-
scalars, _, _ = buffer.get_and_clear()
|
|
289
|
+
scalars, _, _, _ = buffer.get_and_clear()
|
|
289
290
|
assert all(s["mode"] == "train" for s in scalars)
|
|
290
291
|
|
|
291
292
|
|
|
@@ -293,7 +294,7 @@ def test_mode_provided_strips_matching_prefix(buffer):
|
|
|
293
294
|
"""Test that when mode is provided and matches key prefix, the prefix is stripped."""
|
|
294
295
|
buffer.add_scalar("train/loss", 0.5, mode="train")
|
|
295
296
|
|
|
296
|
-
scalars, _, _ = buffer.get_and_clear()
|
|
297
|
+
scalars, _, _, _ = buffer.get_and_clear()
|
|
297
298
|
# Should be train/loss (prefix stripped)
|
|
298
299
|
assert len(scalars) == 1
|
|
299
300
|
assert scalars[0] == {"step": 0, "mode": "train", "name": "loss", "value": 0.5}
|
|
@@ -303,7 +304,7 @@ def test_mode_provided_keeps_mismatched_prefix(buffer):
|
|
|
303
304
|
"""Test that when mode is provided and doesn't match key prefix, full key is kept."""
|
|
304
305
|
buffer.add_scalar("train/loss", 0.5, mode="eval")
|
|
305
306
|
|
|
306
|
-
scalars, _, _ = buffer.get_and_clear()
|
|
307
|
+
scalars, _, _, _ = buffer.get_and_clear()
|
|
307
308
|
# Should be eval/train/loss (full key kept as name)
|
|
308
309
|
assert len(scalars) == 1
|
|
309
310
|
assert scalars[0] == {"step": 0, "mode": "eval", "name": "train/loss", "value": 0.5}
|
|
@@ -313,7 +314,7 @@ def test_mode_not_provided_extracts_from_key(buffer):
|
|
|
313
314
|
"""Test that when mode is not provided, it's extracted from key prefix."""
|
|
314
315
|
buffer.add_scalar("eval/accuracy", 0.95)
|
|
315
316
|
|
|
316
|
-
scalars, _, _ = buffer.get_and_clear()
|
|
317
|
+
scalars, _, _, _ = buffer.get_and_clear()
|
|
317
318
|
# Should be eval/accuracy (mode extracted, name is accuracy)
|
|
318
319
|
assert len(scalars) == 1
|
|
319
320
|
assert scalars[0] == {"step": 0, "mode": "eval", "name": "accuracy", "value": 0.95}
|
|
@@ -323,7 +324,7 @@ def test_mode_not_provided_simple_key_defaults_train(buffer):
|
|
|
323
324
|
"""Test that simple keys without mode default to train mode."""
|
|
324
325
|
buffer.add_scalar("loss", 0.5)
|
|
325
326
|
|
|
326
|
-
scalars, _, _ = buffer.get_and_clear()
|
|
327
|
+
scalars, _, _, _ = buffer.get_and_clear()
|
|
327
328
|
# Should be train/loss (default mode)
|
|
328
329
|
assert len(scalars) == 1
|
|
329
330
|
assert scalars[0] == {"step": 0, "mode": "train", "name": "loss", "value": 0.5}
|
|
@@ -342,7 +343,7 @@ def test_add_error(buffer):
|
|
|
342
343
|
|
|
343
344
|
buffer.add_error(error)
|
|
344
345
|
|
|
345
|
-
scalars, rollouts, errors = buffer.get_and_clear()
|
|
346
|
+
scalars, rollouts, errors, _ = buffer.get_and_clear()
|
|
346
347
|
assert scalars == []
|
|
347
348
|
assert rollouts == []
|
|
348
349
|
assert len(errors) == 1
|
|
@@ -369,7 +370,7 @@ def test_add_multiple_errors(buffer):
|
|
|
369
370
|
buffer.add_error(error1)
|
|
370
371
|
buffer.add_error(error2)
|
|
371
372
|
|
|
372
|
-
_, _, errors = buffer.get_and_clear()
|
|
373
|
+
_, _, errors, _ = buffer.get_and_clear()
|
|
373
374
|
assert len(errors) == 2
|
|
374
375
|
assert errors[0] == error1
|
|
375
376
|
assert errors[1] == error2
|
|
@@ -398,7 +399,7 @@ def test_mixed_scalars_rollouts_and_errors(buffer):
|
|
|
398
399
|
)
|
|
399
400
|
buffer.add_scalar("accuracy", 0.9)
|
|
400
401
|
|
|
401
|
-
scalars, rollouts, errors = buffer.get_and_clear()
|
|
402
|
+
scalars, rollouts, errors, _ = buffer.get_and_clear()
|
|
402
403
|
assert len(scalars) == 2
|
|
403
404
|
assert len(rollouts) == 1
|
|
404
405
|
assert len(errors) == 1
|
|
@@ -438,13 +439,58 @@ def test_get_and_clear_empties_errors(buffer):
|
|
|
438
439
|
assert not buffer.is_empty()
|
|
439
440
|
|
|
440
441
|
# First call returns data
|
|
441
|
-
_, _, errors1 = buffer.get_and_clear()
|
|
442
|
+
_, _, errors1, _ = buffer.get_and_clear()
|
|
442
443
|
assert len(errors1) > 0
|
|
443
444
|
|
|
444
445
|
# Buffer should now be empty
|
|
445
446
|
assert buffer.is_empty()
|
|
446
447
|
|
|
447
448
|
# Second call returns empty
|
|
448
|
-
_, _, errors2 = buffer.get_and_clear()
|
|
449
|
+
_, _, errors2, _ = buffer.get_and_clear()
|
|
449
450
|
assert errors2 == []
|
|
450
451
|
|
|
452
|
+
|
|
453
|
+
def test_add_env_log(buffer):
|
|
454
|
+
"""Test adding an environment log entry."""
|
|
455
|
+
buffer.add_env_log("rollout-1", "step 1 observation")
|
|
456
|
+
|
|
457
|
+
_, _, _, env_logs = buffer.get_and_clear()
|
|
458
|
+
assert len(env_logs) == 1
|
|
459
|
+
assert env_logs[0] == {"rollout_id": "rollout-1", "content": "step 1 observation"}
|
|
460
|
+
|
|
461
|
+
|
|
462
|
+
def test_add_multiple_env_logs(buffer):
|
|
463
|
+
"""Test adding multiple environment log entries."""
|
|
464
|
+
buffer.add_env_log("rollout-1", "step 1")
|
|
465
|
+
buffer.add_env_log("rollout-1", "step 2")
|
|
466
|
+
buffer.add_env_log("rollout-2", "other rollout")
|
|
467
|
+
|
|
468
|
+
_, _, _, env_logs = buffer.get_and_clear()
|
|
469
|
+
assert len(env_logs) == 3
|
|
470
|
+
assert env_logs[0] == {"rollout_id": "rollout-1", "content": "step 1"}
|
|
471
|
+
assert env_logs[1] == {"rollout_id": "rollout-1", "content": "step 2"}
|
|
472
|
+
assert env_logs[2] == {"rollout_id": "rollout-2", "content": "other rollout"}
|
|
473
|
+
|
|
474
|
+
|
|
475
|
+
def test_is_empty_with_only_env_logs(buffer):
|
|
476
|
+
"""Test is_empty with only env logs."""
|
|
477
|
+
assert buffer.is_empty()
|
|
478
|
+
|
|
479
|
+
buffer.add_env_log("rollout-1", "some content")
|
|
480
|
+
assert not buffer.is_empty()
|
|
481
|
+
|
|
482
|
+
buffer.get_and_clear()
|
|
483
|
+
assert buffer.is_empty()
|
|
484
|
+
|
|
485
|
+
|
|
486
|
+
def test_get_and_clear_empties_env_logs(buffer):
|
|
487
|
+
"""Test that get_and_clear empties env logs from buffer."""
|
|
488
|
+
buffer.add_env_log("rollout-1", "content")
|
|
489
|
+
|
|
490
|
+
_, _, _, env_logs1 = buffer.get_and_clear()
|
|
491
|
+
assert len(env_logs1) == 1
|
|
492
|
+
|
|
493
|
+
assert buffer.is_empty()
|
|
494
|
+
|
|
495
|
+
_, _, _, env_logs2 = buffer.get_and_clear()
|
|
496
|
+
assert env_logs2 == []
|
|
@@ -295,15 +295,24 @@ def test_log_rollouts(client):
|
|
|
295
295
|
|
|
296
296
|
|
|
297
297
|
def test_log_env_logs(client):
|
|
298
|
-
"""Test logging environment logs."""
|
|
298
|
+
"""Test logging environment logs as a batch."""
|
|
299
|
+
logs = [
|
|
300
|
+
{"rollout_id": "rollout-456", "content": "step 1 observation"},
|
|
301
|
+
{"rollout_id": "rollout-456", "content": "step 2 observation"},
|
|
302
|
+
]
|
|
299
303
|
with patch.object(client, "_request") as mock_request:
|
|
300
|
-
client.log_env_logs("exp-123",
|
|
304
|
+
client.log_env_logs("exp-123", logs)
|
|
301
305
|
|
|
302
306
|
mock_request.assert_called_once()
|
|
303
307
|
call_args = mock_request.call_args
|
|
304
308
|
assert call_args[0][0] == "POST"
|
|
305
309
|
assert call_args[0][1] == "https://test.example.com/api/experiments/exp-123/env-logs"
|
|
306
|
-
assert call_args[1]["json"] == {
|
|
310
|
+
assert call_args[1]["json"] == {
|
|
311
|
+
"logs": [
|
|
312
|
+
{"rolloutId": "rollout-456", "content": "step 1 observation"},
|
|
313
|
+
{"rolloutId": "rollout-456", "content": "step 2 observation"},
|
|
314
|
+
]
|
|
315
|
+
}
|
|
307
316
|
|
|
308
317
|
|
|
309
318
|
def test_log_errors(client):
|
|
@@ -176,7 +176,7 @@ class TestAPIClientIntegration:
|
|
|
176
176
|
response = requests.get(
|
|
177
177
|
f"{client.base_url}/api/experiments/{experiment_id}/errors",
|
|
178
178
|
params={"mode": "train"},
|
|
179
|
-
headers={"
|
|
179
|
+
headers={"Authorization": f"Bearer {client.api_key}"},
|
|
180
180
|
)
|
|
181
181
|
|
|
182
182
|
assert response.status_code == 200
|
|
@@ -370,7 +370,7 @@ class TestAPIClientIntegration:
|
|
|
370
370
|
errors_response = requests.get(
|
|
371
371
|
f"{client.base_url}/api/experiments/{experiment_id}/errors",
|
|
372
372
|
params={"mode": "train"},
|
|
373
|
-
headers={"
|
|
373
|
+
headers={"Authorization": f"Bearer {client.api_key}"},
|
|
374
374
|
)
|
|
375
375
|
|
|
376
376
|
assert errors_response.status_code == 200
|
|
@@ -1096,7 +1096,7 @@ class TestEnvironmentLogs:
|
|
|
1096
1096
|
)
|
|
1097
1097
|
cleanup_experiments.append(run._experiment_id)
|
|
1098
1098
|
|
|
1099
|
-
expt_logger.log_environment("00000000-0000-0000-0000-000000000001", "observation: step=1, reward=0.9")
|
|
1099
|
+
expt_logger.log_environment("00000000-0000-0000-0000-000000000001", "observation: step=1, reward=0.9", k=-1)
|
|
1100
1100
|
time.sleep(0.5)
|
|
1101
1101
|
|
|
1102
1102
|
env_logs = fetch_env_logs(run._experiment_id, "00000000-0000-0000-0000-000000000001")
|
|
@@ -1120,9 +1120,9 @@ class TestEnvironmentLogs:
|
|
|
1120
1120
|
)
|
|
1121
1121
|
cleanup_experiments.append(run._experiment_id)
|
|
1122
1122
|
|
|
1123
|
-
expt_logger.log_environment("00000000-0000-0000-0000-000000000002", "step 1: action=left")
|
|
1124
|
-
expt_logger.log_environment("00000000-0000-0000-0000-000000000002", "step 2: action=right")
|
|
1125
|
-
expt_logger.log_environment("00000000-0000-0000-0000-000000000002", "step 3: action=jump")
|
|
1123
|
+
expt_logger.log_environment("00000000-0000-0000-0000-000000000002", "step 1: action=left", commit=True)
|
|
1124
|
+
expt_logger.log_environment("00000000-0000-0000-0000-000000000002", "step 2: action=right", commit=True)
|
|
1125
|
+
expt_logger.log_environment("00000000-0000-0000-0000-000000000002", "step 3: action=jump", commit=True)
|
|
1126
1126
|
time.sleep(0.5)
|
|
1127
1127
|
|
|
1128
1128
|
env_logs = fetch_env_logs(run._experiment_id, "00000000-0000-0000-0000-000000000002")
|
|
@@ -1147,8 +1147,8 @@ class TestEnvironmentLogs:
|
|
|
1147
1147
|
)
|
|
1148
1148
|
cleanup_experiments.append(run._experiment_id)
|
|
1149
1149
|
|
|
1150
|
-
expt_logger.log_environment("00000000-0000-0000-0000-00000000003a", "log for rollout A")
|
|
1151
|
-
expt_logger.log_environment("00000000-0000-0000-0000-00000000003b", "log for rollout B")
|
|
1150
|
+
expt_logger.log_environment("00000000-0000-0000-0000-00000000003a", "log for rollout A", commit=True)
|
|
1151
|
+
expt_logger.log_environment("00000000-0000-0000-0000-00000000003b", "log for rollout B", commit=True)
|
|
1152
1152
|
time.sleep(0.5)
|
|
1153
1153
|
|
|
1154
1154
|
logs_a = fetch_env_logs(run._experiment_id, "00000000-0000-0000-0000-00000000003a")
|
|
@@ -1160,6 +1160,50 @@ class TestEnvironmentLogs:
|
|
|
1160
1160
|
|
|
1161
1161
|
expt_logger.end()
|
|
1162
1162
|
|
|
1163
|
+
def test_log_environment_k_threshold_batching(
|
|
1164
|
+
self,
|
|
1165
|
+
shared_api_key: str,
|
|
1166
|
+
base_url: str,
|
|
1167
|
+
cleanup_experiments: list[str],
|
|
1168
|
+
fetch_env_logs,
|
|
1169
|
+
) -> None:
|
|
1170
|
+
"""Test that k parameter batches env logs and flushes only when threshold is exceeded.
|
|
1171
|
+
|
|
1172
|
+
With k=2, the flush triggers when the main thread calls log_environment and
|
|
1173
|
+
finds > 2 entries already in the worker's buffer. So we need k+2 total calls:
|
|
1174
|
+
the first k+1 accumulate in the buffer, and the (k+2)th call detects the overflow
|
|
1175
|
+
and triggers a commit.
|
|
1176
|
+
"""
|
|
1177
|
+
k = 2
|
|
1178
|
+
rollout_id = "00000000-0000-0000-0000-00000000005a"
|
|
1179
|
+
run = expt_logger.init(
|
|
1180
|
+
name="test-env-log-k-threshold",
|
|
1181
|
+
api_key=shared_api_key,
|
|
1182
|
+
base_url=base_url,
|
|
1183
|
+
)
|
|
1184
|
+
cleanup_experiments.append(run._experiment_id)
|
|
1185
|
+
|
|
1186
|
+
# Log k+1 items without committing; sleep to let the worker process them into the buffer
|
|
1187
|
+
for i in range(k + 1):
|
|
1188
|
+
expt_logger.log_environment(rollout_id, f"step {i}", commit=False, k=k)
|
|
1189
|
+
time.sleep(0.3)
|
|
1190
|
+
|
|
1191
|
+
# Nothing should be on the server yet — threshold not exceeded in the main thread's view
|
|
1192
|
+
env_logs = fetch_env_logs(run._experiment_id, rollout_id)
|
|
1193
|
+
assert len(env_logs) == 0, f"Expected 0 logs before threshold, got {len(env_logs)}"
|
|
1194
|
+
|
|
1195
|
+
# One more call: the main thread now sees k+1 items in the buffer (> k) and commits
|
|
1196
|
+
expt_logger.log_environment(rollout_id, f"step {k + 1}", commit=False, k=k)
|
|
1197
|
+
time.sleep(0.5)
|
|
1198
|
+
|
|
1199
|
+
# All k+2 logs should now be on the server
|
|
1200
|
+
env_logs = fetch_env_logs(run._experiment_id, rollout_id)
|
|
1201
|
+
assert len(env_logs) == k + 2
|
|
1202
|
+
contents = {log["content"] for log in env_logs}
|
|
1203
|
+
assert contents == {f"step {i}" for i in range(k + 2)}
|
|
1204
|
+
|
|
1205
|
+
expt_logger.end()
|
|
1206
|
+
|
|
1163
1207
|
def test_log_environment_multiline_content(
|
|
1164
1208
|
self,
|
|
1165
1209
|
shared_api_key: str,
|
|
@@ -1176,7 +1220,7 @@ class TestEnvironmentLogs:
|
|
|
1176
1220
|
cleanup_experiments.append(run._experiment_id)
|
|
1177
1221
|
|
|
1178
1222
|
content = "obs: {x: 1.0, y: 2.0}\nreward: 0.5\ndone: false\ninfo: {step: 10}"
|
|
1179
|
-
expt_logger.log_environment("00000000-0000-0000-0000-000000000004", content)
|
|
1223
|
+
expt_logger.log_environment("00000000-0000-0000-0000-000000000004", content, commit=True)
|
|
1180
1224
|
time.sleep(0.5)
|
|
1181
1225
|
|
|
1182
1226
|
env_logs = fetch_env_logs(run._experiment_id, "00000000-0000-0000-0000-000000000004")
|
|
@@ -1486,47 +1486,37 @@ def test_log_error_with_invalid_mode_empty(mock_client):
|
|
|
1486
1486
|
|
|
1487
1487
|
|
|
1488
1488
|
def test_log_environment_calls_client(mock_client):
|
|
1489
|
-
"""Test log_environment()
|
|
1489
|
+
"""Test log_environment() flushes to client on commit."""
|
|
1490
1490
|
_, client_instance = mock_client
|
|
1491
1491
|
|
|
1492
1492
|
run = Run(name="test-run", api_key="test-key", base_url="https://test.example.com")
|
|
1493
1493
|
|
|
1494
1494
|
run.log_environment("rollout-abc", "step 1 observation")
|
|
1495
|
-
|
|
1496
|
-
# Give worker time to process
|
|
1497
|
-
time.sleep(0.1)
|
|
1495
|
+
run.end()
|
|
1498
1496
|
|
|
1499
1497
|
assert client_instance.log_env_logs.called
|
|
1500
1498
|
call_args = client_instance.log_env_logs.call_args
|
|
1501
1499
|
assert call_args[0][0] == "test-exp-id"
|
|
1502
|
-
assert call_args[0][1] == "rollout-abc"
|
|
1503
|
-
assert call_args[0][2] == "step 1 observation"
|
|
1504
|
-
|
|
1505
|
-
run.end()
|
|
1500
|
+
assert call_args[0][1] == [{"rollout_id": "rollout-abc", "content": "step 1 observation"}]
|
|
1506
1501
|
|
|
1507
1502
|
|
|
1508
1503
|
def test_log_environment_multiple_calls(mock_client):
|
|
1509
|
-
"""Test log_environment()
|
|
1504
|
+
"""Test multiple log_environment() calls are batched into a single client call."""
|
|
1510
1505
|
_, client_instance = mock_client
|
|
1511
1506
|
|
|
1512
1507
|
run = Run(name="test-run", api_key="test-key", base_url="https://test.example.com")
|
|
1513
1508
|
|
|
1514
1509
|
run.log_environment("rollout-1", "log content 1")
|
|
1515
1510
|
run.log_environment("rollout-2", "log content 2")
|
|
1516
|
-
|
|
1517
|
-
# Give worker time to process
|
|
1518
|
-
time.sleep(0.1)
|
|
1519
|
-
|
|
1520
|
-
assert client_instance.log_env_logs.call_count == 2
|
|
1521
|
-
first_call = client_instance.log_env_logs.call_args_list[0]
|
|
1522
|
-
second_call = client_instance.log_env_logs.call_args_list[1]
|
|
1523
|
-
assert first_call[0][1] == "rollout-1"
|
|
1524
|
-
assert first_call[0][2] == "log content 1"
|
|
1525
|
-
assert second_call[0][1] == "rollout-2"
|
|
1526
|
-
assert second_call[0][2] == "log content 2"
|
|
1527
|
-
|
|
1528
1511
|
run.end()
|
|
1529
1512
|
|
|
1513
|
+
assert client_instance.log_env_logs.call_count == 1
|
|
1514
|
+
call_args = client_instance.log_env_logs.call_args
|
|
1515
|
+
assert call_args[0][1] == [
|
|
1516
|
+
{"rollout_id": "rollout-1", "content": "log content 1"},
|
|
1517
|
+
{"rollout_id": "rollout-2", "content": "log content 2"},
|
|
1518
|
+
]
|
|
1519
|
+
|
|
1530
1520
|
|
|
1531
1521
|
def test_log_environment_queue_full_handling(mock_client):
|
|
1532
1522
|
"""Test that queue full is handled gracefully for environment logs."""
|
|
@@ -1556,12 +1546,8 @@ def test_log_environment_api_error_is_logged(mock_client):
|
|
|
1556
1546
|
|
|
1557
1547
|
with patch("expt_logger.run.logger") as mock_logger:
|
|
1558
1548
|
run.log_environment("rollout-abc", "some content")
|
|
1559
|
-
|
|
1560
|
-
# Give worker time to process and handle error
|
|
1561
|
-
time.sleep(0.2)
|
|
1549
|
+
run.end()
|
|
1562
1550
|
|
|
1563
1551
|
assert run._worker_thread is not None
|
|
1564
|
-
assert run._worker_thread.is_alive()
|
|
1552
|
+
assert not run._worker_thread.is_alive()
|
|
1565
1553
|
assert mock_logger.error.called
|
|
1566
|
-
|
|
1567
|
-
run.end()
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|