mlrun 1.4.0rc25__py3-none-any.whl → 1.5.0rc2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of mlrun might be problematic.
- mlrun/__init__.py +2 -35
- mlrun/__main__.py +3 -41
- mlrun/api/api/api.py +6 -0
- mlrun/api/api/endpoints/feature_store.py +0 -4
- mlrun/api/api/endpoints/files.py +14 -2
- mlrun/api/api/endpoints/frontend_spec.py +2 -1
- mlrun/api/api/endpoints/functions.py +95 -59
- mlrun/api/api/endpoints/grafana_proxy.py +9 -9
- mlrun/api/api/endpoints/logs.py +17 -3
- mlrun/api/api/endpoints/model_endpoints.py +3 -2
- mlrun/api/api/endpoints/pipelines.py +1 -5
- mlrun/api/api/endpoints/projects.py +88 -0
- mlrun/api/api/endpoints/runs.py +48 -6
- mlrun/api/api/endpoints/submit.py +2 -1
- mlrun/api/api/endpoints/workflows.py +355 -0
- mlrun/api/api/utils.py +3 -4
- mlrun/api/crud/__init__.py +1 -0
- mlrun/api/crud/client_spec.py +6 -2
- mlrun/api/crud/feature_store.py +5 -0
- mlrun/api/crud/model_monitoring/__init__.py +1 -0
- mlrun/api/crud/model_monitoring/deployment.py +497 -0
- mlrun/api/crud/model_monitoring/grafana.py +96 -42
- mlrun/api/crud/model_monitoring/helpers.py +159 -0
- mlrun/api/crud/model_monitoring/model_endpoints.py +202 -476
- mlrun/api/crud/notifications.py +9 -4
- mlrun/api/crud/pipelines.py +6 -11
- mlrun/api/crud/projects.py +2 -2
- mlrun/api/crud/runtime_resources.py +4 -3
- mlrun/api/crud/runtimes/nuclio/helpers.py +5 -1
- mlrun/api/crud/secrets.py +21 -0
- mlrun/api/crud/workflows.py +352 -0
- mlrun/api/db/base.py +16 -1
- mlrun/api/db/init_db.py +2 -4
- mlrun/api/db/session.py +1 -1
- mlrun/api/db/sqldb/db.py +129 -31
- mlrun/api/db/sqldb/models/models_mysql.py +15 -1
- mlrun/api/db/sqldb/models/models_sqlite.py +16 -2
- mlrun/api/launcher.py +38 -6
- mlrun/api/main.py +3 -2
- mlrun/api/rundb/__init__.py +13 -0
- mlrun/{db → api/rundb}/sqldb.py +36 -84
- mlrun/api/runtime_handlers/__init__.py +56 -0
- mlrun/api/runtime_handlers/base.py +1247 -0
- mlrun/api/runtime_handlers/daskjob.py +209 -0
- mlrun/api/runtime_handlers/kubejob.py +37 -0
- mlrun/api/runtime_handlers/mpijob.py +147 -0
- mlrun/api/runtime_handlers/remotesparkjob.py +29 -0
- mlrun/api/runtime_handlers/sparkjob.py +148 -0
- mlrun/api/schemas/__init__.py +17 -6
- mlrun/api/utils/builder.py +1 -4
- mlrun/api/utils/clients/chief.py +14 -0
- mlrun/api/utils/clients/iguazio.py +33 -33
- mlrun/api/utils/clients/nuclio.py +2 -2
- mlrun/api/utils/periodic.py +9 -2
- mlrun/api/utils/projects/follower.py +14 -7
- mlrun/api/utils/projects/leader.py +2 -1
- mlrun/api/utils/projects/remotes/nop_follower.py +2 -2
- mlrun/api/utils/projects/remotes/nop_leader.py +2 -2
- mlrun/api/utils/runtimes/__init__.py +14 -0
- mlrun/api/utils/runtimes/nuclio.py +43 -0
- mlrun/api/utils/scheduler.py +98 -15
- mlrun/api/utils/singletons/db.py +5 -1
- mlrun/api/utils/singletons/project_member.py +4 -1
- mlrun/api/utils/singletons/scheduler.py +1 -1
- mlrun/artifacts/base.py +6 -6
- mlrun/artifacts/dataset.py +4 -4
- mlrun/artifacts/manager.py +2 -3
- mlrun/artifacts/model.py +2 -2
- mlrun/artifacts/plots.py +8 -8
- mlrun/common/db/__init__.py +14 -0
- mlrun/common/helpers.py +37 -0
- mlrun/{mlutils → common/model_monitoring}/__init__.py +3 -2
- mlrun/common/model_monitoring/helpers.py +69 -0
- mlrun/common/schemas/__init__.py +13 -1
- mlrun/common/schemas/auth.py +4 -1
- mlrun/common/schemas/client_spec.py +1 -1
- mlrun/common/schemas/function.py +17 -0
- mlrun/common/schemas/model_monitoring/__init__.py +48 -0
- mlrun/common/{model_monitoring.py → schemas/model_monitoring/constants.py} +11 -23
- mlrun/common/schemas/model_monitoring/grafana.py +55 -0
- mlrun/common/schemas/{model_endpoints.py → model_monitoring/model_endpoints.py} +32 -65
- mlrun/common/schemas/notification.py +1 -0
- mlrun/common/schemas/object.py +4 -0
- mlrun/common/schemas/project.py +1 -0
- mlrun/common/schemas/regex.py +1 -1
- mlrun/common/schemas/runs.py +1 -8
- mlrun/common/schemas/schedule.py +1 -8
- mlrun/common/schemas/workflow.py +54 -0
- mlrun/config.py +45 -42
- mlrun/datastore/__init__.py +21 -0
- mlrun/datastore/base.py +1 -1
- mlrun/datastore/datastore.py +9 -0
- mlrun/datastore/dbfs_store.py +168 -0
- mlrun/datastore/helpers.py +18 -0
- mlrun/datastore/sources.py +1 -0
- mlrun/datastore/store_resources.py +2 -5
- mlrun/datastore/v3io.py +1 -2
- mlrun/db/__init__.py +4 -68
- mlrun/db/base.py +12 -0
- mlrun/db/factory.py +65 -0
- mlrun/db/httpdb.py +175 -20
- mlrun/db/nopdb.py +4 -2
- mlrun/execution.py +4 -2
- mlrun/feature_store/__init__.py +1 -0
- mlrun/feature_store/api.py +1 -2
- mlrun/feature_store/common.py +2 -1
- mlrun/feature_store/feature_set.py +1 -11
- mlrun/feature_store/feature_vector.py +340 -2
- mlrun/feature_store/ingestion.py +5 -10
- mlrun/feature_store/retrieval/base.py +118 -104
- mlrun/feature_store/retrieval/dask_merger.py +17 -10
- mlrun/feature_store/retrieval/job.py +4 -1
- mlrun/feature_store/retrieval/local_merger.py +18 -18
- mlrun/feature_store/retrieval/spark_merger.py +21 -14
- mlrun/feature_store/retrieval/storey_merger.py +22 -16
- mlrun/kfpops.py +3 -9
- mlrun/launcher/base.py +57 -53
- mlrun/launcher/client.py +5 -4
- mlrun/launcher/factory.py +24 -13
- mlrun/launcher/local.py +6 -6
- mlrun/launcher/remote.py +4 -4
- mlrun/lists.py +0 -11
- mlrun/model.py +11 -17
- mlrun/model_monitoring/__init__.py +2 -22
- mlrun/model_monitoring/features_drift_table.py +1 -1
- mlrun/model_monitoring/helpers.py +22 -210
- mlrun/model_monitoring/model_endpoint.py +1 -1
- mlrun/model_monitoring/model_monitoring_batch.py +127 -50
- mlrun/model_monitoring/prometheus.py +219 -0
- mlrun/model_monitoring/stores/__init__.py +16 -11
- mlrun/model_monitoring/stores/kv_model_endpoint_store.py +95 -23
- mlrun/model_monitoring/stores/models/mysql.py +47 -29
- mlrun/model_monitoring/stores/models/sqlite.py +47 -29
- mlrun/model_monitoring/stores/sql_model_endpoint_store.py +31 -19
- mlrun/model_monitoring/{stream_processing_fs.py → stream_processing.py} +206 -64
- mlrun/model_monitoring/tracking_policy.py +104 -0
- mlrun/package/packager.py +6 -8
- mlrun/package/packagers/default_packager.py +121 -10
- mlrun/package/packagers/numpy_packagers.py +1 -1
- mlrun/platforms/__init__.py +0 -2
- mlrun/platforms/iguazio.py +0 -56
- mlrun/projects/pipelines.py +53 -159
- mlrun/projects/project.py +10 -37
- mlrun/render.py +1 -1
- mlrun/run.py +8 -124
- mlrun/runtimes/__init__.py +6 -42
- mlrun/runtimes/base.py +29 -1249
- mlrun/runtimes/daskjob.py +2 -198
- mlrun/runtimes/funcdoc.py +0 -9
- mlrun/runtimes/function.py +25 -29
- mlrun/runtimes/kubejob.py +5 -29
- mlrun/runtimes/local.py +1 -1
- mlrun/runtimes/mpijob/__init__.py +2 -2
- mlrun/runtimes/mpijob/abstract.py +10 -1
- mlrun/runtimes/mpijob/v1.py +0 -76
- mlrun/runtimes/mpijob/v1alpha1.py +1 -74
- mlrun/runtimes/nuclio.py +3 -2
- mlrun/runtimes/pod.py +28 -18
- mlrun/runtimes/remotesparkjob.py +1 -15
- mlrun/runtimes/serving.py +14 -6
- mlrun/runtimes/sparkjob/__init__.py +0 -1
- mlrun/runtimes/sparkjob/abstract.py +4 -131
- mlrun/runtimes/utils.py +0 -26
- mlrun/serving/routers.py +7 -7
- mlrun/serving/server.py +11 -8
- mlrun/serving/states.py +7 -1
- mlrun/serving/v2_serving.py +6 -6
- mlrun/utils/helpers.py +23 -42
- mlrun/utils/notifications/notification/__init__.py +4 -0
- mlrun/utils/notifications/notification/webhook.py +61 -0
- mlrun/utils/notifications/notification_pusher.py +5 -25
- mlrun/utils/regex.py +7 -2
- mlrun/utils/version/version.json +2 -2
- {mlrun-1.4.0rc25.dist-info → mlrun-1.5.0rc2.dist-info}/METADATA +26 -25
- {mlrun-1.4.0rc25.dist-info → mlrun-1.5.0rc2.dist-info}/RECORD +180 -158
- {mlrun-1.4.0rc25.dist-info → mlrun-1.5.0rc2.dist-info}/WHEEL +1 -1
- mlrun/mlutils/data.py +0 -160
- mlrun/mlutils/models.py +0 -78
- mlrun/mlutils/plots.py +0 -902
- mlrun/utils/model_monitoring.py +0 -249
- /mlrun/{api/db/sqldb/session.py → common/db/sql_session.py} +0 -0
- {mlrun-1.4.0rc25.dist-info → mlrun-1.5.0rc2.dist-info}/LICENSE +0 -0
- {mlrun-1.4.0rc25.dist-info → mlrun-1.5.0rc2.dist-info}/entry_points.txt +0 -0
- {mlrun-1.4.0rc25.dist-info → mlrun-1.5.0rc2.dist-info}/top_level.txt +0 -0
mlrun/mlutils/plots.py
DELETED
@@ -1,902 +0,0 @@
-# Copyright 2023 Iguazio
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#   http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-#
-from itertools import cycle
-from typing import List
-
-import matplotlib.pyplot as plt
-import numpy as np
-import pandas as pd
-import seaborn as sns
-from deprecated import deprecated
-from scikitplot.metrics import plot_calibration_curve
-from scipy import interp
-from sklearn import metrics
-from sklearn.calibration import calibration_curve
-from sklearn.metrics import confusion_matrix as sklearn_confusion_matrix
-from sklearn.preprocessing import LabelBinarizer
-
-from ..artifacts import PlotArtifact
-
-# TODO: remove mlutils in 1.5.0
-
-
-@deprecated(
-    version="1.3.0",
-    reason="'mlrun.mlutils' will be removed in 1.5.0, use 'mlrun.framework' instead",
-    category=FutureWarning,
-)
-def gcf_clear(plt):
-    """Utility to clear matplotlib figure
-    Run this inside every plot method before calling any matplotlib
-    methods
-    :param plot: matloblib figure object
-    """
-    plt.cla()
-    plt.clf()
-    plt.close()
-
-
-@deprecated(
-    version="1.3.0",
-    reason="'mlrun.mlutils' will be removed in 1.5.0, use 'mlrun.framework' instead",
-    category=FutureWarning,
-)
-def feature_importances(model, header):
-    """Display estimated feature importances
-    Only works for models with attribute 'feature_importances_`
-    :param model:       fitted model
-    :param header:      feature labels
-    """
-    if not hasattr(model, "feature_importances_"):
-        raise Exception(
-            "feature importances are only available for some models, if you got "
-            "here then please make sure to check your estimated model for a "
-            "`feature_importances_` attribute before calling this method"
-        )
-
-    # create a feature importance table with desired labels
-    zipped = zip(model.feature_importances_, header)
-    feature_imp = pd.DataFrame(sorted(zipped), columns=["freq", "feature"]).sort_values(
-        by="freq", ascending=False
-    )
-
-    plt.clf()  # gcf_clear(plt)
-    plt.figure()
-    sns.barplot(x="freq", y="feature", data=feature_imp)
-    plt.title("features")
-    plt.tight_layout()
-
-    return (
-        PlotArtifact(
-            "feature-importances", body=plt.gcf(), title="Feature Importances"
-        ),
-        feature_imp,
-    )
-
-
-@deprecated(
-    version="1.3.0",
-    reason="'mlrun.mlutils' will be removed in 1.5.0, use 'mlrun.framework' instead",
-    category=FutureWarning,
-)
-def plot_importance(
-    context, model, key: str = "feature-importances", plots_dest: str = "plots"
-):
-    """Display estimated feature importances
-    Only works for models with attribute 'feature_importances_`
-
-    **legacy version please deprecate in functions and demos**
-
-    :param context:     function context
-    :param model:       fitted model
-    :param key:         key of feature importances plot and table in artifact
-                        store
-    :param plots_dest:  subfolder in artifact store
-    """
-    if not hasattr(model, "feature_importances_"):
-        raise Exception("feature importaces are only available for some models")
-
-    # create a feature importance table with desired labels
-    zipped = zip(model.feature_importances_, context.header)
-    feature_imp = pd.DataFrame(sorted(zipped), columns=["freq", "feature"]).sort_values(
-        by="freq", ascending=False
-    )
-
-    gcf_clear(plt)
-    plt.figure(figsize=(20, 10))
-    sns.barplot(x="freq", y="feature", data=feature_imp)
-    plt.title("features")
-    plt.tight_layout()
-
-    fname = f"{plots_dest}/{key}.html"
-    context.log_artifact(PlotArtifact(key, body=plt.gcf()), local_path=fname)
-
-    # feature importances are also saved as a csv table (generally small):
-    fname = key + "-tbl.csv"
-    return context.log_dataset(key + "-tbl", df=feature_imp, local_path=fname)
-
-
-@deprecated(
-    version="1.3.0",
-    reason="'mlrun.mlutils' will be removed in 1.5.0, use 'mlrun.framework' instead",
-    category=FutureWarning,
-)
-def learning_curves(model):
-    """model class dependent
-
-    WIP
-
-    get training history plots for xgboost, lightgbm
-
-    returns list of PlotArtifacts, can be empty if no history
-    is found
-    """
-    plots = []
-
-    # do this here and not in the call to learning_curve plots,
-    # this is default approach for xgboost and lightgbm
-    if hasattr(model, "evals_result"):
-        results = model.evals_result()
-        train_set = list(results.items())[0]
-        valid_set = list(results.items())[1]
-
-        learning_curves = pd.DataFrame(
-            {
-                "train_error": train_set[1]["error"],
-                "train_auc": train_set[1]["auc"],
-                "valid_error": valid_set[1]["error"],
-                "valid_auc": valid_set[1]["auc"],
-            }
-        )
-
-        plt.clf()  # gcf_clear(plt)
-        fig, ax = plt.subplots()
-        plt.xlabel("# training examples")
-        plt.ylabel("auc")
-        plt.title("learning curve - auc")
-        ax.plot(learning_curves.train_auc, label="train")
-        ax.plot(learning_curves.valid_auc, label="valid")
-        ax.legend(loc="lower left")
-        plots.append(PlotArtifact("learning curve - auc", body=plt.gcf()))
-
-        plt.clf()  # gcf_clear(plt)
-        fig, ax = plt.subplots()
-        plt.xlabel("# training examples")
-        plt.ylabel("error rate")
-        plt.title("learning curve - error")
-        ax.plot(learning_curves.train_error, label="train")
-        ax.plot(learning_curves.valid_error, label="valid")
-        ax.legend(loc="lower left")
-        plots.append(PlotArtifact("learning curve - taoot", body=plt.gcf()))
-
-    # elif some other model history api...
-
-    return plots
-
-
-@deprecated(
-    version="1.3.0",
-    reason="'mlrun.mlutils' will be removed in 1.5.0, use 'mlrun.framework' instead",
-    category=FutureWarning,
-)
-def confusion_matrix(model, xtest, ytest, cmap="Blues"):
-    cmd = metrics.plot_confusion_matrix(
-        model,
-        xtest,
-        ytest,
-        normalize="all",
-        values_format=".2g",
-        cmap=plt.get_cmap(cmap),
-    )
-    # for now only 1, add different views to this array for display in UI
-    cmd.plot()
-    return PlotArtifact(
-        "confusion-matrix-normalized",
-        body=cmd.figure_,
-        title="Confusion Matrix - Normalized Plot",
-    )
-
-
-@deprecated(
-    version="1.3.0",
-    reason="'mlrun.mlutils' will be removed in 1.5.0, use 'mlrun.framework' instead",
-    category=FutureWarning,
-)
-def precision_recall_multi(ytest_b, yprob, labels, scoring="micro"):
-    """"""
-    n_classes = len(labels)
-
-    precision = dict()
-    recall = dict()
-    avg_prec = dict()
-    for i in range(n_classes):
-        precision[i], recall[i], _ = metrics.precision_recall_curve(
-            ytest_b[:, i], yprob[:, i]
-        )
-        avg_prec[i] = metrics.average_precision_score(ytest_b[:, i], yprob[:, i])
-    precision["micro"], recall["micro"], _ = metrics.precision_recall_curve(
-        ytest_b.ravel(), yprob.ravel()
-    )
-    avg_prec["micro"] = metrics.average_precision_score(ytest_b, yprob, average="micro")
-    ap_micro = avg_prec["micro"]
-    # model_metrics.update({'precision-micro-avg-classes': ap_micro})
-
-    # gcf_clear(plt)
-    colors = cycle(["navy", "turquoise", "darkorange", "cornflowerblue", "teal"])
-    plt.figure()
-    f_scores = np.linspace(0.2, 0.8, num=4)
-    lines = []
-    labels = []
-    for f_score in f_scores:
-        x = np.linspace(0.01, 1)
-        y = f_score * x / (2 * x - f_score)
-        (l,) = plt.plot(x[y >= 0], y[y >= 0], color="gray", alpha=0.2)
-        plt.annotate(f"f1={f_score:0.1f}", xy=(0.9, y[45] + 0.02))
-
-    lines.append(l)
-    labels.append("iso-f1 curves")
-    (l,) = plt.plot(recall["micro"], precision["micro"], color="gold", lw=10)
-    lines.append(l)
-    labels.append(f"micro-average precision-recall (area = {ap_micro:0.2f})")
-
-    for i, color in zip(range(n_classes), colors):
-        (l,) = plt.plot(recall[i], precision[i], color=color, lw=2)
-        lines.append(l)
-        labels.append(f"precision-recall for class {i} (area = {avg_prec[i]:0.2f})")
-
-    # fig = plt.gcf()
-    # fig.subplots_adjust(bottom=0.25)
-    plt.xlim([0.0, 1.0])
-    plt.ylim([0.0, 1.05])
-    plt.xlabel("recall")
-    plt.ylabel("precision")
-    plt.title("precision recall - multiclass")
-    plt.legend(lines, labels, loc=(0, -0.41), prop=dict(size=10))
-
-    return PlotArtifact(
-        "precision-recall-multiclass",
-        body=plt.gcf(),
-        title="Multiclass Precision Recall",
-    )
-
-
-@deprecated(
-    version="1.3.0",
-    reason="'mlrun.mlutils' will be removed in 1.5.0, use 'mlrun.framework' instead",
-    category=FutureWarning,
-)
-def roc_multi(ytest_b, yprob, labels):
-    """"""
-    n_classes = len(labels)
-
-    # Compute ROC curve and ROC area for each class
-    fpr = dict()
-    tpr = dict()
-    roc_auc = dict()
-    for i in range(n_classes):
-        fpr[i], tpr[i], _ = metrics.roc_curve(ytest_b[:, i], yprob[:, i])
-        roc_auc[i] = metrics.auc(fpr[i], tpr[i])
-
-    # Compute micro-average ROC curve and ROC area
-    fpr["micro"], tpr["micro"], _ = metrics.roc_curve(ytest_b.ravel(), yprob.ravel())
-    roc_auc["micro"] = metrics.auc(fpr["micro"], tpr["micro"])
-
-    # First aggregate all false positive rates
-    all_fpr = np.unique(np.concatenate([fpr[i] for i in range(n_classes)]))
-
-    # Then interpolate all ROC curves at this points
-    mean_tpr = np.zeros_like(all_fpr)
-    for i in range(n_classes):
-        mean_tpr += interp(all_fpr, fpr[i], tpr[i])
-
-    # Finally average it and compute AUC
-    mean_tpr /= n_classes
-
-    fpr["macro"] = all_fpr
-    tpr["macro"] = mean_tpr
-    roc_auc["macro"] = metrics.auc(fpr["macro"], tpr["macro"])
-
-    # Plot all ROC curves
-    gcf_clear(plt)
-    plt.figure()
-    plt.plot(
-        fpr["micro"],
-        tpr["micro"],
-        label=f"micro-average ROC curve (area = {roc_auc['micro']:0.2f})",
-        color="deeppink",
-        linestyle=":",
-        linewidth=4,
-    )
-
-    plt.plot(
-        fpr["macro"],
-        tpr["macro"],
-        label=f"macro-average ROC curve (area = {roc_auc['macro']:0.2f})",
-        color="navy",
-        linestyle=":",
-        linewidth=4,
-    )
-
-    colors = cycle(["aqua", "darkorange", "cornflowerblue"])
-    for i, color in zip(range(n_classes), colors):
-        plt.plot(
-            fpr[i],
-            tpr[i],
-            color=color,
-            lw=2,
-            label=f"ROC curve of class {i} (area = {roc_auc[i]:0.2f})",
-        )
-
-    plt.plot([0, 1], [0, 1], "k--", lw=2)
-    plt.xlim([0.0, 1.0])
-    plt.ylim([0.0, 1.05])
-    plt.xlabel("False Positive Rate")
-    plt.ylabel("True Positive Rate")
-    plt.title("receiver operating characteristic - multiclass")
-    plt.legend(loc=(0, -0.68), prop=dict(size=10))
-
-    return PlotArtifact("roc-multiclass", body=plt.gcf(), title="Multiclass ROC Curve")
-
-
-@deprecated(
-    version="1.3.0",
-    reason="'mlrun.mlutils' will be removed in 1.5.0, use 'mlrun.framework' instead",
-    category=FutureWarning,
-)
-def roc_bin(ytest, yprob, clear: bool = False):
-    """"""
-    # ROC plot
-    if clear:
-        gcf_clear(plt)
-    fpr, tpr, _ = metrics.roc_curve(ytest, yprob)
-    plt.figure()
-    plt.plot([0, 1], [0, 1], "k--")
-    plt.plot(fpr, tpr, label="a label")
-    plt.xlabel("false positive rate")
-    plt.ylabel("true positive rate")
-    plt.title("roc curve")
-    plt.legend(loc="best")
-
-    return PlotArtifact("roc-binary", body=plt.gcf(), title="Binary ROC Curve")
-
-
-@deprecated(
-    version="1.3.0",
-    reason="'mlrun.mlutils' will be removed in 1.5.0, use 'mlrun.framework' instead",
-    category=FutureWarning,
-)
-def precision_recall_bin(model, xtest, ytest, yprob, clear=False):
-    """"""
-    if clear:
-        gcf_clear(plt)
-    disp = metrics.plot_precision_recall_curve(model, xtest, ytest)
-    disp.ax_.set_title(
-        f"precision recall: AP={metrics.average_precision_score(ytest, yprob):0.2f}"
-    )
-
-    return PlotArtifact(
-        "precision-recall-binary", body=disp.figure_, title="Binary Precision Recall"
-    )
-
-
-@deprecated(
-    version="1.3.0",
-    reason="'mlrun.mlutils' will be removed in 1.5.0, use 'mlrun.framework' instead",
-    category=FutureWarning,
-)
-def plot_roc(
-    context,
-    y_labels,
-    y_probs,
-    key="roc",
-    plots_dir: str = "plots",
-    fmt="png",
-    fpr_label: str = "false positive rate",
-    tpr_label: str = "true positive rate",
-    title: str = "roc curve",
-    legend_loc: str = "best",
-    clear: bool = True,
-):
-    """plot roc curves
-
-    **legacy version please deprecate in functions and demos**
-
-    :param context:      the function context
-    :param y_labels:     ground truth labels, hot encoded for multiclass
-    :param y_probs:      model prediction probabilities
-    :param key:          ("roc") key of plot in artifact store
-    :param plots_dir:    ("plots") destination folder relative path to artifact path
-    :param fmt:          ("png") plot format
-    :param fpr_label:    ("false positive rate") x-axis labels
-    :param tpr_label:    ("true positive rate") y-axis labels
-    :param title:        ("roc curve") title of plot
-    :param legend_loc:   ("best") location of plot legend
-    :param clear:        (True) clear the matplotlib figure before drawing
-    """
-    # clear matplotlib current figure
-    if clear:
-        gcf_clear(plt)
-
-    # draw 45 degree line
-    plt.plot([0, 1], [0, 1], "k--")
-
-    # labelling
-    plt.xlabel(fpr_label)
-    plt.ylabel(tpr_label)
-    plt.title(title)
-    plt.legend(loc=legend_loc)
-
-    # single ROC or multiple
-    if y_labels.shape[1] > 1:
-
-        # data accumulators by class
-        fpr = dict()
-        tpr = dict()
-        roc_auc = dict()
-        for i in range(y_labels[:, :-1].shape[1]):
-            fpr[i], tpr[i], _ = metrics.roc_curve(
-                y_labels[:, i], y_probs[:, i], pos_label=1
-            )
-            roc_auc[i] = metrics.auc(fpr[i], tpr[i])
-            plt.plot(fpr[i], tpr[i], label=f"class {i}")
-    else:
-        fpr, tpr, _ = metrics.roc_curve(y_labels, y_probs[:, 1], pos_label=1)
-        plt.plot(fpr, tpr, label="positive class")
-
-    fname = f"{plots_dir}/{key}.html"
-    return context.log_artifact(PlotArtifact(key, body=plt.gcf()), local_path=fname)
-
-
-@deprecated(
-    version="1.3.0",
-    reason="'mlrun.mlutils' will be removed in 1.5.0, use 'mlrun.framework' instead",
-    category=FutureWarning,
-)
-def eval_class_model(
-    xtest, ytest, model, labels: str = "labels", pred_params: dict = {}
-):
-    """generate predictions and validation stats
-
-    pred_params are non-default, scikit-learn api prediction-function parameters.
-    For example, a tree-type of model may have a tree depth limit for its prediction
-    function.
-
-    :param xtest:        features array type Union(DataItem, DataFrame, np. Array)
-    :param ytest:        ground-truth labels Union(DataItem, DataFrame, Series, np. Array, List)
-    :param model:        estimated model
-    :param labels:       ('labels') labels in ytest is a pd.DataFrame or Series
-    :param pred_params:  (None) dict of predict function parameters
-    """
-    if isinstance(ytest, (pd.DataFrame, pd.Series)):
-        unique_labels = ytest[labels].unique()
-        ytest = ytest.values
-    elif isinstance(ytest, np.ndarray):
-        unique_labels = np.unique(ytest)
-    elif isinstance(ytest, list):
-        unique_labels = set(ytest)
-
-    n_classes = len(unique_labels)
-    is_multiclass = True if n_classes > 2 else False
-
-    # PROBS
-    ypred = model.predict(xtest, **pred_params)
-    if hasattr(model, "predict_proba"):
-        yprob = model.predict_proba(xtest, **pred_params)
-    else:
-        # todo if decision fn...
-        raise Exception("not implemented for this classifier")
-
-    # todo - calibrate
-    # outputs are some stats and some plots and...
-    # should be option, some classifiers don't need, some do it already, many don't
-
-    model_metrics = {
-        "plots": [],  # placeholder for plots
-        "accuracy": float(metrics.accuracy_score(ytest, ypred)),
-        "test-error-rate": np.sum(ytest != ypred) / ytest.shape[0],
-    }
-
-    # CONFUSION MATRIX
-    gcf_clear(plt)
-    cmd = metrics.plot_confusion_matrix(
-        model, xtest, ytest, normalize="all", cmap=plt.cm.Blues
-    )
-    model_metrics["plots"].append(PlotArtifact("confusion-matrix", body=cmd.figure_))
-
-    if is_multiclass:
-        # PRECISION-RECALL CURVES MICRO AVGED
-        # binarize/hot-encode here since we look at each class
-        lb = LabelBinarizer()
-        ytest_b = lb.fit_transform(ytest)
-
-        precision = dict()
-        recall = dict()
-        avg_prec = dict()
-        for i in range(n_classes):
-            precision[i], recall[i], _ = metrics.precision_recall_curve(
-                ytest_b[:, i], yprob[:, i]
-            )
-            avg_prec[i] = metrics.average_precision_score(ytest_b[:, i], yprob[:, i])
-        precision["micro"], recall["micro"], _ = metrics.precision_recall_curve(
-            ytest_b.ravel(), yprob.ravel()
-        )
-        avg_prec["micro"] = metrics.average_precision_score(
-            ytest_b, yprob, average="micro"
-        )
-        ap_micro = avg_prec["micro"]
-        model_metrics.update({"precision-micro-avg-classes": ap_micro})
-
-        gcf_clear(plt)
-        colors = cycle(["navy", "turquoise", "darkorange", "cornflowerblue", "teal"])
-        plt.figure(figsize=(7, 8))
-        f_scores = np.linspace(0.2, 0.8, num=4)
-        lines = []
-        labels = []
-        for f_score in f_scores:
-            x = np.linspace(0.01, 1)
-            y = f_score * x / (2 * x - f_score)
-            (l,) = plt.plot(x[y >= 0], y[y >= 0], color="gray", alpha=0.2)
-            plt.annotate(f"f1={f_score:0.1f}", xy=(0.9, y[45] + 0.02))
-
-        lines.append(l)
-        labels.append("iso-f1 curves")
-        (l,) = plt.plot(recall["micro"], precision["micro"], color="gold", lw=10)
-        lines.append(l)
-        labels.append(f"micro-average precision-recall (area = {ap_micro:0.2f})")
-
-        for i, color in zip(range(n_classes), colors):
-            (l,) = plt.plot(recall[i], precision[i], color=color, lw=2)
-            lines.append(l)
-            labels.append(f"precision-recall for class {i} (area = {avg_prec[i]:0.2f})")
-
-        fig = plt.gcf()
-        fig.subplots_adjust(bottom=0.25)
-        plt.xlim([0.0, 1.0])
-        plt.ylim([0.0, 1.05])
-        plt.xlabel("recall")
-        plt.ylabel("precision")
-        plt.title("precision recall - multiclass")
-        plt.legend(lines, labels, loc=(0, -0.38), prop=dict(size=10))
-        model_metrics["plots"].append(
-            PlotArtifact("precision-recall-multiclass", body=plt.gcf())
-        )
-
-        # ROC CURVES
-        # Compute ROC curve and ROC area for each class
-        fpr = dict()
-        tpr = dict()
-        roc_auc = dict()
-        for i in range(n_classes):
-            fpr[i], tpr[i], _ = metrics.roc_curve(ytest_b[:, i], yprob[:, i])
-            roc_auc[i] = metrics.auc(fpr[i], tpr[i])
-
-        # Compute micro-average ROC curve and ROC area
-        fpr["micro"], tpr["micro"], _ = metrics.roc_curve(
-            ytest_b.ravel(), yprob.ravel()
-        )
-        roc_auc["micro"] = metrics.auc(fpr["micro"], tpr["micro"])
-
-        # First aggregate all false positive rates
-        all_fpr = np.unique(np.concatenate([fpr[i] for i in range(n_classes)]))
-
-        # Then interpolate all ROC curves at this points
-        mean_tpr = np.zeros_like(all_fpr)
-        for i in range(n_classes):
-            mean_tpr += interp(all_fpr, fpr[i], tpr[i])
-
-        # Finally average it and compute AUC
-        mean_tpr /= n_classes
-
-        fpr["macro"] = all_fpr
-        tpr["macro"] = mean_tpr
-        roc_auc["macro"] = metrics.auc(fpr["macro"], tpr["macro"])
-
-        # Plot all ROC curves
-        gcf_clear(plt)
-        plt.figure()
-        plt.plot(
-            fpr["micro"],
-            tpr["micro"],
-            label=f"micro-average ROC curve (area = {roc_auc['micro']:0.2f})",
-            color="deeppink",
-            linestyle=":",
-            linewidth=4,
-        )
-
-        plt.plot(
-            fpr["macro"],
-            tpr["macro"],
-            label=f"macro-average ROC curve (area = {roc_auc['macro']:0.2f})",
-            color="navy",
-            linestyle=":",
-            linewidth=4,
-        )
-
-        colors = cycle(["aqua", "darkorange", "cornflowerblue"])
-        for i, color in zip(range(n_classes), colors):
-            plt.plot(
-                fpr[i],
-                tpr[i],
-                color=color,
-                lw=2,
-                label=f"ROC curve of class {i} (area = {roc_auc[i]:0.2f})",
-            )
-
-        plt.plot([0, 1], [0, 1], "k--", lw=2)
-        plt.xlim([0.0, 1.0])
-        plt.ylim([0.0, 1.05])
-        plt.xlabel("False Positive Rate")
-        plt.ylabel("True Positive Rate")
-        plt.title("receiver operating characteristic - multiclass")
-        plt.legend(loc="lower right")
-        model_metrics["plots"].append(PlotArtifact("roc-multiclass", body=plt.gcf()))
-        # AUC multiclass
-        model_metrics.update(
-            {
-                "auc-macro": metrics.roc_auc_score(
-                    ytest_b, yprob, multi_class="ovo", average="macro"
-                ),
-                "auc-weighted": metrics.roc_auc_score(
-                    ytest_b, yprob, multi_class="ovo", average="weighted"
-                ),
-            }
-        )
-
-        # others (todo - macro, micro...)
-        model_metrics.update(
-            {
-                "f1-score": metrics.f1_score(ytest, ypred, average="macro"),
-                "recall_score": metrics.recall_score(ytest, ypred, average="macro"),
-            }
-        )
-    else:
-        # binary
-        yprob_pos = yprob[:, 1]
-
-        model_metrics.update(
-            {
-                "rocauc": metrics.roc_auc_score(ytest, yprob_pos),
-                "brier_score": metrics.brier_score_loss(
-                    ytest, yprob_pos, pos_label=ytest.max()
-                ),
-            }
-        )
-
-        # precision-recall
-
-        # ROC plot
-
-    return model_metrics
-
-
-@deprecated(
-    version="1.3.0",
-    reason="'mlrun.mlutils' will be removed in 1.5.0, use 'mlrun.framework' instead",
-    category=FutureWarning,
-)
-def eval_model_v2(
-    context,
-    xtest,
-    ytest,
-    model,
-    pcurve_bins: int = 10,
-    pcurve_names: List[str] = ["my classifier"],
-    plots_artifact_path: str = "",
-    pred_params: dict = {},
-    cmap="Blues",
-):
-    """generate predictions and validation stats
-
-    pred_params are non-default, scikit-learn api prediction-function
-    parameters. For example, a tree-type of model may have a tree depth
-    limit for its prediction function.
-
-    :param xtest:              features array type Union(DataItem, DataFrame,
-                               numpy array)
-    :param ytest:              ground-truth labels Union(DataItem, DataFrame,
-                               Series, numpy array, List)
-    :param model:              estimated model
-    :param pcurve_bins:        (10) subdivide [0,1] interval into n bins, x-axis
-    :param pcurve_names:       label for each calibration curve
-    :param pred_params:        (None) dict of predict function parameters
-    :param cmap:               ('Blues') matplotlib color map
-    """
-
-    import numpy as np
-
-    def df_blob(df):
-        return bytes(df.to_csv(index=False), encoding="utf-8")
-
-    if isinstance(ytest, np.ndarray):
-        unique_labels = np.unique(ytest)
-    elif isinstance(ytest, list):
-        unique_labels = set(ytest)
-    else:
-        try:
-            ytest = ytest.values
-            unique_labels = np.unique(ytest)
-        except Exception as exc:
-            raise Exception("unrecognized data type for ytest") from exc
-
-    n_classes = len(unique_labels)
-    is_multiclass = True if n_classes > 2 else False
-
-    # INIT DICT...OR SOME OTHER COLLECTOR THAT CAN BE ACCESSED
-    plots_path = plots_artifact_path or context.artifact_subpath("plots")
-    extra_data = {}
-
-    ypred = model.predict(xtest, **pred_params)
-    context.log_results(
-        {
-            "accuracy": float(metrics.accuracy_score(ytest, ypred)),
-            "test-error": np.sum(ytest != ypred) / ytest.shape[0],
-        }
-    )
-
-    # PROBABILITIES
-    if hasattr(model, "predict_proba"):
-        yprob = model.predict_proba(xtest, **pred_params)
-        if not is_multiclass:
-            fraction_of_positives, mean_predicted_value = calibration_curve(
-                ytest, yprob[:, -1], n_bins=pcurve_bins, strategy="uniform"
-            )
-            cmd = plot_calibration_curve(ytest, [yprob], pcurve_names)
-            calibration = context.log_artifact(
-                PlotArtifact(
-                    "probability-calibration",
-                    body=cmd.get_figure(),
-                    title="probability calibration plot",
-                ),
-                artifact_path=plots_path,
-                db_key=False,
-            )
-            extra_data["probability calibration"] = calibration
-
-    # CONFUSION MATRIX
-    cm = sklearn_confusion_matrix(ytest, ypred, normalize="all")
-    df = pd.DataFrame(data=cm)
-    extra_data["confusion matrix table.csv"] = df_blob(df)
-
-    cmd = metrics.plot_confusion_matrix(
-        model,
-        xtest,
-        ytest,
-        normalize="all",
-        values_format=".2g",
-        cmap=plt.get_cmap(cmap),
-    )
-    confusion = context.log_artifact(
-        PlotArtifact(
-            "confusion-matrix",
-            body=cmd.figure_,
-            title="Confusion Matrix - Normalized Plot",
-        ),
-        artifact_path=plots_path,
-        db_key=False,
-    )
-    extra_data["confusion matrix"] = confusion
-
-    # LEARNING CURVES
-    if hasattr(model, "evals_result"):
-        results = model.evals_result()
-        train_set = list(results.items())[0]
-        valid_set = list(results.items())[1]
-
-        learning_curves_df = None
-        if is_multiclass:
-            if hasattr(train_set[1], "merror"):
-                learning_curves_df = pd.DataFrame(
-                    {
-                        "train_error": train_set[1]["merror"],
-                        "valid_error": valid_set[1]["merror"],
-                    }
-                )
-        else:
-            if hasattr(train_set[1], "error"):
-                learning_curves_df = pd.DataFrame(
-                    {
-                        "train_error": train_set[1]["error"],
-                        "valid_error": valid_set[1]["error"],
-                    }
-                )
-
-        if learning_curves_df:
-            extra_data["learning curve table.csv"] = df_blob(learning_curves_df)
-
-            _, ax = plt.subplots()
-            plt.xlabel("# training examples")
-            plt.ylabel("error rate")
-            plt.title("learning curve - error")
-            ax.plot(learning_curves_df["train_error"], label="train")
-            ax.plot(learning_curves_df["valid_error"], label="valid")
-            learning = context.log_artifact(
-                PlotArtifact(
-                    "learning-curve", body=plt.gcf(), title="Learning Curve - erreur"
-                ),
-                artifact_path=plots_path,
-                db_key=False,
-            )
-            extra_data["learning curve"] = learning
-
-    # FEATURE IMPORTANCES
-    if hasattr(model, "feature_importances_"):
-        (fi_plot, fi_tbl) = feature_importances(model, xtest.columns)
-        extra_data["feature importances"] = context.log_artifact(
-            fi_plot, db_key=False, artifact_path=plots_path
-        )
-        extra_data["feature importances table.csv"] = df_blob(fi_tbl)
-
-    # AUC - ROC - PR CURVES
-    if is_multiclass:
-        lb = LabelBinarizer()
-        ytest_b = lb.fit_transform(ytest)
-
-        extra_data["precision_recall_multi"] = context.log_artifact(
-            precision_recall_multi(ytest_b, yprob, unique_labels),
-            artifact_path=plots_path,
-            db_key=False,
-        )
-        extra_data["roc_multi"] = context.log_artifact(
-            roc_multi(ytest_b, yprob, unique_labels),
-            artifact_path=plots_path,
-            db_key=False,
-        )
-
-        # AUC multiclass
-        aucmicro = metrics.roc_auc_score(
-            ytest_b, yprob, multi_class="ovo", average="micro"
-        )
-        aucweighted = metrics.roc_auc_score(
-            ytest_b, yprob, multi_class="ovo", average="weighted"
-        )
-
-        context.log_results({"auc-micro": aucmicro, "auc-weighted": aucweighted})
-
-        # others (todo - macro, micro...)
-        f1 = metrics.f1_score(ytest, ypred, average="macro")
-        ps = metrics.precision_score(ytest, ypred, average="macro")
-        rs = metrics.recall_score(ytest, ypred, average="macro")
-        context.log_results({"f1-score": f1, "precision_score": ps, "recall_score": rs})
-
-    else:
-        yprob_pos = yprob[:, 1]
-        extra_data["precision_recall_bin"] = context.log_artifact(
-            precision_recall_bin(model, xtest, ytest, yprob_pos),
-            artifact_path=plots_path,
-            db_key=False,
-        )
-        extra_data["roc_bin"] = context.log_artifact(
-            roc_bin(ytest, yprob_pos, clear=True),
-            artifact_path=plots_path,
-            db_key=False,
-        )
-
-        rocauc = metrics.roc_auc_score(ytest, yprob_pos)
-        brier_score = metrics.brier_score_loss(ytest, yprob_pos, pos_label=ytest.max())
-        f1 = metrics.f1_score(ytest, ypred)
-        ps = metrics.precision_score(ytest, ypred)
-        rs = metrics.recall_score(ytest, ypred)
-        context.log_results(
-            {
-                "rocauc": rocauc,
-                "brier_score": brier_score,
-                "f1-score": f1,
-                "precision_score": ps,
-                "recall_score": rs,
-            }
-        )
-
-    # return all model metrics and plots
-    return extra_data