jupyterlab-mlflow 0.4.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of jupyterlab-mlflow might be problematic. Click here for more details.

Files changed (23) hide show
  1. jupyterlab_mlflow/__init__.py +63 -0
  2. jupyterlab_mlflow/_version.py +6 -0
  3. jupyterlab_mlflow/post_install.py +44 -0
  4. jupyterlab_mlflow/schema/plugin.json +17 -0
  5. jupyterlab_mlflow/serverextension/__init__.py +69 -0
  6. jupyterlab_mlflow/serverextension/handlers.py +570 -0
  7. jupyterlab_mlflow/serverextension/mlflow_server.py +214 -0
  8. jupyterlab_mlflow-0.4.1.data/data/etc/jupyter/jupyter_server_config.d/jupyterlab_mlflow.json +8 -0
  9. jupyterlab_mlflow-0.4.1.data/data/share/jupyter/labextensions/jupyterlab-mlflow/install.json +12 -0
  10. jupyterlab_mlflow-0.4.1.data/data/share/jupyter/labextensions/jupyterlab-mlflow/package.json +103 -0
  11. jupyterlab_mlflow-0.4.1.data/data/share/jupyter/labextensions/jupyterlab-mlflow/schema/plugin.json +17 -0
  12. jupyterlab_mlflow-0.4.1.data/data/share/jupyter/labextensions/jupyterlab-mlflow/schemas/jupyterlab-mlflow/package.json.orig +98 -0
  13. jupyterlab_mlflow-0.4.1.data/data/share/jupyter/labextensions/jupyterlab-mlflow/schemas/jupyterlab-mlflow/plugin.json +17 -0
  14. jupyterlab_mlflow-0.4.1.data/data/share/jupyter/labextensions/jupyterlab-mlflow/static/218.47b1285b67dde3db8969.js +1 -0
  15. jupyterlab_mlflow-0.4.1.data/data/share/jupyter/labextensions/jupyterlab-mlflow/static/665.f3ea36ea04224fd9c2f3.js +1 -0
  16. jupyterlab_mlflow-0.4.1.data/data/share/jupyter/labextensions/jupyterlab-mlflow/static/remoteEntry.121dc9414dda869fb1a6.js +1 -0
  17. jupyterlab_mlflow-0.4.1.data/data/share/jupyter/labextensions/jupyterlab-mlflow/static/style.js +4 -0
  18. jupyterlab_mlflow-0.4.1.data/data/share/jupyter/labextensions/jupyterlab-mlflow/static/third-party-licenses.json +16 -0
  19. jupyterlab_mlflow-0.4.1.dist-info/METADATA +273 -0
  20. jupyterlab_mlflow-0.4.1.dist-info/RECORD +23 -0
  21. jupyterlab_mlflow-0.4.1.dist-info/WHEEL +4 -0
  22. jupyterlab_mlflow-0.4.1.dist-info/entry_points.txt +2 -0
  23. jupyterlab_mlflow-0.4.1.dist-info/licenses/LICENSE +30 -0
@@ -0,0 +1,570 @@
1
+ """
2
+ MLflow API handlers for JupyterLab extension
3
+ """
4
+
5
+ import json
6
+ import os
7
+ from typing import Dict, Any, Optional
8
+ from urllib.parse import urlparse
9
+
10
+ import mlflow
11
+ from mlflow.tracking import MlflowClient
12
+ from mlflow.exceptions import MlflowException
13
+ from tornado import web
14
+ from tornado.web import RequestHandler
15
+ from .mlflow_server import start_mlflow_server, stop_mlflow_server, get_mlflow_server_status
16
+
17
+
18
def get_mlflow_client(tracking_uri: Optional[str] = None) -> MlflowClient:
    """
    Build an MLflow client, selecting the tracking URI first.

    Parameters
    ----------
    tracking_uri : str, optional
        MLflow tracking URI. When empty or None, the ``MLFLOW_TRACKING_URI``
        environment variable is used instead, if it is set and non-empty.

    Returns
    -------
    MlflowClient
        Client created after the tracking URI (if any) has been applied.

    Notes
    -----
    The URI is applied through ``mlflow.set_tracking_uri``, i.e. on the
    ``mlflow`` module rather than on the returned client only. When neither
    source yields a URI, nothing is set and MLflow's current setting is used.
    """
    # An explicit argument wins; the environment variable is only a fallback.
    chosen_uri = tracking_uri or os.environ.get("MLFLOW_TRACKING_URI")
    if chosen_uri:
        mlflow.set_tracking_uri(chosen_uri)

    return MlflowClient()
