experimaestro 2.0.0b4__py3-none-any.whl → 2.0.0b8__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of experimaestro might be problematic. Click here for more details.
- experimaestro/cli/__init__.py +177 -31
- experimaestro/experiments/cli.py +6 -2
- experimaestro/scheduler/base.py +21 -0
- experimaestro/scheduler/experiment.py +64 -34
- experimaestro/scheduler/interfaces.py +27 -0
- experimaestro/scheduler/remote/__init__.py +31 -0
- experimaestro/scheduler/remote/client.py +874 -0
- experimaestro/scheduler/remote/protocol.py +467 -0
- experimaestro/scheduler/remote/server.py +423 -0
- experimaestro/scheduler/remote/sync.py +144 -0
- experimaestro/scheduler/services.py +158 -32
- experimaestro/scheduler/state_db.py +58 -9
- experimaestro/scheduler/state_provider.py +512 -91
- experimaestro/scheduler/state_sync.py +65 -8
- experimaestro/tests/test_cli_jobs.py +3 -3
- experimaestro/tests/test_remote_state.py +671 -0
- experimaestro/tests/test_state_db.py +8 -8
- experimaestro/tui/app.py +100 -8
- experimaestro/version.py +2 -2
- {experimaestro-2.0.0b4.dist-info → experimaestro-2.0.0b8.dist-info}/METADATA +4 -4
- {experimaestro-2.0.0b4.dist-info → experimaestro-2.0.0b8.dist-info}/RECORD +24 -18
- {experimaestro-2.0.0b4.dist-info → experimaestro-2.0.0b8.dist-info}/WHEEL +0 -0
- {experimaestro-2.0.0b4.dist-info → experimaestro-2.0.0b8.dist-info}/entry_points.txt +0 -0
- {experimaestro-2.0.0b4.dist-info → experimaestro-2.0.0b8.dist-info}/licenses/LICENSE +0 -0
|
@@ -0,0 +1,671 @@
|
|
|
1
|
+
"""Tests for SSH-based remote state provider
|
|
2
|
+
|
|
3
|
+
Tests cover:
|
|
4
|
+
- Protocol serialization/deserialization
|
|
5
|
+
- Server request handling
|
|
6
|
+
- Client-server communication (using pipes instead of SSH)
|
|
7
|
+
- File synchronization logic
|
|
8
|
+
"""
|
|
9
|
+
|
|
10
|
+
import io
|
|
11
|
+
import json
|
|
12
|
+
from datetime import datetime
|
|
13
|
+
from pathlib import Path
|
|
14
|
+
from unittest.mock import MagicMock, patch
|
|
15
|
+
|
|
16
|
+
import pytest
|
|
17
|
+
|
|
18
|
+
from experimaestro.scheduler.remote.protocol import (
|
|
19
|
+
JSONRPC_VERSION,
|
|
20
|
+
RPCMethod,
|
|
21
|
+
NotificationMethod,
|
|
22
|
+
RPCRequest,
|
|
23
|
+
RPCResponse,
|
|
24
|
+
RPCNotification,
|
|
25
|
+
RPCError,
|
|
26
|
+
parse_message,
|
|
27
|
+
create_request,
|
|
28
|
+
create_success_response,
|
|
29
|
+
create_error_response,
|
|
30
|
+
create_notification,
|
|
31
|
+
serialize_datetime,
|
|
32
|
+
deserialize_datetime,
|
|
33
|
+
serialize_job,
|
|
34
|
+
serialize_experiment,
|
|
35
|
+
serialize_run,
|
|
36
|
+
)
|
|
37
|
+
|
|
38
|
+
|
|
39
|
+
# =============================================================================
|
|
40
|
+
# Protocol Tests
|
|
41
|
+
# =============================================================================
|
|
42
|
+
|
|
43
|
+
|
|
44
|
+
class TestProtocolMessages:
    """JSON-RPC message construction and parsing."""

    # Error text shared by both jsonrpc-version validation tests.
    _BAD_VERSION = "Invalid or missing jsonrpc version"

    def test_request_creation(self):
        """``create_request`` emits a request with version, method, params and id."""
        payload = json.loads(
            create_request(RPCMethod.GET_EXPERIMENTS, {"since": None}, 1)
        )

        expected = {
            "jsonrpc": JSONRPC_VERSION,
            "method": "get_experiments",
            "params": {"since": None},
            "id": 1,
        }
        for key, value in expected.items():
            assert payload[key] == value

    def test_request_parsing(self):
        """A raw request string is parsed into an ``RPCRequest``."""
        raw = '{"jsonrpc": "2.0", "method": "get_jobs", "params": {"experiment_id": "exp1"}, "id": 42}'
        parsed = parse_message(raw)

        assert isinstance(parsed, RPCRequest)
        assert (parsed.method, parsed.params, parsed.id) == (
            "get_jobs",
            {"experiment_id": "exp1"},
            42,
        )

    def test_response_creation(self):
        """A success response carries ``result`` and omits ``error``."""
        payload = json.loads(create_success_response(1, [{"id": "test"}]))

        assert payload["jsonrpc"] == JSONRPC_VERSION
        assert payload["id"] == 1
        assert payload["result"] == [{"id": "test"}]
        assert "error" not in payload

    def test_error_response_creation(self):
        """An error response nests code, message and data under ``error``."""
        payload = json.loads(
            create_error_response(
                1, -32600, "Invalid request", {"detail": "missing method"}
            )
        )

        assert payload["jsonrpc"] == JSONRPC_VERSION
        assert payload["id"] == 1

        error = payload["error"]
        assert error["code"] == -32600
        assert error["message"] == "Invalid request"
        assert error["data"] == {"detail": "missing method"}

    def test_response_parsing(self):
        """A raw success response is parsed into an ``RPCResponse`` without error."""
        raw = '{"jsonrpc": "2.0", "id": 1, "result": {"success": true}}'
        parsed = parse_message(raw)

        assert isinstance(parsed, RPCResponse)
        assert parsed.id == 1
        assert parsed.result == {"success": True}
        assert parsed.error is None

    def test_error_response_parsing(self):
        """A raw error response is parsed into an ``RPCResponse`` with an error."""
        raw = '{"jsonrpc": "2.0", "id": 1, "error": {"code": -32601, "message": "Method not found"}}'
        parsed = parse_message(raw)

        assert isinstance(parsed, RPCResponse)
        assert parsed.id == 1
        assert parsed.result is None

        err = parsed.error
        assert (err.code, err.message) == (-32601, "Method not found")

    def test_notification_creation(self):
        """A notification has a namespaced method, params, and no id."""
        payload = json.loads(
            create_notification(
                NotificationMethod.JOB_UPDATED,
                {"job_id": "job1", "state": "running"},
            )
        )

        assert payload["jsonrpc"] == JSONRPC_VERSION
        assert payload["method"] == "notification.job_updated"
        assert payload["params"] == {"job_id": "job1", "state": "running"}
        assert "id" not in payload

    def test_notification_parsing(self):
        """A raw notification string is parsed into an ``RPCNotification``."""
        raw = '{"jsonrpc": "2.0", "method": "notification.shutdown", "params": {"reason": "test"}}'
        parsed = parse_message(raw)

        assert isinstance(parsed, RPCNotification)
        assert parsed.method == "notification.shutdown"
        assert parsed.params == {"reason": "test"}

    def test_parse_invalid_json(self):
        """Non-JSON input is rejected with ``ValueError``."""
        with pytest.raises(ValueError, match="Invalid JSON"):
            parse_message("not valid json")

    def test_parse_missing_version(self):
        """A message lacking ``jsonrpc`` is rejected with ``ValueError``."""
        with pytest.raises(ValueError, match=self._BAD_VERSION):
            parse_message('{"method": "test", "id": 1}')

    def test_parse_wrong_version(self):
        """A message with a non-2.0 ``jsonrpc`` is rejected with ``ValueError``."""
        with pytest.raises(ValueError, match=self._BAD_VERSION):
            parse_message('{"jsonrpc": "1.0", "method": "test", "id": 1}')
