scwidgets 0.1.0.dev0__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- scwidgets-0.1.0.dev0/LICENSE +28 -0
- scwidgets-0.1.0.dev0/PKG-INFO +45 -0
- scwidgets-0.1.0.dev0/README.rst +19 -0
- scwidgets-0.1.0.dev0/pyproject.toml +47 -0
- scwidgets-0.1.0.dev0/setup.cfg +4 -0
- scwidgets-0.1.0.dev0/src/scwidgets/__init__.py +47 -0
- scwidgets-0.1.0.dev0/src/scwidgets/_css_style.py +24 -0
- scwidgets-0.1.0.dev0/src/scwidgets/_utils.py +46 -0
- scwidgets-0.1.0.dev0/src/scwidgets/check/__init__.py +24 -0
- scwidgets-0.1.0.dev0/src/scwidgets/check/_asserts.py +294 -0
- scwidgets-0.1.0.dev0/src/scwidgets/check/_check.py +501 -0
- scwidgets-0.1.0.dev0/src/scwidgets/check/_widget_check_registry.py +334 -0
- scwidgets-0.1.0.dev0/src/scwidgets/code/__init__.py +7 -0
- scwidgets-0.1.0.dev0/src/scwidgets/code/_widget_code_input.py +250 -0
- scwidgets-0.1.0.dev0/src/scwidgets/code/_widget_parameter_panel.py +100 -0
- scwidgets-0.1.0.dev0/src/scwidgets/css/widgets.css +132 -0
- scwidgets-0.1.0.dev0/src/scwidgets/cue/__init__.py +26 -0
- scwidgets-0.1.0.dev0/src/scwidgets/cue/_widget_cue.py +111 -0
- scwidgets-0.1.0.dev0/src/scwidgets/cue/_widget_cue_box.py +205 -0
- scwidgets-0.1.0.dev0/src/scwidgets/cue/_widget_cue_figure.py +162 -0
- scwidgets-0.1.0.dev0/src/scwidgets/cue/_widget_cue_object.py +76 -0
- scwidgets-0.1.0.dev0/src/scwidgets/cue/_widget_cue_output.py +78 -0
- scwidgets-0.1.0.dev0/src/scwidgets/cue/_widget_reset_cue_button.py +352 -0
- scwidgets-0.1.0.dev0/src/scwidgets/exercise/__init__.py +5 -0
- scwidgets-0.1.0.dev0/src/scwidgets/exercise/_widget_code_exercise.py +774 -0
- scwidgets-0.1.0.dev0/src/scwidgets/exercise/_widget_exercise_registry.py +575 -0
- scwidgets-0.1.0.dev0/src/scwidgets/exercise/_widget_text_exercise.py +217 -0
- scwidgets-0.1.0.dev0/src/scwidgets.egg-info/PKG-INFO +45 -0
- scwidgets-0.1.0.dev0/src/scwidgets.egg-info/SOURCES.txt +34 -0
- scwidgets-0.1.0.dev0/src/scwidgets.egg-info/dependency_links.txt +1 -0
- scwidgets-0.1.0.dev0/src/scwidgets.egg-info/requires.txt +5 -0
- scwidgets-0.1.0.dev0/src/scwidgets.egg-info/top_level.txt +1 -0
- scwidgets-0.1.0.dev0/tests/test_answer.py +289 -0
- scwidgets-0.1.0.dev0/tests/test_check.py +405 -0
- scwidgets-0.1.0.dev0/tests/test_code.py +375 -0
- scwidgets-0.1.0.dev0/tests/test_widgets.py +1433 -0
|
@@ -0,0 +1,28 @@
|
|
|
1
|
+
BSD 3-Clause License
|
|
2
|
+
|
|
3
|
+
Copyright (c) 2023, scicode-widgets developer team
|
|
4
|
+
|
|
5
|
+
Redistribution and use in source and binary forms, with or without
|
|
6
|
+
modification, are permitted provided that the following conditions are met:
|
|
7
|
+
|
|
8
|
+
1. Redistributions of source code must retain the above copyright notice, this
|
|
9
|
+
list of conditions and the following disclaimer.
|
|
10
|
+
|
|
11
|
+
2. Redistributions in binary form must reproduce the above copyright notice,
|
|
12
|
+
this list of conditions and the following disclaimer in the documentation
|
|
13
|
+
and/or other materials provided with the distribution.
|
|
14
|
+
|
|
15
|
+
3. Neither the name of the copyright holder nor the names of its
|
|
16
|
+
contributors may be used to endorse or promote products derived from
|
|
17
|
+
this software without specific prior written permission.
|
|
18
|
+
|
|
19
|
+
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
|
20
|
+
AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
|
21
|
+
IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
|
22
|
+
DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
|
23
|
+
FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
|
24
|
+
DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
|
25
|
+
SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
|
26
|
+
CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
|
27
|
+
OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
|
28
|
+
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
|
@@ -0,0 +1,45 @@
|
|
|
1
|
+
Metadata-Version: 2.1
|
|
2
|
+
Name: scwidgets
|
|
3
|
+
Version: 0.1.0.dev0
|
|
4
|
+
Summary: A collection of widgets to prepare interactive scientific visualisations, including user code input and validation
|
|
5
|
+
License: BSD-3-Clause
|
|
6
|
+
Classifier: Intended Audience :: Science/Research
|
|
7
|
+
Classifier: Operating System :: POSIX
|
|
8
|
+
Classifier: License :: OSI Approved :: BSD License
|
|
9
|
+
Classifier: Programming Language :: Python
|
|
10
|
+
Classifier: Programming Language :: Python :: 3.8
|
|
11
|
+
Classifier: Programming Language :: Python :: 3.9
|
|
12
|
+
Classifier: Programming Language :: Python :: 3.10
|
|
13
|
+
Classifier: Programming Language :: Python :: 3.11
|
|
14
|
+
Classifier: Topic :: Scientific/Engineering
|
|
15
|
+
Classifier: Topic :: Scientific/Engineering :: Chemistry
|
|
16
|
+
Classifier: Topic :: Scientific/Engineering :: Physics
|
|
17
|
+
Classifier: Topic :: Software Development :: Libraries :: Python Modules
|
|
18
|
+
Requires-Python: >=3.9
|
|
19
|
+
Description-Content-Type: text/x-rst
|
|
20
|
+
License-File: LICENSE
|
|
21
|
+
Requires-Dist: ipywidgets>=8.0.0
|
|
22
|
+
Requires-Dist: numpy<2.0.0
|
|
23
|
+
Requires-Dist: widget_code_input>=4.0.13
|
|
24
|
+
Requires-Dist: matplotlib
|
|
25
|
+
Requires-Dist: termcolor
|
|
26
|
+
|
|
27
|
+
Important
|
|
28
|
+
=========
|
|
29
|
+
|
|
30
|
+
So far scicode-widgets has been developed by prototyping, without much concern for code quality. This allowed faster development at the cost of readability and maintainability of the code. Now that we have finished the prototype phase and converged on a set of functionalities we are satisfied with, we are refactoring the resulting code in this branch. While the refactoring is ongoing, we recommend using the `vertical-slice branch <https://github.com/osscar-org/scicode-widgets/tree/vertical-slice>`_ until all features have been implemented in the refactor.
|
|
31
|
+
|
|
32
|
+
|
|
33
|
+
scicode-widgets
|
|
34
|
+
===============
|
|
35
|
+
|
|
36
|
+
.. marker-package-description
|
|
37
|
+
|
|
38
|
+
A collection of ipywidgets for the creation of interactive code demos and educational notebooks with exercises that can be checked and exported.
|
|
39
|
+
|
|
40
|
+
Installation
|
|
41
|
+
------------
|
|
42
|
+
|
|
43
|
+
.. code-block:: bash
|
|
44
|
+
|
|
45
|
+
pip install .
|
|
@@ -0,0 +1,19 @@
|
|
|
1
|
+
Important
|
|
2
|
+
=========
|
|
3
|
+
|
|
4
|
+
So far scicode-widgets has been developed by prototyping, without much concern for code quality. This allowed faster development at the cost of readability and maintainability of the code. Now that we have finished the prototype phase and converged on a set of functionalities we are satisfied with, we are refactoring the resulting code in this branch. While the refactoring is ongoing, we recommend using the `vertical-slice branch <https://github.com/osscar-org/scicode-widgets/tree/vertical-slice>`_ until all features have been implemented in the refactor.
|
|
5
|
+
|
|
6
|
+
|
|
7
|
+
scicode-widgets
|
|
8
|
+
===============
|
|
9
|
+
|
|
10
|
+
.. marker-package-description
|
|
11
|
+
|
|
12
|
+
A collection of ipywidgets for the creation of interactive code demos and educational notebooks with exercises that can be checked and exported.
|
|
13
|
+
|
|
14
|
+
Installation
|
|
15
|
+
------------
|
|
16
|
+
|
|
17
|
+
.. code-block:: bash
|
|
18
|
+
|
|
19
|
+
pip install .
|
|
@@ -0,0 +1,47 @@
|
|
|
1
|
+
[build-system]
|
|
2
|
+
# setuptools requirement due to specifying the version in pyproject.toml
|
|
3
|
+
# https://packaging.python.org/en/latest/guides/single-sourcing-package-version/
|
|
4
|
+
requires = [
|
|
5
|
+
"setuptools>=61.0.0",
|
|
6
|
+
"wheel"
|
|
7
|
+
]
|
|
8
|
+
build-backend = "setuptools.build_meta"
|
|
9
|
+
|
|
10
|
+
[project]
|
|
11
|
+
name = "scwidgets"
|
|
12
|
+
description = "A collection of widgets to prepare interactive scientific visualisations, including user code input and validation"
|
|
13
|
+
readme = "README.rst"
|
|
14
|
+
requires-python = ">=3.9"
|
|
15
|
+
license = {text = "BSD-3-Clause"}
|
|
16
|
+
classifiers = [
|
|
17
|
+
"Intended Audience :: Science/Research",
|
|
18
|
+
"Operating System :: POSIX",
|
|
19
|
+
"License :: OSI Approved :: BSD License",
|
|
20
|
+
"Programming Language :: Python",
|
|
21
|
+
"Programming Language :: Python :: 3.8",
|
|
22
|
+
"Programming Language :: Python :: 3.9",
|
|
23
|
+
"Programming Language :: Python :: 3.10",
|
|
24
|
+
"Programming Language :: Python :: 3.11",
|
|
25
|
+
"Topic :: Scientific/Engineering",
|
|
26
|
+
"Topic :: Scientific/Engineering :: Chemistry",
|
|
27
|
+
"Topic :: Scientific/Engineering :: Physics",
|
|
28
|
+
"Topic :: Software Development :: Libraries :: Python Modules",
|
|
29
|
+
]
|
|
30
|
+
dependencies = [
|
|
31
|
+
"ipywidgets>=8.0.0",
|
|
32
|
+
"numpy<2.0.0",
|
|
33
|
+
"widget_code_input>=4.0.13",
|
|
34
|
+
"matplotlib",
|
|
35
|
+
"termcolor"
|
|
36
|
+
]
|
|
37
|
+
dynamic = ["version"]
|
|
38
|
+
|
|
39
|
+
[tool.setuptools.dynamic]
|
|
40
|
+
version = {attr = "scwidgets.__version__"}
|
|
41
|
+
readme = {file = ["README.rst"]}
|
|
42
|
+
|
|
43
|
+
[tool.setuptools.package-data]
|
|
44
|
+
scwidgets = ["css/widgets.css"]
|
|
45
|
+
|
|
46
|
+
[tool.isort]
|
|
47
|
+
profile = "black"
|
|
@@ -0,0 +1,47 @@
|
|
|
1
|
+
# Canonical PEP 440 spelling of the version; this is the same string setuptools
# writes into the distribution metadata (``Version: 0.1.0.dev0``), so the
# runtime ``__version__`` and ``importlib.metadata.version`` agree.
__version__ = "0.1.0.dev0"
__authors__ = "the scicode-widgets developer team"
|
|
3
|
+
|
|
4
|
+
from ._css_style import CssStyle, get_css_style
|
|
5
|
+
from .check import * # noqa: F403
|
|
6
|
+
from .code import * # noqa: F403
|
|
7
|
+
from .cue import * # noqa: F403
|
|
8
|
+
from .exercise import * # noqa: F403
|
|
9
|
+
|
|
10
|
+
# Public API of the package. The names are re-exported from the subpackages
# pulled in with star imports; ``noqa: F405`` silences flake8, which cannot
# tell that these names come from those star imports.
__all__ = [  # noqa: F405
    # css_style
    "CssStyle",
    "get_css_style",
    # cue
    "CueWidget",
    "CheckCueBox",
    "CueBox",
    "SaveCueBox",
    "UpdateCueBox",
    "ResetCueButton",
    "SaveResetCueButton",
    "CheckResetCueButton",
    "UpdateResetCueButton",
    "CueOutput",
    "CueObject",
    "CueFigure",
    # code
    "CodeInput",
    "ParameterPanel",
    # check
    "Check",
    "CheckResult",
    "AssertResult",
    "CheckRegistry",
    "CheckableWidget",
    "assert_equal",
    "assert_shape",
    "assert_numpy_allclose",
    "assert_type",
    "assert_numpy_floating_sub_dtype",
    "assert_numpy_sub_dtype",
    # exercise
    "CodeExercise",
    "TextExercise",
    "ExerciseWidget",
    "ExerciseRegistry",
]
|
|
@@ -0,0 +1,24 @@
|
|
|
1
|
+
import os
|
|
2
|
+
|
|
3
|
+
from ipywidgets import HTML
|
|
4
|
+
|
|
5
|
+
|
|
6
|
+
class CssStyle(HTML):
    """
    This HTML widget has to be displayed so the css style is loaded in the notebook.

    The widget's value is ``preamble`` followed by the content of the packaged
    ``css/widgets.css`` file wrapped in a ``<style>`` tag.

    :param preamble: Text to appear before the style sheet
    """

    def __init__(self, preamble: str = ""):
        css_path = os.path.join(os.path.dirname(__file__), "css/widgets.css")
        # Read with an explicit encoding: the default of ``open`` depends on
        # the platform locale (e.g. cp1252 on Windows), which can garble or
        # fail on non-ASCII characters in the style sheet.
        with open(css_path, encoding="utf-8") as file:
            style_txt = file.read()

        super().__init__(preamble + "<style>" + style_txt + "</style>")