38
+
39
+
40
class HealthCheckHandler(RequestHandler):
    """Health check endpoint, useful when diagnosing extension loading issues"""

    def get(self):
        """Reply with a fixed status document for the extension"""
        payload = {
            "status": "ok",
            "extension": "jupyterlab-mlflow",
            "message": "Server extension is loaded and responding",
        }
        self.write(payload)
50
+
51
+
52
class MLflowBaseHandler(RequestHandler):
    """Base handler for MLflow API endpoints

    Adds CORS headers to every response, answers CORS preflight requests,
    resolves the per-request MLflow tracking URI and renders errors as JSON.
    """

    def set_default_headers(self):
        """Set CORS headers

        NOTE(review): the wildcard origin allows any web page to read these
        responses and this class performs no authentication of its own --
        confirm that this is intended for the deployment.
        """
        self.set_header("Access-Control-Allow-Origin", "*")
        self.set_header("Access-Control-Allow-Headers", "Content-Type")
        # DELETE is advertised too: LocalMLflowServerHandler in this module
        # implements delete(), so a preflight must be able to allow it.
        self.set_header("Access-Control-Allow-Methods", "GET, POST, DELETE, OPTIONS")

    def options(self, *args, **kwargs):
        """Handle OPTIONS request for CORS

        Accepts and ignores the URL capture groups: Tornado passes them to
        every HTTP-verb method, so a zero-argument signature would fail on
        routes such as ``.../runs/([^/]+)``.
        """
        self.set_status(204)
        self.finish()

    def get_tracking_uri(self) -> Optional[str]:
        """Get tracking URI from the query string or, for POST, the JSON body

        Returns
        -------
        str or None
            The ``tracking_uri`` query argument when non-empty; otherwise, for
            POST requests, the ``tracking_uri`` member of the JSON object in
            the body; otherwise None.
        """
        from_query = self.get_query_argument("tracking_uri", None)
        if from_query:
            return from_query

        if self.request.method != "POST":
            return None

        try:
            body = json.loads(self.request.body.decode("utf-8"))
        except ValueError:
            # Both json.JSONDecodeError and UnicodeDecodeError are ValueError
            # subclasses; an unreadable body simply carries no tracking URI.
            return None

        # Valid JSON is not necessarily an object (it may be a list, a string...).
        if isinstance(body, dict):
            return body.get("tracking_uri")
        return None

    def write_error(self, status_code: int, **kwargs):
        """Write error response as JSON

        MLflow exceptions are reported with their own message; everything else
        falls back to the generic HTTP reason phrase.
        """
        exc_info = kwargs.get("exc_info")
        if exc_info:
            exception = exc_info[1]
            if isinstance(exception, MlflowException):
                self.write({
                    "error": str(exception),
                    "status_code": status_code
                })
                return

        self.write({
            "error": f"HTTP {status_code}: {self._reason}",
            "status_code": status_code
        })
98
+
99
+
100
class ExperimentsHandler(MLflowBaseHandler):
    """Handler for the experiment listing endpoint"""

    def get(self):
        """Write ``{"experiments": [...]}`` or a 500 response with the error text"""
        try:
            client = get_mlflow_client(self.get_tracking_uri())

            # One plain dict per experiment returned by the search.
            listing = [
                {
                    "experiment_id": entry.experiment_id,
                    "name": entry.name,
                    "artifact_location": entry.artifact_location,
                    "lifecycle_stage": entry.lifecycle_stage,
                    "tags": entry.tags or {},
                }
                for entry in client.search_experiments()
            ]

            self.write({"experiments": listing})
        except Exception as err:
            self.set_status(500)
            self.write({"error": str(err)})
