expt-logger 0.1.0.dev20__tar.gz → 0.1.0.dev22__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (27)
  1. {expt_logger-0.1.0.dev20 → expt_logger-0.1.0.dev22}/PKG-INFO +1 -1
  2. {expt_logger-0.1.0.dev20 → expt_logger-0.1.0.dev22}/pyproject.toml +1 -1
  3. {expt_logger-0.1.0.dev20 → expt_logger-0.1.0.dev22}/src/expt_logger/__init__.py +5 -1
  4. {expt_logger-0.1.0.dev20 → expt_logger-0.1.0.dev22}/src/expt_logger/buffer.py +16 -4
  5. {expt_logger-0.1.0.dev20 → expt_logger-0.1.0.dev22}/src/expt_logger/client.py +5 -7
  6. {expt_logger-0.1.0.dev20 → expt_logger-0.1.0.dev22}/src/expt_logger/run.py +23 -7
  7. {expt_logger-0.1.0.dev20 → expt_logger-0.1.0.dev22}/src/expt_logger/types.py +7 -0
  8. {expt_logger-0.1.0.dev20 → expt_logger-0.1.0.dev22}/tests/test_buffer.py +73 -27
  9. {expt_logger-0.1.0.dev20 → expt_logger-0.1.0.dev22}/tests/test_client.py +12 -3
  10. {expt_logger-0.1.0.dev20 → expt_logger-0.1.0.dev22}/tests/test_client_integration.py +2 -2
  11. {expt_logger-0.1.0.dev20 → expt_logger-0.1.0.dev22}/tests/test_integration_e2e.py +51 -7
  12. {expt_logger-0.1.0.dev20 → expt_logger-0.1.0.dev22}/tests/test_run.py +13 -27
  13. {expt_logger-0.1.0.dev20 → expt_logger-0.1.0.dev22}/.gitignore +0 -0
  14. {expt_logger-0.1.0.dev20 → expt_logger-0.1.0.dev22}/.pre-commit-config.yaml +0 -0
  15. {expt_logger-0.1.0.dev20 → expt_logger-0.1.0.dev22}/DEVELOPMENT.md +0 -0
  16. {expt_logger-0.1.0.dev20 → expt_logger-0.1.0.dev22}/README.md +0 -0
  17. {expt_logger-0.1.0.dev20 → expt_logger-0.1.0.dev22}/src/expt_logger/config.py +0 -0
  18. {expt_logger-0.1.0.dev20 → expt_logger-0.1.0.dev22}/src/expt_logger/env.py +0 -0
  19. {expt_logger-0.1.0.dev20 → expt_logger-0.1.0.dev22}/src/expt_logger/exceptions.py +0 -0
  20. {expt_logger-0.1.0.dev20 → expt_logger-0.1.0.dev22}/src/expt_logger/py.typed +0 -0
  21. {expt_logger-0.1.0.dev20 → expt_logger-0.1.0.dev22}/src/expt_logger/validation.py +0 -0
  22. {expt_logger-0.1.0.dev20 → expt_logger-0.1.0.dev22}/tests/conftest.py +0 -0
  23. {expt_logger-0.1.0.dev20 → expt_logger-0.1.0.dev22}/tests/test_config.py +0 -0
  24. {expt_logger-0.1.0.dev20 → expt_logger-0.1.0.dev22}/tests/test_env.py +0 -0
  25. {expt_logger-0.1.0.dev20 → expt_logger-0.1.0.dev22}/tests/test_exceptions.py +0 -0
  26. {expt_logger-0.1.0.dev20 → expt_logger-0.1.0.dev22}/tests/test_global_api.py +0 -0
  27. {expt_logger-0.1.0.dev20 → expt_logger-0.1.0.dev22}/tests/test_validation.py +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: expt-logger
3
- Version: 0.1.0.dev20
3
+ Version: 0.1.0.dev22
4
4
  Summary: Simple experiment logging library
5
5
  Requires-Python: >=3.10
6
6
  Requires-Dist: requests>=2.31.0
@@ -1,6 +1,6 @@
1
1
  [project]
2
2
  name = "expt-logger"
3
- version = "0.1.0.dev20"
3
+ version = "0.1.0.dev22"
4
4
  description = "Simple experiment logging library"
5
5
  readme = "README.md"
6
6
  requires-python = ">=3.10"