|
|
146
|
+
|
|
147
|
+
|
|
148
|
+
class TestDatetimeSerialization:
    """Test datetime serialization helpers"""

    def test_serialize_none(self):
        """``None`` serializes to ``None``."""
        assert serialize_datetime(None) is None

    def test_serialize_datetime(self):
        """A naive ``datetime`` serializes to its ISO-8601 string."""
        dt = datetime(2024, 1, 15, 10, 30, 0)
        assert serialize_datetime(dt) == "2024-01-15T10:30:00"

    def test_serialize_timestamp(self):
        """A Unix timestamp serializes to an ISO string on the right calendar day.

        1704067200.0 is 2024-01-01T00:00:00Z. If the serializer renders the
        timestamp in local time (the original note said "adjusted for local
        timezone"), that instant falls on 2023-12-31 anywhere west of UTC.
        The previous hard-coded "2024-01-01" check therefore failed depending
        on the machine's timezone. The expected date is now derived from the
        same timestamp, accepting either the local or the UTC calendar date.
        """
        from datetime import timezone

        timestamp = 1704067200.0
        result = serialize_datetime(timestamp)

        accepted_dates = {
            datetime.fromtimestamp(timestamp).date().isoformat(),
            datetime.fromtimestamp(timestamp, tz=timezone.utc).date().isoformat(),
        }
        # ISO-8601 output begins with YYYY-MM-DD (see test_serialize_datetime).
        assert result[:10] in accepted_dates

    def test_serialize_string_passthrough(self):
        """Strings are returned unchanged."""
        result = serialize_datetime("2024-01-15T10:30:00")
        assert result == "2024-01-15T10:30:00"

    def test_deserialize_none(self):
        """``None`` deserializes to ``None``."""
        assert deserialize_datetime(None) is None

    def test_deserialize_iso_string(self):
        """An ISO-8601 string deserializes to the matching naive ``datetime``."""
        result = deserialize_datetime("2024-01-15T10:30:00")
        assert result == datetime(2024, 1, 15, 10, 30, 0)

    def test_roundtrip(self):
        """Serializing then deserializing yields the original ``datetime``."""
        original = datetime(2024, 6, 15, 14, 30, 45)
        serialized = serialize_datetime(original)
        deserialized = deserialize_datetime(serialized)
        assert deserialized == original
