jupyterlab-mlflow 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (20) hide show
  1. jupyterlab_mlflow/__init__.py +22 -0
  2. jupyterlab_mlflow/_version.py +6 -0
  3. jupyterlab_mlflow/schema/plugin.json +17 -0
  4. jupyterlab_mlflow/serverextension/__init__.py +29 -0
  5. jupyterlab_mlflow/serverextension/handlers.py +511 -0
  6. jupyterlab_mlflow/serverextension/mlflow_server.py +214 -0
  7. jupyterlab_mlflow-0.1.0.data/data/share/jupyter/labextensions/jupyterlab-mlflow/install.json +12 -0
  8. jupyterlab_mlflow-0.1.0.data/data/share/jupyter/labextensions/jupyterlab-mlflow/package.json +97 -0
  9. jupyterlab_mlflow-0.1.0.data/data/share/jupyter/labextensions/jupyterlab-mlflow/schema/plugin.json +17 -0
  10. jupyterlab_mlflow-0.1.0.data/data/share/jupyter/labextensions/jupyterlab-mlflow/schemas/jupyterlab-mlflow/package.json.orig +93 -0
  11. jupyterlab_mlflow-0.1.0.data/data/share/jupyter/labextensions/jupyterlab-mlflow/schemas/jupyterlab-mlflow/plugin.json +17 -0
  12. jupyterlab_mlflow-0.1.0.data/data/share/jupyter/labextensions/jupyterlab-mlflow/static/218.0cf25be5c060df009f4a.js +1 -0
  13. jupyterlab_mlflow-0.1.0.data/data/share/jupyter/labextensions/jupyterlab-mlflow/static/665.f3ea36ea04224fd9c2f3.js +1 -0
  14. jupyterlab_mlflow-0.1.0.data/data/share/jupyter/labextensions/jupyterlab-mlflow/static/remoteEntry.cd2e48e97a1a6275a623.js +1 -0
  15. jupyterlab_mlflow-0.1.0.data/data/share/jupyter/labextensions/jupyterlab-mlflow/static/style.js +4 -0
  16. jupyterlab_mlflow-0.1.0.data/data/share/jupyter/labextensions/jupyterlab-mlflow/static/third-party-licenses.json +16 -0
  17. jupyterlab_mlflow-0.1.0.dist-info/METADATA +157 -0
  18. jupyterlab_mlflow-0.1.0.dist-info/RECORD +20 -0
  19. jupyterlab_mlflow-0.1.0.dist-info/WHEEL +4 -0
  20. jupyterlab_mlflow-0.1.0.dist-info/licenses/LICENSE +30 -0
@@ -0,0 +1,22 @@
1
+ """
2
+ JupyterLab MLflow Extension
3
+ """
4
+
5
+ from ._version import __version__
6
+
7
def _jupyter_labextension_paths():
    """Describe where JupyterLab finds this package's prebuilt labextension.

    Called by JupyterLab to detect that this package provides a valid
    labextension and to install it.

    Returns
    -------
    list of dict
        One entry. ``src`` is the source directory name that webpack writes
        the built extension into and that JupyterLab copies from during
        installation. ``dest`` is the destination directory name to install to.
    """
    extension_spec = dict(src="labextension", dest="jupyterlab-mlflow")
    return [extension_spec]