125
+
126
+
127
class ExperimentHandler(MLflowBaseHandler):
    """Handler for fetching a single experiment"""

    def get(self, experiment_id: str):
        """Write the experiment's fields or a 500 response with the error text"""
        try:
            client = get_mlflow_client(self.get_tracking_uri())
            found = client.get_experiment(experiment_id)

            # Attributes copied verbatim; tags are handled separately because
            # a missing/empty value is reported as an empty dict.
            copied = ("experiment_id", "name", "artifact_location", "lifecycle_stage")
            details = {attr: getattr(found, attr) for attr in copied}
            details["tags"] = found.tags or {}

            self.write(details)
        except Exception as err:
            self.set_status(500)
            self.write({"error": str(err)})
148
+
149
+
150
class RunsHandler(MLflowBaseHandler):
    """Handler for listing the runs of an experiment"""

    @staticmethod
    def _describe(run) -> Dict[str, Any]:
        """Flatten one MLflow run into a plain dict of its info and data"""
        info = run.info
        data = run.data
        return {
            "run_id": info.run_id,
            "run_name": info.run_name,
            "experiment_id": info.experiment_id,
            "status": info.status,
            "start_time": info.start_time,
            "end_time": info.end_time,
            "user_id": info.user_id,
            "metrics": dict(data.metrics),
            "params": dict(data.params),
            "tags": dict(data.tags),
            "artifact_uri": info.artifact_uri,
        }

    def get(self, experiment_id: str):
        """Write ``{"runs": [...]}`` (at most 1000 runs requested) or a 500 response"""
        try:
            client = get_mlflow_client(self.get_tracking_uri())
            found = client.search_runs(
                experiment_ids=[experiment_id],
                max_results=1000,
            )
            self.write({"runs": [self._describe(run) for run in found]})
        except Exception as err:
            self.set_status(500)
            self.write({"error": str(err)})
184
+
185
+
186
class RunHandler(MLflowBaseHandler):
    """Handler for fetching a single run"""

    def get(self, run_id: str):
        """Write the run's info and data fields or a 500 response with the error text"""
        try:
            client = get_mlflow_client(self.get_tracking_uri())
            fetched = client.get_run(run_id)
            info, data = fetched.info, fetched.data

            details = {
                "run_id": info.run_id,
                "run_name": info.run_name,
                "experiment_id": info.experiment_id,
                "status": info.status,
                "start_time": info.start_time,
                "end_time": info.end_time,
                "user_id": info.user_id,
                # Copies, so the response owns its own mappings.
                "metrics": dict(data.metrics),
                "params": dict(data.params),
                "tags": dict(data.tags),
                "artifact_uri": info.artifact_uri,
            }
            self.write(details)
        except Exception as err:
            self.set_status(500)
            self.write({"error": str(err)})
213
+
214
+
215
class ArtifactsHandler(MLflowBaseHandler):
    """Handler for listing artifacts"""

    def get(self, run_id: str):
        """Get list of artifacts for a run

        The optional ``path`` query argument selects a directory inside the
        run's artifact tree. A listing failure is logged and the response
        still carries the run id and artifact URI with whatever entries were
        collected; any other failure yields a 500 response with the error text.
        """
        # Imported locally, as the other functions of this module do.
        import logging

        try:
            client = get_mlflow_client(self.get_tracking_uri())

            run = client.get_run(run_id)
            artifact_uri = run.info.artifact_uri

            # Get optional path parameter for listing artifacts in a directory
            path = self.get_query_argument("path", None)

            artifacts = []
            try:
                if path:
                    artifact_list = client.list_artifacts(run_id, path)
                else:
                    artifact_list = client.list_artifacts(run_id)
                for artifact in artifact_list:
                    artifacts.append({
                        "path": artifact.path,
                        "is_dir": artifact.is_dir,
                        "file_size": getattr(artifact, "file_size", None),
                    })
            except Exception as e:
                # Best effort: return basic info when listing fails.
                # tornado's RequestHandler has no ``log`` attribute, so a
                # module logger is used; ``self.log`` would raise here and
                # turn this recoverable case into a 500.
                logging.getLogger(__name__).warning("Could not list artifacts: %s", e)

            self.write({
                "run_id": run_id,
                "artifact_uri": artifact_uri,
                "artifacts": artifacts
            })
        except Exception as e:
            self.set_status(500)
            self.write({"error": str(e)})