|
|
187
|
+
|
|
188
|
+
|
|
189
|
+
class TestJobSerialization:
    """Serialization of job objects into wire dictionaries."""

    def test_serialize_mock_job(self):
        """``serialize_job`` exposes the identifying fields of a ``MockJob``."""
        from experimaestro.scheduler.state_provider import MockJob

        job_fields = dict(
            identifier="job123",
            task_id="task.MyTask",
            locator="job123",
            path=Path("/tmp/jobs/job123"),
            state="running",
            submittime=1704067200.0,
            starttime=1704067300.0,
            endtime=None,
            progress=[],
            tags={"tag1": "value1"},
            experiment_id="exp1",
            run_id="run1",
            updated_at="2024-01-01T00:00:00",
        )
        serialized = serialize_job(MockJob(**job_fields))

        assert serialized["identifier"] == "job123"
        assert serialized["task_id"] == "task.MyTask"
        assert serialized["path"] == "/tmp/jobs/job123"
        # The state is derived from a JobState enum whose casing is not
        # pinned down, so compare case-insensitively.
        assert serialized["state"].upper() == "RUNNING"
        assert serialized["tags"] == {"tag1": "value1"}
        assert serialized["experiment_id"] == "exp1"
        assert serialized["run_id"] == "run1"
|
|
222
|
+
|
|
223
|
+
|
|
224
|
+
class TestExperimentSerialization:
    """Serialization of experiment objects into wire dictionaries."""

    def test_serialize_mock_experiment(self):
        """``serialize_experiment`` derives the id from the workdir and copies counters."""
        from experimaestro.scheduler.state_provider import MockExperiment

        experiment = MockExperiment(
            workdir=Path("/tmp/xp/myexp"),
            current_run_id="run_20240101",
            total_jobs=10,
            finished_jobs=5,
            failed_jobs=1,
            updated_at="2024-01-01T12:00:00",
            started_at=1704067200.0,
            ended_at=None,
            hostname="server1",
        )

        serialized = serialize_experiment(experiment)

        expected = {
            "experiment_id": "myexp",
            "workdir": "/tmp/xp/myexp",
            "current_run_id": "run_20240101",
            "total_jobs": 10,
            "finished_jobs": 5,
            "failed_jobs": 1,
            "hostname": "server1",
        }
        for field, value in expected.items():
            assert serialized[field] == value