22
+
@@ -0,0 +1,6 @@
1
+ """
2
+ Version information for jupyterlab-mlflow
3
+ """
4
+
5
# Package version string; re-exported by ``jupyterlab_mlflow/__init__.py``.
__version__: str = "0.1.0"
6
+
@@ -0,0 +1,17 @@
1
+ {
2
+ "jupyter.lab.setting-icon": "mlflow:icon",
3
+ "jupyter.lab.setting-icon-label": "MLflow",
4
+ "title": "MLflow",
5
+ "description": "MLflow extension settings",
6
+ "type": "object",
7
+ "properties": {
8
+ "mlflowTrackingUri": {
9
+ "type": "string",
10
+ "title": "MLflow Tracking URI",
11
+ "description": "URI of the MLflow tracking server (e.g., http://localhost:5000). Leave empty to use MLFLOW_TRACKING_URI environment variable.",
12
+ "default": ""
13
+ }
14
+ },
15
+ "additionalProperties": false
16
+ }
17
+
@@ -0,0 +1,29 @@
1
+ """
2
+ JupyterLab MLflow Server Extension
3
+ """
4
+
5
+ import os
6
+ from .handlers import setup_handlers
7
+
8
+
9
+ def _jupyter_server_extension_points():
10
+ """
11
+ Returns a list of dictionaries with metadata about
12
+ the server extension points.
13
+ """
14
+ return [{
15
+ "module": "jupyterlab_mlflow.serverextension"
16
+ }]
17
+
18
+
19
+ def _load_jupyter_server_extension(server_app):
20
+ """Registers the API handler to receive HTTP requests from the frontend extension.
21
+
22
+ Parameters
23
+ ----------
24
+ server_app: jupyter_server.serverapp.ServerApp
25
+ Jupyter Server application instance
26
+ """
27
+ setup_handlers(server_app.web_app)
28
+ server_app.log.info("Registered jupyterlab-mlflow server extension")
29
+
@@ -0,0 +1,511 @@
1
+ """
2
+ MLflow API handlers for JupyterLab extension
3
+ """
4
+
5
+ import json
6
+ import os
7
+ from typing import Dict, Any, Optional
8
+ from urllib.parse import urlparse
9
+
10
+ import mlflow
11
+ from mlflow.tracking import MlflowClient
12
+ from mlflow.exceptions import MlflowException
13
+ from tornado import web
14
+ from tornado.web import RequestHandler
15
+ from .mlflow_server import start_mlflow_server, stop_mlflow_server, get_mlflow_server_status
16
+
17
+
18
def get_mlflow_client(tracking_uri: Optional[str] = None) -> MlflowClient:
    """
    Build an MLflow client bound to a specific tracking server.

    The URI is passed straight to ``MlflowClient`` rather than installed with
    ``mlflow.set_tracking_uri``. That setter changes process-wide state, so a
    URI supplied by one request would silently persist. It would then be used
    by later requests that did not supply one (when the environment variable
    is unset).

    Parameters
    ----------
    tracking_uri : str, optional
        MLflow tracking URI. If empty or None, the ``MLFLOW_TRACKING_URI``
        environment variable is used. If that is also unset or empty, MLflow's
        own default URI resolution applies.

    Returns
    -------
    MlflowClient
        Client configured for the resolved tracking URI.
    """
    # An empty env var is treated like an unset one, matching the truthiness
    # checks used before; ``or None`` normalises "" to None.
    resolved_uri = tracking_uri or os.environ.get("MLFLOW_TRACKING_URI") or None
    return MlflowClient(tracking_uri=resolved_uri)