@@ -159,19 +159,23 @@ def log_rollout(
159
159
  def log_environment(
160
160
  rollout_id: str,
161
161
  content: str,
162
+ k: int | None = None,
163
+ commit: bool = True,
162
164
  ) -> None:
163
165
  """Log an environment log entry associated with a rollout.
164
166
 
165
167
  Args:
166
168
  rollout_id: ID of the rollout this log entry is associated with
167
169
  content: Log content string
170
+ k: If set, commit only when the buffer has more than k env log entries.
171
+ commit: Whether to flush buffer after logging
168
172
 
169
173
  Raises:
170
174
  RuntimeError: If no active run exists
171
175
  """
172
176
  if _active_run is None:
173
177
  raise RuntimeError("No active run. Call init() first.")
174
- _active_run.log_environment(rollout_id=rollout_id, content=content)
178
+ _active_run.log_environment(rollout_id=rollout_id, content=content, k=k, commit=commit)
175
179
 
176
180
 
177
181
  def log_error(
@@ -2,7 +2,7 @@
2
2
 
3
3
  import logging
4
4
 
5
- from expt_logger.types import ErrorItem, RolloutItem, ScalarItem
5
+ from expt_logger.types import EnvLogItem, ErrorItem, RolloutItem, ScalarItem
6
6
 
7
7
  logger = logging.getLogger(__name__)
8
8
 
@@ -20,6 +20,7 @@ class Buffer:
20
20
  self._scalars: dict[str, float] = {} # full_key (mode/metric) -> value
21
21
  self._rollouts: list[RolloutItem] = []
22
22
  self._errors: list[ErrorItem] = []
23
+ self._env_logs: list[EnvLogItem] = []
23
24
 
24
25
  def add_scalar(self, name: str, value: float, mode: str | None = None) -> None:
25
26
  """Add a scalar metric to the buffer.
@@ -88,7 +89,16 @@ class Buffer:
88
89
  """
89
90
  self._errors.append(error)
90
91
 
91
- def get_and_clear(self) -> tuple[list[ScalarItem], list[RolloutItem], list[ErrorItem]]:
92
+ def add_env_log(self, rollout_id: str, content: str) -> None:
93
+ """Add an environment log entry to the buffer.
94
+
95
+ Args:
96
+ rollout_id: ID of the rollout this log entry is associated with
97
+ content: Log content string
98
+ """
99
+ self._env_logs.append({"rollout_id": rollout_id, "content": content})
100
+
101
+ def get_and_clear(self) -> tuple[list[ScalarItem], list[RolloutItem], list[ErrorItem], list[EnvLogItem]]:
92
102
  """Get all buffered data and clear the buffer.
93
103
 
94
104
  Returns:
@@ -110,12 +120,14 @@ class Buffer:
110
120
 
111
121
  rollouts = self._rollouts.copy()
112
122
  errors = self._errors.copy()
123
+ env_logs = self._env_logs.copy()
113
124
 
114
125
  self._scalars.clear()
115
126
  self._rollouts.clear()
116
127
  self._errors.clear()
128
+ self._env_logs.clear()
117
129
 
118
- return scalar_items, rollouts, errors
130
+ return scalar_items, rollouts, errors, env_logs
119
131
 
120
132
  def is_empty(self) -> bool:
121
133
  """Check if buffer has any data.
@@ -123,4 +135,4 @@ class Buffer:
123
135
  Returns:
124
136
  True if buffer is empty, False otherwise
125
137
  """
126
- return not self._scalars and not self._rollouts and not self._errors
138
+ return not self._scalars and not self._rollouts and not self._errors and not self._env_logs
@@ -7,7 +7,7 @@ from typing import Any
7
7
  import requests
8
8
 
9
9
  from expt_logger.exceptions import APIError, AuthenticationError
10
- from expt_logger.types import ErrorItem, RolloutItem, ScalarItem, ScalarValue
10
+ from expt_logger.types import EnvLogItem, ErrorItem, RolloutItem, ScalarItem, ScalarValue
11
11
 
12
12
  logger = logging.getLogger(__name__)
13
13
 
@@ -283,21 +283,19 @@ class APIClient:
283
283
  def log_env_logs(
284
284
  self,
285
285
  experiment_id: str,
286
- rollout_id: str,
287
- content: str,
286
+ logs: list[EnvLogItem],
288
287
  ) -> None:
289
- """Log an environment log entry for an experiment.
288
+ """Log environment log entries for an experiment.
290
289
 
291
290
  Args:
292
291
  experiment_id: Experiment ID
293
- rollout_id: Rollout ID this log entry is associated with
294
- content: Log content string
292
+ logs: List of env log entries with rollout_id and content
295
293
 
296
294
  Raises:
297
295
  APIError: If request fails
298
296
  """
299
297
  url = f"{self.base_url}/api/experiments/{experiment_id}/env-logs"
300
- payload = {"rolloutId": rollout_id, "content": content}
298
+ payload = {"logs": [{"rolloutId": log["rollout_id"], "content": log["content"]} for log in logs]}
301
299
  self._request("POST", url, json=payload)
302
300
 
303
301
  def get_env_logs(
@@ -82,7 +82,9 @@ class Run:
82
82
  resolved_experiment_id = get_experiment_id(experiment_id, is_main_process=False)
83
83
  if resolved_experiment_id is not None:
84
84
  self._experiment_id = resolved_experiment_id
85
- self._validate_and_attach_experiment(config=config)
85
+ # Do not set the config if already attached to an experiment to avoid
86
+ # overwriting existing settings
87
+ self._validate_and_attach_experiment()
86
88
  logger.info(
87
89
  f"[expt_logger] Attached to experiment ID: {self._experiment_id} (subprocess)"
88
90
  )
@@ -291,12 +293,16 @@ class Run:
291
293
  self,
292
294
  rollout_id: str,
293
295
  content: str,
296
+ k: int | None = None,
297
+ commit: bool = True,
294
298
  ) -> None:
295
299
  """Log an environment log entry associated with a rollout.
296
300
 
297
301
  Args:
298
302
  rollout_id: ID of the rollout this log entry is associated with
299
303
  content: Log content string
304
+ k: If set, commit when the buffer has more than k elements
305
+ commit: Whether to flush buffer after logging
300
306
  """
301
307
  env_cmd: LogEnvironmentCommand = {
302
308
  "rollout_id": rollout_id,
@@ -305,7 +311,12 @@ class Run:
305
311
  try:
306
312
  self._queue.put_nowait(("log_environment", env_cmd))
307
313
  except Full:
308
- logger.warning(f"Command queue full, dropping environment log for rollout: {rollout_id}")
314
+ logger.warning(
315
+ f"Command queue full, dropping environment log for rollout: {rollout_id}"
316
+ )
317
+
318
+ if commit or (k is not None and len(self._buffer._env_logs) > k):
319
+ self.commit()
309
320
 
310
321
  def log_error(
311
322
  self,
@@ -569,10 +580,7 @@ class Run:
569
580
  Args:
570
581
  data: Log environment command data with rollout_id and content
571
582
  """
572
- try:
573
- self._client.log_env_logs(self._experiment_id, data["rollout_id"], data["content"])
574
- except Exception as e:
575
- logger.error(f"Error logging environment log: {e}", exc_info=True)
583
+ self._buffer.add_env_log(data["rollout_id"], data["content"])
576
584
 
577
585
  def _handle_log_error(self, data: LogErrorCommand) -> None:
578
586
  """Handle log_error command.
@@ -619,7 +627,7 @@ class Run:
619
627
  if self._buffer.is_empty():
620
628
  return
621
629
 
622
- scalars, rollouts, errors = self._buffer.get_and_clear()
630
+ scalars, rollouts, errors, env_logs = self._buffer.get_and_clear()
623
631
 
624
632
  # Send scalars if any
625
633
  if scalars:
@@ -644,3 +652,11 @@ class Run:
644
652
  logger.debug(f"Flushed {len(errors)} errors at step {self._step}")
645
653
  except Exception as e:
646
654
  logger.error(f"Error logging errors: {e}", exc_info=True)
655
+
656
+ # Send env logs if any
657
+ if env_logs:
658
+ try:
659
+ self._client.log_env_logs(self._experiment_id, env_logs)
660
+ logger.debug(f"Flushed {len(env_logs)} environment logs")
661
+ except Exception as e:
662
+ logger.error(f"Error logging environment logs: {e}", exc_info=True)
@@ -49,6 +49,13 @@ class ErrorItem(TypedDict):
49
49
  traceback: str | None
50
50
 
51
51
 
52
+ class EnvLogItem(TypedDict):
53
+ """An environment log item for API submission."""
54
+
55
+ rollout_id: str
56
+ content: str
57
+
58
+
52
59
  class ScalarValue(TypedDict):
53
60
  """A single scalar value at a specific step (used in GET responses)."""
54
61
 
@@ -16,17 +16,18 @@ def buffer():
16
16
  def test_buffer_initialization(buffer):
17
17
  """Test buffer starts empty."""
18
18
  assert buffer.is_empty()
19
- scalars, rollouts, errors = buffer.get_and_clear()
19
+ scalars, rollouts, errors, env_logs = buffer.get_and_clear()
20
20
  assert scalars == []
21
21
  assert rollouts == []
22
22
  assert errors == []
23
+ assert env_logs == []
23
24
 
24
25
 
25
26
  def test_add_scalar_simple(buffer):
26
27
  """Test adding a simple scalar metric."""
27
28
  buffer.add_scalar("loss", 0.5)
28
29
 
29
- scalars, rollouts, errors = buffer.get_and_clear()
30
+ scalars, rollouts, errors, _ = buffer.get_and_clear()
30
31
  assert len(scalars) == 1
31
32
  assert scalars[0] == {"step": 0, "mode": "train", "name": "loss", "value": 0.5}
32
33
  assert rollouts == []
@@ -37,7 +38,7 @@ def test_add_scalar_with_mode_in_key(buffer):
37
38
  """Test adding scalar with mode already in key."""
38
39
  buffer.add_scalar("train/loss", 0.5)
39
40
 
40
- scalars, rollouts, errors = buffer.get_and_clear()
41
+ scalars, rollouts, errors, _ = buffer.get_and_clear()
41
42
  assert len(scalars) == 1
42
43
  assert scalars[0] == {"step": 0, "mode": "train", "name": "loss", "value": 0.5}
43
44
 
@@ -46,7 +47,7 @@ def test_add_scalar_with_mode_parameter(buffer):
46
47
  """Test adding scalar with explicit mode parameter."""
47
48
  buffer.add_scalar("loss", 0.5, mode="eval")
48
49
 
49
- scalars, rollouts, errors = buffer.get_and_clear()
50
+ scalars, rollouts, errors, _ = buffer.get_and_clear()
50
51
  assert len(scalars) == 1
51
52
  assert scalars[0] == {"step": 0, "mode": "eval", "name": "loss", "value": 0.5}
52
53
 
@@ -57,7 +58,7 @@ def test_add_scalar_mode_conflict_key_wins(buffer, caplog):
57
58
  buffer.add_scalar("train/loss", 0.5, mode="eval")
58
59
 
59
60
  # Should use "eval" mode from parameter, with full "train/loss" as the name
60
- scalars, _, _ = buffer.get_and_clear()
61
+ scalars, _, _, _ = buffer.get_and_clear()
61
62
  assert len(scalars) == 1
62
63
  assert scalars[0] == {"step": 0, "mode": "eval", "name": "train/loss", "value": 0.5}
63
64
 
@@ -70,7 +71,7 @@ def test_add_scalar_mode_no_conflict(buffer, caplog):
70
71
  with caplog.at_level(logging.WARNING):
71
72
  buffer.add_scalar("train/loss", 0.5, mode="train")
72
73
 
73
- scalars, _, _ = buffer.get_and_clear()
74
+ scalars, _, _, _ = buffer.get_and_clear()
74
75
  assert len(scalars) == 1
75
76
  assert scalars[0] == {"step": 0, "mode": "train", "name": "loss", "value": 0.5}
76
77
 
@@ -84,7 +85,7 @@ def test_add_multiple_scalars(buffer):
84
85
  buffer.add_scalar("accuracy", 0.9)
85
86
  buffer.add_scalar("eval/loss", 0.6)
86
87
 
87
- scalars, _, _ = buffer.get_and_clear()
88
+ scalars, _, _, _ = buffer.get_and_clear()
88
89
  assert len(scalars) == 3
89
90
  # Convert to dict for easier comparison
90
91
  scalars_dict = {f"{s['mode']}/{s['name']}": s["value"] for s in scalars}
@@ -101,7 +102,7 @@ def test_add_scalar_last_write_wins(buffer, caplog):
101
102
  buffer.add_scalar("loss", 0.5)
102
103
  buffer.add_scalar("loss", 0.3) # Overwrites
103
104
 
104
- scalars, _, _ = buffer.get_and_clear()
105
+ scalars, _, _, _ = buffer.get_and_clear()
105
106
  assert len(scalars) == 1
106
107
  assert scalars[0] == {"step": 0, "mode": "train", "name": "loss", "value": 0.3}
107
108
 
@@ -117,7 +118,7 @@ def test_add_scalar_last_write_wins_different_modes(buffer):
117
118
  buffer.add_scalar("loss", 0.5, mode="train")
118
119
  buffer.add_scalar("loss", 0.6, mode="eval")
119
120
 
120
- scalars, _, _ = buffer.get_and_clear()
121
+ scalars, _, _, _ = buffer.get_and_clear()
121
122
  assert len(scalars) == 2
122
123
  # Convert to dict for easier comparison
123
124
  scalars_dict = {f"{s['mode']}/{s['name']}": s["value"] for s in scalars}
@@ -139,7 +140,7 @@ def test_add_rollout(buffer):
139
140
 
140
141
  buffer.add_rollout(rollout)
141
142
 
142
- scalars, rollouts, errors = buffer.get_and_clear()
143
+ scalars, rollouts, errors, _ = buffer.get_and_clear()
143
144
  assert scalars == []
144
145
  assert len(rollouts) == 1
145
146
  assert rollouts[0] == rollout
@@ -165,7 +166,7 @@ def test_add_multiple_rollouts(buffer):
165
166
  buffer.add_rollout(rollout1)
166
167
  buffer.add_rollout(rollout2)
167
168
 
168
- _, rollouts, _ = buffer.get_and_clear()
169
+ _, rollouts, _, _ = buffer.get_and_clear()
169
170
  assert len(rollouts) == 2
170
171
  assert rollouts[0] == rollout1
171
172
  assert rollouts[1] == rollout2
@@ -185,7 +186,7 @@ def test_mixed_scalars_and_rollouts(buffer):
185
186
  )
186
187
  buffer.add_scalar("accuracy", 0.9)
187
188
 
188
- scalars, rollouts, errors = buffer.get_and_clear()
189
+ scalars, rollouts, errors, _ = buffer.get_and_clear()
189
190
  assert len(scalars) == 2
190
191
  scalars_dict = {f"{s['mode']}/{s['name']}": s["value"] for s in scalars}
191
192
  assert scalars_dict == {"train/loss": 0.5, "train/accuracy": 0.9}
@@ -208,7 +209,7 @@ def test_get_and_clear_empties_buffer(buffer):
208
209
  assert not buffer.is_empty()
209
210
 
210
211
  # First call returns data
211
- scalars1, rollouts1, errors1 = buffer.get_and_clear()
212
+ scalars1, rollouts1, errors1, _ = buffer.get_and_clear()
212
213
  assert len(scalars1) > 0
213
214
  assert len(rollouts1) > 0
214
215
 
@@ -216,7 +217,7 @@ def test_get_and_clear_empties_buffer(buffer):
216
217
  assert buffer.is_empty()
217
218
 
218
219
  # Second call returns empty
219
- scalars2, rollouts2, errors2 = buffer.get_and_clear()
220
+ scalars2, rollouts2, errors2, _ = buffer.get_and_clear()
220
221
  assert scalars2 == []
221
222
  assert rollouts2 == []
222
223
  assert errors2 == []
@@ -226,13 +227,13 @@ def test_get_and_clear_returns_copy(buffer):
226
227
  """Test that get_and_clear returns a copy, not reference."""
227
228
  buffer.add_scalar("loss", 0.5)
228
229
 
229
- scalars1, _, _ = buffer.get_and_clear()
230
+ scalars1, _, _, _ = buffer.get_and_clear()
230
231
  # Modify returned list
231
232
  scalars1.append({"step": 999, "mode": "test", "name": "modified", "value": 999})
232
233
 
233
234
  # Add new data
234
235
  buffer.add_scalar("accuracy", 0.9)
235
- scalars2, _, _ = buffer.get_and_clear()
236
+ scalars2, _, _, _ = buffer.get_and_clear()
236
237
 
237
238
  # Should not contain the modification
238
239
  assert len(scalars2) == 1
@@ -274,7 +275,7 @@ def test_metric_key_with_multiple_slashes(buffer):
274
275
  # Edge case: what if key has multiple slashes?
275
276
  buffer.add_scalar("train/sub/metric", 0.5)
276
277
 
277
- scalars, _, _ = buffer.get_and_clear()
278
+ scalars, _, _, _ = buffer.get_and_clear()
278
279
  # Should split on first slash only
279
280
  assert len(scalars) == 1
280
281
  assert scalars[0] == {"step": 0, "mode": "train", "name": "sub/metric", "value": 0.5}
@@ -285,7 +286,7 @@ def test_default_mode_is_train(buffer):
285
286
  buffer.add_scalar("loss", 0.5)
286
287
  buffer.add_scalar("accuracy", 0.9)
287
288
 
288
- scalars, _, _ = buffer.get_and_clear()
289
+ scalars, _, _, _ = buffer.get_and_clear()
289
290
  assert all(s["mode"] == "train" for s in scalars)
290
291
 
291
292
 
@@ -293,7 +294,7 @@ def test_mode_provided_strips_matching_prefix(buffer):
293
294
  """Test that when mode is provided and matches key prefix, the prefix is stripped."""
294
295
  buffer.add_scalar("train/loss", 0.5, mode="train")
295
296
 
296
- scalars, _, _ = buffer.get_and_clear()
297
+ scalars, _, _, _ = buffer.get_and_clear()
297
298
  # Should be train/loss (prefix stripped)
298
299
  assert len(scalars) == 1
299
300
  assert scalars[0] == {"step": 0, "mode": "train", "name": "loss", "value": 0.5}
@@ -303,7 +304,7 @@ def test_mode_provided_keeps_mismatched_prefix(buffer):
303
304
  """Test that when mode is provided and doesn't match key prefix, full key is kept."""
304
305
  buffer.add_scalar("train/loss", 0.5, mode="eval")
305
306
 
306
- scalars, _, _ = buffer.get_and_clear()
307
+ scalars, _, _, _ = buffer.get_and_clear()
307
308
  # Should be eval/train/loss (full key kept as name)
308
309
  assert len(scalars) == 1
309
310
  assert scalars[0] == {"step": 0, "mode": "eval", "name": "train/loss", "value": 0.5}
@@ -313,7 +314,7 @@ def test_mode_not_provided_extracts_from_key(buffer):
313
314
  """Test that when mode is not provided, it's extracted from key prefix."""
314
315
  buffer.add_scalar("eval/accuracy", 0.95)
315
316
 
316
- scalars, _, _ = buffer.get_and_clear()
317
+ scalars, _, _, _ = buffer.get_and_clear()
317
318
  # Should be eval/accuracy (mode extracted, name is accuracy)
318
319
  assert len(scalars) == 1
319
320
  assert scalars[0] == {"step": 0, "mode": "eval", "name": "accuracy", "value": 0.95}
@@ -323,7 +324,7 @@ def test_mode_not_provided_simple_key_defaults_train(buffer):
323
324
  """Test that simple keys without mode default to train mode."""
324
325
  buffer.add_scalar("loss", 0.5)
325
326
 
326
- scalars, _, _ = buffer.get_and_clear()
327
+ scalars, _, _, _ = buffer.get_and_clear()
327
328
  # Should be train/loss (default mode)
328
329
  assert len(scalars) == 1
329
330
  assert scalars[0] == {"step": 0, "mode": "train", "name": "loss", "value": 0.5}
@@ -342,7 +343,7 @@ def test_add_error(buffer):
342
343
 
343
344
  buffer.add_error(error)
344
345
 
345
- scalars, rollouts, errors = buffer.get_and_clear()
346
+ scalars, rollouts, errors, _ = buffer.get_and_clear()
346
347
  assert scalars == []
347
348
  assert rollouts == []
348
349
  assert len(errors) == 1
@@ -369,7 +370,7 @@ def test_add_multiple_errors(buffer):
369
370
  buffer.add_error(error1)
370
371
  buffer.add_error(error2)
371
372
 
372
- _, _, errors = buffer.get_and_clear()
373
+ _, _, errors, _ = buffer.get_and_clear()
373
374
  assert len(errors) == 2
374
375
  assert errors[0] == error1
375
376
  assert errors[1] == error2
@@ -398,7 +399,7 @@ def test_mixed_scalars_rollouts_and_errors(buffer):
398
399
  )
399
400
  buffer.add_scalar("accuracy", 0.9)
400
401
 
401
- scalars, rollouts, errors = buffer.get_and_clear()
402
+ scalars, rollouts, errors, _ = buffer.get_and_clear()
402
403
  assert len(scalars) == 2
403
404
  assert len(rollouts) == 1
404
405
  assert len(errors) == 1
@@ -438,13 +439,58 @@ def test_get_and_clear_empties_errors(buffer):
438
439
  assert not buffer.is_empty()
439
440
 
440
441
  # First call returns data
441
- _, _, errors1 = buffer.get_and_clear()
442
+ _, _, errors1, _ = buffer.get_and_clear()
442
443
  assert len(errors1) > 0
443
444
 
444
445
  # Buffer should now be empty
445
446
  assert buffer.is_empty()
446
447
 
447
448
  # Second call returns empty
448
- _, _, errors2 = buffer.get_and_clear()
449
+ _, _, errors2, _ = buffer.get_and_clear()
449
450
  assert errors2 == []
450
451
 
452
+
453
+ def test_add_env_log(buffer):
454
+ """Test adding an environment log entry."""
455
+ buffer.add_env_log("rollout-1", "step 1 observation")
456
+
457
+ _, _, _, env_logs = buffer.get_and_clear()
458
+ assert len(env_logs) == 1
459
+ assert env_logs[0] == {"rollout_id": "rollout-1", "content": "step 1 observation"}
460
+
461
+
462
+ def test_add_multiple_env_logs(buffer):
463
+ """Test adding multiple environment log entries."""
464
+ buffer.add_env_log("rollout-1", "step 1")
465
+ buffer.add_env_log("rollout-1", "step 2")
466
+ buffer.add_env_log("rollout-2", "other rollout")
467
+
468
+ _, _, _, env_logs = buffer.get_and_clear()
469
+ assert len(env_logs) == 3
470
+ assert env_logs[0] == {"rollout_id": "rollout-1", "content": "step 1"}
471
+ assert env_logs[1] == {"rollout_id": "rollout-1", "content": "step 2"}
472
+ assert env_logs[2] == {"rollout_id": "rollout-2", "content": "other rollout"}
473
+
474
+
475
+ def test_is_empty_with_only_env_logs(buffer):
476
+ """Test is_empty with only env logs."""
477
+ assert buffer.is_empty()
478
+
479
+ buffer.add_env_log("rollout-1", "some content")
480
+ assert not buffer.is_empty()
481
+
482
+ buffer.get_and_clear()
483
+ assert buffer.is_empty()
484
+
485
+
486
+ def test_get_and_clear_empties_env_logs(buffer):
487
+ """Test that get_and_clear empties env logs from buffer."""
488
+ buffer.add_env_log("rollout-1", "content")
489
+
490
+ _, _, _, env_logs1 = buffer.get_and_clear()
491
+ assert len(env_logs1) == 1
492
+
493
+ assert buffer.is_empty()
494
+
495
+ _, _, _, env_logs2 = buffer.get_and_clear()
496
+ assert env_logs2 == []
@@ -295,15 +295,24 @@ def test_log_rollouts(client):
295
295
 
296
296
 
297
297
  def test_log_env_logs(client):
298
- """Test logging environment logs."""
298
+ """Test logging environment logs as a batch."""
299
+ logs = [
300
+ {"rollout_id": "rollout-456", "content": "step 1 observation"},
301
+ {"rollout_id": "rollout-456", "content": "step 2 observation"},
302
+ ]
299
303
  with patch.object(client, "_request") as mock_request:
300
- client.log_env_logs("exp-123", "rollout-456", "step 1 observation")
304
+ client.log_env_logs("exp-123", logs)
301
305
 
302
306
  mock_request.assert_called_once()
303
307
  call_args = mock_request.call_args
304
308
  assert call_args[0][0] == "POST"
305
309
  assert call_args[0][1] == "https://test.example.com/api/experiments/exp-123/env-logs"
306
- assert call_args[1]["json"] == {"rolloutId": "rollout-456", "content": "step 1 observation"}
310
+ assert call_args[1]["json"] == {
311
+ "logs": [
312
+ {"rolloutId": "rollout-456", "content": "step 1 observation"},
313
+ {"rolloutId": "rollout-456", "content": "step 2 observation"},
314
+ ]
315
+ }
307
316
 
308
317
 
309
318
  def test_log_errors(client):
@@ -176,7 +176,7 @@ class TestAPIClientIntegration:
176
176
  response = requests.get(
177
177
  f"{client.base_url}/api/experiments/{experiment_id}/errors",
178
178
  params={"mode": "train"},
179
- headers={"x-api-key": client.api_key},
179
+ headers={"Authorization": f"Bearer {client.api_key}"},
180
180
  )
181
181
 
182
182
  assert response.status_code == 200
@@ -370,7 +370,7 @@ class TestAPIClientIntegration:
370
370
  errors_response = requests.get(
371
371
  f"{client.base_url}/api/experiments/{experiment_id}/errors",
372
372
  params={"mode": "train"},
373
- headers={"x-api-key": client.api_key},
373
+ headers={"Authorization": f"Bearer {client.api_key}"},
374
374
  )
375
375
 
376
376
  assert errors_response.status_code == 200
@@ -1096,7 +1096,7 @@ class TestEnvironmentLogs:
1096
1096
  )
1097
1097
  cleanup_experiments.append(run._experiment_id)
1098
1098
 
1099
- expt_logger.log_environment("00000000-0000-0000-0000-000000000001", "observation: step=1, reward=0.9")
1099
+ expt_logger.log_environment("00000000-0000-0000-0000-000000000001", "observation: step=1, reward=0.9", k=-1)
1100
1100
  time.sleep(0.5)
1101
1101
 
1102
1102
  env_logs = fetch_env_logs(run._experiment_id, "00000000-0000-0000-0000-000000000001")
@@ -1120,9 +1120,9 @@ class TestEnvironmentLogs:
1120
1120
  )
