jupyterlab-mlflow 0.4.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of jupyterlab-mlflow might be problematic. Click here for more details.
- jupyterlab_mlflow/__init__.py +63 -0
- jupyterlab_mlflow/_version.py +6 -0
- jupyterlab_mlflow/post_install.py +44 -0
- jupyterlab_mlflow/schema/plugin.json +17 -0
- jupyterlab_mlflow/serverextension/__init__.py +69 -0
- jupyterlab_mlflow/serverextension/handlers.py +570 -0
- jupyterlab_mlflow/serverextension/mlflow_server.py +214 -0
- jupyterlab_mlflow-0.4.1.data/data/etc/jupyter/jupyter_server_config.d/jupyterlab_mlflow.json +8 -0
- jupyterlab_mlflow-0.4.1.data/data/share/jupyter/labextensions/jupyterlab-mlflow/install.json +12 -0
- jupyterlab_mlflow-0.4.1.data/data/share/jupyter/labextensions/jupyterlab-mlflow/package.json +103 -0
- jupyterlab_mlflow-0.4.1.data/data/share/jupyter/labextensions/jupyterlab-mlflow/schema/plugin.json +17 -0
- jupyterlab_mlflow-0.4.1.data/data/share/jupyter/labextensions/jupyterlab-mlflow/schemas/jupyterlab-mlflow/package.json.orig +98 -0
- jupyterlab_mlflow-0.4.1.data/data/share/jupyter/labextensions/jupyterlab-mlflow/schemas/jupyterlab-mlflow/plugin.json +17 -0
- jupyterlab_mlflow-0.4.1.data/data/share/jupyter/labextensions/jupyterlab-mlflow/static/218.47b1285b67dde3db8969.js +1 -0
- jupyterlab_mlflow-0.4.1.data/data/share/jupyter/labextensions/jupyterlab-mlflow/static/665.f3ea36ea04224fd9c2f3.js +1 -0
- jupyterlab_mlflow-0.4.1.data/data/share/jupyter/labextensions/jupyterlab-mlflow/static/remoteEntry.121dc9414dda869fb1a6.js +1 -0
- jupyterlab_mlflow-0.4.1.data/data/share/jupyter/labextensions/jupyterlab-mlflow/static/style.js +4 -0
- jupyterlab_mlflow-0.4.1.data/data/share/jupyter/labextensions/jupyterlab-mlflow/static/third-party-licenses.json +16 -0
- jupyterlab_mlflow-0.4.1.dist-info/METADATA +273 -0
- jupyterlab_mlflow-0.4.1.dist-info/RECORD +23 -0
- jupyterlab_mlflow-0.4.1.dist-info/WHEEL +4 -0
- jupyterlab_mlflow-0.4.1.dist-info/entry_points.txt +2 -0
- jupyterlab_mlflow-0.4.1.dist-info/licenses/LICENSE +30 -0
|
@@ -0,0 +1,570 @@
|
|
|
1
|
+
"""
|
|
2
|
+
MLflow API handlers for JupyterLab extension
|
|
3
|
+
"""
|
|
4
|
+
|
|
5
|
+
import json
|
|
6
|
+
import os
|
|
7
|
+
from typing import Dict, Any, Optional
|
|
8
|
+
from urllib.parse import urlparse
|
|
9
|
+
|
|
10
|
+
import mlflow
|
|
11
|
+
from mlflow.tracking import MlflowClient
|
|
12
|
+
from mlflow.exceptions import MlflowException
|
|
13
|
+
from tornado import web
|
|
14
|
+
from tornado.web import RequestHandler
|
|
15
|
+
from .mlflow_server import start_mlflow_server, stop_mlflow_server, get_mlflow_server_status
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
def get_mlflow_client(tracking_uri: Optional[str] = None) -> MlflowClient:
    """
    Get an MLflow client bound to a tracking URI.

    The URI is resolved in this order:

    1. the explicit ``tracking_uri`` argument;
    2. the ``MLFLOW_TRACKING_URI`` environment variable;
    3. MLflow's own default resolution (by passing ``None`` to the client).

    The URI is handed directly to ``MlflowClient`` instead of being installed
    with ``mlflow.set_tracking_uri``. Installing it globally would make a URI
    supplied by one request stick for every later request that does not
    supply one.

    Parameters
    ----------
    tracking_uri : str, optional
        MLflow tracking URI. If None or empty, uses the environment variable
        or MLflow's default.

    Returns
    -------
    MlflowClient
        Configured MLflow client
    """
    resolved_uri = tracking_uri or os.environ.get("MLFLOW_TRACKING_URI") or None
    return MlflowClient(tracking_uri=resolved_uri)
