jupyterlab-mlflow 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- jupyterlab_mlflow/__init__.py +22 -0
- jupyterlab_mlflow/_version.py +6 -0
- jupyterlab_mlflow/schema/plugin.json +17 -0
- jupyterlab_mlflow/serverextension/__init__.py +29 -0
- jupyterlab_mlflow/serverextension/handlers.py +511 -0
- jupyterlab_mlflow/serverextension/mlflow_server.py +214 -0
- jupyterlab_mlflow-0.1.0.data/data/share/jupyter/labextensions/jupyterlab-mlflow/install.json +12 -0
- jupyterlab_mlflow-0.1.0.data/data/share/jupyter/labextensions/jupyterlab-mlflow/package.json +97 -0
- jupyterlab_mlflow-0.1.0.data/data/share/jupyter/labextensions/jupyterlab-mlflow/schema/plugin.json +17 -0
- jupyterlab_mlflow-0.1.0.data/data/share/jupyter/labextensions/jupyterlab-mlflow/schemas/jupyterlab-mlflow/package.json.orig +93 -0
- jupyterlab_mlflow-0.1.0.data/data/share/jupyter/labextensions/jupyterlab-mlflow/schemas/jupyterlab-mlflow/plugin.json +17 -0
- jupyterlab_mlflow-0.1.0.data/data/share/jupyter/labextensions/jupyterlab-mlflow/static/218.0cf25be5c060df009f4a.js +1 -0
- jupyterlab_mlflow-0.1.0.data/data/share/jupyter/labextensions/jupyterlab-mlflow/static/665.f3ea36ea04224fd9c2f3.js +1 -0
- jupyterlab_mlflow-0.1.0.data/data/share/jupyter/labextensions/jupyterlab-mlflow/static/remoteEntry.cd2e48e97a1a6275a623.js +1 -0
- jupyterlab_mlflow-0.1.0.data/data/share/jupyter/labextensions/jupyterlab-mlflow/static/style.js +4 -0
- jupyterlab_mlflow-0.1.0.data/data/share/jupyter/labextensions/jupyterlab-mlflow/static/third-party-licenses.json +16 -0
- jupyterlab_mlflow-0.1.0.dist-info/METADATA +157 -0
- jupyterlab_mlflow-0.1.0.dist-info/RECORD +20 -0
- jupyterlab_mlflow-0.1.0.dist-info/WHEEL +4 -0
- jupyterlab_mlflow-0.1.0.dist-info/licenses/LICENSE +30 -0
|
@@ -0,0 +1,22 @@
|
|
|
1
|
+
"""
|
|
2
|
+
JupyterLab MLflow Extension
|
|
3
|
+
"""
|
|
4
|
+
|
|
5
|
+
from ._version import __version__
|
|
6
|
+
|
|
7
|
+
def _jupyter_labextension_paths():
|
|
8
|
+
"""Called by Jupyter Lab Server to detect if it is a valid labextension and
|
|
9
|
+
to install the widget
|
|
10
|
+
|
|
11
|
+
Returns
|
|
12
|
+
=======
|
|
13
|
+
src: Source directory name to copy files from. Webpack outputs generated files
|
|
14
|
+
into this directory and Jupyter Lab copies from this directory during
|
|
15
|
+
widget installation
|
|
16
|
+
dest: Destination directory name to install to
|
|
17
|
+
"""
|
|
18
|
+
return [{
|
|
19
|
+
'src': 'labextension',
|
|
20
|
+
'dest': 'jupyterlab-mlflow'
|
|
21
|
+
}]
|
|
22
|
+
|
|
@@ -0,0 +1,17 @@
|
|
|
1
|
+
{
|
|
2
|
+
"jupyter.lab.setting-icon": "mlflow:icon",
|
|
3
|
+
"jupyter.lab.setting-icon-label": "MLflow",
|
|
4
|
+
"title": "MLflow",
|
|
5
|
+
"description": "MLflow extension settings",
|
|
6
|
+
"type": "object",
|
|
7
|
+
"properties": {
|
|
8
|
+
"mlflowTrackingUri": {
|
|
9
|
+
"type": "string",
|
|
10
|
+
"title": "MLflow Tracking URI",
|
|
11
|
+
"description": "URI of the MLflow tracking server (e.g., http://localhost:5000). Leave empty to use MLFLOW_TRACKING_URI environment variable.",
|
|
12
|
+
"default": ""
|
|
13
|
+
}
|
|
14
|
+
},
|
|
15
|
+
"additionalProperties": false
|
|
16
|
+
}
|
|
17
|
+
|
|
@@ -0,0 +1,29 @@
|
|
|
1
|
+
"""
|
|
2
|
+
JupyterLab MLflow Server Extension
|
|
3
|
+
"""
|
|
4
|
+
|
|
5
|
+
import os
|
|
6
|
+
from .handlers import setup_handlers
|
|
7
|
+
|
|
8
|
+
|
|
9
|
+
def _jupyter_server_extension_points():
|
|
10
|
+
"""
|
|
11
|
+
Returns a list of dictionaries with metadata about
|
|
12
|
+
the server extension points.
|
|
13
|
+
"""
|
|
14
|
+
return [{
|
|
15
|
+
"module": "jupyterlab_mlflow.serverextension"
|
|
16
|
+
}]
|
|
17
|
+
|
|
18
|
+
|
|
19
|
+
def _load_jupyter_server_extension(server_app):
    """Attach the MLflow REST handlers to a running Jupyter Server.

    Parameters
    ----------
    server_app : jupyter_server.serverapp.ServerApp
        Server application; its ``web_app`` receives the routes and its
        ``log`` records that registration happened.
    """
    tornado_app = server_app.web_app
    setup_handlers(tornado_app)
    server_app.log.info("Registered jupyterlab-mlflow server extension")