|
|
252
|
+
|
|
253
|
+
|
|
254
|
+
class TestRunSerialization:
    """Serialization of experiment-run records."""

    def test_serialize_run_dict(self):
        """``serialize_run`` preserves the identifying fields of a run dict."""
        serialized = serialize_run(
            {
                "run_id": "run_20240101",
                "experiment_id": "exp1",
                "hostname": "server1",
                "started_at": "2024-01-01T10:00:00",
                "ended_at": None,
                "status": "active",
            }
        )

        for field, value in (
            ("run_id", "run_20240101"),
            ("experiment_id", "exp1"),
            ("hostname", "server1"),
            ("status", "active"),
        ):
            assert serialized[field] == value
|
|
274
|
+
|
|
275
|
+
|
|
276
|
+
# =============================================================================
|
|
277
|
+
# Server Tests
|
|
278
|
+
# =============================================================================
|
|
279
|
+
|
|
280
|
+
|
|
281
|
+
class TestServerRequestHandling:
    """Server request handlers exercised against a mocked state provider."""

    @pytest.fixture
    def mock_state_provider(self):
        """Return a ``MagicMock`` standing in for the server's state provider.

        List-returning queries answer ``[]`` and single-object lookups answer
        ``None`` unless a test overrides them.
        """
        provider = MagicMock()
        provider.workspace_path = Path("/tmp/workspace")

        empty_list_queries = (
            "get_experiments",
            "get_experiment_runs",
            "get_jobs",
            "get_all_jobs",
            "get_services",
        )
        none_queries = ("get_experiment", "get_job", "get_last_sync_time")

        for name in empty_list_queries:
            getattr(provider, name).return_value = []
        for name in none_queries:
            getattr(provider, name).return_value = None
        return provider

    @pytest.fixture
    def server_with_mock(self, mock_state_provider, tmp_path):
        """Build an ``SSHStateProviderServer`` backed by ``mock_state_provider``."""
        from experimaestro.scheduler.remote.server import SSHStateProviderServer

        # Lay out workspace/.experimaestro before constructing the server.
        workspace = tmp_path / "workspace"
        workspace.mkdir()
        (workspace / ".experimaestro").mkdir()

        instance = SSHStateProviderServer(workspace)
        # Replace the provider so the handlers query the mock.
        instance._state_provider = mock_state_provider
        return instance

    def test_handle_get_experiments(self, server_with_mock, mock_state_provider):
        """``get_experiments`` serializes provider results and forwards ``since``."""
        from experimaestro.scheduler.state_provider import MockExperiment

        mock_state_provider.get_experiments.return_value = [
            MockExperiment(
                workdir=Path("/tmp/xp/exp1"),
                current_run_id="run1",
                total_jobs=5,
                finished_jobs=3,
                failed_jobs=0,
                updated_at="2024-01-01T00:00:00",
            )
        ]

        experiments = server_with_mock._handle_get_experiments({"since": None})

        assert len(experiments) == 1
        assert experiments[0]["experiment_id"] == "exp1"
        mock_state_provider.get_experiments.assert_called_once_with(since=None)

    def test_handle_get_jobs(self, server_with_mock, mock_state_provider):
        """``get_jobs`` serializes the jobs returned by the provider."""
        from experimaestro.scheduler.state_provider import MockJob

        mock_state_provider.get_jobs.return_value = [
            MockJob(
                identifier="job1",
                task_id="task.Test",
                locator="job1",
                path=Path("/tmp/jobs/job1"),
                state="done",
                submittime=None,
                starttime=None,
                endtime=None,
                progress=[],
                tags={},
                experiment_id="exp1",
                run_id="run1",
                updated_at="",
            )
        ]

        request_params = {"experiment_id": "exp1", "run_id": "run1"}
        jobs = server_with_mock._handle_get_jobs(request_params)

        assert len(jobs) == 1
        assert jobs[0]["identifier"] == "job1"

    def test_handle_get_sync_info(self, server_with_mock):
        """``get_sync_info`` reports the workspace path and last sync time."""
        info = server_with_mock._handle_get_sync_info({})

        for key in ("workspace_path", "last_sync_time"):
            assert key in info

    def test_handle_unknown_method(self, server_with_mock):
        """Unregistered method names have no handler."""
        handlers = server_with_mock._handlers
        assert "unknown_method" not in handlers