|
|
38
|
+
|
|
39
|
+
|
|
40
|
+
class HealthCheckHandler(RequestHandler):
    """Liveness endpoint: lets callers confirm the server extension is loaded."""

    def get(self):
        """Reply with a fixed JSON payload identifying the extension."""
        payload = dict(
            status="ok",
            extension="jupyterlab-mlflow",
            message="Server extension is loaded and responding",
        )
        self.write(payload)
|
|
50
|
+
|
|
51
|
+
|
|
52
|
+
class MLflowBaseHandler(RequestHandler):
    """
    Base handler for the MLflow API endpoints.

    Provides:

    - opt-in CORS headers;
    - CORS preflight handling;
    - tracking-URI extraction from the request;
    - a JSON body for error responses.

    NOTE(review): this extends tornado's plain ``RequestHandler`` rather than a
    Jupyter server handler, so Jupyter's token/cookie authentication is not
    applied to these endpoints. Confirm this is intended.
    """

    def set_default_headers(self):
        """
        Set CORS headers only when cross-origin access has been configured.

        The JupyterLab frontend calls these endpoints from the same origin and
        needs no CORS headers. Answering every origin with ``*`` would let any
        web page the user visits read these responses.

        Cross-origin access is granted only if the web app settings contain
        one of:

        - ``allow_origin``: sent verbatim;
        - ``allow_origin_pat``: a regex that the request's ``Origin`` must
          match.

        NOTE(review): these are the setting names jupyter_server uses for its
        own CORS policy. Confirm against the deployed version.
        """
        import re

        allow_origin = self.settings.get("allow_origin")
        allow_origin_pat = self.settings.get("allow_origin_pat")
        request_origin = self.request.headers.get("Origin")

        if allow_origin:
            self.set_header("Access-Control-Allow-Origin", allow_origin)
        elif allow_origin_pat and request_origin and re.match(allow_origin_pat, request_origin):
            self.set_header("Access-Control-Allow-Origin", request_origin)
            # The response now varies with the caller's Origin header.
            self.add_header("Vary", "Origin")
        else:
            return

        self.set_header("Access-Control-Allow-Headers", "Content-Type")
        # DELETE is listed because the local-server endpoint implements it.
        self.set_header("Access-Control-Allow-Methods", "GET, POST, DELETE, OPTIONS")

    def options(self, *args, **kwargs):
        """
        Handle a CORS preflight request with an empty 204 response.

        tornado passes the route's captured path groups to every verb method.
        Accepting ``*args``/``**kwargs`` keeps preflight working on routes such
        as ``experiments/<id>``; a bare ``options(self)`` raised TypeError
        there.
        """
        self.set_status(204)
        self.finish()

    def get_tracking_uri(self) -> Optional[str]:
        """
        Return the tracking URI supplied by the request, if any.

        Lookup order:

        1. the ``tracking_uri`` query argument;
        2. for POST requests, the ``tracking_uri`` key of a JSON-object body.

        A body that is not valid UTF-8 JSON, or is JSON but not an object,
        yields None instead of raising.
        """
        tracking_uri = self.get_query_argument("tracking_uri", None)
        if tracking_uri:
            return tracking_uri

        if self.request.method == "POST":
            try:
                body = json.loads(self.request.body.decode("utf-8"))
            except ValueError:
                # json.JSONDecodeError and UnicodeDecodeError are both ValueError.
                return None
            if isinstance(body, dict):
                return body.get("tracking_uri")

        return None

    def write_error(self, status_code: int, **kwargs):
        """
        Render an uncaught exception as ``{"error": ..., "status_code": ...}``.

        MlflowException messages are passed through as-is. Anything else is
        reported as the HTTP status code plus tornado's reason phrase.
        """
        exc_info = kwargs.get("exc_info")
        if exc_info:
            exception = exc_info[1]
            if isinstance(exception, MlflowException):
                self.write({
                    "error": str(exception),
                    "status_code": status_code
                })
                return

        self.write({
            "error": f"HTTP {status_code}: {self._reason}",
            "status_code": status_code
        })