255
+
256
+
257
class ArtifactDownloadHandler(MLflowBaseHandler):
    """Handler for downloading artifacts"""

    # Suffix(es) -> Content-Type; the first matching entry wins and anything
    # unmatched is served as application/octet-stream.
    _CONTENT_TYPES = (
        (".json", "application/json"),
        (".csv", "text/csv"),
        ((".txt", ".log"), "text/plain"),
        (".png", "image/png"),
        ((".jpg", ".jpeg"), "image/jpeg"),
        (".html", "text/html"),
    )

    def _reject_directory(self, path: str) -> None:
        """Answer 400 because ``path`` names a directory, not a file"""
        self.set_status(400)
        self.write({"error": f"'{path}' is a directory. Cannot download directories. Expand to see files inside."})

    @staticmethod
    def _looks_like_directory(client, run_id: str, path: str) -> bool:
        """Best-effort check, before downloading, whether ``path`` is a directory

        Listing errors are ignored on purpose: the download itself is then
        attempted and reports the real problem.
        """
        try:
            # A non-empty listing means the path has children.
            if client.list_artifacts(run_id, path):
                return True
        except Exception:
            pass

        try:
            # Also consult the parent listing, which covers directories that
            # list as empty. Only done when the path has a parent component.
            parent_path = os.path.dirname(path)
            if parent_path:
                for entry in client.list_artifacts(run_id, parent_path):
                    if entry.path == path and entry.is_dir:
                        return True
        except Exception:
            pass

        return False

    def get(self, run_id: str):
        """Download an artifact

        Requires the ``path`` query argument. Responds 400 when it is missing
        or names a directory, 500 with the error text on any other failure,
        and otherwise sends the file content as an attachment.
        """
        # Imported locally, as the other functions of this module do.
        import tempfile

        try:
            path = self.get_query_argument("path", "")
            if not path:
                self.set_status(400)
                self.write({"error": "path parameter is required"})
                return

            client = get_mlflow_client(self.get_tracking_uri())

            if self._looks_like_directory(client, run_id, path):
                self._reject_directory(path)
                return

            # Download into a scratch directory that is removed afterwards, so
            # repeated downloads do not pile up files on the server's disk.
            with tempfile.TemporaryDirectory() as scratch_dir:
                artifact_path = client.download_artifacts(run_id, path, dst_path=scratch_dir)

                # Check if downloaded path is actually a directory
                if os.path.isdir(artifact_path):
                    self._reject_directory(path)
                    return

                with open(artifact_path, "rb") as artifact_file:
                    content = artifact_file.read()

            content_type = "application/octet-stream"
            for suffixes, mime_type in self._CONTENT_TYPES:
                if path.endswith(suffixes):
                    content_type = mime_type
                    break

            # Escape backslashes and quotes so the quoted filename stays well-formed.
            filename = os.path.basename(path).replace("\\", "\\\\").replace('"', '\\"')

            self.set_header("Content-Type", content_type)
            self.set_header("Content-Disposition", f'attachment; filename="{filename}"')
            self.write(content)
        except Exception as e:
            error_msg = str(e)
            # Some failures only reveal through the OS error text that the
            # path was a directory.
            if "Is a directory" in error_msg or "[Errno 21]" in error_msg:
                self._reject_directory(path)
            else:
                self.set_status(500)
                self.write({"error": error_msg})