|
|
29
|
+
|
|
@@ -0,0 +1,511 @@
|
|
|
1
|
+
"""
|
|
2
|
+
MLflow API handlers for JupyterLab extension
|
|
3
|
+
"""
|
|
4
|
+
|
|
5
|
+
import json
import os
import tempfile
from typing import Dict, Any, Optional
from urllib.parse import urlparse

import mlflow
from mlflow.tracking import MlflowClient
from mlflow.exceptions import MlflowException
from tornado import web
from tornado.log import app_log
from tornado.web import RequestHandler

from .mlflow_server import start_mlflow_server, stop_mlflow_server, get_mlflow_server_status
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
def get_mlflow_client(tracking_uri: Optional[str] = None) -> MlflowClient:
    """
    Get an MLflow client bound to a tracking URI from the request or environment.

    The URI is passed straight to ``MlflowClient`` instead of going through
    ``mlflow.set_tracking_uri``. The latter changes MLflow's process-wide
    tracking URI, so a URI supplied by one request (or by a connection test)
    would silently become the default for every later request that did not
    specify one.

    Parameters
    ----------
    tracking_uri : str, optional
        MLflow tracking URI. When empty or None, the ``MLFLOW_TRACKING_URI``
        environment variable is used if it is set and non-empty; otherwise
        MLflow's own default resolution applies.

    Returns
    -------
    MlflowClient
        Client configured for the resolved tracking URI.
    """
    resolved_uri = tracking_uri or os.environ.get("MLFLOW_TRACKING_URI") or None
    return MlflowClient(tracking_uri=resolved_uri)