|
|
98
|
+
|
|
99
|
+
|
|
100
|
+
class ExperimentsHandler(MLflowBaseHandler):
    """Handler for listing experiments."""

    def get(self):
        """
        Return all experiments reported by the tracking server.

        ``MlflowClient.search_experiments`` returns one page per call, with a
        continuation token on the returned list. The handler keeps requesting
        pages until no token is returned. Without this, experiments beyond the
        first page would be silently omitted.

        Responses:

        - success: ``{"experiments": [...]}``;
        - any failure: status 500 with ``{"error": ...}``.
        """
        try:
            tracking_uri = self.get_tracking_uri()
            client = get_mlflow_client(tracking_uri)

            experiments = []
            page_token = None
            while True:
                page = client.search_experiments(page_token=page_token)
                experiments.extend(page)
                page_token = page.token
                # Stop at the last page. The empty-page check is defensive, so
                # a server that keeps returning a token cannot loop forever.
                if not page_token or not page:
                    break

            experiments_data = [
                {
                    "experiment_id": exp.experiment_id,
                    "name": exp.name,
                    "artifact_location": exp.artifact_location,
                    "lifecycle_stage": exp.lifecycle_stage,
                    "tags": exp.tags or {}
                }
                for exp in experiments
            ]

            self.write({"experiments": experiments_data})
        except Exception as e:
            self.set_status(500)
            self.write({"error": str(e)})
|
|
125
|
+
|
|
126
|
+
|
|
127
|
+
class ExperimentHandler(MLflowBaseHandler):
    """Serve the metadata of a single experiment, addressed by ID in the URL."""

    def get(self, experiment_id: str):
        """
        Write the experiment's metadata as JSON.

        The payload carries the ID, name, artifact location, lifecycle stage
        and tags. Any failure results in status 500 with ``{"error": ...}``.
        """
        try:
            client = get_mlflow_client(self.get_tracking_uri())
            exp = client.get_experiment(experiment_id)

            details = {
                "experiment_id": exp.experiment_id,
                "name": exp.name,
                "artifact_location": exp.artifact_location,
                "lifecycle_stage": exp.lifecycle_stage,
                "tags": exp.tags or {},
            }
            self.write(details)
        except Exception as err:
            self.set_status(500)
            self.write({"error": str(err)})
|
|
148
|
+
|
|
149
|
+
|
|
150
|
+
class RunsHandler(MLflowBaseHandler):
    """List the runs that belong to one experiment."""

    def get(self, experiment_id: str):
        """
        Write ``{"runs": [...]}`` for the experiment named in the URL.

        At most 1000 runs are requested (``max_results=1000``), so runs beyond
        that are not included. Any failure results in status 500 with
        ``{"error": ...}``.
        """
        try:
            client = get_mlflow_client(self.get_tracking_uri())
            found = client.search_runs(
                experiment_ids=[experiment_id],
                max_results=1000,
            )

            def describe(run):
                # Flatten RunInfo and RunData into one JSON-friendly dict.
                info, data = run.info, run.data
                return {
                    "run_id": info.run_id,
                    "run_name": info.run_name,
                    "experiment_id": info.experiment_id,
                    "status": info.status,
                    "start_time": info.start_time,
                    "end_time": info.end_time,
                    "user_id": info.user_id,
                    "metrics": dict(data.metrics.items()),
                    "params": dict(data.params.items()),
                    "tags": dict(data.tags.items()),
                    "artifact_uri": info.artifact_uri,
                }

            self.write({"runs": [describe(r) for r in found]})
        except Exception as err:
            self.set_status(500)
            self.write({"error": str(err)})