38
+
39
+
40
class MLflowBaseHandler(RequestHandler):
    """Base handler for MLflow API endpoints.

    Provides three things:

    - CORS headers;
    - extraction of the ``tracking_uri`` sent with a request;
    - a JSON body for errors that tornado reports through ``write_error``.

    NOTE(review): this derives from tornado's plain ``RequestHandler`` rather
    than Jupyter Server's ``APIHandler``. Requests are therefore not checked
    against Jupyter's authentication. Consider basing it on
    ``jupyter_server.base.handlers.APIHandler`` and decorating the HTTP verbs
    with ``tornado.web.authenticated``.
    """

    def set_default_headers(self):
        """Emit CORS headers only for origins the server operator allowed.

        A wildcard ``Access-Control-Allow-Origin: *`` would let any web page
        the user visits read these responses (experiments, runs, artifacts).
        The JupyterLab frontend calls these endpoints from the same origin and
        needs no CORS headers. They are therefore only sent when Jupyter
        Server's ``allow_origin`` setting is non-empty, or when the request's
        ``Origin`` matches the ``allow_origin_pat`` setting.
        """
        allowed_origin = self.settings.get("allow_origin") or ""
        if not allowed_origin:
            origin = self.request.headers.get("Origin")
            pattern = self.settings.get("allow_origin_pat")
            matcher = getattr(pattern, "match", None)
            if origin and matcher is not None and matcher(origin):
                allowed_origin = origin
                # The echoed value depends on the request's Origin header.
                self.set_header("Vary", "Origin")
        if not allowed_origin:
            return
        self.set_header("Access-Control-Allow-Origin", allowed_origin)
        self.set_header("Access-Control-Allow-Headers", "Content-Type")
        # DELETE is listed because LocalMLflowServerHandler implements it.
        self.set_header("Access-Control-Allow-Methods", "GET, POST, DELETE, OPTIONS")

    def options(self, *args, **kwargs):
        """Answer a CORS preflight request with an empty 204 response."""
        self.set_status(204)
        self.finish()

    def get_tracking_uri(self) -> Optional[str]:
        """Return the tracking URI supplied with the request, if any.

        A non-empty ``tracking_uri`` query argument takes precedence. For POST
        requests, a ``tracking_uri`` key in a JSON-object body is the fallback.
        Bodies that are not valid UTF-8 JSON, or that are not JSON objects,
        are ignored instead of raising.
        """
        tracking_uri = self.get_query_argument("tracking_uri", None)
        if tracking_uri:
            return tracking_uri

        if self.request.method != "POST":
            return None

        try:
            body = json.loads(self.request.body.decode("utf-8"))
        except ValueError:
            # Covers json.JSONDecodeError and UnicodeDecodeError, both of
            # which subclass ValueError.
            return None
        if not isinstance(body, dict):
            # e.g. a JSON array or string body: it has no ``.get``.
            return None
        return body.get("tracking_uri")

    def write_error(self, status_code: int, **kwargs):
        """Write a JSON error body for errors reported by tornado.

        ``MlflowException`` instances are reported with their own message.
        Anything else gets a generic ``HTTP <code>: <reason>`` message.
        """
        exc_info = kwargs.get("exc_info")
        exception = exc_info[1] if exc_info else None
        if isinstance(exception, MlflowException):
            message = str(exception)
        else:
            message = f"HTTP {status_code}: {self._reason}"
        self.write({"error": message, "status_code": status_code})
86
+
87
+
88
class ExperimentsHandler(MLflowBaseHandler):
    """Handler for listing experiments."""

    def get(self):
        """Return every active experiment as ``{"experiments": [...]}``.

        ``MlflowClient.search_experiments`` returns one page at a time (a
        ``PagedList`` carrying a continuation ``token``). The token is followed
        until it is exhausted; stopping after the first call would silently
        drop experiments beyond the first page. Any failure yields HTTP 500
        with ``{"error": message}``.
        """
        try:
            tracking_uri = self.get_tracking_uri()
            client = get_mlflow_client(tracking_uri)

            experiments_data = []
            page_token = None
            while True:
                page = client.search_experiments(page_token=page_token)
                for exp in page:
                    experiments_data.append({
                        "experiment_id": exp.experiment_id,
                        "name": exp.name,
                        "artifact_location": exp.artifact_location,
                        "lifecycle_stage": exp.lifecycle_stage,
                        "tags": exp.tags or {},
                    })
                page_token = getattr(page, "token", None)
                if not page_token:
                    break

            self.write({"experiments": experiments_data})
        except Exception as e:
            self.set_status(500)
            self.write({"error": str(e)})
113
+
114
+
115
class ExperimentHandler(MLflowBaseHandler):
    """Serve the metadata of a single experiment."""

    def get(self, experiment_id: str):
        """Write the experiment identified by ``experiment_id`` as JSON.

        The body holds the experiment's id, name, artifact location, lifecycle
        stage and tags. Any failure becomes HTTP 500 with ``{"error": message}``.
        """
        try:
            client = get_mlflow_client(self.get_tracking_uri())
            exp = client.get_experiment(experiment_id)
            payload = dict(
                experiment_id=exp.experiment_id,
                name=exp.name,
                artifact_location=exp.artifact_location,
                lifecycle_stage=exp.lifecycle_stage,
                tags=exp.tags or {},
            )
            self.write(payload)
        except Exception as err:
            self.set_status(500)
            self.write({"error": str(err)})