338
+
339
+
340
class ModelsHandler(MLflowBaseHandler):
    """Handler for listing the models of the model registry"""

    @staticmethod
    def _summarize_version(version):
        """Pick the fields reported for one entry of ``latest_versions``"""
        return {
            "version": version.version,
            "stage": version.current_stage,
            "status": version.status,
            "run_id": version.run_id,
            "creation_timestamp": version.creation_timestamp,
        }

    def get(self):
        """Write ``{"models": [...]}`` or a 500 response with the error text"""
        try:
            client = get_mlflow_client(self.get_tracking_uri())

            catalog = [
                {
                    "name": registered.name,
                    "latest_versions": [
                        self._summarize_version(item)
                        for item in registered.latest_versions
                    ],
                    "creation_timestamp": registered.creation_timestamp,
                    "last_updated_timestamp": registered.last_updated_timestamp,
                    "description": registered.description,
                    "tags": registered.tags or {},
                }
                for registered in client.search_registered_models()
            ]

            self.write({"models": catalog})
        except Exception as err:
            self.set_status(500)
            self.write({"error": str(err)})
376
+
377
+
378
class ModelHandler(MLflowBaseHandler):
    """Handler for fetching a single registered model"""

    def get(self, model_name: str):
        """Write the model's fields or a 500 response with the error text

        The ``versions`` member is built from the model's ``latest_versions``
        attribute only.
        """
        try:
            client = get_mlflow_client(self.get_tracking_uri())
            registered = client.get_registered_model(model_name)

            version_rows = [
                {
                    "version": item.version,
                    "stage": item.current_stage,
                    "status": item.status,
                    "run_id": item.run_id,
                    "creation_timestamp": item.creation_timestamp,
                    "description": item.description,
                }
                for item in registered.latest_versions
            ]

            details = {
                "name": registered.name,
                "versions": version_rows,
                "creation_timestamp": registered.creation_timestamp,
                "last_updated_timestamp": registered.last_updated_timestamp,
                "description": registered.description,
                "tags": registered.tags or {},
            }
            self.write(details)
        except Exception as err:
            self.set_status(500)
            self.write({"error": str(err)})
412
+
413
+
414
class ConnectionTestHandler(MLflowBaseHandler):
    """Handler for testing MLflow connection"""

    def post(self):
        """Test connection to MLflow server

        Expects a JSON object with a non-empty ``tracking_uri`` member.
        Responds 400 when the body is not valid JSON or the URI is missing,
        500 when the connection attempt fails, and a success document
        otherwise.
        """
        try:
            body = json.loads(self.request.body.decode("utf-8"))
        except ValueError as e:
            # json.JSONDecodeError and UnicodeDecodeError both derive from
            # ValueError; a malformed body is a client error, not a server one.
            self.set_status(400)
            self.write({
                "success": False,
                "error": f"Invalid JSON in request body: {str(e)}"
            })
            return

        # Valid JSON that is not an object cannot carry a tracking URI.
        tracking_uri = body.get("tracking_uri") if isinstance(body, dict) else None
        if not tracking_uri:
            self.set_status(400)
            self.write({"error": "tracking_uri is required"})
            return

        try:
            client = get_mlflow_client(tracking_uri)

            # Try to list experiments to test connection
            experiments = client.search_experiments(max_results=1)

            self.write({
                "success": True,
                "message": "Connection successful",
                "experiment_count": len(experiments)
            })
        except Exception as e:
            self.set_status(500)
            self.write({
                "success": False,
                "error": str(e)
            })