|
|
184
|
+
|
|
185
|
+
|
|
186
|
+
class RunHandler(MLflowBaseHandler):
    """Serve the full record of one run, addressed by its run ID."""

    def get(self, run_id: str):
        """
        Write the run as JSON.

        The payload carries the run's info fields, metrics, params, tags and
        artifact URI. Any failure results in status 500 with ``{"error": ...}``.
        """
        try:
            client = get_mlflow_client(self.get_tracking_uri())
            record = client.get_run(run_id)
            info, data = record.info, record.data

            self.write({
                "run_id": info.run_id,
                "run_name": info.run_name,
                "experiment_id": info.experiment_id,
                "status": info.status,
                "start_time": info.start_time,
                "end_time": info.end_time,
                "user_id": info.user_id,
                "metrics": dict(data.metrics.items()),
                "params": dict(data.params.items()),
                "tags": dict(data.tags.items()),
                "artifact_uri": info.artifact_uri,
            })
        except Exception as err:
            self.set_status(500)
            self.write({"error": str(err)})
|
|
213
|
+
|
|
214
|
+
|
|
215
|
+
class ArtifactsHandler(MLflowBaseHandler):
    """Handler for listing the artifacts of a run."""

    def get(self, run_id: str):
        """
        List a run's artifacts, optionally inside the directory given by ``?path=``.

        Writes ``{"run_id", "artifact_uri", "artifacts"}``. Each artifact entry
        has:

        - ``path``;
        - ``is_dir``;
        - ``file_size`` (None when the entry has no such attribute).

        Listing is best-effort: if it fails, a warning is logged and
        ``artifacts`` is empty. Failing to load the run itself yields status
        500 with ``{"error": ...}``.
        """
        import logging

        try:
            tracking_uri = self.get_tracking_uri()
            client = get_mlflow_client(tracking_uri)

            run = client.get_run(run_id)
            artifact_uri = run.info.artifact_uri

            # Optional sub-directory to list; absent or empty means the root.
            path = self.get_query_argument("path", None)

            artifacts = []
            try:
                for artifact in client.list_artifacts(run_id, path or None):
                    artifacts.append({
                        "path": artifact.path,
                        "is_dir": artifact.is_dir,
                        "file_size": getattr(artifact, "file_size", None)
                    })
            except Exception as e:
                # tornado's plain RequestHandler defines no ``log`` attribute,
                # so ``self.log`` would raise AttributeError here and turn this
                # best-effort branch into a 500. Use a module logger instead.
                logging.getLogger(__name__).warning("Could not list artifacts: %s", e)

            self.write({
                "run_id": run_id,
                "artifact_uri": artifact_uri,
                "artifacts": artifacts
            })
        except Exception as e:
            self.set_status(500)
            self.write({"error": str(e)})