|
|
38
|
+
|
|
39
|
+
|
|
40
|
+
class MLflowBaseHandler(RequestHandler):
    """Base handler for MLflow API endpoints.

    Supplies tracking-URI resolution from the request and JSON error bodies.

    NOTE(review): this derives from ``tornado.web.RequestHandler`` and none of
    the verbs in this module carry an authentication decorator, so Jupyter's
    token/password auth is not enforced here. Deriving from
    ``jupyter_server.base.handlers.APIHandler`` and decorating the verbs with
    ``@tornado.web.authenticated`` is recommended.
    """

    def set_default_headers(self):
        """Set the preflight-related headers.

        A wildcard ``Access-Control-Allow-Origin`` is deliberately NOT sent.
        Combined with the missing authentication noted on the class, it let
        any web page the user visited read MLflow data (and drive the local
        MLflow server) through the Jupyter server. The JupyterLab frontend
        calls these endpoints from the same origin, so it does not need CORS.
        """
        self.set_header("Access-Control-Allow-Headers", "Content-Type")
        self.set_header("Access-Control-Allow-Methods", "GET, POST, OPTIONS")

    def options(self):
        """Answer OPTIONS requests with an empty 204 response."""
        self.set_status(204)
        self.finish()

    def get_tracking_uri(self) -> Optional[str]:
        """Return the tracking URI supplied by the client, if any.

        The ``tracking_uri`` query argument takes precedence. For POST
        requests, a string ``tracking_uri`` field in a JSON-object body is
        used as a fallback. Bodies that are not valid UTF-8 JSON objects are
        ignored instead of raising.

        Returns
        -------
        str or None
            The URI, or None when the request does not provide one.
        """
        tracking_uri = self.get_query_argument("tracking_uri", None)
        if tracking_uri:
            return tracking_uri

        if self.request.method == "POST":
            try:
                body = json.loads(self.request.body.decode("utf-8"))
            except ValueError:
                # json.JSONDecodeError and UnicodeDecodeError both subclass ValueError.
                return None
            # A JSON array or scalar has no .get(); treat it as "not provided".
            if isinstance(body, dict):
                value = body.get("tracking_uri")
                return value if isinstance(value, str) else None

        return None

    def write_error(self, status_code: int, **kwargs):
        """Render uncaught errors as a JSON object instead of Tornado's HTML page.

        If the error is an ``MlflowException``, its message is reported.
        Otherwise the body carries the HTTP status and Tornado's reason phrase.
        """
        exc_info = kwargs.get("exc_info")
        if exc_info and isinstance(exc_info[1], MlflowException):
            message = str(exc_info[1])
        else:
            message = f"HTTP {status_code}: {self._reason}"
        self.write({
            "error": message,
            "status_code": status_code
        })
|
|
86
|
+
|
|
87
|
+
|
|
88
|
+
class ExperimentsHandler(MLflowBaseHandler):
    """GET endpoint that lists experiments from the tracking server."""

    def get(self):
        """Respond with ``{"experiments": [...]}``, or a 500 JSON error on failure.

        Experiments come from ``MlflowClient.search_experiments()`` with its
        default arguments.
        NOTE(review): no page token is followed, so only the first page is
        returned. Confirm whether pagination is needed.
        """
        try:
            client = get_mlflow_client(self.get_tracking_uri())
            summaries = [
                {
                    "experiment_id": experiment.experiment_id,
                    "name": experiment.name,
                    "artifact_location": experiment.artifact_location,
                    "lifecycle_stage": experiment.lifecycle_stage,
                    "tags": experiment.tags or {},
                }
                for experiment in client.search_experiments()
            ]
            self.write({"experiments": summaries})
        except Exception as err:
            self.set_status(500)
            self.write({"error": str(err)})
|
|
113
|
+
|
|
114
|
+
|
|
115
|
+
class ExperimentHandler(MLflowBaseHandler):
    """GET endpoint describing a single experiment."""

    def get(self, experiment_id: str):
        """Respond with one experiment's metadata, or a 500 JSON error on failure.

        ``experiment_id`` is the path segment captured by the route.
        """
        try:
            client = get_mlflow_client(self.get_tracking_uri())
            exp = client.get_experiment(experiment_id)
            details = dict(
                experiment_id=exp.experiment_id,
                name=exp.name,
                artifact_location=exp.artifact_location,
                lifecycle_stage=exp.lifecycle_stage,
                tags=exp.tags or {},
            )
            self.write(details)
        except Exception as err:
            self.set_status(500)
            self.write({"error": str(err)})
