experimaestro 2.0.0b4__py3-none-any.whl → 2.0.0b8__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of experimaestro might be problematic; see the package's registry page for more details.

@@ -1,9 +1,14 @@
1
1
  import abc
2
2
  from enum import Enum
3
- import functools
4
3
  import logging
5
4
  import threading
6
- from typing import Set
5
+ from pathlib import Path
6
+ from typing import Callable, Optional, Set, TYPE_CHECKING
7
+
8
+ from experimaestro.scheduler.interfaces import BaseService
9
+
10
+ if TYPE_CHECKING:
11
+ from experimaestro.scheduler.experiment import Experiment
7
12
 
8
13
  logger = logging.getLogger(__name__)
9
14
 
@@ -28,7 +33,7 @@ class ServiceState(Enum):
28
33
  STOPPING = 3
29
34
 
30
35
 
31
- class Service:
36
+ class Service(BaseService):
32
37
  """An experiment service
33
38
 
34
39
  Services can be associated with an experiment. They send
@@ -46,32 +51,89 @@ class Service:
46
51
  self._listeners: Set[ServiceListener] = set()
47
52
  self._listeners_lock = threading.Lock()
48
53
 
54
+ def set_experiment(self, xp: "Experiment") -> None:
55
+ """Called when the service is added to an experiment.
56
+
57
+ Override this method to access the experiment context (e.g., workdir).
58
+ The default implementation does nothing.
59
+
60
+ Args:
61
+ xp: The experiment this service is being added to.
62
+ """
63
+ pass
64
+
49
65
  def state_dict(self) -> dict:
50
- """Return a dictionary representation for serialization.
66
+ """Return parameters needed to recreate this service.
67
+
68
+ Subclasses should override this to return constructor arguments.
69
+ Path values are automatically serialized and restored (with
70
+ translation for remote monitoring).
71
+
72
+ Example::
73
+
74
+ def state_dict(self):
75
+ return {
76
+ "log_dir": self.log_dir, # Path is auto-handled
77
+ "name": self.name,
78
+ }
79
+
80
+ Returns:
81
+ Dict with constructor kwargs (no need to include __class__).
82
+ """
83
+ return {}
51
84
 
52
- Subclasses should override this to include any parameters needed
53
- to recreate the service. The base implementation returns the
54
- class module and name.
85
+ def _full_state_dict(self) -> dict:
86
+ """Get complete state_dict including __class__ for serialization."""
87
+ d = self.state_dict()
88
+ d["__class__"] = f"{self.__class__.__module__}.{self.__class__.__name__}"
89
+ return d
90
+
91
+ @staticmethod
92
+ def serialize_state_dict(data: dict) -> dict:
93
+ """Serialize a state_dict, converting Path objects to serializable format.
94
+
95
+ This is called automatically when storing services. Path values are
96
+ converted to {"__path__": "/path/string"} format.
97
+
98
+ Args:
99
+ data: Raw state_dict from service (should include __class__)
55
100
 
56
101
  Returns:
57
- Dict with '__class__' key and any additional kwargs.
102
+ Serializable dictionary with paths converted
58
103
  """
59
- return {
60
- "__class__": f"{self.__class__.__module__}.{self.__class__.__name__}",
61
- }
104
+ result = {}
105
+ for k, v in data.items():
106
+ if isinstance(v, Path):
107
+ result[k] = {"__path__": str(v)}
108
+ else:
109
+ result[k] = v
110
+ return result
62
111
 
63
112
  @staticmethod
64
- def from_state_dict(data: dict) -> "Service":
113
+ def from_state_dict(
114
+ data: dict, path_translator: Optional[Callable[[str], Path]] = None
115
+ ) -> "Service":
65
116
  """Recreate a service from a state dictionary.
66
117
 
67
118
  Args:
68
- data: Dictionary from :meth:`state_dict`
119
+ data: Dictionary from :meth:`state_dict` (may be serialized)
120
+ path_translator: Optional function to translate remote paths to local.
121
+ Used by remote clients to map paths to local cache.
69
122
 
70
123
  Returns:
71
124
  A new Service instance, or raises if the class cannot be loaded.
125
+
126
+ Raises:
127
+ ValueError: If __unserializable__ is True or __class__ is missing
72
128
  """
73
129
  import importlib
74
130
 