|
|
255
|
+
|
|
256
|
+
|
|
257
|
+
class ArtifactDownloadHandler(MLflowBaseHandler):
    """Handler for downloading a single artifact file of a run."""

    # Content types for known artifact extensions (case-sensitive suffix
    # match, first hit wins). Anything else is served as a binary stream.
    _CONTENT_TYPES = (
        ((".json",), "application/json"),
        ((".csv",), "text/csv"),
        ((".txt", ".log"), "text/plain"),
        ((".png",), "image/png"),
        ((".jpg", ".jpeg"), "image/jpeg"),
        ((".html",), "text/html"),
    )

    def get(self, run_id: str):
        """
        Download the artifact named by the required ``?path=`` query argument.

        Responses:

        - 400 if ``path`` is missing or refers to a directory;
        - 500 for any other failure.

        The artifact is downloaded into a private temporary directory that is
        removed once its bytes have been read. Without ``dst_path``, MLflow
        leaves a temporary copy on the server's disk for every download.
        """
        import tempfile

        # Bound before the try so the error branch can always reference it.
        path = ""
        try:
            path = self.get_query_argument("path", "")
            if not path:
                self.set_status(400)
                self.write({"error": "path parameter is required"})
                return

            tracking_uri = self.get_tracking_uri()
            client = get_mlflow_client(tracking_uri)

            if self._is_listed_directory(client, run_id, path):
                self._reject_directory(path)
                return

            with tempfile.TemporaryDirectory(prefix="jupyterlab-mlflow-") as download_dir:
                artifact_path = client.download_artifacts(run_id, path, dst_path=download_dir)

                # The listing checks above are best-effort, so confirm on the
                # downloaded result as well.
                if os.path.isdir(artifact_path):
                    self._reject_directory(path)
                    return

                with open(artifact_path, "rb") as f:
                    content = f.read()

            self.set_header("Content-Type", self._content_type_for(path))
            self.set_header(
                "Content-Disposition",
                self._content_disposition(os.path.basename(path)),
            )
            self.write(content)
        except Exception as e:
            error_msg = str(e)
            # Reading a directory as a file surfaces as EISDIR (errno 21).
            if "Is a directory" in error_msg or "[Errno 21]" in error_msg:
                self._reject_directory(path)
            else:
                self.set_status(500)
                self.write({"error": error_msg})

    def _reject_directory(self, path: str) -> None:
        """Respond 400, explaining that directories cannot be downloaded."""
        self.set_status(400)
        self.write({"error": f"'{path}' is a directory. Cannot download directories. Expand to see files inside."})

    @staticmethod
    def _is_listed_directory(client, run_id: str, path: str) -> bool:
        """
        Best-effort check, using artifact listings, whether ``path`` is a directory.

        ``path`` counts as a directory when either:

        - listing it returns entries; or
        - its parent's listing contains it with ``is_dir`` set.

        Listing errors are deliberately ignored, so the caller can still
        attempt the download.
        """
        try:
            if client.list_artifacts(run_id, path):
                return True
        except Exception:
            pass

        parent_path = os.path.dirname(path)
        if not parent_path:
            return False
        try:
            return any(
                art.path == path and art.is_dir
                for art in client.list_artifacts(run_id, parent_path)
            )
        except Exception:
            return False

    @classmethod
    def _content_type_for(cls, path: str) -> str:
        """Return the Content-Type for ``path`` based on its file extension."""
        for suffixes, content_type in cls._CONTENT_TYPES:
            if path.endswith(suffixes):
                return content_type
        return "application/octet-stream"

    @staticmethod
    def _content_disposition(filename: str) -> str:
        """
        Build an ``attachment`` Content-Disposition header value for ``filename``.

        The plain ``filename`` parameter gets a printable-ASCII copy, with
        quotes, backslashes and other characters replaced by ``_``. This is
        needed because header values must be Latin-1 encodable and an embedded
        quote would break the header syntax. The exact name is carried,
        percent-encoded, in the RFC 5987 ``filename*`` parameter.
        """
        from urllib.parse import quote

        ascii_name = "".join(
            ch if " " <= ch <= "~" and ch not in '"\\' else "_" for ch in filename
        )
        encoded_name = quote(filename, safe="")
        return f"attachment; filename=\"{ascii_name}\"; filename*=UTF-8''{encoded_name}"
|
|
338
|
+
|
|
339
|
+
|
|
340
|
+
class ModelsHandler(MLflowBaseHandler):
    """Handler for listing models from the model registry."""

    @staticmethod
    def _version_summary(version) -> Dict[str, Any]:
        """Serialize one model version for the listing response."""
        return {
            "version": version.version,
            "stage": version.current_stage,
            "status": version.status,
            "run_id": version.run_id,
            "creation_timestamp": version.creation_timestamp
        }

    def get(self):
        """
        Return all registered models together with their latest versions.

        ``MlflowClient.search_registered_models`` returns one page per call,
        with a continuation token on the returned list. The handler keeps
        requesting pages until no token is returned. Without this, models past
        the first page would be silently dropped.

        Responses:

        - success: ``{"models": [...]}``;
        - any failure: status 500 with ``{"error": ...}``.
        """
        try:
            tracking_uri = self.get_tracking_uri()
            client = get_mlflow_client(tracking_uri)

            models = []
            page_token = None
            while True:
                page = client.search_registered_models(page_token=page_token)
                models.extend(page)
                page_token = page.token
                # Stop at the last page. The empty-page check is defensive, so
                # a server that keeps returning a token cannot loop forever.
                if not page_token or not page:
                    break

            models_data = [
                {
                    "name": model.name,
                    # latest_versions may be unset on some registries.
                    "latest_versions": [
                        self._version_summary(v) for v in (model.latest_versions or [])
                    ],
                    "creation_timestamp": model.creation_timestamp,
                    "last_updated_timestamp": model.last_updated_timestamp,
                    "description": model.description,
                    "tags": model.tags or {}
                }
                for model in models
            ]

            self.write({"models": models_data})
        except Exception as e:
            self.set_status(500)
            self.write({"error": str(e)})
