validmind 2.5.18__py3-none-any.whl → 2.5.23__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (30)
  1. validmind/__init__.py +7 -46
  2. validmind/__version__.py +1 -1
  3. validmind/ai/test_result_description/context.py +2 -2
  4. validmind/api_client.py +131 -266
  5. validmind/client_config.py +1 -3
  6. validmind/datasets/__init__.py +1 -1
  7. validmind/datasets/nlp/__init__.py +1 -1
  8. validmind/errors.py +20 -30
  9. validmind/tests/data_validation/ProtectedClassesCombination.py +17 -9
  10. validmind/tests/data_validation/ProtectedClassesDisparity.py +12 -4
  11. validmind/tests/data_validation/ProtectedClassesThresholdOptimizer.py +18 -10
  12. validmind/tests/load.py +25 -5
  13. validmind/tests/model_validation/ragas/AnswerCorrectness.py +12 -6
  14. validmind/tests/model_validation/ragas/AnswerRelevance.py +12 -6
  15. validmind/tests/model_validation/ragas/AnswerSimilarity.py +12 -6
  16. validmind/tests/model_validation/ragas/AspectCritique.py +19 -13
  17. validmind/tests/model_validation/ragas/ContextEntityRecall.py +12 -6
  18. validmind/tests/model_validation/ragas/ContextPrecision.py +12 -6
  19. validmind/tests/model_validation/ragas/ContextRecall.py +12 -6
  20. validmind/tests/model_validation/ragas/ContextUtilization.py +12 -6
  21. validmind/tests/model_validation/ragas/Faithfulness.py +12 -6
  22. validmind/tests/model_validation/ragas/NoiseSensitivity.py +12 -6
  23. validmind/tests/model_validation/sklearn/ClassifierPerformance.py +5 -2
  24. validmind/tests/run.py +219 -116
  25. validmind/vm_models/test/result_wrapper.py +4 -4
  26. {validmind-2.5.18.dist-info → validmind-2.5.23.dist-info}/METADATA +12 -12
  27. {validmind-2.5.18.dist-info → validmind-2.5.23.dist-info}/RECORD +30 -30
  28. {validmind-2.5.18.dist-info → validmind-2.5.23.dist-info}/WHEEL +1 -1
  29. {validmind-2.5.18.dist-info → validmind-2.5.23.dist-info}/LICENSE +0 -0
  30. {validmind-2.5.18.dist-info → validmind-2.5.23.dist-info}/entry_points.txt +0 -0
@@ -3,7 +3,7 @@
 # SPDX-License-Identifier: AGPL-3.0 AND ValidMind Commercial
 
 """
-Example datasets that can be used with the developer framework.
+Example datasets that can be used with the library.
 """
 
 __all__ = [
validmind/errors.py CHANGED
@@ -3,7 +3,7 @@
 # SPDX-License-Identifier: AGPL-3.0 AND ValidMind Commercial
 
 """
-This module contains all the custom errors that are used in the developer framework.
+This module contains all the custom errors that are used in the library.
 
 The following base errors are defined for others:
 - BaseError
@@ -207,6 +207,23 @@ class MissingRequiredTestInputError(BaseError):
     pass
 
 
+class MissingDependencyError(BaseError):
+    """
+    When a required dependency is missing.
+    """
+
+    def __init__(self, message="", required_dependencies=None, extra=None):
+        """
+        Args:
+            message (str): The error message.
+            required_dependencies (list): A list of required dependencies.
+            extra (str): The particular validmind `extra` that will install the missing dependencies.
+        """
+        super().__init__(message)
+        self.required_dependencies = required_dependencies or []
+        self.extra = extra
+
+
 class MissingRExtrasError(BaseError):
     """
     When the R extras have not been installed.
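The new error carries structured metadata (`required_dependencies` and an optional `extra`) so callers can build an actionable install hint. A minimal sketch of the guarded-import pattern the tests below adopt; `some_optional_pkg` and `MyCustomTest` are hypothetical placeholders:

```python
from validmind.errors import MissingDependencyError

try:
    import some_optional_pkg  # hypothetical optional dependency
except ImportError as e:
    raise MissingDependencyError(
        "Missing required package `some_optional_pkg` for MyCustomTest.",
        required_dependencies=["some_optional_pkg"],
        extra="llm",  # optional: names the validmind extra that installs it
    ) from e
```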
@@ -219,14 +236,6 @@ class MissingRExtrasError(BaseError):
         )
 
 
-class MissingRunCUIDError(APIRequestError):
-    """
-    When data is being sent to the API but the run_cuid is missing.
-    """
-
-    pass
-
-
 class MissingTextContentIdError(APIRequestError):
     """
     When a Text object is sent to the API without a content_id.
@@ -243,30 +252,14 @@ class MissingTextContentsError(APIRequestError):
     pass
 
 
-class MissingProjectIdError(BaseError):
+class MissingModelIdError(BaseError):
     def description(self, *args, **kwargs):
         return (
             self.message
-            or "Project ID must be provided either as an environment variable or as an argument to init."
+            or "Model ID must be provided either as an environment variable or as an argument to init."
         )
 
 
-class StartTestRunFailedError(APIRequestError):
-    """
-    When the API was not able to start a test run.
-    """
-
-    pass
-
-
-class TestRunNotFoundError(APIRequestError):
-    """
-    When a test run is not found in the API.
-    """
-
-    pass
-
-
 class TestInputInvalidDatasetError(BaseError):
     """
     When an invalid dataset is used in a test context.
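The rename from `MissingProjectIdError` to `MissingModelIdError` tracks the project-to-model terminology change in `init`. A hedged sketch of the corresponding call, assuming the keyword argument is named `model` (all values are placeholders):

```python
import validmind as vm

# Keyword names assumed; supply real values from your ValidMind model registry.
vm.init(
    api_host="https://<your-instance>/api/v1/tracking",  # placeholder URL
    api_key="<api-key>",
    api_secret="<api-secret>",
    model="<model-identifier>",  # omitting this now raises MissingModelIdError
)
```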
@@ -352,11 +345,8 @@ def raise_api_error(error_string):
         "missing_text": MissingTextContentsError,
         "invalid_text_object": InvalidTextObjectError,
         "invalid_content_id_prefix": InvalidContentIdPrefixError,
-        "missing_run_cuid": MissingRunCUIDError,
-        "test_run_not_found": TestRunNotFoundError,
         "invalid_metric_results": InvalidMetricResultsError,
         "invalid_test_results": InvalidTestResultsError,
-        "start_test_run_failed": StartTestRunFailedError,
     }
 
     error_class = error_map.get(api_code, APIRequestError)
validmind/tests/data_validation/ProtectedClassesCombination.py CHANGED
@@ -7,19 +7,27 @@ import sys
 import pandas as pd
 import plotly.graph_objects as go
 import plotly.subplots as sp
-from fairlearn.metrics import (
-    MetricFrame,
-    count,
-    demographic_parity_ratio,
-    equalized_odds_ratio,
-    false_positive_rate,
-    selection_rate,
-    true_positive_rate,
-)
 
 from validmind import tags, tasks
+from validmind.errors import MissingDependencyError
 from validmind.logging import get_logger
 
+try:
+    from fairlearn.metrics import (
+        MetricFrame,
+        count,
+        demographic_parity_ratio,
+        equalized_odds_ratio,
+        false_positive_rate,
+        selection_rate,
+        true_positive_rate,
+    )
+except ImportError as e:
+    raise MissingDependencyError(
+        "Missing required package `fairlearn` for ProtectedClassesCombination.",
+        required_dependencies=["fairlearn"],
+    ) from e
+
 logger = get_logger(__name__)
validmind/tests/data_validation/ProtectedClassesDisparity.py CHANGED
@@ -5,15 +5,23 @@
 import io
 import sys
 
-import aequitas.plot as ap
 import pandas as pd
-from aequitas.bias import Bias
-from aequitas.group import Group
-from aequitas.plotting import Plot
 
 from validmind import tags, tasks
+from validmind.errors import MissingDependencyError
 from validmind.logging import get_logger
 
+try:
+    import aequitas.plot as ap
+    from aequitas.bias import Bias
+    from aequitas.group import Group
+    from aequitas.plotting import Plot
+except ImportError as e:
+    raise MissingDependencyError(
+        "Missing required package `aequitas` for ProtectedClassesDisparity.",
+        required_dependencies=["aequitas"],
+    ) from e
+
 logger = get_logger(__name__)
validmind/tests/data_validation/ProtectedClassesThresholdOptimizer.py CHANGED
@@ -7,20 +7,28 @@ import sys
 
 import matplotlib.pyplot as plt
 import pandas as pd
-from fairlearn.metrics import (
-    MetricFrame,
-    count,
-    demographic_parity_ratio,
-    equalized_odds_ratio,
-    false_negative_rate,
-    false_positive_rate,
-    true_positive_rate,
-)
-from fairlearn.postprocessing import ThresholdOptimizer, plot_threshold_optimizer
 
 from validmind import tags, tasks
+from validmind.errors import MissingDependencyError
 from validmind.logging import get_logger
 
+try:
+    from fairlearn.metrics import (
+        MetricFrame,
+        count,
+        demographic_parity_ratio,
+        equalized_odds_ratio,
+        false_negative_rate,
+        false_positive_rate,
+        true_positive_rate,
+    )
+    from fairlearn.postprocessing import ThresholdOptimizer, plot_threshold_optimizer
+except ImportError as e:
+    raise MissingDependencyError(
+        "Missing required package `fairlearn` for ProtectedClassesThresholdOptimizer.",
+        required_dependencies=["fairlearn"],
+    ) from e
+
 logger = get_logger(__name__)
validmind/tests/load.py CHANGED
@@ -15,7 +15,7 @@ from uuid import uuid4
 import pandas as pd
 from ipywidgets import HTML, Accordion
 
-from ..errors import LoadTestError
+from ..errors import LoadTestError, MissingDependencyError
 from ..html_templates.content_blocks import test_content_block_html
 from ..logging import get_logger
 from ..unit_metrics.composite import load_composite_metric
@@ -88,10 +88,30 @@ def list_tests(
     Returns:
         list or pandas.DataFrame: A list of all tests or a formatted table.
     """
-    tests = {
-        test_id: load_test(test_id, reload=True)
-        for test_id in test_store.get_test_ids()
-    }
+    # tests = {
+    #     test_id: load_test(test_id, reload=True)
+    #     for test_id in test_store.get_test_ids()
+    # }
+    tests = {}
+    for test_id in test_store.get_test_ids():
+        try:
+            tests[test_id] = load_test(test_id, reload=True)
+        except MissingDependencyError as e:
+            # skip tests that have missing dependencies
+            logger.debug(str(e))
+
+            if e.extra:
+                logger.info(
+                    f"Skipping `{test_id}` as it requires extra dependencies: {e.required_dependencies}."
+                    f" Please run `pip install validmind[{e.extra}]` to view and run this test."
+                )
+            else:
+                logger.info(
+                    f"Skipping `{test_id}` as it requires missing dependencies: {e.required_dependencies}."
+                    " Please install the missing dependencies to view and run this test."
+                )
+
+            continue
 
     # first search by the filter string since it's the most general search
     if filter is not None:
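With this change, `list_tests` degrades gracefully: tests whose optional dependencies are absent are skipped with an informative log line instead of aborting the whole listing. A short usage sketch (output depends on which extras are installed in the environment):

```python
import validmind as vm

# Tests with missing optional dependencies are skipped and logged, e.g.
# "Skipping `...` ... Please run `pip install validmind[llm]` to view and run this test."
ragas_tests = vm.tests.list_tests(filter="ragas")
```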
validmind/tests/model_validation/ragas/AnswerCorrectness.py CHANGED
@@ -8,9 +8,21 @@ import plotly.express as px
 from datasets import Dataset
 
 from validmind import tags, tasks
+from validmind.errors import MissingDependencyError
 
 from .utils import get_ragas_config, get_renamed_columns
 
+try:
+    from ragas import evaluate
+    from ragas.metrics import answer_correctness
+except ImportError as e:
+    raise MissingDependencyError(
+        "Missing required package `ragas` for AnswerCorrectness. "
+        "Please run `pip install validmind[llm]` to use LLM tests",
+        required_dependencies=["ragas"],
+        extra="llm",
+    ) from e
+
 
 @tags("ragas", "llm")
 @tasks("text_qa", "text_generation", "text_summarization")
@@ -88,12 +100,6 @@ def AnswerCorrectness(
     }
     ```
     """
-    try:
-        from ragas import evaluate
-        from ragas.metrics import answer_correctness
-    except ImportError:
-        raise ImportError("Please run `pip install validmind[llm]` to use LLM tests")
-
     warnings.filterwarnings(
         "ignore",
         category=FutureWarning,
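Moving the guard to module import time means a missing `ragas` install now surfaces as a `MissingDependencyError` carrying `extra="llm"`, rather than a bare `ImportError` raised inside the test body. A minimal caller-side sketch; the test ID follows the library's usual dotted naming and is assumed here:

```python
from validmind.errors import MissingDependencyError
from validmind.tests import load_test

try:
    load_test("validmind.model_validation.ragas.AnswerCorrectness")  # assumed test ID
except MissingDependencyError as e:
    # e.required_dependencies == ["ragas"], e.extra == "llm"
    print(f"Install the LLM extras first: pip install validmind[{e.extra}]")
```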
validmind/tests/model_validation/ragas/AnswerRelevance.py CHANGED
@@ -8,9 +8,21 @@ import plotly.express as px
 from datasets import Dataset
 
 from validmind import tags, tasks
+from validmind.errors import MissingDependencyError
 
 from .utils import get_ragas_config, get_renamed_columns
 
+try:
+    from ragas import evaluate
+    from ragas.metrics import answer_relevancy
+except ImportError as e:
+    raise MissingDependencyError(
+        "Missing required package `ragas` for AnswerRelevance. "
+        "Please run `pip install validmind[llm]` to use LLM tests",
+        required_dependencies=["ragas"],
+        extra="llm",
+    ) from e
+
 
 @tags("ragas", "llm", "rag_performance")
 @tasks("text_qa", "text_generation", "text_summarization")
@@ -92,12 +104,6 @@ def AnswerRelevance(
     }
     ```
     """
-    try:
-        from ragas import evaluate
-        from ragas.metrics import answer_relevancy
-    except ImportError:
-        raise ImportError("Please run `pip install validmind[llm]` to use LLM tests")
-
     warnings.filterwarnings(
         "ignore",
         category=FutureWarning,
validmind/tests/model_validation/ragas/AnswerSimilarity.py CHANGED
@@ -8,9 +8,21 @@ import plotly.express as px
 from datasets import Dataset
 
 from validmind import tags, tasks
+from validmind.errors import MissingDependencyError
 
 from .utils import get_ragas_config, get_renamed_columns
 
+try:
+    from ragas import evaluate
+    from ragas.metrics import answer_similarity
+except ImportError as e:
+    raise MissingDependencyError(
+        "Missing required package `ragas` for AnswerSimilarity. "
+        "Please run `pip install validmind[llm]` to use LLM tests",
+        required_dependencies=["ragas"],
+        extra="llm",
+    ) from e
+
 
 @tags("ragas", "llm")
 @tasks("text_qa", "text_generation", "text_summarization")
@@ -78,12 +90,6 @@ def AnswerSimilarity(
     }
     ```
     """
-    try:
-        from ragas import evaluate
-        from ragas.metrics import answer_similarity
-    except ImportError:
-        raise ImportError("Please run `pip install validmind[llm]` to use LLM tests")
-
     warnings.filterwarnings(
         "ignore",
         category=FutureWarning,
validmind/tests/model_validation/ragas/AspectCritique.py CHANGED
@@ -8,9 +8,28 @@ import plotly.express as px
 from datasets import Dataset
 
 from validmind import tags, tasks
+from validmind.errors import MissingDependencyError
 
 from .utils import get_ragas_config, get_renamed_columns
 
+try:
+    from ragas import evaluate
+    from ragas.metrics import AspectCritic
+    from ragas.metrics._aspect_critic import (
+        coherence,
+        conciseness,
+        correctness,
+        harmfulness,
+        maliciousness,
+    )
+except ImportError as e:
+    raise MissingDependencyError(
+        "Missing required package `ragas` for AspectCritique. "
+        "Please run `pip install validmind[llm]` to use LLM tests",
+        required_dependencies=["ragas"],
+        extra="llm",
+    ) from e
+
 LOWER_IS_BETTER_ASPECTS = ["harmfulness", "maliciousness"]
 
 
@@ -101,19 +120,6 @@ def AspectCritique(
     )
     ```
     """
-    try:
-        from ragas import evaluate
-        from ragas.metrics import AspectCritic
-        from ragas.metrics._aspect_critic import (
-            coherence,
-            conciseness,
-            correctness,
-            harmfulness,
-            maliciousness,
-        )
-    except ImportError:
-        raise ImportError("Please run `pip install validmind[llm]` to use LLM tests")
-
     built_in_aspects = {
         "coherence": coherence,
         "conciseness": conciseness,
validmind/tests/model_validation/ragas/ContextEntityRecall.py CHANGED
@@ -8,9 +8,21 @@ import plotly.express as px
 from datasets import Dataset
 
 from validmind import tags, tasks
+from validmind.errors import MissingDependencyError
 
 from .utils import get_ragas_config, get_renamed_columns
 
+try:
+    from ragas import evaluate
+    from ragas.metrics import context_entity_recall
+except ImportError as e:
+    raise MissingDependencyError(
+        "Missing required package `ragas` for ContextEntityRecall. "
+        "Please run `pip install validmind[llm]` to use LLM tests",
+        required_dependencies=["ragas"],
+        extra="llm",
+    ) from e
+
 
 @tags("ragas", "llm", "retrieval_performance")
 @tasks("text_qa", "text_generation", "text_summarization")
@@ -84,12 +96,6 @@ def ContextEntityRecall(
     }
     ```
     """
-    try:
-        from ragas import evaluate
-        from ragas.metrics import context_entity_recall
-    except ImportError:
-        raise ImportError("Please run `pip install validmind[llm]` to use LLM tests")
-
     warnings.filterwarnings(
         "ignore",
         category=FutureWarning,
validmind/tests/model_validation/ragas/ContextPrecision.py CHANGED
@@ -8,9 +8,21 @@ import plotly.express as px
 from datasets import Dataset
 
 from validmind import tags, tasks
+from validmind.errors import MissingDependencyError
 
 from .utils import get_ragas_config, get_renamed_columns
 
+try:
+    from ragas import evaluate
+    from ragas.metrics import context_precision
+except ImportError as e:
+    raise MissingDependencyError(
+        "Missing required package `ragas` for ContextPrecision. "
+        "Please run `pip install validmind[llm]` to use LLM tests",
+        required_dependencies=["ragas"],
+        extra="llm",
+    ) from e
+
 
 @tags("ragas", "llm", "retrieval_performance")
 @tasks("text_qa", "text_generation", "text_summarization", "text_classification")
@@ -79,12 +91,6 @@ def ContextPrecision(
     }
     ```
     """
-    try:
-        from ragas import evaluate
-        from ragas.metrics import context_precision
-    except ImportError:
-        raise ImportError("Please run `pip install validmind[llm]` to use LLM tests")
-
     warnings.filterwarnings(
         "ignore",
         category=FutureWarning,
validmind/tests/model_validation/ragas/ContextRecall.py CHANGED
@@ -8,9 +8,21 @@ import plotly.express as px
 from datasets import Dataset
 
 from validmind import tags, tasks
+from validmind.errors import MissingDependencyError
 
 from .utils import get_ragas_config, get_renamed_columns
 
+try:
+    from ragas import evaluate
+    from ragas.metrics import context_recall
+except ImportError as e:
+    raise MissingDependencyError(
+        "Missing required package `ragas` for ContextRecall. "
+        "Please run `pip install validmind[llm]` to use LLM tests",
+        required_dependencies=["ragas"],
+        extra="llm",
+    ) from e
+
 
 @tags("ragas", "llm", "retrieval_performance")
 @tasks("text_qa", "text_generation", "text_summarization", "text_classification")
@@ -79,12 +91,6 @@ def ContextRecall(
     }
     ```
     """
-    try:
-        from ragas import evaluate
-        from ragas.metrics import context_recall
-    except ImportError:
-        raise ImportError("Please run `pip install validmind[llm]` to use LLM tests")
-
     warnings.filterwarnings(
         "ignore",
         category=FutureWarning,
validmind/tests/model_validation/ragas/ContextUtilization.py CHANGED
@@ -8,9 +8,21 @@ import plotly.express as px
 from datasets import Dataset
 
 from validmind import tags, tasks
+from validmind.errors import MissingDependencyError
 
 from .utils import get_ragas_config, get_renamed_columns
 
+try:
+    from ragas import evaluate
+    from ragas.metrics import context_utilization
+except ImportError as e:
+    raise MissingDependencyError(
+        "Missing required package `ragas` for ContextUtilization. "
+        "Please run `pip install validmind[llm]` to use LLM tests",
+        required_dependencies=["ragas"],
+        extra="llm",
+    ) from e
+
 
 @tags("ragas", "llm", "retrieval_performance")
 @tasks("text_qa", "text_generation", "text_summarization", "text_classification")
@@ -107,12 +119,6 @@ def ContextUtilization(
     - Requires proper context retrieval to be effective; irrelevant context chunks can skew the results.
     - Dependent on large sample sizes to provide stable and reliable estimates of utilization performance.
     """
-    try:
-        from ragas import evaluate
-        from ragas.metrics import context_utilization
-    except ImportError:
-        raise ImportError("Please run `pip install validmind[llm]` to use LLM tests")
-
     warnings.filterwarnings(
         "ignore",
        category=FutureWarning,
validmind/tests/model_validation/ragas/Faithfulness.py CHANGED
@@ -8,9 +8,21 @@ import plotly.express as px
 from datasets import Dataset
 
 from validmind import tags, tasks
+from validmind.errors import MissingDependencyError
 
 from .utils import get_ragas_config, get_renamed_columns
 
+try:
+    from ragas import evaluate
+    from ragas.metrics import faithfulness
+except ImportError as e:
+    raise MissingDependencyError(
+        "Missing required package `ragas` for Faithfulness. "
+        "Please run `pip install validmind[llm]` to use LLM tests",
+        required_dependencies=["ragas"],
+        extra="llm",
+    ) from e
+
 
 @tags("ragas", "llm", "rag_performance")
 @tasks("text_qa", "text_generation", "text_summarization")
@@ -78,12 +90,6 @@ def Faithfulness(
     }
     ```
     """
-    try:
-        from ragas import evaluate
-        from ragas.metrics import faithfulness
-    except ImportError:
-        raise ImportError("Please run `pip install validmind[llm]` to use LLM tests")
-
     warnings.filterwarnings(
         "ignore",
         category=FutureWarning,
validmind/tests/model_validation/ragas/NoiseSensitivity.py CHANGED
@@ -8,9 +8,21 @@ import plotly.express as px
 from datasets import Dataset
 
 from validmind import tags, tasks
+from validmind.errors import MissingDependencyError
 
 from .utils import get_ragas_config, get_renamed_columns
 
+try:
+    from ragas import evaluate
+    from ragas.metrics import noise_sensitivity_relevant
+except ImportError as e:
+    raise MissingDependencyError(
+        "Missing required package `ragas` for NoiseSensitivity. "
+        "Please run `pip install validmind[llm]` to use LLM tests",
+        required_dependencies=["ragas"],
+        extra="llm",
+    ) from e
+
 
 @tags("ragas", "llm", "rag_performance")
 @tasks("text_qa", "text_generation", "text_summarization")
@@ -100,12 +112,6 @@ def NoiseSensitivity(
     - Primarily applicable to tasks like text QA, text generation, and text summarization where contextual relevance is
     critical.
     """
-    try:
-        from ragas import evaluate
-        from ragas.metrics import noise_sensitivity_relevant
-    except ImportError:
-        raise ImportError("Please run `pip install validmind[llm]` to use LLM tests")
-
     warnings.filterwarnings(
         "ignore",
         category=FutureWarning,
validmind/tests/model_validation/sklearn/ClassifierPerformance.py CHANGED
@@ -67,6 +67,7 @@ class ClassifierPerformance(Metric):
         "multiclass_classification",
         "model_performance",
     ]
+    default_params = {"average": "macro"}
 
     def summary(self, metric_value: dict):
         """
@@ -134,11 +135,13 @@ class ClassifierPerformance(Metric):
         if len(np.unique(y_true)) > 2:
             y_pred = self.inputs.dataset.y_pred(self.inputs.model)
             y_true = y_true.astype(y_pred.dtype)
-            roc_auc = multiclass_roc_auc_score(y_true, y_pred)
+            roc_auc = multiclass_roc_auc_score(
+                y_true, y_pred, average=self.params["average"]
+            )
         else:
             y_prob = self.inputs.dataset.y_prob(self.inputs.model)
             y_true = y_true.astype(y_prob.dtype).flatten()
-            roc_auc = roc_auc_score(y_true, y_prob)
+            roc_auc = roc_auc_score(y_true, y_prob, average=self.params["average"])
 
         report["roc_auc"] = roc_auc
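The new `default_params` entry makes the ROC AUC averaging strategy configurable instead of hard-coded. A hedged sketch of overriding it at run time; the `run_test` call shape and input keys follow the library's usual pattern, and `vm_model`/`vm_test_ds` are assumed to be previously initialized ValidMind objects:

```python
import validmind as vm

# "macro" is the new default; "micro" and "weighted" are other common
# scikit-learn averaging options for multiclass ROC AUC.
vm.tests.run_test(
    "validmind.model_validation.sklearn.ClassifierPerformance",  # assumed test ID
    inputs={"model": vm_model, "dataset": vm_test_ds},
    params={"average": "micro"},
)
```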