|
|
136
|
+
|
|
137
|
+
|
|
138
|
+
class RunsHandler(MLflowBaseHandler):
    """GET endpoint listing the runs of one experiment."""

    def get(self, experiment_id: str):
        """Respond with ``{"runs": [...]}``, at most 1000 runs, or a 500 JSON error."""
        try:
            client = get_mlflow_client(self.get_tracking_uri())
            matches = client.search_runs(experiment_ids=[experiment_id], max_results=1000)

            serialized = []
            for match in matches:
                info, data = match.info, match.data
                serialized.append({
                    "run_id": info.run_id,
                    "run_name": info.run_name,
                    "experiment_id": info.experiment_id,
                    "status": info.status,
                    "start_time": info.start_time,
                    "end_time": info.end_time,
                    "user_id": info.user_id,
                    # Shallow copies so the response owns plain dicts.
                    "metrics": dict(data.metrics),
                    "params": dict(data.params),
                    "tags": dict(data.tags),
                    "artifact_uri": info.artifact_uri,
                })

            self.write({"runs": serialized})
        except Exception as err:
            self.set_status(500)
            self.write({"error": str(err)})
|
|
172
|
+
|
|
173
|
+
|
|
174
|
+
class RunHandler(MLflowBaseHandler):
    """GET endpoint describing a single run."""

    @staticmethod
    def _describe(run):
        """Flatten an MLflow run's info and data into a JSON-serialisable dict."""
        info, data = run.info, run.data
        return {
            "run_id": info.run_id,
            "run_name": info.run_name,
            "experiment_id": info.experiment_id,
            "status": info.status,
            "start_time": info.start_time,
            "end_time": info.end_time,
            "user_id": info.user_id,
            "metrics": dict(data.metrics),
            "params": dict(data.params),
            "tags": dict(data.tags),
            "artifact_uri": info.artifact_uri,
        }

    def get(self, run_id: str):
        """Respond with the run's details, or a 500 JSON error on failure."""
        try:
            client = get_mlflow_client(self.get_tracking_uri())
            self.write(self._describe(client.get_run(run_id)))
        except Exception as err:
            self.set_status(500)
            self.write({"error": str(err)})
|
|
201
|
+
|
|
202
|
+
|
|
203
|
+
class ArtifactsHandler(MLflowBaseHandler):
    """GET endpoint listing a run's artifacts, optionally inside a sub-directory."""

    def get(self, run_id: str):
        """List the artifacts of ``run_id``.

        Query arguments
        ---------------
        path : str, optional
            Artifact sub-directory to list. The artifact root is listed when
            it is omitted or empty.

        The response always carries ``run_id`` and ``artifact_uri``. If the
        listing itself fails, a warning is logged and ``artifacts`` is an empty
        list, so the run's artifact URI is still reported. Failures before
        that point (e.g. fetching the run) produce a 500 JSON error.
        """
        try:
            client = get_mlflow_client(self.get_tracking_uri())
            artifact_uri = client.get_run(run_id).info.artifact_uri
            path = self.get_query_argument("path", None)

            artifacts = []
            try:
                # list_artifacts(run_id) and list_artifacts(run_id, None) both list the root.
                listing = client.list_artifacts(run_id, path or None)
                artifacts = [
                    {
                        "path": artifact.path,
                        "is_dir": artifact.is_dir,
                        "file_size": getattr(artifact, "file_size", None),
                    }
                    for artifact in listing
                ]
            except Exception as e:
                # Best-effort: report the run without its listing. The previous
                # ``self.log`` does not exist on tornado's RequestHandler, so the
                # AttributeError it raised turned this fallback into a 500.
                app_log.warning("Could not list artifacts: %s", e)

            self.write({
                "run_id": run_id,
                "artifact_uri": artifact_uri,
                "artifacts": artifacts
            })
        except Exception as e:
            self.set_status(500)
            self.write({"error": str(e)})