1121
1121
  cleanup_experiments.append(run._experiment_id)
1122
1122
 
1123
- expt_logger.log_environment("00000000-0000-0000-0000-000000000002", "step 1: action=left")
1124
- expt_logger.log_environment("00000000-0000-0000-0000-000000000002", "step 2: action=right")
1125
- expt_logger.log_environment("00000000-0000-0000-0000-000000000002", "step 3: action=jump")
1123
+ expt_logger.log_environment("00000000-0000-0000-0000-000000000002", "step 1: action=left", commit=True)
1124
+ expt_logger.log_environment("00000000-0000-0000-0000-000000000002", "step 2: action=right", commit=True)
1125
+ expt_logger.log_environment("00000000-0000-0000-0000-000000000002", "step 3: action=jump", commit=True)
1126
1126
  time.sleep(0.5)
1127
1127
 
1128
1128
  env_logs = fetch_env_logs(run._experiment_id, "00000000-0000-0000-0000-000000000002")
@@ -1147,8 +1147,8 @@ class TestEnvironmentLogs:
1147
1147
  )
1148
1148
  cleanup_experiments.append(run._experiment_id)
1149
1149
 
1150
- expt_logger.log_environment("00000000-0000-0000-0000-00000000003a", "log for rollout A")
1151
- expt_logger.log_environment("00000000-0000-0000-0000-00000000003b", "log for rollout B")
1150
+ expt_logger.log_environment("00000000-0000-0000-0000-00000000003a", "log for rollout A", commit=True)
1151
+ expt_logger.log_environment("00000000-0000-0000-0000-00000000003b", "log for rollout B", commit=True)
1152
1152
  time.sleep(0.5)