|
|
376
|
+
|
|
377
|
+
|
|
378
|
+
class ModelHandler(MLflowBaseHandler):
    """Serve one registered model together with its latest versions."""

    def get(self, model_name: str):
        """
        Write the registered model named in the URL as JSON.

        The ``versions`` list is built from ``model.latest_versions`` (what
        the registry reports as latest), not from a search over every
        version. Any failure results in status 500 with ``{"error": ...}``.
        """
        try:
            client = get_mlflow_client(self.get_tracking_uri())
            registered = client.get_registered_model(model_name)

            version_rows = [
                {
                    "version": mv.version,
                    "stage": mv.current_stage,
                    "status": mv.status,
                    "run_id": mv.run_id,
                    "creation_timestamp": mv.creation_timestamp,
                    "description": mv.description,
                }
                for mv in registered.latest_versions
            ]

            self.write({
                "name": registered.name,
                "versions": version_rows,
                "creation_timestamp": registered.creation_timestamp,
                "last_updated_timestamp": registered.last_updated_timestamp,
                "description": registered.description,
                "tags": registered.tags or {},
            })
        except Exception as err:
            self.set_status(500)
            self.write({"error": str(err)})
|
|
412
|
+
|
|
413
|
+
|
|
414
|
+
class ConnectionTestHandler(MLflowBaseHandler):
    """Handler for testing the connection to an MLflow tracking server."""

    def post(self):
        """
        Probe the tracking URI given in the JSON request body.

        Expects a JSON object containing ``tracking_uri``. Responses:

        - 400 if the body cannot be decoded as JSON or is not a JSON object
          (these are client errors, not server failures);
        - 400 if ``tracking_uri`` is missing or empty;
        - 500 with ``{"success": False, "error": ...}`` if the probe fails;
        - otherwise ``{"success": True, ...}``.
        """
        try:
            body = json.loads(self.request.body.decode("utf-8"))
        except ValueError as e:
            # json.JSONDecodeError and UnicodeDecodeError are both ValueError.
            self.set_status(400)
            self.write({
                "success": False,
                "error": f"Invalid JSON in request body: {e}"
            })
            return
        if not isinstance(body, dict):
            self.set_status(400)
            self.write({
                "success": False,
                "error": "Invalid JSON in request body: expected a JSON object"
            })
            return

        tracking_uri = body.get("tracking_uri")
        if not tracking_uri:
            self.set_status(400)
            self.write({"error": "tracking_uri is required"})
            return

        try:
            client = get_mlflow_client(tracking_uri)

            # Fetching a single experiment is enough to prove the server
            # answers, so experiment_count is 0 or 1, not the server's total.
            experiments = client.search_experiments(max_results=1)

            self.write({
                "success": True,
                "message": "Connection successful",
                "experiment_count": len(experiments)
            })
        except Exception as e:
            self.set_status(500)
            self.write({
                "success": False,
                "error": str(e)
            })