|
|
374
|
+
|
|
375
|
+
|
|
376
|
+
# =============================================================================
|
|
377
|
+
# Client-Server Integration Tests (using pipes)
|
|
378
|
+
# =============================================================================
|
|
379
|
+
|
|
380
|
+
|
|
381
|
+
class TestClientServerIntegration:
    """Client/server message exchange tests.

    ``pipe_pair`` provides real OS pipes so a message written by one side can
    be read back by the other. The remaining tests round-trip messages
    through the protocol helpers without any transport.
    """

    @pytest.fixture
    def pipe_pair(self):
        """Yield the four ends of two OS pipes wired for client-server traffic.

        The server reads from client_to_server and writes to server_to_client.
        The client reads from server_to_client and writes to client_to_server.
        Keys keep their original meaning:

        * ``server_stdin``: the server's read end.
        * ``server_stdout``: the server's write end.
        * ``client_stdin``: the end the client writes to (the server's stdin).
        * ``client_stdout``: the end the client reads from (the server's stdout).

        The previous version returned four independent ``io.BytesIO`` objects,
        so bytes written on one end could never be read from the other. Write
        ends are unbuffered so each write is immediately visible to the reader.
        Every end is closed on teardown.
        """
        import os

        client_to_server_r, client_to_server_w = os.pipe()
        server_to_client_r, server_to_client_w = os.pipe()

        ends = {
            "server_stdin": io.open(client_to_server_r, "rb"),
            "server_stdout": io.open(server_to_client_w, "wb", buffering=0),
            "client_stdin": io.open(client_to_server_w, "wb", buffering=0),
            "client_stdout": io.open(server_to_client_r, "rb"),
        }
        try:
            yield ends
        finally:
            for stream in ends.values():
                stream.close()

    def test_messages_cross_pipes(self, pipe_pair):
        """A request and its response travel through the pipes intact."""
        request = create_request(RPCMethod.GET_EXPERIMENTS, {"since": None}, 1)
        # Frame each message as one newline-terminated line so readline() can
        # split it. Strip first in case the helper already appends a newline.
        pipe_pair["client_stdin"].write(request.rstrip("\n").encode("utf-8") + b"\n")

        received = pipe_pair["server_stdin"].readline().decode("utf-8")
        req_msg = parse_message(received.strip())
        assert isinstance(req_msg, RPCRequest)
        assert req_msg.method == "get_experiments"

        response = create_success_response(req_msg.id, [])
        pipe_pair["server_stdout"].write(
            response.rstrip("\n").encode("utf-8") + b"\n"
        )

        answered = pipe_pair["client_stdout"].readline().decode("utf-8")
        resp_msg = parse_message(answered.strip())
        assert isinstance(resp_msg, RPCResponse)
        assert resp_msg.id == 1
        assert resp_msg.result == []

    def test_request_response_cycle(self):
        """Test a complete request-response cycle"""
        # Build a request and the matching (empty) server response
        request = create_request(RPCMethod.GET_EXPERIMENTS, {"since": None}, 1)
        response = create_success_response(1, [])

        # Parse request
        req_msg = parse_message(request)
        assert isinstance(req_msg, RPCRequest)
        assert req_msg.method == "get_experiments"

        # Parse response
        resp_msg = parse_message(response)
        assert isinstance(resp_msg, RPCResponse)
        assert resp_msg.result == []

    def test_notification_handling(self):
        """Test notification message handling"""
        notification = create_notification(
            NotificationMethod.JOB_UPDATED,
            {
                "job_id": "job1",
                "experiment_id": "exp1",
                "run_id": "run1",
                "state": "running",
            },
        )

        msg = parse_message(notification)
        assert isinstance(msg, RPCNotification)
        assert msg.method == "notification.job_updated"
        assert msg.params["job_id"] == "job1"