136
+
137
+
138
class RunsHandler(MLflowBaseHandler):
    """List the runs recorded under one experiment."""

    def get(self, experiment_id: str):
        """Write the runs of ``experiment_id`` as ``{"runs": [...]}``.

        A single ``search_runs`` call with ``max_results=1000`` is made, so at
        most 1000 runs are returned. Any failure becomes HTTP 500 with
        ``{"error": message}``.
        """
        try:
            client = get_mlflow_client(self.get_tracking_uri())
            found = client.search_runs(experiment_ids=[experiment_id], max_results=1000)

            def serialize(run):
                # Flatten RunInfo and RunData into one JSON-friendly mapping.
                info, data = run.info, run.data
                return {
                    "run_id": info.run_id,
                    "run_name": info.run_name,
                    "experiment_id": info.experiment_id,
                    "status": info.status,
                    "start_time": info.start_time,
                    "end_time": info.end_time,
                    "user_id": info.user_id,
                    "metrics": dict(data.metrics),
                    "params": dict(data.params),
                    "tags": dict(data.tags),
                    "artifact_uri": info.artifact_uri,
                }

            self.write({"runs": [serialize(r) for r in found]})
        except Exception as err:
            self.set_status(500)
            self.write({"error": str(err)})
172
+
173
+
174
class RunHandler(MLflowBaseHandler):
    """Serve the full record of one run."""

    def get(self, run_id: str):
        """Write the info, metrics, params and tags of ``run_id`` as one JSON object.

        Any failure becomes HTTP 500 with ``{"error": message}``.
        """
        try:
            client = get_mlflow_client(self.get_tracking_uri())
            run = client.get_run(run_id)
            info, data = run.info, run.data
            self.write({
                "run_id": info.run_id,
                "run_name": info.run_name,
                "experiment_id": info.experiment_id,
                "status": info.status,
                "start_time": info.start_time,
                "end_time": info.end_time,
                "user_id": info.user_id,
                "metrics": dict(data.metrics),
                "params": dict(data.params),
                "tags": dict(data.tags),
                "artifact_uri": info.artifact_uri,
            })
        except Exception as err:
            self.set_status(500)
            self.write({"error": str(err)})
201
+
202
+
203
class ArtifactsHandler(MLflowBaseHandler):
    """Handler for listing artifacts."""

    def get(self, run_id: str):
        """List a run's artifacts, optionally inside the directory given by ``path``.

        The response has the form::

            {"run_id": ..., "artifact_uri": ...,
             "artifacts": [{"path", "is_dir", "file_size"}, ...]}

        Listing is best-effort. If ``list_artifacts`` fails, a warning is
        logged and whatever entries were collected (possibly none) are
        returned. A failure to fetch the run itself yields HTTP 500 with
        ``{"error": message}``.
        """
        # tornado's RequestHandler has no ``log`` attribute (Jupyter Server's
        # handler classes add one). Calling ``self.log`` here raised
        # AttributeError, which turned the best-effort fallback into an
        # HTTP 500. A module logger is always available.
        import logging

        logger = logging.getLogger(__name__)
        try:
            tracking_uri = self.get_tracking_uri()
            client = get_mlflow_client(tracking_uri)

            run = client.get_run(run_id)
            artifact_uri = run.info.artifact_uri

            # Optional sub-directory to list instead of the artifact root.
            path = self.get_query_argument("path", None)

            artifacts = []
            try:
                if path:
                    artifact_list = client.list_artifacts(run_id, path)
                else:
                    artifact_list = client.list_artifacts(run_id)
                for artifact in artifact_list:
                    artifacts.append({
                        "path": artifact.path,
                        "is_dir": artifact.is_dir,
                        "file_size": getattr(artifact, "file_size", None),
                    })
            except Exception as e:
                # Deliberately best-effort: still report the run's artifact URI.
                logger.warning("Could not list artifacts: %s", e)

            self.write({
                "run_id": run_id,
                "artifact_uri": artifact_uri,
                "artifacts": artifacts,
            })
        except Exception as e:
            self.set_status(500)
            self.write({"error": str(e)})