|
|
243
|
+
|
|
244
|
+
|
|
245
|
+
class ArtifactDownloadHandler(MLflowBaseHandler):
    """GET endpoint returning the bytes of a single run artifact file."""

    # Suffix -> Content-Type for the artifact kinds the UI knows about.
    # Matching is case-sensitive; anything else is a generic binary download.
    _CONTENT_TYPES = (
        (".json", "application/json"),
        (".csv", "text/csv"),
        (".txt", "text/plain"),
        (".log", "text/plain"),
        (".png", "image/png"),
        (".jpg", "image/jpeg"),
        (".jpeg", "image/jpeg"),
        (".html", "text/html"),
    )

    @classmethod
    def _content_type_for(cls, path: str) -> str:
        """Return the Content-Type header value to use for ``path``."""
        for suffix, content_type in cls._CONTENT_TYPES:
            if path.endswith(suffix):
                return content_type
        return "application/octet-stream"

    def _reject_directory(self, path: str) -> None:
        """Send the 400 response used whenever ``path`` turns out to be a directory."""
        self.set_status(400)
        self.write({"error": f"'{path}' is a directory. Cannot download directories. Expand to see files inside."})

    @staticmethod
    def _looks_like_directory(client, run_id: str, path: str) -> bool:
        """Best-effort check, via artifact listings, whether ``path`` is a directory.

        Listing failures are logged at debug level and treated as "not a
        directory", so the download is still attempted.
        """
        # A non-empty listing of the path itself means it has children.
        try:
            if client.list_artifacts(run_id, path):
                return True
        except Exception as e:
            app_log.debug("Listing artifact path %r failed: %s", path, e)

        # An empty directory lists as nothing, so also look it up in its parent.
        parent_path = os.path.dirname(path)
        if parent_path:
            try:
                return any(
                    art.path == path and art.is_dir
                    for art in client.list_artifacts(run_id, parent_path)
                )
            except Exception as e:
                app_log.debug("Listing parent path %r failed: %s", parent_path, e)
        return False

    def get(self, run_id: str):
        """Download the artifact named by the ``path`` query argument.

        Responds 400 when ``path`` is missing or names a directory, and 500
        with a JSON error for any other failure. On success, responds with
        the raw file bytes as an attachment.
        """
        # Bound before the try so the error handler below can always refer to it.
        path = self.get_query_argument("path", "")
        try:
            if not path:
                self.set_status(400)
                self.write({"error": "path parameter is required"})
                return

            client = get_mlflow_client(self.get_tracking_uri())

            if self._looks_like_directory(client, run_id, path):
                self._reject_directory(path)
                return

            # Download into a directory we own and remove afterwards. Without
            # dst_path, every call left a new temporary directory on disk.
            with tempfile.TemporaryDirectory(prefix="jupyterlab-mlflow-") as tmp_dir:
                artifact_path = client.download_artifacts(run_id, path, dst_path=tmp_dir)

                if os.path.isdir(artifact_path):
                    self._reject_directory(path)
                    return

                with open(artifact_path, "rb") as f:
                    content = f.read()

            self.set_header("Content-Type", self._content_type_for(path))
            self.set_header("Content-Disposition", f'attachment; filename="{os.path.basename(path)}"')
            self.write(content)
        except Exception as e:
            error_msg = str(e)
            # Some artifact stores only reveal a directory when opening it fails.
            if (
                isinstance(e, IsADirectoryError)
                or "Is a directory" in error_msg
                or "[Errno 21]" in error_msg
            ):
                self._reject_directory(path)
            else:
                self.set_status(500)
                self.write({"error": error_msg})