131
+ # Check if service is marked as unserializable
132
+ if data.get("__unserializable__"):
133
+ raise ValueError(
134
+ f"Service cannot be recreated: {data.get('__reason__', 'unknown reason')}"
135
+ )
136
+
75
137
  class_path = data.get("__class__")
76
138
  if not class_path:
77
139
  raise ValueError("Missing '__class__' in service state_dict")
@@ -80,8 +142,22 @@ class Service:
80
142
  module = importlib.import_module(module_name)
81
143
  cls = getattr(module, class_name)
82
144
 
83
- # Remove __class__ and pass remaining as kwargs
84
- kwargs = {k: v for k, v in data.items() if k != "__class__"}
145
+ # Build kwargs, detecting and translating paths automatically
146
+ kwargs = {}
147
+ for k, v in data.items():
148
+ if k.startswith("__"):
149
+ continue # Skip special keys
150
+ if isinstance(v, dict) and "__path__" in v:
151
+ # Serialized path - deserialize with optional translation
152
+ path_str = v["__path__"]
153
+ if path_translator:
154
+ kwargs[k] = path_translator(path_str)
155
+ else:
156
+ kwargs[k] = Path(path_str)
157
+ else:
158
+ kwargs[k] = v
159
+
160
+ logger.debug("Creating %s with kwargs: %s", cls.__name__, kwargs)
85
161
  return cls(**kwargs)
86
162
 
87
163
  def add_listener(self, listener: ServiceListener):
@@ -158,6 +234,8 @@ class WebService(Service):
158
234
  self.url = None
159
235
  self.thread = None
160
236
  self._stop_event = threading.Event()
237
+ self._start_lock = threading.Lock()
238
+ self._running_event: Optional[threading.Event] = None
161
239
 
162
240
  def should_stop(self) -> bool:
163
241
  """Check if the service should stop.
@@ -173,21 +251,46 @@ class WebService(Service):
173
251
  """Get the URL of this web service, starting it if needed.
174
252
 
175
253
  If the service is not running, this method will start it and
176
- block until the URL is available.
254
+ block until the URL is available. If the service is already
255
+ starting or running, returns the existing URL.
177
256
 
178
257
  :return: The URL where this service can be accessed
258
+ :raises RuntimeError: If called while service is stopping
179
259
  """
180
- if self.state == ServiceState.STOPPED:
181
- self._stop_event.clear()
182
- self.state = ServiceState.STARTING
183
- self.running = threading.Event()
184
- self.serve()
185
-
186
- # Wait until the server is ready
187
- self.running.wait()
188
- self.state = ServiceState.RUNNING
260
+ with self._start_lock:
261
+ if self.state == ServiceState.STOPPING:
262
+ raise RuntimeError("Cannot start service while it is stopping")
263
+
264
+ if self.state == ServiceState.RUNNING:
265
+ logger.debug("Service already running, returning existing URL")
266
+ return self.url
267
+
268
+ if self.state == ServiceState.STOPPED:
269
+ logger.info(
270
+ "Starting service %s (id=%s)", self.__class__.__name__, id(self)
271
+ )
272
+ self._stop_event.clear()
273
+ self.state = ServiceState.STARTING
274
+ self._running_event = threading.Event()
275
+ self.serve()
276
+ else:
277
+ logger.info(
278
+ "Service %s (id=%s) already starting, waiting for it",
279
+ self.__class__.__name__,
280
+ id(self),
281
+ )
282
+
283
+ # State is STARTING - wait for it to be ready
284
+ running_event = self._running_event
285
+
286
+ # Wait outside the lock to avoid blocking other callers
287
+ if running_event:
288
+ running_event.wait()
289
+ # Set state to RUNNING (this will notify listeners)
290
+ with self._start_lock:
291
+ if self.state == ServiceState.STARTING:
292
+ self.state = ServiceState.RUNNING
189
293
 
190
- # Returns the URL
191
294
  return self.url
192
295
 
193
296
  def stop(self, timeout: float = 2.0):
@@ -199,10 +302,21 @@ class WebService(Service):
199
302
 
200
303
  :param timeout: Seconds to wait for graceful shutdown before forcing
201
304
  """
202
- if self.state == ServiceState.STOPPED:
203
- return
305
+ with self._start_lock:
306
+ if self.state == ServiceState.STOPPED:
307
+ return
308
+
309
+ if self.state == ServiceState.STARTING:
310
+ # Wait for service to finish starting before stopping
311
+ running_event = self._running_event
312
+ else:
313
+ running_event = None
204
314
 