243
+
244
+
245
class ArtifactDownloadHandler(MLflowBaseHandler):
    """Handler for downloading a single artifact file of a run."""

    # (suffixes, Content-Type) pairs, checked in order against the requested
    # path. Paths matching none of them are served as application/octet-stream.
    _CONTENT_TYPES = (
        ((".json",), "application/json"),
        ((".csv",), "text/csv"),
        ((".txt", ".log"), "text/plain"),
        ((".png",), "image/png"),
        ((".jpg", ".jpeg"), "image/jpeg"),
        ((".html",), "text/html"),
    )

    @staticmethod
    def _directory_error(path):
        """Build the JSON error body used when ``path`` names a directory."""
        return {"error": f"'{path}' is a directory. Cannot download directories. Expand to see files inside."}

    @staticmethod
    def _looks_like_directory(client, run_id, path):
        """Best-effort check whether ``path`` is a directory artifact of ``run_id``.

        Two checks are made:

        1. Listing ``path`` itself returns entries.
        2. The parent directory's listing marks ``path`` with ``is_dir``.

        Listing failures are ignored; ``get`` re-checks after downloading.
        """
        try:
            if client.list_artifacts(run_id, path):
                return True
        except Exception:
            pass

        parent_path = os.path.dirname(path)
        if not parent_path:
            return False
        try:
            return any(
                art.path == path and art.is_dir
                for art in client.list_artifacts(run_id, parent_path)
            )
        except Exception:
            return False

    def get(self, run_id: str):
        """Send the artifact named by the ``path`` query argument as an attachment.

        Responds with HTTP 400 when ``path`` is missing or names a directory.
        Any other failure yields HTTP 500 with ``{"error": message}``.

        The artifact is downloaded into a temporary directory that is removed
        before the response is sent. Calling ``download_artifacts`` without
        ``dst_path`` may leave a fresh temporary copy on disk for every request
        (depending on the artifact repository).
        """
        import tempfile

        # Pre-initialised so the exception handler below can always format it.
        path = ""
        try:
            path = self.get_query_argument("path", "")
            if not path:
                self.set_status(400)
                self.write({"error": "path parameter is required"})
                return

            tracking_uri = self.get_tracking_uri()
            client = get_mlflow_client(tracking_uri)

            if self._looks_like_directory(client, run_id, path):
                self.set_status(400)
                self.write(self._directory_error(path))
                return

            with tempfile.TemporaryDirectory() as download_dir:
                artifact_path = client.download_artifacts(run_id, path, dst_path=download_dir)

                # The listing checks can miss directories on some stores.
                if os.path.isdir(artifact_path):
                    self.set_status(400)
                    self.write(self._directory_error(path))
                    return

                with open(artifact_path, "rb") as f:
                    content = f.read()

            content_type = "application/octet-stream"
            for suffixes, mime_type in self._CONTENT_TYPES:
                if path.endswith(suffixes):
                    content_type = mime_type
                    break

            # Escape backslashes and quotes so they cannot break out of the
            # quoted filename parameter.
            filename = os.path.basename(path).replace("\\", "\\\\").replace('"', '\\"')
            self.set_header("Content-Type", content_type)
            # Stop browsers from sniffing artifact bytes into an executable type.
            self.set_header("X-Content-Type-Options", "nosniff")
            self.set_header("Content-Disposition", f'attachment; filename="{filename}"')
            self.write(content)
        except Exception as e:
            error_msg = str(e)
            # Some stores surface directory downloads as EISDIR (errno 21).
            if "Is a directory" in error_msg or "[Errno 21]" in error_msg:
                self.set_status(400)
                self.write(self._directory_error(path))
            else:
                self.set_status(500)
                self.write({"error": error_msg})