|
|
431
|
+
|
|
432
|
+
|
|
433
|
+
# =============================================================================
|
|
434
|
+
# Client Tests
|
|
435
|
+
# =============================================================================
|
|
436
|
+
|
|
437
|
+
|
|
438
|
+
class TestClientDataConversion:
    """Conversion of wire dictionaries into client-side mock objects."""

    @pytest.fixture
    def client(self, tmp_path):
        """Build an ``SSHStateProviderClient`` whose local cache is ``tmp_path``.

        ``connect()`` normally prepares the temporary directory. The attributes
        it would populate are assigned by hand here instead.
        """
        from experimaestro.scheduler.remote.client import SSHStateProviderClient

        instance = SSHStateProviderClient(
            host="testhost",
            remote_workspace="/remote/workspace",
        )
        instance._temp_dir = str(tmp_path)
        instance.local_cache_dir = tmp_path
        instance.workspace_path = tmp_path
        return instance

    def test_dict_to_job(self, client, tmp_path):
        """A job dict becomes a ``MockJob`` whose path is remapped into the cache."""
        converted = client._dict_to_job(
            {
                "identifier": "job123",
                "task_id": "task.MyTask",
                "locator": "job123",
                "path": "/remote/workspace/jobs/job123",
                "state": "running",
                "submittime": "2024-01-01T10:00:00",
                "starttime": "2024-01-01T10:01:00",
                "endtime": None,
                "progress": [],
                "tags": {"key": "value"},
                "experiment_id": "exp1",
                "run_id": "run1",
            }
        )

        assert converted.identifier == "job123"
        assert converted.task_id == "task.MyTask"
        # A path under the remote workspace maps into the local cache.
        assert converted.path == tmp_path / "jobs/job123"
        assert converted.tags == {"key": "value"}

    def test_dict_to_experiment(self, client, tmp_path):
        """An experiment dict becomes a ``MockExperiment`` with a local workdir."""
        converted = client._dict_to_experiment(
            {
                "experiment_id": "myexp",
                "workdir": "/remote/workspace/xp/myexp",
                "current_run_id": "run1",
                "total_jobs": 10,
                "finished_jobs": 5,
                "failed_jobs": 1,
                "updated_at": "2024-01-01T12:00:00",
                "hostname": "server1",
            }
        )

        assert converted.experiment_id == "myexp"
        # The workdir under the remote workspace maps into the local cache.
        assert converted.workdir == tmp_path / "xp/myexp"
        assert converted.total_jobs == 10
        assert converted.hostname == "server1"

    def test_path_mapping_outside_workspace(self, client, tmp_path):
        """Paths outside the remote workspace are not remapped."""
        outside = "/other/path/job123"  # Not under remote_workspace
        converted = client._dict_to_job(
            {
                "identifier": "job123",
                "task_id": "task.MyTask",
                "locator": "job123",
                "path": outside,
                "state": "done",
                "progress": [],
                "tags": {},
            }
        )

        assert converted.path == Path("/other/path/job123")
|
|
519
|
+
|
|
520
|
+
|
|
521
|
+
# =============================================================================
|
|
522
|
+
# Synchronizer Tests
|
|
523
|
+
# =============================================================================
|
|
524
|
+
|
|
525
|
+
|
|
526
|
+
class TestRemoteFileSynchronizer:
    """Path mapping and rsync invocation of ``RemoteFileSynchronizer``."""

    @pytest.fixture
    def synchronizer(self, tmp_path):
        """Build a synchronizer whose local cache is a fresh ``tmp_path/cache``."""
        from experimaestro.scheduler.remote.sync import RemoteFileSynchronizer

        cache_dir = tmp_path / "cache"
        cache_dir.mkdir()
        return RemoteFileSynchronizer(
            host="testhost",
            remote_workspace=Path("/remote/workspace"),
            local_cache=cache_dir,
        )

    def test_get_local_path(self, synchronizer):
        """A path under the remote workspace maps into the local cache."""
        mapped = synchronizer.get_local_path("/remote/workspace/xp/exp1/jobs.jsonl")
        assert mapped == synchronizer.local_cache / "xp/exp1/jobs.jsonl"

    def test_get_local_path_outside_workspace(self, synchronizer):
        """A path outside the remote workspace is returned unchanged."""
        mapped = synchronizer.get_local_path("/other/path/file.txt")
        assert mapped == Path("/other/path/file.txt")

    @patch("subprocess.run")
    def test_rsync_command_construction(self, mock_run, synchronizer):
        """``_rsync`` runs one rsync command carrying the expected flags and source."""
        mock_run.return_value = MagicMock(returncode=0)

        source = "testhost:/remote/workspace/logs/"
        destination = str(synchronizer.local_cache / "logs") + "/"
        synchronizer._rsync(source, destination)

        mock_run.assert_called_once()
        positional_args, _kwargs = mock_run.call_args
        cmd = positional_args[0]

        expected_tokens = (
            "rsync",
            "--inplace",
            "--delete",
            "-L",
            "-a",
            "-z",
            "-v",
            "testhost:/remote/workspace/logs/",
        )
        for token in expected_tokens:
            assert token in cmd
