mlrun 1.5.0rc1__py3-none-any.whl → 1.5.0rc2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of mlrun might be problematic.

Files changed (119)
  1. mlrun/__init__.py +2 -35
  2. mlrun/__main__.py +1 -40
  3. mlrun/api/api/api.py +6 -0
  4. mlrun/api/api/endpoints/feature_store.py +0 -4
  5. mlrun/api/api/endpoints/files.py +14 -2
  6. mlrun/api/api/endpoints/functions.py +6 -1
  7. mlrun/api/api/endpoints/logs.py +17 -3
  8. mlrun/api/api/endpoints/pipelines.py +1 -5
  9. mlrun/api/api/endpoints/projects.py +88 -0
  10. mlrun/api/api/endpoints/runs.py +48 -6
  11. mlrun/api/api/endpoints/workflows.py +355 -0
  12. mlrun/api/api/utils.py +1 -1
  13. mlrun/api/crud/__init__.py +1 -0
  14. mlrun/api/crud/client_spec.py +3 -0
  15. mlrun/api/crud/model_monitoring/deployment.py +36 -7
  16. mlrun/api/crud/model_monitoring/grafana.py +1 -1
  17. mlrun/api/crud/model_monitoring/helpers.py +32 -2
  18. mlrun/api/crud/model_monitoring/model_endpoints.py +27 -5
  19. mlrun/api/crud/notifications.py +9 -4
  20. mlrun/api/crud/pipelines.py +4 -9
  21. mlrun/api/crud/runtime_resources.py +4 -3
  22. mlrun/api/crud/secrets.py +21 -0
  23. mlrun/api/crud/workflows.py +352 -0
  24. mlrun/api/db/base.py +16 -1
  25. mlrun/api/db/sqldb/db.py +97 -16
  26. mlrun/api/launcher.py +26 -7
  27. mlrun/api/main.py +3 -4
  28. mlrun/{mlutils → api/rundb}/__init__.py +2 -6
  29. mlrun/{db → api/rundb}/sqldb.py +35 -83
  30. mlrun/api/runtime_handlers/__init__.py +56 -0
  31. mlrun/api/runtime_handlers/base.py +1247 -0
  32. mlrun/api/runtime_handlers/daskjob.py +209 -0
  33. mlrun/api/runtime_handlers/kubejob.py +37 -0
  34. mlrun/api/runtime_handlers/mpijob.py +147 -0
  35. mlrun/api/runtime_handlers/remotesparkjob.py +29 -0
  36. mlrun/api/runtime_handlers/sparkjob.py +148 -0
  37. mlrun/api/utils/builder.py +1 -4
  38. mlrun/api/utils/clients/chief.py +14 -0
  39. mlrun/api/utils/scheduler.py +98 -15
  40. mlrun/api/utils/singletons/db.py +4 -0
  41. mlrun/artifacts/manager.py +1 -2
  42. mlrun/common/schemas/__init__.py +6 -0
  43. mlrun/common/schemas/auth.py +4 -1
  44. mlrun/common/schemas/client_spec.py +1 -1
  45. mlrun/common/schemas/model_monitoring/__init__.py +1 -0
  46. mlrun/common/schemas/model_monitoring/constants.py +11 -0
  47. mlrun/common/schemas/project.py +1 -0
  48. mlrun/common/schemas/runs.py +1 -8
  49. mlrun/common/schemas/schedule.py +1 -8
  50. mlrun/common/schemas/workflow.py +54 -0
  51. mlrun/config.py +42 -40
  52. mlrun/datastore/sources.py +1 -1
  53. mlrun/db/__init__.py +4 -68
  54. mlrun/db/base.py +12 -0
  55. mlrun/db/factory.py +65 -0
  56. mlrun/db/httpdb.py +175 -19
  57. mlrun/db/nopdb.py +4 -2
  58. mlrun/execution.py +4 -2
  59. mlrun/feature_store/__init__.py +1 -0
  60. mlrun/feature_store/api.py +1 -2
  61. mlrun/feature_store/feature_set.py +0 -10
  62. mlrun/feature_store/feature_vector.py +340 -2
  63. mlrun/feature_store/ingestion.py +5 -10
  64. mlrun/feature_store/retrieval/base.py +118 -104
  65. mlrun/feature_store/retrieval/dask_merger.py +17 -10
  66. mlrun/feature_store/retrieval/job.py +4 -1
  67. mlrun/feature_store/retrieval/local_merger.py +18 -18
  68. mlrun/feature_store/retrieval/spark_merger.py +21 -14
  69. mlrun/feature_store/retrieval/storey_merger.py +21 -15
  70. mlrun/kfpops.py +3 -9
  71. mlrun/launcher/base.py +3 -3
  72. mlrun/launcher/client.py +3 -2
  73. mlrun/launcher/factory.py +16 -13
  74. mlrun/lists.py +0 -11
  75. mlrun/model.py +9 -15
  76. mlrun/model_monitoring/helpers.py +15 -25
  77. mlrun/model_monitoring/model_monitoring_batch.py +72 -4
  78. mlrun/model_monitoring/prometheus.py +219 -0
  79. mlrun/model_monitoring/stores/__init__.py +15 -9
  80. mlrun/model_monitoring/stores/sql_model_endpoint_store.py +3 -1
  81. mlrun/model_monitoring/stream_processing.py +181 -29
  82. mlrun/package/packager.py +6 -8
  83. mlrun/package/packagers/default_packager.py +121 -10
  84. mlrun/platforms/__init__.py +0 -2
  85. mlrun/platforms/iguazio.py +0 -56
  86. mlrun/projects/pipelines.py +57 -158
  87. mlrun/projects/project.py +6 -32
  88. mlrun/render.py +1 -1
  89. mlrun/run.py +2 -124
  90. mlrun/runtimes/__init__.py +6 -42
  91. mlrun/runtimes/base.py +26 -1241
  92. mlrun/runtimes/daskjob.py +2 -198
  93. mlrun/runtimes/function.py +16 -5
  94. mlrun/runtimes/kubejob.py +5 -29
  95. mlrun/runtimes/mpijob/__init__.py +2 -2
  96. mlrun/runtimes/mpijob/abstract.py +10 -1
  97. mlrun/runtimes/mpijob/v1.py +0 -76
  98. mlrun/runtimes/mpijob/v1alpha1.py +1 -74
  99. mlrun/runtimes/nuclio.py +3 -2
  100. mlrun/runtimes/pod.py +0 -10
  101. mlrun/runtimes/remotesparkjob.py +1 -15
  102. mlrun/runtimes/serving.py +1 -1
  103. mlrun/runtimes/sparkjob/__init__.py +0 -1
  104. mlrun/runtimes/sparkjob/abstract.py +4 -131
  105. mlrun/serving/states.py +1 -1
  106. mlrun/utils/db.py +0 -2
  107. mlrun/utils/helpers.py +19 -13
  108. mlrun/utils/notifications/notification_pusher.py +5 -25
  109. mlrun/utils/regex.py +7 -2
  110. mlrun/utils/version/version.json +2 -2
  111. {mlrun-1.5.0rc1.dist-info → mlrun-1.5.0rc2.dist-info}/METADATA +24 -23
  112. {mlrun-1.5.0rc1.dist-info → mlrun-1.5.0rc2.dist-info}/RECORD +116 -107
  113. {mlrun-1.5.0rc1.dist-info → mlrun-1.5.0rc2.dist-info}/WHEEL +1 -1
  114. mlrun/mlutils/data.py +0 -160
  115. mlrun/mlutils/models.py +0 -78
  116. mlrun/mlutils/plots.py +0 -902
  117. {mlrun-1.5.0rc1.dist-info → mlrun-1.5.0rc2.dist-info}/LICENSE +0 -0
  118. {mlrun-1.5.0rc1.dist-info → mlrun-1.5.0rc2.dist-info}/entry_points.txt +0 -0
  119. {mlrun-1.5.0rc1.dist-info → mlrun-1.5.0rc2.dist-info}/top_level.txt +0 -0
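The legacy mlrun.mlutils package (entries 114-116 above) is removed entirely in this release; the full content of the deleted mlrun/mlutils/plots.py is shown below. As a minimal sketch, not part of this diff, downstream code that still imports the module can guard the import while migrating off the deprecated helpers:

# Sketch only, not shipped in mlrun: mlrun.mlutils is gone in 1.5.0rc2,
# so importing it raises ImportError after the upgrade.
try:
    import mlrun.mlutils.plots as legacy_plots  # available up to 1.5.0rc1
except ImportError:
    legacy_plots = None  # removed in 1.5.0rc2; migrate before upgrading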
mlrun/mlutils/plots.py DELETED
@@ -1,902 +0,0 @@
- # Copyright 2023 Iguazio
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- #    http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- #
- from itertools import cycle
- from typing import List
-
- import matplotlib.pyplot as plt
- import numpy as np
- import pandas as pd
- import seaborn as sns
- from deprecated import deprecated
- from scikitplot.metrics import plot_calibration_curve
- from scipy import interp
- from sklearn import metrics
- from sklearn.calibration import calibration_curve
- from sklearn.metrics import confusion_matrix as sklearn_confusion_matrix
- from sklearn.preprocessing import LabelBinarizer
-
- from ..artifacts import PlotArtifact
-
- # TODO: remove mlutils in 1.5.0
-
-
- @deprecated(
-     version="1.3.0",
-     reason="'mlrun.mlutils' will be removed in 1.5.0, use 'mlrun.framework' instead",
-     category=FutureWarning,
- )
- def gcf_clear(plt):
-     """Utility to clear matplotlib figure
-     Run this inside every plot method before calling any matplotlib
-     methods
-     :param plt: matplotlib figure object
-     """
-     plt.cla()
-     plt.clf()
-     plt.close()
-
-
- @deprecated(
-     version="1.3.0",
-     reason="'mlrun.mlutils' will be removed in 1.5.0, use 'mlrun.framework' instead",
-     category=FutureWarning,
- )
- def feature_importances(model, header):
-     """Display estimated feature importances
-     Only works for models with attribute `feature_importances_`
-     :param model: fitted model
-     :param header: feature labels
-     """
-     if not hasattr(model, "feature_importances_"):
-         raise Exception(
-             "feature importances are only available for some models, if you got "
-             "here then please make sure to check your estimated model for a "
-             "`feature_importances_` attribute before calling this method"
-         )
-
-     # create a feature importance table with desired labels
-     zipped = zip(model.feature_importances_, header)
-     feature_imp = pd.DataFrame(sorted(zipped), columns=["freq", "feature"]).sort_values(
-         by="freq", ascending=False
-     )
-
-     plt.clf()  # gcf_clear(plt)
-     plt.figure()
-     sns.barplot(x="freq", y="feature", data=feature_imp)
-     plt.title("features")
-     plt.tight_layout()
-
-     return (
-         PlotArtifact(
-             "feature-importances", body=plt.gcf(), title="Feature Importances"
-         ),
-         feature_imp,
-     )
-
-
- @deprecated(
-     version="1.3.0",
-     reason="'mlrun.mlutils' will be removed in 1.5.0, use 'mlrun.framework' instead",
-     category=FutureWarning,
- )
- def plot_importance(
-     context, model, key: str = "feature-importances", plots_dest: str = "plots"
- ):
-     """Display estimated feature importances
-     Only works for models with attribute `feature_importances_`
-
-     **legacy version please deprecate in functions and demos**
-
-     :param context: function context
-     :param model: fitted model
-     :param key: key of feature importances plot and table in artifact
-         store
-     :param plots_dest: subfolder in artifact store
-     """
-     if not hasattr(model, "feature_importances_"):
-         raise Exception("feature importances are only available for some models")
-
-     # create a feature importance table with desired labels
-     zipped = zip(model.feature_importances_, context.header)
-     feature_imp = pd.DataFrame(sorted(zipped), columns=["freq", "feature"]).sort_values(
-         by="freq", ascending=False
-     )
-
-     gcf_clear(plt)
-     plt.figure(figsize=(20, 10))
-     sns.barplot(x="freq", y="feature", data=feature_imp)
-     plt.title("features")
-     plt.tight_layout()
-
-     fname = f"{plots_dest}/{key}.html"
-     context.log_artifact(PlotArtifact(key, body=plt.gcf()), local_path=fname)
-
-     # feature importances are also saved as a csv table (generally small):
-     fname = key + "-tbl.csv"
-     return context.log_dataset(key + "-tbl", df=feature_imp, local_path=fname)
-
-
- @deprecated(
-     version="1.3.0",
-     reason="'mlrun.mlutils' will be removed in 1.5.0, use 'mlrun.framework' instead",
-     category=FutureWarning,
- )
- def learning_curves(model):
-     """model class dependent
-
-     WIP
-
-     get training history plots for xgboost, lightgbm
-
-     returns list of PlotArtifacts, can be empty if no history
-     is found
-     """
-     plots = []
-
-     # do this here and not in the call to learning_curve plots,
-     # this is default approach for xgboost and lightgbm
-     if hasattr(model, "evals_result"):
-         results = model.evals_result()
-         train_set = list(results.items())[0]
-         valid_set = list(results.items())[1]
-
-         learning_curves = pd.DataFrame(
-             {
-                 "train_error": train_set[1]["error"],
-                 "train_auc": train_set[1]["auc"],
-                 "valid_error": valid_set[1]["error"],
-                 "valid_auc": valid_set[1]["auc"],
-             }
-         )
-
-         plt.clf()  # gcf_clear(plt)
-         fig, ax = plt.subplots()
-         plt.xlabel("# training examples")
-         plt.ylabel("auc")
-         plt.title("learning curve - auc")
-         ax.plot(learning_curves.train_auc, label="train")
-         ax.plot(learning_curves.valid_auc, label="valid")
-         ax.legend(loc="lower left")
-         plots.append(PlotArtifact("learning curve - auc", body=plt.gcf()))
-
-         plt.clf()  # gcf_clear(plt)
-         fig, ax = plt.subplots()
-         plt.xlabel("# training examples")
-         plt.ylabel("error rate")
-         plt.title("learning curve - error")
-         ax.plot(learning_curves.train_error, label="train")
-         ax.plot(learning_curves.valid_error, label="valid")
-         ax.legend(loc="lower left")
-         plots.append(PlotArtifact("learning curve - error", body=plt.gcf()))
-
-     # elif some other model history api...
-
-     return plots
-
-
- @deprecated(
-     version="1.3.0",
-     reason="'mlrun.mlutils' will be removed in 1.5.0, use 'mlrun.framework' instead",
-     category=FutureWarning,
- )
- def confusion_matrix(model, xtest, ytest, cmap="Blues"):
-     cmd = metrics.plot_confusion_matrix(
-         model,
-         xtest,
-         ytest,
-         normalize="all",
-         values_format=".2g",
-         cmap=plt.get_cmap(cmap),
-     )
-     # for now only 1, add different views to this array for display in UI
-     cmd.plot()
-     return PlotArtifact(
-         "confusion-matrix-normalized",
-         body=cmd.figure_,
-         title="Confusion Matrix - Normalized Plot",
-     )
-
-
- @deprecated(
-     version="1.3.0",
-     reason="'mlrun.mlutils' will be removed in 1.5.0, use 'mlrun.framework' instead",
-     category=FutureWarning,
- )
- def precision_recall_multi(ytest_b, yprob, labels, scoring="micro"):
-     """"""
-     n_classes = len(labels)
-
-     precision = dict()
-     recall = dict()
-     avg_prec = dict()
-     for i in range(n_classes):
-         precision[i], recall[i], _ = metrics.precision_recall_curve(
-             ytest_b[:, i], yprob[:, i]
-         )
-         avg_prec[i] = metrics.average_precision_score(ytest_b[:, i], yprob[:, i])
-     precision["micro"], recall["micro"], _ = metrics.precision_recall_curve(
-         ytest_b.ravel(), yprob.ravel()
-     )
-     avg_prec["micro"] = metrics.average_precision_score(ytest_b, yprob, average="micro")
-     ap_micro = avg_prec["micro"]
-     # model_metrics.update({'precision-micro-avg-classes': ap_micro})
-
-     # gcf_clear(plt)
-     colors = cycle(["navy", "turquoise", "darkorange", "cornflowerblue", "teal"])
-     plt.figure()
-     f_scores = np.linspace(0.2, 0.8, num=4)
-     lines = []
-     labels = []
-     for f_score in f_scores:
-         x = np.linspace(0.01, 1)
-         y = f_score * x / (2 * x - f_score)
-         (l,) = plt.plot(x[y >= 0], y[y >= 0], color="gray", alpha=0.2)
-         plt.annotate(f"f1={f_score:0.1f}", xy=(0.9, y[45] + 0.02))
-
-     lines.append(l)
-     labels.append("iso-f1 curves")
-     (l,) = plt.plot(recall["micro"], precision["micro"], color="gold", lw=10)
-     lines.append(l)
-     labels.append(f"micro-average precision-recall (area = {ap_micro:0.2f})")
-
-     for i, color in zip(range(n_classes), colors):
-         (l,) = plt.plot(recall[i], precision[i], color=color, lw=2)
-         lines.append(l)
-         labels.append(f"precision-recall for class {i} (area = {avg_prec[i]:0.2f})")
-
-     # fig = plt.gcf()
-     # fig.subplots_adjust(bottom=0.25)
-     plt.xlim([0.0, 1.0])
-     plt.ylim([0.0, 1.05])
-     plt.xlabel("recall")
-     plt.ylabel("precision")
-     plt.title("precision recall - multiclass")
-     plt.legend(lines, labels, loc=(0, -0.41), prop=dict(size=10))
-
-     return PlotArtifact(
-         "precision-recall-multiclass",
-         body=plt.gcf(),
-         title="Multiclass Precision Recall",
-     )
-
-
- @deprecated(
-     version="1.3.0",
-     reason="'mlrun.mlutils' will be removed in 1.5.0, use 'mlrun.framework' instead",
-     category=FutureWarning,
- )
- def roc_multi(ytest_b, yprob, labels):
-     """"""
-     n_classes = len(labels)
-
-     # Compute ROC curve and ROC area for each class
-     fpr = dict()
-     tpr = dict()
-     roc_auc = dict()
-     for i in range(n_classes):
-         fpr[i], tpr[i], _ = metrics.roc_curve(ytest_b[:, i], yprob[:, i])
-         roc_auc[i] = metrics.auc(fpr[i], tpr[i])
-
-     # Compute micro-average ROC curve and ROC area
-     fpr["micro"], tpr["micro"], _ = metrics.roc_curve(ytest_b.ravel(), yprob.ravel())
-     roc_auc["micro"] = metrics.auc(fpr["micro"], tpr["micro"])
-
-     # First aggregate all false positive rates
-     all_fpr = np.unique(np.concatenate([fpr[i] for i in range(n_classes)]))
-
-     # Then interpolate all ROC curves at these points
-     mean_tpr = np.zeros_like(all_fpr)
-     for i in range(n_classes):
-         mean_tpr += interp(all_fpr, fpr[i], tpr[i])
-
-     # Finally average it and compute AUC
-     mean_tpr /= n_classes
-
-     fpr["macro"] = all_fpr
-     tpr["macro"] = mean_tpr
-     roc_auc["macro"] = metrics.auc(fpr["macro"], tpr["macro"])
-
-     # Plot all ROC curves
-     gcf_clear(plt)
-     plt.figure()
-     plt.plot(
-         fpr["micro"],
-         tpr["micro"],
-         label=f"micro-average ROC curve (area = {roc_auc['micro']:0.2f})",
-         color="deeppink",
-         linestyle=":",
-         linewidth=4,
-     )
-
-     plt.plot(
-         fpr["macro"],
-         tpr["macro"],
-         label=f"macro-average ROC curve (area = {roc_auc['macro']:0.2f})",
-         color="navy",
-         linestyle=":",
-         linewidth=4,
-     )
-
-     colors = cycle(["aqua", "darkorange", "cornflowerblue"])
-     for i, color in zip(range(n_classes), colors):
-         plt.plot(
-             fpr[i],
-             tpr[i],
-             color=color,
-             lw=2,
-             label=f"ROC curve of class {i} (area = {roc_auc[i]:0.2f})",
-         )
-
-     plt.plot([0, 1], [0, 1], "k--", lw=2)
-     plt.xlim([0.0, 1.0])
-     plt.ylim([0.0, 1.05])
-     plt.xlabel("False Positive Rate")
-     plt.ylabel("True Positive Rate")
-     plt.title("receiver operating characteristic - multiclass")
-     plt.legend(loc=(0, -0.68), prop=dict(size=10))
-
-     return PlotArtifact("roc-multiclass", body=plt.gcf(), title="Multiclass ROC Curve")
-
-
- @deprecated(
-     version="1.3.0",
-     reason="'mlrun.mlutils' will be removed in 1.5.0, use 'mlrun.framework' instead",
-     category=FutureWarning,
- )
- def roc_bin(ytest, yprob, clear: bool = False):
-     """"""
-     # ROC plot
-     if clear:
-         gcf_clear(plt)
-     fpr, tpr, _ = metrics.roc_curve(ytest, yprob)
-     plt.figure()
-     plt.plot([0, 1], [0, 1], "k--")
-     plt.plot(fpr, tpr, label="a label")
-     plt.xlabel("false positive rate")
-     plt.ylabel("true positive rate")
-     plt.title("roc curve")
-     plt.legend(loc="best")
-
-     return PlotArtifact("roc-binary", body=plt.gcf(), title="Binary ROC Curve")
-
-
- @deprecated(
-     version="1.3.0",
-     reason="'mlrun.mlutils' will be removed in 1.5.0, use 'mlrun.framework' instead",
-     category=FutureWarning,
- )
- def precision_recall_bin(model, xtest, ytest, yprob, clear=False):
-     """"""
-     if clear:
-         gcf_clear(plt)
-     disp = metrics.plot_precision_recall_curve(model, xtest, ytest)
-     disp.ax_.set_title(
-         f"precision recall: AP={metrics.average_precision_score(ytest, yprob):0.2f}"
-     )
-
-     return PlotArtifact(
-         "precision-recall-binary", body=disp.figure_, title="Binary Precision Recall"
-     )
-
-
- @deprecated(
-     version="1.3.0",
-     reason="'mlrun.mlutils' will be removed in 1.5.0, use 'mlrun.framework' instead",
-     category=FutureWarning,
- )
- def plot_roc(
-     context,
-     y_labels,
-     y_probs,
-     key="roc",
-     plots_dir: str = "plots",
-     fmt="png",
-     fpr_label: str = "false positive rate",
-     tpr_label: str = "true positive rate",
-     title: str = "roc curve",
-     legend_loc: str = "best",
-     clear: bool = True,
- ):
-     """plot roc curves
-
-     **legacy version please deprecate in functions and demos**
-
-     :param context: the function context
-     :param y_labels: ground truth labels, hot encoded for multiclass
-     :param y_probs: model prediction probabilities
-     :param key: ("roc") key of plot in artifact store
-     :param plots_dir: ("plots") destination folder relative path to artifact path
-     :param fmt: ("png") plot format
-     :param fpr_label: ("false positive rate") x-axis labels
-     :param tpr_label: ("true positive rate") y-axis labels
-     :param title: ("roc curve") title of plot
-     :param legend_loc: ("best") location of plot legend
-     :param clear: (True) clear the matplotlib figure before drawing
-     """
-     # clear matplotlib current figure
-     if clear:
-         gcf_clear(plt)
-
-     # draw 45 degree line
-     plt.plot([0, 1], [0, 1], "k--")
-
-     # labelling
-     plt.xlabel(fpr_label)
-     plt.ylabel(tpr_label)
-     plt.title(title)
-     plt.legend(loc=legend_loc)
-
-     # single ROC or multiple
-     if y_labels.shape[1] > 1:
-
-         # data accumulators by class
-         fpr = dict()
-         tpr = dict()
-         roc_auc = dict()
-         for i in range(y_labels[:, :-1].shape[1]):
-             fpr[i], tpr[i], _ = metrics.roc_curve(
-                 y_labels[:, i], y_probs[:, i], pos_label=1
-             )
-             roc_auc[i] = metrics.auc(fpr[i], tpr[i])
-             plt.plot(fpr[i], tpr[i], label=f"class {i}")
-     else:
-         fpr, tpr, _ = metrics.roc_curve(y_labels, y_probs[:, 1], pos_label=1)
-         plt.plot(fpr, tpr, label="positive class")
-
-     fname = f"{plots_dir}/{key}.html"
-     return context.log_artifact(PlotArtifact(key, body=plt.gcf()), local_path=fname)
-
-
- @deprecated(
-     version="1.3.0",
-     reason="'mlrun.mlutils' will be removed in 1.5.0, use 'mlrun.framework' instead",
-     category=FutureWarning,
- )
- def eval_class_model(
-     xtest, ytest, model, labels: str = "labels", pred_params: dict = {}
- ):
-     """generate predictions and validation stats
-
-     pred_params are non-default, scikit-learn api prediction-function parameters.
-     For example, a tree-type of model may have a tree depth limit for its prediction
-     function.
-
-     :param xtest: features array type Union(DataItem, DataFrame, np. Array)
-     :param ytest: ground-truth labels Union(DataItem, DataFrame, Series, np. Array, List)
-     :param model: estimated model
-     :param labels: ('labels') label column name when ytest is a pd.DataFrame or Series
-     :param pred_params: (None) dict of predict function parameters
-     """
-     if isinstance(ytest, (pd.DataFrame, pd.Series)):
-         unique_labels = ytest[labels].unique()
-         ytest = ytest.values
-     elif isinstance(ytest, np.ndarray):
-         unique_labels = np.unique(ytest)
-     elif isinstance(ytest, list):
-         unique_labels = set(ytest)
-
-     n_classes = len(unique_labels)
-     is_multiclass = True if n_classes > 2 else False
-
-     # PROBS
-     ypred = model.predict(xtest, **pred_params)
-     if hasattr(model, "predict_proba"):
-         yprob = model.predict_proba(xtest, **pred_params)
-     else:
-         # todo if decision fn...
-         raise Exception("not implemented for this classifier")
-
-     # todo - calibrate
-     # outputs are some stats and some plots and...
-     # should be option, some classifiers don't need, some do it already, many don't
-
-     model_metrics = {
-         "plots": [],  # placeholder for plots
-         "accuracy": float(metrics.accuracy_score(ytest, ypred)),
-         "test-error-rate": np.sum(ytest != ypred) / ytest.shape[0],
-     }
-
-     # CONFUSION MATRIX
-     gcf_clear(plt)
-     cmd = metrics.plot_confusion_matrix(
-         model, xtest, ytest, normalize="all", cmap=plt.cm.Blues
-     )
-     model_metrics["plots"].append(PlotArtifact("confusion-matrix", body=cmd.figure_))
-
-     if is_multiclass:
-         # PRECISION-RECALL CURVES MICRO AVGED
-         # binarize/hot-encode here since we look at each class
-         lb = LabelBinarizer()
-         ytest_b = lb.fit_transform(ytest)
-
-         precision = dict()
-         recall = dict()
-         avg_prec = dict()
-         for i in range(n_classes):
-             precision[i], recall[i], _ = metrics.precision_recall_curve(
-                 ytest_b[:, i], yprob[:, i]
-             )
-             avg_prec[i] = metrics.average_precision_score(ytest_b[:, i], yprob[:, i])
-         precision["micro"], recall["micro"], _ = metrics.precision_recall_curve(
-             ytest_b.ravel(), yprob.ravel()
-         )
-         avg_prec["micro"] = metrics.average_precision_score(
-             ytest_b, yprob, average="micro"
-         )
-         ap_micro = avg_prec["micro"]
-         model_metrics.update({"precision-micro-avg-classes": ap_micro})
-
-         gcf_clear(plt)
-         colors = cycle(["navy", "turquoise", "darkorange", "cornflowerblue", "teal"])
-         plt.figure(figsize=(7, 8))
-         f_scores = np.linspace(0.2, 0.8, num=4)
-         lines = []
-         labels = []
-         for f_score in f_scores:
-             x = np.linspace(0.01, 1)
-             y = f_score * x / (2 * x - f_score)
-             (l,) = plt.plot(x[y >= 0], y[y >= 0], color="gray", alpha=0.2)
-             plt.annotate(f"f1={f_score:0.1f}", xy=(0.9, y[45] + 0.02))
-
-         lines.append(l)
-         labels.append("iso-f1 curves")
-         (l,) = plt.plot(recall["micro"], precision["micro"], color="gold", lw=10)
-         lines.append(l)
-         labels.append(f"micro-average precision-recall (area = {ap_micro:0.2f})")
-
-         for i, color in zip(range(n_classes), colors):
-             (l,) = plt.plot(recall[i], precision[i], color=color, lw=2)
-             lines.append(l)
-             labels.append(f"precision-recall for class {i} (area = {avg_prec[i]:0.2f})")
-
-         fig = plt.gcf()
-         fig.subplots_adjust(bottom=0.25)
-         plt.xlim([0.0, 1.0])
-         plt.ylim([0.0, 1.05])
-         plt.xlabel("recall")
-         plt.ylabel("precision")
-         plt.title("precision recall - multiclass")
-         plt.legend(lines, labels, loc=(0, -0.38), prop=dict(size=10))
-         model_metrics["plots"].append(
-             PlotArtifact("precision-recall-multiclass", body=plt.gcf())
-         )
-
-         # ROC CURVES
-         # Compute ROC curve and ROC area for each class
-         fpr = dict()
-         tpr = dict()
-         roc_auc = dict()
-         for i in range(n_classes):
-             fpr[i], tpr[i], _ = metrics.roc_curve(ytest_b[:, i], yprob[:, i])
-             roc_auc[i] = metrics.auc(fpr[i], tpr[i])
-
-         # Compute micro-average ROC curve and ROC area
-         fpr["micro"], tpr["micro"], _ = metrics.roc_curve(
-             ytest_b.ravel(), yprob.ravel()
-         )
-         roc_auc["micro"] = metrics.auc(fpr["micro"], tpr["micro"])
-
-         # First aggregate all false positive rates
-         all_fpr = np.unique(np.concatenate([fpr[i] for i in range(n_classes)]))
-
-         # Then interpolate all ROC curves at these points
-         mean_tpr = np.zeros_like(all_fpr)
-         for i in range(n_classes):
-             mean_tpr += interp(all_fpr, fpr[i], tpr[i])
-
-         # Finally average it and compute AUC
-         mean_tpr /= n_classes
-
-         fpr["macro"] = all_fpr
-         tpr["macro"] = mean_tpr
-         roc_auc["macro"] = metrics.auc(fpr["macro"], tpr["macro"])
-
-         # Plot all ROC curves
-         gcf_clear(plt)
-         plt.figure()
-         plt.plot(
-             fpr["micro"],
-             tpr["micro"],
-             label=f"micro-average ROC curve (area = {roc_auc['micro']:0.2f})",
-             color="deeppink",
-             linestyle=":",
-             linewidth=4,
-         )
-
-         plt.plot(
-             fpr["macro"],
-             tpr["macro"],
-             label=f"macro-average ROC curve (area = {roc_auc['macro']:0.2f})",
-             color="navy",
-             linestyle=":",
-             linewidth=4,
-         )
-
-         colors = cycle(["aqua", "darkorange", "cornflowerblue"])
-         for i, color in zip(range(n_classes), colors):
-             plt.plot(
-                 fpr[i],
-                 tpr[i],
-                 color=color,
-                 lw=2,
-                 label=f"ROC curve of class {i} (area = {roc_auc[i]:0.2f})",
-             )
-
-         plt.plot([0, 1], [0, 1], "k--", lw=2)
-         plt.xlim([0.0, 1.0])
-         plt.ylim([0.0, 1.05])
-         plt.xlabel("False Positive Rate")
-         plt.ylabel("True Positive Rate")
-         plt.title("receiver operating characteristic - multiclass")
-         plt.legend(loc="lower right")
-         model_metrics["plots"].append(PlotArtifact("roc-multiclass", body=plt.gcf()))
-         # AUC multiclass
-         model_metrics.update(
-             {
-                 "auc-macro": metrics.roc_auc_score(
-                     ytest_b, yprob, multi_class="ovo", average="macro"
-                 ),
-                 "auc-weighted": metrics.roc_auc_score(
-                     ytest_b, yprob, multi_class="ovo", average="weighted"
-                 ),
-             }
-         )
-
-         # others (todo - macro, micro...)
-         model_metrics.update(
-             {
-                 "f1-score": metrics.f1_score(ytest, ypred, average="macro"),
-                 "recall_score": metrics.recall_score(ytest, ypred, average="macro"),
-             }
-         )
-     else:
-         # binary
-         yprob_pos = yprob[:, 1]
-
-         model_metrics.update(
-             {
-                 "rocauc": metrics.roc_auc_score(ytest, yprob_pos),
-                 "brier_score": metrics.brier_score_loss(
-                     ytest, yprob_pos, pos_label=ytest.max()
-                 ),
-             }
-         )
-
-         # precision-recall
-
-         # ROC plot
-
-     return model_metrics
-
-
- @deprecated(
-     version="1.3.0",
-     reason="'mlrun.mlutils' will be removed in 1.5.0, use 'mlrun.framework' instead",
-     category=FutureWarning,
- )
- def eval_model_v2(
-     context,
-     xtest,
-     ytest,
-     model,
-     pcurve_bins: int = 10,
-     pcurve_names: List[str] = ["my classifier"],
-     plots_artifact_path: str = "",
-     pred_params: dict = {},
-     cmap="Blues",
- ):
-     """generate predictions and validation stats
-
-     pred_params are non-default, scikit-learn api prediction-function
-     parameters. For example, a tree-type of model may have a tree depth
-     limit for its prediction function.
-
-     :param xtest: features array type Union(DataItem, DataFrame,
-         numpy array)
-     :param ytest: ground-truth labels Union(DataItem, DataFrame,
-         Series, numpy array, List)
-     :param model: estimated model
-     :param pcurve_bins: (10) subdivide [0,1] interval into n bins, x-axis
-     :param pcurve_names: label for each calibration curve
-     :param pred_params: (None) dict of predict function parameters
-     :param cmap: ('Blues') matplotlib color map
-     """
-
-     import numpy as np
-
-     def df_blob(df):
-         return bytes(df.to_csv(index=False), encoding="utf-8")
-
-     if isinstance(ytest, np.ndarray):
-         unique_labels = np.unique(ytest)
-     elif isinstance(ytest, list):
-         unique_labels = set(ytest)
-     else:
-         try:
-             ytest = ytest.values
-             unique_labels = np.unique(ytest)
-         except Exception as exc:
-             raise Exception("unrecognized data type for ytest") from exc
-
-     n_classes = len(unique_labels)
-     is_multiclass = True if n_classes > 2 else False
-
-     # INIT DICT...OR SOME OTHER COLLECTOR THAT CAN BE ACCESSED
-     plots_path = plots_artifact_path or context.artifact_subpath("plots")
-     extra_data = {}
-
-     ypred = model.predict(xtest, **pred_params)
-     context.log_results(
-         {
-             "accuracy": float(metrics.accuracy_score(ytest, ypred)),
-             "test-error": np.sum(ytest != ypred) / ytest.shape[0],
-         }
-     )
-
-     # PROBABILITIES
-     if hasattr(model, "predict_proba"):
-         yprob = model.predict_proba(xtest, **pred_params)
-         if not is_multiclass:
-             fraction_of_positives, mean_predicted_value = calibration_curve(
-                 ytest, yprob[:, -1], n_bins=pcurve_bins, strategy="uniform"
-             )
-             cmd = plot_calibration_curve(ytest, [yprob], pcurve_names)
-             calibration = context.log_artifact(
-                 PlotArtifact(
-                     "probability-calibration",
-                     body=cmd.get_figure(),
-                     title="probability calibration plot",
-                 ),
-                 artifact_path=plots_path,
-                 db_key=False,
-             )
-             extra_data["probability calibration"] = calibration
-
-     # CONFUSION MATRIX
-     cm = sklearn_confusion_matrix(ytest, ypred, normalize="all")
-     df = pd.DataFrame(data=cm)
-     extra_data["confusion matrix table.csv"] = df_blob(df)
-
-     cmd = metrics.plot_confusion_matrix(
-         model,
-         xtest,
-         ytest,
-         normalize="all",
-         values_format=".2g",
-         cmap=plt.get_cmap(cmap),
-     )
-     confusion = context.log_artifact(
-         PlotArtifact(
-             "confusion-matrix",
-             body=cmd.figure_,
-             title="Confusion Matrix - Normalized Plot",
-         ),
-         artifact_path=plots_path,
-         db_key=False,
-     )
-     extra_data["confusion matrix"] = confusion
-
-     # LEARNING CURVES
-     if hasattr(model, "evals_result"):
-         results = model.evals_result()
-         train_set = list(results.items())[0]
-         valid_set = list(results.items())[1]
-
-         learning_curves_df = None
-         if is_multiclass:
-             if hasattr(train_set[1], "merror"):
-                 learning_curves_df = pd.DataFrame(
-                     {
-                         "train_error": train_set[1]["merror"],
-                         "valid_error": valid_set[1]["merror"],
-                     }
-                 )
-         else:
-             if hasattr(train_set[1], "error"):
-                 learning_curves_df = pd.DataFrame(
-                     {
-                         "train_error": train_set[1]["error"],
-                         "valid_error": valid_set[1]["error"],
-                     }
-                 )
-
-         if learning_curves_df:
-             extra_data["learning curve table.csv"] = df_blob(learning_curves_df)
-
-             _, ax = plt.subplots()
-             plt.xlabel("# training examples")
-             plt.ylabel("error rate")
-             plt.title("learning curve - error")
-             ax.plot(learning_curves_df["train_error"], label="train")
-             ax.plot(learning_curves_df["valid_error"], label="valid")
-             learning = context.log_artifact(
-                 PlotArtifact(
-                     "learning-curve", body=plt.gcf(), title="Learning Curve - Error"
-                 ),
-                 artifact_path=plots_path,
-                 db_key=False,
-             )
-             extra_data["learning curve"] = learning
-
-     # FEATURE IMPORTANCES
-     if hasattr(model, "feature_importances_"):
-         (fi_plot, fi_tbl) = feature_importances(model, xtest.columns)
-         extra_data["feature importances"] = context.log_artifact(
-             fi_plot, db_key=False, artifact_path=plots_path
-         )
-         extra_data["feature importances table.csv"] = df_blob(fi_tbl)
-
-     # AUC - ROC - PR CURVES
-     if is_multiclass:
-         lb = LabelBinarizer()
-         ytest_b = lb.fit_transform(ytest)
-
-         extra_data["precision_recall_multi"] = context.log_artifact(
-             precision_recall_multi(ytest_b, yprob, unique_labels),
-             artifact_path=plots_path,
-             db_key=False,
-         )
-         extra_data["roc_multi"] = context.log_artifact(
-             roc_multi(ytest_b, yprob, unique_labels),
-             artifact_path=plots_path,
-             db_key=False,
-         )
-
-         # AUC multiclass
-         aucmicro = metrics.roc_auc_score(
-             ytest_b, yprob, multi_class="ovo", average="micro"
-         )
-         aucweighted = metrics.roc_auc_score(
-             ytest_b, yprob, multi_class="ovo", average="weighted"
-         )
-
-         context.log_results({"auc-micro": aucmicro, "auc-weighted": aucweighted})
-
-         # others (todo - macro, micro...)
-         f1 = metrics.f1_score(ytest, ypred, average="macro")
-         ps = metrics.precision_score(ytest, ypred, average="macro")
-         rs = metrics.recall_score(ytest, ypred, average="macro")
-         context.log_results({"f1-score": f1, "precision_score": ps, "recall_score": rs})
-
-     else:
-         yprob_pos = yprob[:, 1]
-         extra_data["precision_recall_bin"] = context.log_artifact(
-             precision_recall_bin(model, xtest, ytest, yprob_pos),
-             artifact_path=plots_path,
-             db_key=False,
-         )
-         extra_data["roc_bin"] = context.log_artifact(
-             roc_bin(ytest, yprob_pos, clear=True),
-             artifact_path=plots_path,
-             db_key=False,
-         )
-
-         rocauc = metrics.roc_auc_score(ytest, yprob_pos)
-         brier_score = metrics.brier_score_loss(ytest, yprob_pos, pos_label=ytest.max())
-         f1 = metrics.f1_score(ytest, ypred)
-         ps = metrics.precision_score(ytest, ypred)
-         rs = metrics.recall_score(ytest, ypred)
-         context.log_results(
-             {
-                 "rocauc": rocauc,
-                 "brier_score": brier_score,
-                 "f1-score": f1,
-                 "precision_score": ps,
-                 "recall_score": rs,
-             }
-         )
-
-     # return all model metrics and plots
-     return extra_data
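
For reference, a minimal sketch of how the removed eval_model_v2 helper was typically driven from an MLRun handler, based only on the signature and docstring shown above; the handler name, estimator, and data split are illustrative placeholders, not part of this diff:

# Illustrative only: eval_model_v2 was removed together with mlrun.mlutils in
# 1.5.0rc2, so this snippet targets 1.5.0rc1 or earlier.
from sklearn.ensemble import RandomForestClassifier

from mlrun.mlutils.plots import eval_model_v2

def evaluate(context, x_train, y_train, x_test, y_test):
    model = RandomForestClassifier().fit(x_train, y_train)
    # logs accuracy, F1 and ROC results plus plot artifacts to the run context
    return eval_model_v2(context, x_test, y_test, model)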