|
|
18
|
+
|
|
19
|
+
|
|
20
|
+
def get_css_style() -> HTML:
    """Build a :class:`CssStyle` widget with the default explanatory preamble.

    The preamble asks the notebook user to keep the cell output alive. The
    widget has to be displayed for the scicode-widgets style sheet to load.
    """
    default_preamble = (
        "HTML with scicode-widget css style sheet. "
        "Please keep this cell output alive."
    )
    return CssStyle(preamble=default_preamble)
|
|
@@ -0,0 +1,46 @@
|
|
|
1
|
+
import re
|
|
2
|
+
|
|
3
|
+
from termcolor import colored
|
|
4
|
+
|
|
5
|
+
|
|
6
|
+
class Formatter:
    """Static helpers that lay out and colour text messages.

    Configuration lives in the class attributes and is read on every call, so
    changing e.g. ``Formatter.LINE_LENGTH`` affects subsequent calls.
    """

    # maximum line length used by ``break_lines`` and ``format_title_message``
    LINE_LENGTH = 120
    # termcolor colour names; the ``color_assert_*`` helpers use the
    # ``light_`` variant of the same colour
    INFO_COLOR = "blue"
    ERROR_COLOR = "red"
    SUCCESS_COLOR = "green"

    @staticmethod
    def format_title_message(message: str) -> str:
        """Return ``message`` centered in a line padded with ``-`` characters.

        NOTE(review): by operator precedence the padded width is
        ``LINE_LENGTH - (len(message) // 2)``, so longer titles yield shorter
        lines. Kept unchanged to preserve existing output; confirm whether a
        fixed ``LINE_LENGTH`` width was intended.
        """
        return message.center(Formatter.LINE_LENGTH - len(message) // 2, "-")

    @staticmethod
    def break_lines(message: str) -> str:
        """Split ``message`` into chunks of at most ``LINE_LENGTH`` characters.

        Chunks are joined with a newline followed by one space, so continuation
        lines are indented by one column. Since ``.`` does not match newlines,
        newline characters already in ``message`` also become chunk boundaries
        and empty lines are dropped.
        """
        return "\n ".join(
            re.findall(r".{1," + str(Formatter.LINE_LENGTH) + "}", message)
        )

    @staticmethod
    def color_error_message(message: str) -> str:
        """Return ``message`` coloured in bold ``ERROR_COLOR``."""
        return colored(message, Formatter.ERROR_COLOR, attrs=["bold"])

    @staticmethod
    def color_success_message(message: str) -> str:
        """Return ``message`` coloured in bold ``SUCCESS_COLOR``."""
        return colored(message, Formatter.SUCCESS_COLOR, attrs=["bold"])

    @staticmethod
    def color_info_message(message: str) -> str:
        """Return ``message`` coloured in bold ``INFO_COLOR``."""
        return colored(message, Formatter.INFO_COLOR, attrs=["bold"])

    @staticmethod
    def color_assert_failed(message: str) -> str:
        """Return ``message`` coloured in the light variant of ``ERROR_COLOR``."""
        return colored(message, "light_" + Formatter.ERROR_COLOR)

    @staticmethod
    def color_assert_info(message: str) -> str:
        """Return ``message`` coloured in the light variant of ``INFO_COLOR``."""
        return colored(message, "light_" + Formatter.INFO_COLOR)

    @staticmethod
    def color_assert_success(message: str) -> str:
        """Return ``message`` coloured in the light variant of ``SUCCESS_COLOR``."""
        return colored(message, "light_" + Formatter.SUCCESS_COLOR)
|
|
@@ -0,0 +1,24 @@
|
|
|
1
|
+
from ._asserts import (
|
|
2
|
+
assert_equal,
|
|
3
|
+
assert_numpy_allclose,
|
|
4
|
+
assert_numpy_floating_sub_dtype,
|
|
5
|
+
assert_numpy_sub_dtype,
|
|
6
|
+
assert_shape,
|
|
7
|
+
assert_type,
|
|
8
|
+
)
|
|
9
|
+
from ._check import AssertResult, Check, CheckResult
|
|
10
|
+
from ._widget_check_registry import CheckableWidget, CheckRegistry
|
|
11
|
+
|
|
12
|
+
# Public API of the ``check`` subpackage: the check/result classes, the
# registry and the ready-made assert functions imported above.
__all__ = [
    "Check",
    "CheckResult",
    "AssertResult",
    "CheckRegistry",
    "CheckableWidget",
    "assert_equal",
    "assert_shape",
    "assert_numpy_allclose",
    "assert_type",
    "assert_numpy_floating_sub_dtype",
    "assert_numpy_sub_dtype",
]
|
|
@@ -0,0 +1,294 @@
|
|
|
1
|
+
import functools
|
|
2
|
+
from collections import abc
|
|
3
|
+
from typing import Iterable, Union
|
|
4
|
+
|
|
5
|
+
import numpy as np
|
|
6
|
+
|
|
7
|
+
from ._check import AssertResult, Check
|
|
8
|
+
|
|
9
|
+
# Type alias for the value an assert function may return: either a ``str``
# (presumably a failure message — TODO confirm against how Check consumes it)
# or an AssertResult as built by the assert helpers in this module.
AssertFunctionOutputT = Union[str, AssertResult]
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
def assert_equal(
    output_parameters: Check.FunOutParamsT,
    output_references: Check.FunOutParamsT,
    parameters_to_check: Union[Iterable[int], str] = "all",
) -> AssertResult:
    """Compare selected outputs with their references using ``==``.

    For numpy arrays with more than one element ``==`` is elementwise and its
    truth value is ambiguous; use :func:`assert_numpy_allclose` for those.

    :param output_parameters: outputs produced by the checked function
    :param output_references: reference outputs, same length as
        ``output_parameters``
    :param parameters_to_check: ``"all"`` to compare every index, or an
        iterable of indices to compare
    :return: AssertResult with the indices, values and messages of the
        outputs that are not equal to their reference
    :raises AssertionError: (via ``assert``) if the sequences differ in length
    :raises ValueError: if ``parameters_to_check`` is a string other than "all"
    :raises TypeError: if ``parameters_to_check`` is neither str nor iterable
    """
    assert len(output_parameters) == len(
        output_references
    ), "output_parameters and output_references have to have the same length"

    indices: Iterable[int]
    if not isinstance(parameters_to_check, str):
        if not isinstance(parameters_to_check, abc.Iterable):
            raise TypeError(
                "Only str and Iterable are accepted for parameters_to_check, "
                f"but got type {type(parameters_to_check)}."
            )
        indices = parameters_to_check  # type: ignore[assignment]
    elif parameters_to_check == "all":
        indices = range(len(output_parameters))
    else:
        raise ValueError(
            f'Got parameters_to_check="{parameters_to_check}" but only "all" '
            "is accepted as string"
        )

    mismatched_indices = []
    mismatched_values = []
    failure_messages = []
    for idx in indices:
        produced = output_parameters[idx]
        expected = output_references[idx]
        if produced == expected:
            continue
        mismatched_indices.append(idx)
        mismatched_values.append(produced)
        failure_messages.append(f"Expected {expected} but got {produced}.")

    return AssertResult(
        assert_name="assert_equal",
        parameter_indices=mismatched_indices,
        parameter_values=mismatched_values,
        messages=failure_messages,
    )
|
|
55
|
+
|
|
56
|
+
|
|
57
|
+
def assert_shape(
    output_parameters: Check.FunOutParamsT,
    output_references: Check.FunOutParamsT,
    parameters_to_check: Union[Iterable[int], str] = "auto",
) -> AssertResult:
    """Check that selected outputs have the same ``shape`` as their references.

    :param output_parameters: outputs produced by the checked function
    :param output_references: reference outputs, same length as
        ``output_parameters``
    :param parameters_to_check: ``"auto"`` checks every index whose reference
        has a ``shape`` attribute, ``"all"`` checks every index, an iterable of
        ints checks exactly those indices
    :return: AssertResult with the indices, values and messages of the outputs
        whose shape differs from the reference (or that have no shape at all)
    :raises AssertionError: (via ``assert``) if the sequences differ in length
    :raises ValueError: if ``parameters_to_check`` is an unsupported string
    :raises TypeError: if ``parameters_to_check`` is neither str nor iterable
    """
    assert len(output_parameters) == len(
        output_references
    ), "output_parameters and output_references have to have the same length"

    parameter_indices: Iterable[int]
    if isinstance(parameters_to_check, str):
        if parameters_to_check == "auto":
            parameter_indices = [
                i
                for i in range(len(output_references))
                if hasattr(output_references[i], "shape")
            ]
        elif parameters_to_check == "all":
            parameter_indices = range(len(output_parameters))
        else:
            raise ValueError(
                f'Got parameters_to_check="{parameters_to_check}" but only "all" '
                ' and "auto" are accepted as string'
            )
    elif isinstance(parameters_to_check, abc.Iterable):
        parameter_indices = parameters_to_check  # type: ignore[assignment]
    else:
        raise TypeError(
            "Only str and Iterable are accepted for parameters_to_check, "
            f"but got type {type(parameters_to_check)}."
        )

    failed_parameter_indices = []
    failed_parameter_values = []
    messages = []
    for i in parameter_indices:
        reference_shape = output_references[i].shape
        # A checked output without a ``shape`` (e.g. a plain list returned
        # where an array is expected) is reported as a failed check instead
        # of aborting the whole check with an AttributeError.
        output_shape = getattr(output_parameters[i], "shape", None)
        if output_shape is None:
            message = (
                f"Expected shape {reference_shape} but got object of type "
                f"{type(output_parameters[i])} without a shape."
            )
        elif output_shape != reference_shape:
            message = f"Expected shape {reference_shape} but got {output_shape}."
        else:
            continue
        failed_parameter_indices.append(i)
        failed_parameter_values.append(output_parameters[i])
        messages.append(message)

    return AssertResult(
        assert_name="assert_shape",
        parameter_indices=failed_parameter_indices,
        parameter_values=failed_parameter_values,
        messages=messages,
    )
|
|
107
|
+
|
|
108
|
+
|
|
109
|
+
def _numpy_differences(output, reference):
    """Return ``(absolute, relative)`` summed differences of two array-likes.

    Both inputs are converted with :func:`numpy.asarray`. Each element's
    relative difference is ``|output - reference|`` divided by the elementwise
    ``max(|output|, |reference|)``. Where both values are zero the divisor is 1
    instead, since the difference is zero there anyway.
    """
    output_arr = np.asarray(output)
    reference_arr = np.asarray(reference)

    diff = np.abs(output_arr - reference_arr)
    # Elementwise maximum broadcasts like ``diff``, so every element is
    # normalised by its own magnitude for inputs of any dimensionality.
    divisor = np.maximum(np.abs(output_arr), np.abs(reference_arr))
    # avoid division by zero where both entries are zero
    divisor = np.where(divisor == 0.0, 1.0, divisor)
    return np.sum(diff), np.sum(diff / divisor)


def assert_numpy_allclose(
    output_parameters: Check.FunOutParamsT,
    output_references: Check.FunOutParamsT,
    parameters_to_check: Union[Iterable[int], str] = "auto",
    rtol=1e-05,
    atol=1e-08,
    equal_nan=False,
) -> AssertResult:
    """Check selected outputs against their references with :func:`numpy.allclose`.

    :param output_parameters: outputs produced by the checked function
    :param output_references: reference outputs, same length as
        ``output_parameters``
    :param parameters_to_check: ``"auto"`` checks every index whose reference
        ``numpy.allclose`` can handle, ``"all"`` checks every index, an iterable
        of ints checks exactly those indices
    :param rtol: relative tolerance forwarded to ``numpy.allclose``
    :param atol: absolute tolerance forwarded to ``numpy.allclose``
    :param equal_nan: forwarded to ``numpy.allclose``
    :return: AssertResult with the indices, values and messages (summed
        absolute and relative differences) of the outputs that are not close
    :raises AssertionError: (via ``assert``) if the sequences differ in length
    :raises ValueError: if ``parameters_to_check`` is an unsupported string
    :raises TypeError: if ``parameters_to_check`` is neither str nor iterable
    """
    assert len(output_parameters) == len(
        output_references
    ), "output_parameters and output_references have to have the same length"

    parameter_indices: Iterable[int]
    if isinstance(parameters_to_check, str):
        if parameters_to_check == "auto":
            parameter_indices = []
            for i in range(len(output_references)):
                # Best-effort detection: references numpy cannot compare with
                # themselves are skipped rather than reported.
                try:
                    np.allclose(output_references[i], output_references[i])
                    parameter_indices.append(i)
                except Exception:
                    pass
        elif parameters_to_check == "all":
            parameter_indices = range(len(output_parameters))
        else:
            raise ValueError(
                f'Got parameters_to_check="{parameters_to_check}" but only "all" '
                ' and "auto" are accepted as string'
            )
    elif isinstance(parameters_to_check, abc.Iterable):
        parameter_indices = parameters_to_check  # type: ignore[assignment]
    else:
        raise TypeError(
            "Only str and Iterable are accepted for parameters_to_check, "
            f"but got type {type(parameters_to_check)}."
        )

    failed_parameter_indices = []
    failed_parameter_values = []
    messages = []
    for i in parameter_indices:
        is_allclose = np.allclose(
            output_parameters[i],
            output_references[i],
            atol=atol,
            rtol=rtol,
            equal_nan=equal_nan,
        )
        if is_allclose:
            continue

        abs_diff, rel_diff = _numpy_differences(
            output_parameters[i], output_references[i]
        )
        message = (
            "Output is not close to reference absolute difference "
            f"is {abs_diff}, relative difference is {rel_diff}."
        )
        failed_parameter_indices.append(i)
        failed_parameter_values.append(output_parameters[i])
        messages.append(message)

    return AssertResult(
        assert_name="assert_numpy_allclose",
        parameter_indices=failed_parameter_indices,
        parameter_values=failed_parameter_values,
        messages=messages,
    )
|
|
192
|
+
|
|
193
|
+
|
|
194
|
+
def assert_type(
    output_parameters: Check.FunOutParamsT,
    output_references: Check.FunOutParamsT,
    parameters_to_check: Union[Iterable[int], str] = "all",
) -> AssertResult:
    """Check that selected outputs are instances of their reference's type.

    The comparison uses ``isinstance``, so an output whose type is a subclass
    of the reference's type passes.

    :param output_parameters: outputs produced by the checked function
    :param output_references: reference outputs, same length as
        ``output_parameters``
    :param parameters_to_check: ``"all"`` to check every index, or an iterable
        of indices to check
    :return: AssertResult with the indices, values and messages of the outputs
        with a mismatching type
    :raises AssertionError: (via ``assert``) if the sequences differ in length
    :raises ValueError: if ``parameters_to_check`` is a string other than "all"
    :raises TypeError: if ``parameters_to_check`` is neither str nor iterable
    """
    assert len(output_parameters) == len(
        output_references
    ), "output_parameters and output_references have to have the same length"

    indices: Iterable[int]
    if isinstance(parameters_to_check, str):
        if parameters_to_check != "all":
            raise ValueError(
                f'Got parameters_to_check="{parameters_to_check}" but only "all" '
                "is accepted as string"
            )
        indices = range(len(output_parameters))
    elif not isinstance(parameters_to_check, abc.Iterable):
        raise TypeError(
            "Only str and Iterable are accepted for parameters_to_check, "
            f"but got type {type(parameters_to_check)}."
        )
    else:
        indices = parameters_to_check  # type: ignore[assignment]

    # Materialise the mismatching indices once; ``indices`` may be a one-shot
    # iterator supplied by the caller.
    mismatched = [
        idx
        for idx in indices
        if not isinstance(output_parameters[idx], type(output_references[idx]))
    ]

    return AssertResult(
        assert_name="assert_type",
        parameter_indices=mismatched,
        parameter_values=[output_parameters[idx] for idx in mismatched],
        messages=[
            f"Expected type {type(output_references[idx])} "
            f"but got {type(output_parameters[idx])}."
            for idx in mismatched
        ],
    )
|
|
238
|
+
|
|
239
|
+
|
|
240
|
+
def assert_numpy_sub_dtype(
    output_parameters: Union[Check.FunOutParamsT, tuple[Check.FingerprintT]],
    numpy_type: Union[np.dtype, type],
    parameters_to_check: Union[Iterable[int], str] = "all",
) -> AssertResult:
    """Check that outputs are numpy arrays whose dtype is a sub-dtype of ``numpy_type``.

    An output that is not a ``numpy.ndarray`` is reported once as a failure
    and its dtype is not inspected.

    :param output_parameters: outputs to check
    :param numpy_type: a ``numpy.dtype`` or numpy scalar type (e.g.
        ``np.floating``) passed to :func:`numpy.issubdtype`
    :param parameters_to_check: ``"all"`` to check every index, or an iterable
        of indices to check
    :return: AssertResult named ``assert_numpy_<type name>_sub_dtype`` with the
        indices, values and messages of the failing outputs
    :raises ValueError: if ``parameters_to_check`` is a string other than "all"
    :raises TypeError: if ``parameters_to_check`` is neither str nor iterable
    """
    # Same argument handling as the other assert functions: a string other
    # than "all" must not be iterated character by character.
    parameter_indices: Iterable[int]
    if isinstance(parameters_to_check, str):
        if parameters_to_check == "all":
            parameter_indices = range(len(output_parameters))
        else:
            raise ValueError(
                f'Got parameters_to_check="{parameters_to_check}" but only "all" '
                "is accepted as string"
            )
    elif isinstance(parameters_to_check, abc.Iterable):
        parameter_indices = parameters_to_check  # type: ignore[assignment]
    else:
        raise TypeError(
            "Only str and Iterable are accepted for parameters_to_check, "
            f"but got type {type(parameters_to_check)}."
        )

    # name of the expected type, used in the messages and the assert name
    if isinstance(numpy_type, np.dtype):
        type_name = numpy_type.type.__name__
    else:
        type_name = numpy_type.__name__

    failed_parameter_indices = []
    failed_parameter_values = []
    messages = []
    for i in parameter_indices:
        output = output_parameters[i]
        if not isinstance(output, np.ndarray):
            failed_parameter_indices.append(i)
            failed_parameter_values.append(output)
            messages.append(
                f"Output expected to be numpy array but got {type(output)}."
            )
            # Skip the dtype check: a non-array (e.g. a list) has no ``dtype``
            # and would raise, and a numpy scalar would be reported twice.
            continue
        if not np.issubdtype(output.dtype, numpy_type):
            failed_parameter_indices.append(i)
            failed_parameter_values.append(output)
            messages.append(
                "Output expected to be sub dtype "
                f"numpy.{type_name} but got "
                f"numpy.{output.dtype.type.__name__}."
            )

    return AssertResult(
        assert_name=f"assert_numpy_{type_name}_sub_dtype",
        parameter_indices=failed_parameter_indices,
        parameter_values=failed_parameter_values,
        messages=messages,
    )
|
|
290
|
+
|
|
291
|
+
|
|
292
|
+
# Shortcut for ``assert_numpy_sub_dtype`` with ``numpy_type=np.floating``. It
# passes for numpy arrays of any floating point dtype (e.g. float32, float64).
# Being a ``functools.partial``, ``numpy_type`` can still be overridden by
# keyword.
assert_numpy_floating_sub_dtype = functools.partial(
    assert_numpy_sub_dtype, numpy_type=np.floating
)
|