|
|
444
|
+
|
|
445
|
+
|
|
446
|
+
class LocalMLflowServerHandler(MLflowBaseHandler):
    """Handler for managing a local MLflow server: status (GET), start (POST), stop (DELETE)."""

    def get(self):
        """Write the local MLflow server status, or status 500 on failure."""
        try:
            status = get_mlflow_server_status()
            self.write(status)
        except Exception as e:
            self.set_status(500)
            self.write({
                "success": False,
                "error": str(e)
            })

    def post(self):
        """
        Start a local MLflow server configured by the JSON request body.

        Recognized keys:

        - ``port`` (default 5000);
        - ``tracking_uri`` (default ``"sqlite:///mlflow.db"``);
        - ``artifact_uri`` and ``backend_uri`` (an empty string is treated as
          not provided).

        Responses:

        - 400 if the body cannot be decoded as JSON or is not a JSON object;
        - 500 with the result from ``start_mlflow_server`` if it reports
          failure;
        - 500 with ``{"success": False, "error": ...}`` if starting raises.
        """
        import logging

        logger = logging.getLogger(__name__)

        try:
            body = json.loads(self.request.body.decode("utf-8"))
        except ValueError as e:
            # json.JSONDecodeError and UnicodeDecodeError are both ValueError.
            self.set_status(400)
            self.write({
                "success": False,
                "error": f"Invalid JSON in request body: {str(e)}"
            })
            return
        if not isinstance(body, dict):
            # A JSON list/null/scalar would otherwise fail on body.get with a 500.
            self.set_status(400)
            self.write({
                "success": False,
                "error": "Invalid JSON in request body: expected a JSON object"
            })
            return

        try:
            port = body.get("port", 5000)
            tracking_uri = body.get("tracking_uri", "sqlite:///mlflow.db")
            artifact_uri = body.get("artifact_uri")
            backend_uri = body.get("backend_uri")

            # Convert empty strings to None
            if artifact_uri == "":
                artifact_uri = None
            if backend_uri == "":
                backend_uri = None

            logger.info(
                "Starting local MLflow server: port=%s, tracking_uri=%s, artifact_uri=%s",
                port, tracking_uri, artifact_uri,
            )

            result = start_mlflow_server(
                port=port,
                tracking_uri=tracking_uri,
                artifact_uri=artifact_uri,
                backend_uri=backend_uri
            )

            if not result.get("success"):
                logger.error("Failed to start MLflow server: %s", result.get("error", "Unknown error"))
                self.set_status(500)
            else:
                logger.info("Successfully started MLflow server: %s", result.get("url"))

            self.write(result)
        except Exception as e:
            logger.error("Error starting local MLflow server: %s", e, exc_info=True)
            self.set_status(500)
            self.write({
                "success": False,
                "error": str(e)
            })

    def delete(self):
        """Stop the local MLflow server; status 500 if stopping fails or raises."""
        try:
            result = stop_mlflow_server()
            if not result.get("success"):
                self.set_status(500)
            self.write(result)
        except Exception as e:
            self.set_status(500)
            self.write({
                "success": False,
                "error": str(e)
            })
|
|
524
|
+
|
|
525
|
+
|
|
526
|
+
def setup_handlers(web_app):
    """
    Register the jupyterlab-mlflow REST endpoints on the tornado web application.

    Every route lives under ``<base_url>mlflow/api/``. ``base_url`` is:

    - read from ``web_app.settings`` (default ``"/"``);
    - given a trailing slash if missing;
    - regex-escaped, because tornado treats route patterns as regular
      expressions.

    Handlers are added for every host (``".*$"``). The resulting route table
    is then reported through the module logger and echoed to stderr.
    """
    import logging
    import re
    import sys

    base_url = web_app.settings.get("base_url", "/")
    if not base_url.endswith("/"):
        base_url += "/"

    # Shared, regex-safe prefix for all routes below.
    api_root = re.escape(base_url) + "mlflow/api/"

    route_table = (
        # Health check endpoint, used to diagnose loading problems.
        ("health", HealthCheckHandler),
        ("experiments", ExperimentsHandler),
        (r"experiments/([^/]+)", ExperimentHandler),
        (r"experiments/([^/]+)/runs", RunsHandler),
        (r"runs/([^/]+)", RunHandler),
        (r"runs/([^/]+)/artifacts", ArtifactsHandler),
        (r"runs/([^/]+)/artifacts/download", ArtifactDownloadHandler),
        ("models", ModelsHandler),
        (r"models/([^/]+)", ModelHandler),
        ("connection/test", ConnectionTestHandler),
        ("local-server", LocalMLflowServerHandler),
    )
    handlers = [(api_root + route, handler_cls) for route, handler_cls in route_table]

    web_app.add_handlers(".*$", handlers)

    logger = logging.getLogger(__name__)
    logger.info(f"✅ Registered jupyterlab-mlflow API handlers with base_url: {base_url}")
    # Echo to stderr as well, for visibility in managed environments.
    print(f"✅ jupyterlab-mlflow: Registered {len(handlers)} API handlers with base_url: {base_url}", file=sys.stderr)
    for pattern, handler_cls in handlers:
        logger.debug(f" - {pattern} -> {handler_cls.__name__}")
        print(f" - {pattern} -> {handler_cls.__name__}", file=sys.stderr)
|
|
570
|
+
|