|
|
326
|
+
|
|
327
|
+
|
|
328
|
+
class ModelsHandler(MLflowBaseHandler):
    """GET endpoint listing models in the MLflow model registry."""

    @staticmethod
    def _summarize_version(version):
        """Reduce a model version to the fields shown in the model list."""
        return {
            "version": version.version,
            "stage": version.current_stage,
            "status": version.status,
            "run_id": version.run_id,
            "creation_timestamp": version.creation_timestamp,
        }

    def get(self):
        """Respond with ``{"models": [...]}`` from ``search_registered_models()``.

        Any failure yields a 500 JSON error.
        """
        try:
            client = get_mlflow_client(self.get_tracking_uri())
            registered = client.search_registered_models()

            listing = []
            for entry in registered:
                listing.append({
                    "name": entry.name,
                    "latest_versions": [self._summarize_version(v) for v in entry.latest_versions],
                    "creation_timestamp": entry.creation_timestamp,
                    "last_updated_timestamp": entry.last_updated_timestamp,
                    "description": entry.description,
                    "tags": entry.tags or {},
                })

            self.write({"models": listing})
        except Exception as err:
            self.set_status(500)
            self.write({"error": str(err)})
|
|
364
|
+
|
|
365
|
+
|
|
366
|
+
class ModelHandler(MLflowBaseHandler):
    """GET endpoint describing one registered model and its versions."""

    @staticmethod
    def _all_versions(client, model):
        """Return every version of ``model``, newest first.

        ``RegisteredModel.latest_versions`` holds only the newest version per
        stage, so older versions were silently missing from the response.
        Falls back to ``latest_versions`` when the registry search fails, or
        when the name contains a single quote that cannot be embedded in the
        filter string.
        """
        name = model.name
        if "'" not in name:
            try:
                found = client.search_model_versions(f"name='{name}'")
            except MlflowException:
                found = None
            if found is not None:
                # Registry version numbers are integer strings; sort numerically.
                return sorted(found, key=lambda v: int(v.version), reverse=True)
        return list(model.latest_versions or [])

    def get(self, model_name: str):
        """Respond with the model's metadata and all of its versions.

        ``model_name`` is the path segment captured by the route. Any failure
        yields a 500 JSON error.
        """
        try:
            tracking_uri = self.get_tracking_uri()
            client = get_mlflow_client(tracking_uri)

            model = client.get_registered_model(model_name)

            versions = [
                {
                    "version": version.version,
                    "stage": version.current_stage,
                    "status": version.status,
                    "run_id": version.run_id,
                    "creation_timestamp": version.creation_timestamp,
                    "description": version.description
                }
                for version in self._all_versions(client, model)
            ]

            self.write({
                "name": model.name,
                "versions": versions,
                "creation_timestamp": model.creation_timestamp,
                "last_updated_timestamp": model.last_updated_timestamp,
                "description": model.description,
                "tags": model.tags or {}
            })
        except Exception as e:
            self.set_status(500)
            self.write({"error": str(e)})
|
|
400
|
+
|
|
401
|
+
|
|
402
|
+
class ConnectionTestHandler(MLflowBaseHandler):
    """POST endpoint that checks whether an MLflow tracking server is reachable."""

    def post(self):
        """Probe the ``tracking_uri`` from the JSON body by listing one experiment.

        Responds 400 unless the body is a JSON object with a non-empty string
        ``tracking_uri``. Malformed and empty bodies are client errors, so
        they no longer surface as a 500 carrying a decoder traceback message.
        If the probe itself fails, responds 500 with ``success: False``.
        ``experiment_count`` is 0 or 1, because only one experiment is requested.
        """
        try:
            body = json.loads(self.request.body.decode("utf-8"))
        except ValueError:
            # Invalid JSON or invalid UTF-8: treat as if no URI had been sent.
            body = None
        tracking_uri = body.get("tracking_uri") if isinstance(body, dict) else None

        if not isinstance(tracking_uri, str) or not tracking_uri:
            self.set_status(400)
            self.write({"error": "tracking_uri is required"})
            return

        try:
            client = get_mlflow_client(tracking_uri)

            # Listing a single experiment is enough to prove the server answers.
            experiments = client.search_experiments(max_results=1)

            self.write({
                "success": True,
                "message": "Connection successful",
                "experiment_count": len(experiments)
            })
        except Exception as e:
            self.set_status(500)
            self.write({
                "success": False,
                "error": str(e)
            })