|
|
579
|
+
|
|
580
|
+
|
|
581
|
+
# =============================================================================
|
|
582
|
+
# Version and Temp Directory Tests
|
|
583
|
+
# =============================================================================
|
|
584
|
+
|
|
585
|
+
|
|
586
|
+
class TestVersionStripping:
    """Removal of ``.devN`` suffixes from version strings."""

    def test_strip_dev_version(self):
        """``.devN`` suffixes are dropped; plain versions are left alone."""
        from experimaestro.scheduler.remote.client import _strip_dev_version

        cases = {
            "2.0.0b3.dev8": "2.0.0b3",
            "1.2.3.dev1": "1.2.3",
            "1.2.3": "1.2.3",
            "2.0.0a1.dev100": "2.0.0a1",
            "0.1.0.dev0": "0.1.0",
        }
        for raw, stripped in cases.items():
            assert _strip_dev_version(raw) == stripped

    def test_strip_dev_preserves_prerelease(self):
        """Alpha, beta and rc markers survive stripping."""
        from experimaestro.scheduler.remote.client import _strip_dev_version

        cases = {
            "1.0.0a1.dev5": "1.0.0a1",
            "1.0.0b2.dev3": "1.0.0b2",
            "1.0.0rc1.dev1": "1.0.0rc1",
        }
        for raw, stripped in cases.items():
            assert _strip_dev_version(raw) == stripped
|
|
606
|
+
|
|
607
|
+
|
|
608
|
+
class TestTempDirectory:
    """Lazy creation of the client's local cache directory."""

    def test_client_temp_dir_not_created_until_connect(self):
        """A freshly constructed client has no temp/cache directory yet."""
        from experimaestro.scheduler.remote.client import SSHStateProviderClient

        fresh_client = SSHStateProviderClient(host="testhost", remote_workspace="/remote")

        # connect() has not been called, so neither attribute is populated.
        assert fresh_client._temp_dir is None
        assert fresh_client.local_cache_dir is None
|
|
623
|
+
|
|
624
|
+
|
|
625
|
+
# =============================================================================
|
|
626
|
+
# Error Handling Tests
|
|
627
|
+
# =============================================================================
|
|
628
|
+
|
|
629
|
+
|
|
630
|
+
class TestErrorHandling:
    """``RPCError`` objects and error-carrying ``RPCResponse`` serialization."""

    def test_rpc_error_creation(self):
        """An ``RPCError`` keeps its fields and reproduces them in ``to_dict``."""
        error = RPCError(
            code=-32600, message="Invalid request", data={"field": "method"}
        )

        assert (error.code, error.message, error.data) == (
            -32600,
            "Invalid request",
            {"field": "method"},
        )

        as_dict = error.to_dict()
        assert as_dict["code"] == -32600
        assert as_dict["message"] == "Invalid request"
        assert as_dict["data"] == {"field": "method"}

    def test_rpc_error_from_dict(self):
        """``RPCError.from_dict`` leaves ``data`` as ``None`` when absent."""
        error = RPCError.from_dict({"code": -32601, "message": "Method not found"})

        assert error.code == -32601
        assert error.message == "Method not found"
        assert error.data is None

    def test_response_with_error(self):
        """``RPCResponse.to_dict`` emits ``error`` and omits ``result`` for failures."""
        failure = RPCResponse(
            id=1,
            error=RPCError(code=-32600, message="Invalid request"),
        )

        as_dict = failure.to_dict()
        assert "error" in as_dict
        assert as_dict["error"]["code"] == -32600
        assert "result" not in as_dict
|