205
- self.state = ServiceState.STOPPING
315
+ self.state = ServiceState.STOPPING
316
+
317
+ # Wait for starting to complete if needed (outside lock to avoid deadlock)
318
+ if running_event is not None:
319
+ running_event.wait()
206
320
 
207
321
  # Signal the service to stop
208
322
  self._stop_event.set()
@@ -215,8 +329,10 @@ class WebService(Service):
215
329
  if self.thread.is_alive():
216
330
  self._force_stop_thread()
217
331
 
218
- self.url = None
219
- self.state = ServiceState.STOPPED
332
+ with self._start_lock:
333
+ self.url = None
334
+ self._running_event = None
335
+ self.state = ServiceState.STOPPED
220
336
 
221
337
  def _force_stop_thread(self):
222
338
  """Attempt to forcefully stop the service thread.
@@ -254,12 +370,22 @@ class WebService(Service):
254
370
  This method creates a daemon thread that calls :meth:`_serve`.
255
371
  """
256
372
  self.thread = threading.Thread(
257
- target=functools.partial(self._serve, self.running),
373
+ target=self._serve_wrapper,
258
374
  name=f"service[{self.id}]",
259
375
  )
260
376
  self.thread.daemon = True
261
377
  self.thread.start()
262
378
 
379
+ def _serve_wrapper(self):
380
+ """Wrapper for _serve that handles state transitions."""
381
+ running_event = self._running_event
382
+ try:
383
+ self._serve(running_event)
384
+ finally:
385
+ # Ensure the event is set even if _serve fails
386
+ if running_event and not running_event.is_set():
387
+ running_event.set()
388
+
263
389
  @abc.abstractmethod
264
390
  def _serve(self, running: threading.Event):
265
391
  """Start the web server (implement in subclasses).
@@ -14,7 +14,9 @@ Key design:
14
14
  - Database instance is passed explicitly to avoid global state
15
15
  """
16
16
 
17
+ import logging
17
18
  from pathlib import Path