1153
1153
 
1154
1154
  logs_a = fetch_env_logs(run._experiment_id, "00000000-0000-0000-0000-00000000003a")
@@ -1160,6 +1160,50 @@ class TestEnvironmentLogs:
1160
1160
 
1161
1161
  expt_logger.end()
1162
1162
 
1163
+ def test_log_environment_k_threshold_batching(
1164
+ self,
1165
+ shared_api_key: str,
1166
+ base_url: str,
1167
+ cleanup_experiments: list[str],
1168
+ fetch_env_logs,
1169
+ ) -> None:
1170
+ """Test that k parameter batches env logs and flushes only when threshold is exceeded.
1171
+
1172
+ With k=2, the flush triggers when the main thread calls log_environment and
1173
+ finds > 2 entries already in the worker's buffer. So we need k+2 total calls:
1174
+ the first k+1 accumulate in the buffer, and the (k+2)th call detects the overflow
1175
+ and triggers a commit.
1176
+ """
1177
+ k = 2
1178
+ rollout_id = "00000000-0000-0000-0000-00000000005a"
1179
+ run = expt_logger.init(
1180
+ name="test-env-log-k-threshold",
1181
+ api_key=shared_api_key,
1182
+ base_url=base_url,
1183
+ )
1184
+ cleanup_experiments.append(run._experiment_id)
1185
+
1186
+ # Log k+1 items without committing; sleep to let the worker process them into the buffer
1187
+ for i in range(k + 1):
1188
+ expt_logger.log_environment(rollout_id, f"step {i}", commit=False, k=k)
1189
+ time.sleep(0.3)
1190
+
1191
+ # Nothing should be on the server yet — threshold not exceeded in the main thread's view
1192
+ env_logs = fetch_env_logs(run._experiment_id, rollout_id)
1193
+ assert len(env_logs) == 0, f"Expected 0 logs before threshold, got {len(env_logs)}"
1194
+
1195
+ # One more call: the main thread now sees k+1 items in the buffer (> k) and commits
1196
+ expt_logger.log_environment(rollout_id, f"step {k + 1}", commit=False, k=k)
1197
+ time.sleep(0.5)
1198
+
1199
+ # All k+2 logs should now be on the server
1200
+ env_logs = fetch_env_logs(run._experiment_id, rollout_id)
1201
+ assert len(env_logs) == k + 2
1202
+ contents = {log["content"] for log in env_logs}
1203
+ assert contents == {f"step {i}" for i in range(k + 2)}
1204
+
1205
+ expt_logger.end()
1206
+
1163
1207
  def test_log_environment_multiline_content(
1164
1208
  self,
1165
1209
  shared_api_key: str,
@@ -1176,7 +1220,7 @@ class TestEnvironmentLogs:
1176
1220
  cleanup_experiments.append(run._experiment_id)
1177
1221
 
1178
1222
  content = "obs: {x: 1.0, y: 2.0}\nreward: 0.5\ndone: false\ninfo: {step: 10}"
1179
- expt_logger.log_environment("00000000-0000-0000-0000-000000000004", content)
1223
+ expt_logger.log_environment("00000000-0000-0000-0000-000000000004", content, commit=True)
1180
1224
  time.sleep(0.5)
1181
1225
 
1182
1226
  env_logs = fetch_env_logs(run._experiment_id, "00000000-0000-0000-0000-000000000004")
@@ -1486,47 +1486,37 @@ def test_log_error_with_invalid_mode_empty(mock_client):
1486
1486
 
1487
1487
 
1488
1488
  def test_log_environment_calls_client(mock_client):
1489
- """Test log_environment() calls client with correct args."""
1489
+ """Test log_environment() flushes to client on commit."""
1490
1490
  _, client_instance = mock_client
1491
1491
 
1492
1492
  run = Run(name="test-run", api_key="test-key", base_url="https://test.example.com")
1493
1493
 
1494
1494
  run.log_environment("rollout-abc", "step 1 observation")
1495
-
1496
- # Give worker time to process
1497
- time.sleep(0.1)
1495
+ run.end()
1498
1496
 
1499
1497
  assert client_instance.log_env_logs.called
1500
1498
  call_args = client_instance.log_env_logs.call_args
1501
1499
  assert call_args[0][0] == "test-exp-id"
1502
- assert call_args[0][1] == "rollout-abc"
1503
- assert call_args[0][2] == "step 1 observation"
1504
-
1505
- run.end()
1500
+ assert call_args[0][1] == [{"rollout_id": "rollout-abc", "content": "step 1 observation"}]
1506
1501
 
1507
1502
 
1508
1503
  def test_log_environment_multiple_calls(mock_client):
1509
- """Test log_environment() can be called multiple times for different rollouts."""
1504
+ """Test multiple log_environment() calls are batched into a single client call."""
1510
1505
  _, client_instance = mock_client
1511
1506
 
1512
1507
  run = Run(name="test-run", api_key="test-key", base_url="https://test.example.com")
1513
1508
 
1514
1509
  run.log_environment("rollout-1", "log content 1")
1515
1510
  run.log_environment("rollout-2", "log content 2")
1516
-
1517
- # Give worker time to process
1518
- time.sleep(0.1)
1519
-
1520
- assert client_instance.log_env_logs.call_count == 2
1521
- first_call = client_instance.log_env_logs.call_args_list[0]
1522
- second_call = client_instance.log_env_logs.call_args_list[1]
1523
- assert first_call[0][1] == "rollout-1"
1524
- assert first_call[0][2] == "log content 1"
1525
- assert second_call[0][1] == "rollout-2"
1526
- assert second_call[0][2] == "log content 2"
1527
-
1528
1511
  run.end()
1529
1512
 
1513
+ assert client_instance.log_env_logs.call_count == 1
1514
+ call_args = client_instance.log_env_logs.call_args
1515
+ assert call_args[0][1] == [
1516
+ {"rollout_id": "rollout-1", "content": "log content 1"},
1517
+ {"rollout_id": "rollout-2", "content": "log content 2"},
1518
+ ]
1519
+
1530
1520
 
1531
1521
  def test_log_environment_queue_full_handling(mock_client):
1532
1522
  """Test that queue full is handled gracefully for environment logs."""
@@ -1556,12 +1546,8 @@ def test_log_environment_api_error_is_logged(mock_client):
1556
1546
 
1557
1547
  with patch("expt_logger.run.logger") as mock_logger:
1558
1548
  run.log_environment("rollout-abc", "some content")
1559
-
1560
- # Give worker time to process and handle error
1561
- time.sleep(0.2)
1549
+ run.end()
1562
1550
 
1563
1551
  assert run._worker_thread is not None
1564
- assert run._worker_thread.is_alive()
1552
+ assert not run._worker_thread.is_alive()
1565
1553
  assert mock_logger.error.called
1566
-
1567
- run.end()