|
|
432
|
+
|
|
433
|
+
|
|
434
|
+
class LocalMLflowServerHandler(MLflowBaseHandler):
    """Endpoint for inspecting (GET), starting (POST) and stopping (DELETE) a local MLflow server.

    NOTE(review): POST launches a server process using caller-supplied URIs.
    Make sure this endpoint is behind Jupyter authentication before exposing
    the Jupyter server beyond localhost.
    """

    def get(self):
        """Return the status reported by ``get_mlflow_server_status()``."""
        try:
            status = get_mlflow_server_status()
            self.write(status)
        except Exception as e:
            self.set_status(500)
            self.write({
                "success": False,
                "error": str(e)
            })

    def post(self):
        """Start the local MLflow server.

        JSON body fields, all optional:

        - ``port``: defaults to 5000.
        - ``tracking_uri``: defaults to ``"sqlite:///mlflow.db"``.
        - ``artifact_uri``: passed through as given, or None.
        - ``backend_uri``: passed through as given, or None.

        An empty body means "use the defaults". Previously it crashed
        ``json.loads`` and produced a 500. A body that is not a JSON object
        yields 400. A result without a truthy ``success`` is returned with
        status 500.
        """
        raw = self.request.body
        try:
            body = json.loads(raw.decode("utf-8")) if raw and raw.strip() else {}
        except ValueError:
            # Invalid JSON or invalid UTF-8.
            body = None
        if not isinstance(body, dict):
            self.set_status(400)
            self.write({
                "success": False,
                "error": "Request body must be a JSON object"
            })
            return

        try:
            result = start_mlflow_server(
                port=body.get("port", 5000),
                tracking_uri=body.get("tracking_uri", "sqlite:///mlflow.db"),
                artifact_uri=body.get("artifact_uri"),
                backend_uri=body.get("backend_uri")
            )

            if not result.get("success"):
                self.set_status(500)

            self.write(result)
        except Exception as e:
            self.set_status(500)
            self.write({
                "success": False,
                "error": str(e)
            })

    def delete(self):
        """Stop the local MLflow server, answering 500 when the stop reports failure."""
        try:
            result = stop_mlflow_server()
            if not result.get("success"):
                self.set_status(500)
            self.write(result)
        except Exception as e:
            self.set_status(500)
            self.write({
                "success": False,
                "error": str(e)
            })
|
|
489
|
+
|
|
490
|
+
|
|
491
|
+
def setup_handlers(web_app):
    """Register the MLflow REST routes on the Jupyter server's Tornado app.

    All routes live under ``<base_url>mlflow/api/``. ``base_url`` is read
    from the app settings and defaults to ``"/"``. The routes match requests
    for any host.
    """
    api_root = f"{web_app.settings.get('base_url', '/')}mlflow/api/"

    route_table = (
        ("experiments", ExperimentsHandler),
        ("experiments/([^/]+)", ExperimentHandler),
        ("experiments/([^/]+)/runs", RunsHandler),
        ("runs/([^/]+)", RunHandler),
        ("runs/([^/]+)/artifacts", ArtifactsHandler),
        ("runs/([^/]+)/artifacts/download", ArtifactDownloadHandler),
        ("models", ModelsHandler),
        ("models/([^/]+)", ModelHandler),
        ("connection/test", ConnectionTestHandler),
        ("local-server", LocalMLflowServerHandler),
    )

    web_app.add_handlers(
        ".*$",
        [(api_root + suffix, handler_cls) for suffix, handler_cls in route_table],
    )
|
|
511
|
+
|