326
+
327
+
328
class ModelsHandler(MLflowBaseHandler):
    """Handler for listing models from the model registry."""

    def get(self):
        """Return every registered model as ``{"models": [...]}``.

        ``MlflowClient.search_registered_models`` is paginated: it returns a
        ``PagedList`` with a continuation ``token``. All pages are followed so
        that models beyond the first page are not silently dropped. Any
        failure yields HTTP 500 with ``{"error": message}``.
        """
        try:
            tracking_uri = self.get_tracking_uri()
            client = get_mlflow_client(tracking_uri)

            models_data = []
            page_token = None
            while True:
                page = client.search_registered_models(page_token=page_token)
                for model in page:
                    # NOTE(review): guard in case a registry backend returns
                    # models without ``latest_versions`` populated (None).
                    latest = model.latest_versions or []
                    models_data.append({
                        "name": model.name,
                        "latest_versions": [
                            {
                                "version": v.version,
                                "stage": v.current_stage,
                                "status": v.status,
                                "run_id": v.run_id,
                                "creation_timestamp": v.creation_timestamp,
                            }
                            for v in latest
                        ],
                        "creation_timestamp": model.creation_timestamp,
                        "last_updated_timestamp": model.last_updated_timestamp,
                        "description": model.description,
                        "tags": model.tags or {},
                    })
                page_token = getattr(page, "token", None)
                if not page_token:
                    break

            self.write({"models": models_data})
        except Exception as e:
            self.set_status(500)
            self.write({"error": str(e)})
364
+
365
+
366
class ModelHandler(MLflowBaseHandler):
    """Serve one registered model together with its reported versions."""

    def get(self, model_name: str):
        """Write the registered model ``model_name`` as JSON.

        The ``versions`` list is built from ``model.latest_versions`` only, not
        from a full version search. NOTE(review): if the complete version
        history is wanted, ``MlflowClient.search_model_versions`` would be
        needed. Any failure becomes HTTP 500 with ``{"error": message}``.
        """
        try:
            client = get_mlflow_client(self.get_tracking_uri())
            model = client.get_registered_model(model_name)

            versions = [
                {
                    "version": mv.version,
                    "stage": mv.current_stage,
                    "status": mv.status,
                    "run_id": mv.run_id,
                    "creation_timestamp": mv.creation_timestamp,
                    "description": mv.description,
                }
                for mv in model.latest_versions
            ]

            self.write({
                "name": model.name,
                "versions": versions,
                "creation_timestamp": model.creation_timestamp,
                "last_updated_timestamp": model.last_updated_timestamp,
                "description": model.description,
                "tags": model.tags or {},
            })
        except Exception as err:
            self.set_status(500)
            self.write({"error": str(err)})
400
+
401
+
402
class ConnectionTestHandler(MLflowBaseHandler):
    """Handler for testing MLflow connection"""

    def post(self):
        """Check that the MLflow server at the posted ``tracking_uri`` answers.

        Expects a JSON object body containing a non-empty string
        ``tracking_uri``. A body that is not valid JSON, or that lacks the
        URI, is a client error (HTTP 400) rather than a server error.

        On success, ``experiment_count`` is the number of experiments in a
        ``max_results=1`` probe, so it is 0 or 1. On failure, the body is
        ``{"success": False, "error": ...}`` with HTTP 500.
        """
        try:
            body = json.loads(self.request.body.decode("utf-8"))
        except ValueError:
            # JSONDecodeError / UnicodeDecodeError: malformed request.
            self.set_status(400)
            self.write({"error": "request body must be valid JSON"})
            return

        tracking_uri = body.get("tracking_uri") if isinstance(body, dict) else None
        if not tracking_uri or not isinstance(tracking_uri, str):
            self.set_status(400)
            self.write({"error": "tracking_uri is required"})
            return

        try:
            client = get_mlflow_client(tracking_uri)

            # A single-result search is enough to prove the server responds.
            experiments = client.search_experiments(max_results=1)

            self.write({
                "success": True,
                "message": "Connection successful",
                "experiment_count": len(experiments),
            })
        except Exception as e:
            self.set_status(500)
            self.write({
                "success": False,
                "error": str(e),
            })