19
+ from typing import Tuple
18
20
  from peewee import (
19
21
  Model,
20
22
  SqliteDatabase,
@@ -30,6 +32,11 @@ from peewee import (
30
32
  from datetime import datetime
31
33
  import fasteners
32
34
 
35
+ logger = logging.getLogger("xpm.state_db")
36
+
37
+ # Database schema version - increment when schema changes require resync
38
+ CURRENT_DB_VERSION = 3
39
+
33
40
 
34
41
  class BaseModel(Model):
35
42
  """Base model for workspace database tables
@@ -84,6 +91,7 @@ class ExperimentRunModel(BaseModel):
84
91
  started_at: When this run started
85
92
  ended_at: When this run completed (null if still active)
86
93
  status: Run status (active, completed, failed, abandoned)
94
+ hostname: Host where the experiment was launched (null for old runs)
87
95
  """
88
96
 
89
97
  experiment_id = CharField(index=True)
@@ -91,6 +99,7 @@ class ExperimentRunModel(BaseModel):
91
99
  started_at = DateTimeField(default=datetime.now)
92
100
  ended_at = DateTimeField(null=True)
93
101
  status = CharField(default="active", index=True)
102
+ hostname = CharField(null=True)
94
103
 
95
104
  class Meta:
96
105
  table_name = "experiment_runs"
@@ -108,11 +117,13 @@ class WorkspaceSyncMetadata(BaseModel):
108
117
  id: Always "workspace" (single row table)
109
118
  last_sync_time: When last sync completed
110
119
  sync_interval_minutes: Minimum interval between syncs
120
+ db_version: Schema version for migration detection
111
121
  """
112
122
 
113
123
  id = CharField(primary_key=True, default="workspace")
114
124
  last_sync_time = DateTimeField(null=True)
115
125
  sync_interval_minutes = IntegerField(default=5)
126
+ db_version = IntegerField(default=1)
116
127
 
117
128
  class Meta:
118
129
  table_name = "workspace_sync_metadata"
@@ -219,26 +230,23 @@ class ServiceModel(BaseModel):
219
230
  """Service information linked to specific experiment run
220
231
 
221
232
  Services are tied to a specific run of an experiment via (experiment_id, run_id).
233
+ Services are only added or removed, not updated - state is managed at runtime.
222
234
 
223
235
  Fields:
224
236
  service_id: Unique identifier for the service
225
237
  experiment_id: ID of the experiment this service belongs to
226
238
  run_id: ID of the run this service belongs to
227
239
  description: Human-readable description
228
- state: Service state (e.g., "running", "stopped")
229
240
  state_dict: JSON serialized state_dict for service recreation
230
- created_at: When service was created
231
- updated_at: Timestamp of last update
241
+ created_at: When service was registered
232
242
  """
233
243
 
234
244
  service_id = CharField()
235
245
  experiment_id = CharField(index=True)
236
246
  run_id = CharField(index=True)
237
247
  description = TextField(default="")
238
- state = CharField()
239
248
  state_dict = TextField(default="{}") # JSON for service recreation
240
249
  created_at = DateTimeField(default=datetime.now)
241
- updated_at = DateTimeField(default=datetime.now)
242
250
 
243
251
  class Meta:
244
252
  table_name = "services"
@@ -311,7 +319,7 @@ ALL_MODELS = [
311
319
 
312
320
  def initialize_workspace_database(
313
321
  db_path: Path, read_only: bool = False
314
- ) -> SqliteDatabase:
322
+ ) -> Tuple[SqliteDatabase, bool]:
315
323
  """Initialize a workspace database connection with proper configuration
316
324
 
317
325
  Creates and configures a SQLite database connection for the workspace.
@@ -325,7 +333,9 @@ def initialize_workspace_database(
325
333
  read_only: If True, open database in read-only mode
326
334
 
327
335
  Returns:
328
- Configured SqliteDatabase instance
336
+ Tuple of (SqliteDatabase instance, needs_resync flag)
337
+ The needs_resync flag is True when the database schema version is outdated
338
+ and a full resync from disk is required.
329
339
  """
330
340
  # Ensure parent directory exists (unless read-only)
331
341
  if not read_only:
@@ -336,6 +346,8 @@ def initialize_workspace_database(
336
346
  lock_path = db_path.parent / f".{db_path.name}.init.lock"
337
347
  lock = fasteners.InterProcessLock(str(lock_path))
338
348
 
349
+ needs_resync = False
350
+
339
351
  # Acquire lock (blocking) - only one process can initialize at a time
340
352
  with lock:
341
353
  # Create database connection
@@ -364,18 +376,55 @@ def initialize_workspace_database(
364
376
  if not read_only:
365
377
  db.create_tables(ALL_MODELS, safe=True)
366
378
 
379
+ # Check database version for migration - use raw SQL since column may not exist
380
+ current_version = 0
381
+ try:
382
+ cursor = db.execute_sql(
383
+ "SELECT db_version FROM workspace_sync_metadata WHERE id='workspace'"
384
+ )
385
+ row = cursor.fetchone()
386
+ if row is not None:
387
+ current_version = row[0]
388
+ if current_version < CURRENT_DB_VERSION:
389
+ needs_resync = True
390
+ except OperationalError:
391
+ # Column doesn't exist - add it and trigger resync
392
+ needs_resync = True
393
+ try:
394
+ db.execute_sql(
395
+ "ALTER TABLE workspace_sync_metadata "
396
+ "ADD COLUMN db_version INTEGER DEFAULT 1"
397
+ )
398
+ except OperationalError:
399
+ pass # Column may already exist
400
+
401
+ # Run schema migrations for older databases
402
+ if current_version < 2:
403
+ # Migration v1 -> v2: Add hostname column to experiment_runs table
404
+ try:
405
+ db.execute_sql(
406
+ "ALTER TABLE experiment_runs ADD COLUMN hostname VARCHAR(255) NULL"
407
+ )
408
+ logger.info("Added hostname column to experiment_runs table")
409
+ except OperationalError:
410
+ pass # Column already exists
411
+
367
412
  # Initialize WorkspaceSyncMetadata with default row if not exists
368
413
  # Use try/except to handle race condition (shouldn't happen with lock, but be safe)
369
414
  try:
370
415
  WorkspaceSyncMetadata.get_or_create(
371
416
  id="workspace",
372
- defaults={"last_sync_time": None, "sync_interval_minutes": 5},
417
+ defaults={
418
+ "last_sync_time": None,
419
+ "sync_interval_minutes": 5,
420
+ "db_version": 1,
421
+ },
373
422
  )
374
423
  except (IntegrityError, OperationalError):
375
424
  # If get_or_create fails, the row likely already exists
376
425
  pass
377
426
 
378
- return db
427
+ return db, needs_resync
379
428
 
380
429
 
381
430
  def close_workspace_database(db: SqliteDatabase):