444
+
445
+
446
class LocalMLflowServerHandler(MLflowBaseHandler):
    """Handler for managing local MLflow server"""

    def get(self):
        """Get status of local MLflow server"""
        try:
            status = get_mlflow_server_status()
            self.write(status)
        except Exception as e:
            self.set_status(500)
            self.write({
                "success": False,
                "error": str(e)
            })

    def post(self):
        """Start local MLflow server

        Expects a JSON object; recognised members are ``port`` (default 5000),
        ``tracking_uri`` (default ``sqlite:///mlflow.db``), ``artifact_uri``
        and ``backend_uri``. Responds 400 when the body is not a JSON object,
        500 when the server could not be started, and otherwise the result
        returned by ``start_mlflow_server``.
        """
        # Imported locally, as the other functions of this module do.
        import logging
        logger = logging.getLogger(__name__)

        try:
            body = json.loads(self.request.body.decode("utf-8"))
        except ValueError as e:
            # Covers json.JSONDecodeError and UnicodeDecodeError. Parsing has
            # its own try block so that a ValueError raised while starting
            # the server is not misreported as a JSON problem.
            self.set_status(400)
            self.write({
                "success": False,
                "error": f"Invalid JSON in request body: {str(e)}"
            })
            return

        if not isinstance(body, dict):
            # e.g. a JSON list or string: there are no members to read.
            self.set_status(400)
            self.write({
                "success": False,
                "error": "Request body must be a JSON object"
            })
            return

        try:
            port = body.get("port", 5000)
            tracking_uri = body.get("tracking_uri", "sqlite:///mlflow.db")
            artifact_uri = body.get("artifact_uri")
            backend_uri = body.get("backend_uri")

            # Convert empty strings to None
            if artifact_uri == "":
                artifact_uri = None
            if backend_uri == "":
                backend_uri = None

            # Log the request for debugging
            logger.info(f"Starting local MLflow server: port={port}, tracking_uri={tracking_uri}, artifact_uri={artifact_uri}")

            result = start_mlflow_server(
                port=port,
                tracking_uri=tracking_uri,
                artifact_uri=artifact_uri,
                backend_uri=backend_uri
            )

            if not result.get("success"):
                logger.error(f"Failed to start MLflow server: {result.get('error', 'Unknown error')}")
                self.set_status(500)
            else:
                logger.info(f"Successfully started MLflow server: {result.get('url')}")

            self.write(result)
        except Exception as e:
            logger.error(f"Error starting local MLflow server: {e}", exc_info=True)
            self.set_status(500)
            self.write({
                "success": False,
                "error": str(e)
            })

    def delete(self):
        """Stop local MLflow server"""
        try:
            result = stop_mlflow_server()
            if not result.get("success"):
                self.set_status(500)
            self.write(result)
        except Exception as e:
            self.set_status(500)
            self.write({
                "success": False,
                "error": str(e)
            })
524
+
525
+
526
def setup_handlers(web_app):
    """Register the extension's API routes on the given web application

    Every route lives under ``<base_url>mlflow/api/``, where ``base_url`` is
    read from the application settings (default ``/``). The registration is
    reported through the module logger and echoed to stderr.
    """
    import logging
    import re
    import sys

    base_url = web_app.settings.get("base_url", "/")
    # A trailing slash is required so the API paths can simply be appended.
    if not base_url.endswith("/"):
        base_url += "/"

    # The routes are regular expressions, so the base URL is escaped before use.
    api_root = re.escape(base_url) + "mlflow/api/"

    # (path below the API root, handler class), in registration order.
    routes = (
        # Health check endpoint (for diagnosing loading issues)
        (r"health", HealthCheckHandler),
        # Main API endpoints
        (r"experiments", ExperimentsHandler),
        (r"experiments/([^/]+)", ExperimentHandler),
        (r"experiments/([^/]+)/runs", RunsHandler),
        (r"runs/([^/]+)", RunHandler),
        (r"runs/([^/]+)/artifacts", ArtifactsHandler),
        (r"runs/([^/]+)/artifacts/download", ArtifactDownloadHandler),
        (r"models", ModelsHandler),
        (r"models/([^/]+)", ModelHandler),
        (r"connection/test", ConnectionTestHandler),
        (r"local-server", LocalMLflowServerHandler),
    )
    handlers = [(api_root + suffix, handler_cls) for suffix, handler_cls in routes]

    web_app.add_handlers(".*$", handlers)

    # Report what was registered; stderr is used as well for visibility in
    # managed environments.
    logger = logging.getLogger(__name__)
    logger.info(f"✅ Registered jupyterlab-mlflow API handlers with base_url: {base_url}")
    print(f"✅ jupyterlab-mlflow: Registered {len(handlers)} API handlers with base_url: {base_url}", file=sys.stderr)
    for pattern, handler in handlers:
        entry = f" - {pattern} -> {handler.__name__}"
        logger.debug(entry)
        print(entry, file=sys.stderr)
570
+