432
+
433
+
434
class LocalMLflowServerHandler(MLflowBaseHandler):
    """Handler for managing a local MLflow server (status, start, stop)."""

    def _write_failure(self, status: int, message: str):
        """Set ``status`` and write a ``{"success": False, "error": message}`` body."""
        self.set_status(status)
        self.write({
            "success": False,
            "error": message,
        })

    @staticmethod
    def _parse_port(value):
        """Return ``value`` as a TCP port number, or None if it is not a valid one.

        Accepts ints and integer strings in the range 1-65535.
        """
        # bool is a subclass of int, but ``true``/``false`` are not ports.
        if isinstance(value, bool) or not isinstance(value, (int, str)):
            return None
        try:
            port = int(value)
        except ValueError:
            return None
        return port if 1 <= port <= 65535 else None

    def get(self):
        """Write the status dict returned by ``get_mlflow_server_status``."""
        try:
            status = get_mlflow_server_status()
            self.write(status)
        except Exception as e:
            self._write_failure(500, str(e))

    def post(self):
        """Start a local MLflow server.

        Expects a JSON object body with these optional keys:

        - ``port`` (default 5000);
        - ``tracking_uri`` (default ``"sqlite:///mlflow.db"``);
        - ``artifact_uri``;
        - ``backend_uri``.

        A malformed body or an invalid port is rejected with HTTP 400 before
        anything is started. HTTP 500 is returned when the result lacks a
        truthy ``success`` or when starting raises.
        """
        try:
            body = json.loads(self.request.body.decode("utf-8"))
        except ValueError:
            # JSONDecodeError / UnicodeDecodeError: client-side error.
            self._write_failure(400, "request body must be valid JSON")
            return
        if not isinstance(body, dict):
            self._write_failure(400, "request body must be a JSON object")
            return

        port = self._parse_port(body.get("port", 5000))
        if port is None:
            self._write_failure(400, "port must be an integer between 1 and 65535")
            return

        try:
            result = start_mlflow_server(
                port=port,
                tracking_uri=body.get("tracking_uri", "sqlite:///mlflow.db"),
                artifact_uri=body.get("artifact_uri"),
                backend_uri=body.get("backend_uri"),
            )

            if not result.get("success"):
                self.set_status(500)

            self.write(result)
        except Exception as e:
            self._write_failure(500, str(e))

    def delete(self):
        """Stop the local MLflow server via ``stop_mlflow_server``.

        HTTP 500 is returned when the result lacks a truthy ``success``.
        """
        try:
            result = stop_mlflow_server()
            if not result.get("success"):
                self.set_status(500)
            self.write(result)
        except Exception as e:
            self._write_failure(500, str(e))
489
+
490
+
491
def setup_handlers(web_app):
    """Mount the MLflow REST API on Jupyter's tornado application.

    Every route lives under ``{base_url}mlflow/api/`` and is registered for
    all hosts. Each ``([^/]+)`` capture is passed to the handler method as a
    positional argument (experiment id, run id or model name).
    """
    api_root = f"{web_app.settings.get('base_url', '/')}mlflow/api"
    segment = "([^/]+)"

    routes = (
        ("/experiments", ExperimentsHandler),
        (f"/experiments/{segment}", ExperimentHandler),
        (f"/experiments/{segment}/runs", RunsHandler),
        (f"/runs/{segment}", RunHandler),
        (f"/runs/{segment}/artifacts", ArtifactsHandler),
        (f"/runs/{segment}/artifacts/download", ArtifactDownloadHandler),
        ("/models", ModelsHandler),
        (f"/models/{segment}", ModelHandler),
        ("/connection/test", ConnectionTestHandler),
        ("/local-server", LocalMLflowServerHandler),
    )

    web_app.add_handlers(".*$", [(api_root + suffix, handler) for suffix, handler in routes])
511
+