caskade 0.6.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- caskade/__init__.py +48 -0
- caskade/_version.py +16 -0
- caskade/base.py +224 -0
- caskade/context.py +73 -0
- caskade/decorators.py +75 -0
- caskade/errors.py +99 -0
- caskade/module.py +322 -0
- caskade/param.py +288 -0
- caskade/tests.py +48 -0
- caskade/warnings.py +17 -0
- caskade-0.6.1.dist-info/METADATA +113 -0
- caskade-0.6.1.dist-info/RECORD +14 -0
- caskade-0.6.1.dist-info/WHEEL +4 -0
- caskade-0.6.1.dist-info/licenses/LICENSE +21 -0
caskade/__init__.py
ADDED
|
@@ -0,0 +1,48 @@
|
|
|
1
|
+
from ._version import version as VERSION # noqa
|
|
2
|
+
|
|
3
|
+
from .base import Node
|
|
4
|
+
from .context import ActiveContext, ValidContext, OverrideParam
|
|
5
|
+
from .decorators import forward
|
|
6
|
+
from .module import Module
|
|
7
|
+
from .param import Param
|
|
8
|
+
from .tests import test
|
|
9
|
+
from .errors import (
|
|
10
|
+
CaskadeException,
|
|
11
|
+
GraphError,
|
|
12
|
+
NodeConfigurationError,
|
|
13
|
+
ParamConfigurationError,
|
|
14
|
+
ParamTypeError,
|
|
15
|
+
ActiveStateError,
|
|
16
|
+
FillDynamicParamsError,
|
|
17
|
+
FillDynamicParamsTensorError,
|
|
18
|
+
FillDynamicParamsSequenceError,
|
|
19
|
+
FillDynamicParamsMappingError,
|
|
20
|
+
)
|
|
21
|
+
from .warnings import CaskadeWarning, InvalidValueWarning
|
|
22
|
+
|
|
23
|
+
|
|
24
|
+
# Package metadata
__version__ = VERSION
__author__ = "Connor Stone and Alexandre Adam"

# Public API, re-exported from the submodules imported above.
__all__ = (
    # graph building blocks
    "Node",
    "Module",
    "Param",
    # context managers
    "ActiveContext",
    "ValidContext",
    "OverrideParam",
    # decorators and helpers
    "forward",
    "test",
    # exceptions
    "CaskadeException",
    "GraphError",
    "NodeConfigurationError",
    "ParamConfigurationError",
    "ParamTypeError",
    "ActiveStateError",
    "FillDynamicParamsError",
    "FillDynamicParamsTensorError",
    "FillDynamicParamsSequenceError",
    "FillDynamicParamsMappingError",
    # warnings
    "CaskadeWarning",
    "InvalidValueWarning",
)
|
caskade/_version.py
ADDED
|
@@ -0,0 +1,16 @@
|
|
|
1
|
+
# file generated by setuptools_scm
|
|
2
|
+
# don't change, don't track in version control
|
|
3
|
+
TYPE_CHECKING = False
if not TYPE_CHECKING:
    # At runtime the alias is only a placeholder object.
    VERSION_TUPLE = object
else:
    from typing import Tuple, Union

    VERSION_TUPLE = Tuple[Union[int, str], ...]

version: str
__version__: str
__version_tuple__: VERSION_TUPLE
version_tuple: VERSION_TUPLE

version = '0.6.1'
__version__ = version
version_tuple = (0, 6, 1)
__version_tuple__ = version_tuple
|
caskade/base.py
ADDED
|
@@ -0,0 +1,224 @@
|
|
|
1
|
+
from typing import Optional, Union
|
|
2
|
+
|
|
3
|
+
from .errors import GraphError, NodeConfigurationError
|
|
4
|
+
|
|
5
|
+
|
|
6
|
+
class Node(object):
    """
    Base graph node class for ``caskade`` objects.

    The ``Node`` object is the base class for all ``caskade`` objects. It is used to
    construct the directed acyclic graph (DAG). The primary function of the
    ``Node`` object is to manage the parent-child relationships between nodes in
    the graph. There is limited functionality for the ``Node`` object, though it
    implements the base versions of the ``active`` state and ``to`` /
    ``update_graph`` methods. The ``active`` state is used to communicate
    through the graph that the simulator is currently running. The ``to`` method
    is used to move and/or cast the values of the parameter. The ``update_graph``
    method is used to signal all parents that the graph below them has changed.

    Examples
    --------

    Example making some ``Node`` objects and then linking/unlinking them::

        n1 = Node()
        n2 = Node()
        n1.link("subnode", n2)  # link n2 as a child of n1, may use any str as the key
        n1.unlink("subnode")  # alternately n1.unlink(n2) to unlink by object
    """

    graphviz_types = {"node": {"style": "solid", "color": "black", "shape": "circle"}}

    def __init__(self, name: Optional[str] = None):
        """Create an unlinked node.

        Parameters
        ----------
        name: (Optional[str], optional)
            Name of the node. Defaults to None, in which case the class name
            is used.

        Raises
        ------
        NodeConfigurationError
            If ``name`` is not a string or contains the ``"|"`` character.
        """
        if name is None:
            name = self.__class__.__name__
        if not isinstance(name, str):
            raise NodeConfigurationError(f"{self.__class__.__name__} name must be a string")
        # "|" is reserved as the name/type separator in ``graph_dict`` keys.
        if "|" in name:
            raise NodeConfigurationError(f"{self.__class__.__name__} cannot contain '|'")
        self._name = name
        self._children = {}  # link key -> child Node
        self._parents = set()
        self._active = False
        self._type = "node"

    @property
    def name(self) -> str:
        """Name of the node."""
        return self._name

    @property
    def children(self) -> dict:
        """Mapping of link key to child ``Node``."""
        return self._children

    @property
    def parents(self) -> set:
        """Set of ``Node`` objects that have this node as a child."""
        return self._parents

    def link(self, key: Union[str, "Node"], child: Optional["Node"] = None):
        """Link the current ``Node`` object to another ``Node`` object as a child.

        Parameters
        ----------
        key: (Union[str, Node])
            The key to link the child node with.
        child: (Optional[Node], optional)
            The child ``Node`` object to link to. Defaults to None in which
            case the key is used as the child and its ``name`` as the key.

        Raises
        ------
        GraphError
            If the key is already in use, the child is already linked, or the
            link would create a cycle in the graph.

        Examples
        --------

        Example making some ``Node`` objects and then linking/unlinking them,
        demonstrating multiple ways to link/unlink::

            n1 = Node()
            n2 = Node()

            n1.link("subnode", n2)  # may use any str as the key
            n1.unlink("subnode")

            # Alternately, link by object
            n1.link(n2)
            n1.unlink(n2)
        """
        if child is None:
            child = key
            key = child.name
        # Avoid double linking to the same object
        if key in self.children:
            raise GraphError(f"Child key {key} already linked to parent {self.name}")
        if child in self.children.values():
            raise GraphError(f"Child {child.name} already linked to parent {self.name}")
        # avoid cycles: self must not already be reachable from the child
        if self in child.topological_ordering():
            raise GraphError(
                f"Linking {child.name} to {self.name} would create a cycle in the graph"
            )

        self._children[key] = child
        child._parents.add(self)
        self.update_graph()

    def unlink(self, key: Union[str, "Node"]):
        """Unlink the current ``Node`` object from another ``Node`` object which is a child.

        Parameters
        ----------
        key: (Union[str, Node])
            The link key of the child, or the child ``Node`` object itself.

        Raises
        ------
        KeyError
            If ``key`` does not identify a child of this node.
        """
        if isinstance(key, Node):
            # Translate the child object into the key it is linked under.
            for link_key, child in self.children.items():
                if child == key:
                    key = link_key
                    break
            else:
                raise KeyError(f"{key.name} is not linked as a child of {self.name}")
        child = self._children[key]
        child._parents.remove(self)
        child.update_graph()
        del self._children[key]
        self.update_graph()

    def topological_ordering(self, with_type: Optional[str] = None) -> tuple["Node", ...]:
        """Return an ordering of this node and every node below it.

        The current node comes first, followed by the nodes reachable through
        each child (in child insertion order), each listed once.

        Parameters
        ----------
        with_type: (Optional[str], optional)
            If given, only nodes whose ``_type`` equals this value are
            returned. Defaults to None (return all nodes).
        """
        ordering = [self]
        # Side set for O(1) membership tests; ``ordering`` keeps the order.
        seen = {self}
        for node in self.children.values():
            for subnode in node.topological_ordering():
                if subnode not in seen:
                    seen.add(subnode)
                    ordering.append(subnode)
        if with_type is None:
            return tuple(ordering)
        return tuple(n for n in ordering if n._type == with_type)

    def update_graph(self):
        """Triggers a call to all parents that the graph below them has been
        updated. The base ``Node`` object does nothing with this information, but
        other node types may use this to update internal state."""
        for parent in self.parents:
            parent.update_graph()

    @property
    def active(self) -> bool:
        """Whether the node is currently active."""
        return self._active

    @active.setter
    def active(self, value: bool):
        # Avoid unnecessary updates (also stops re-propagating through
        # children that are already in the requested state)
        if self._active == value:
            return

        # Set self active level
        self._active = value

        # Propagate active level to children
        for child in self._children.values():
            child.active = value

    def to(self, device=None, dtype=None):
        """
        Moves and/or casts the PyTorch values of the ``Node``.

        The base implementation only forwards the call to all children.

        Parameters
        ----------
        device: (Optional[torch.device], optional)
            The device to move the values to. Defaults to None.
        dtype: (Optional[torch.dtype], optional)
            The desired data type. Defaults to None.

        Returns
        -------
        Node
            ``self``, to allow chaining.
        """

        for child in self.children.values():
            child.to(device=device, dtype=dtype)

        return self

    def graphviz(self, top_down=True) -> "graphviz.Digraph":
        """Return a graphviz object representing the graph below the current
        node in the DAG.

        Parameters
        ----------
        top_down: (bool, optional)
            Whether to draw the graph top-down (current node at top) or
            bottom-up (current node at bottom). Defaults to True.
        """
        import graphviz

        components = set()

        def add_node(node, dot):
            # Each node is drawn once; edges are still added by the caller.
            if node in components:
                return
            dot.attr("node", **node.graphviz_types[node._type])
            dot.node(str(id(node)), repr(node))
            components.add(node)

            for child in node.children.values():
                add_node(child, dot)
                if top_down:
                    dot.edge(str(id(node)), str(id(child)))
                else:
                    dot.edge(str(id(child)), str(id(node)))

        dot = graphviz.Digraph(strict=True)
        add_node(self, dot)
        return dot

    def graph_dict(self) -> dict[str, dict]:
        """Return a dictionary representation of the graph below the current
        node, keyed by ``"name|type"`` strings and nested by child."""
        graph = {
            f"{self.name}|{self._type}": {},
        }
        for node in self.children.values():
            graph[f"{self.name}|{self._type}"].update(node.graph_dict())
        return graph

    def graph_print(self, dag: dict, depth: int = 0, indent: int = 4, result: str = "") -> str:
        """Format the graph dictionary in a human-readable, indented form.

        Parameters
        ----------
        dag: (dict)
            Nested dictionary as produced by ``graph_dict``.
        depth: (int, optional)
            Current nesting depth. Defaults to 0.
        indent: (int, optional)
            Spaces of indentation per depth level. Defaults to 4.
        result: (str, optional)
            Text accumulated so far. Defaults to "".
        """
        for key in dag:
            result = f"{result}{' ' * indent * depth}{key}\n"
            result = self.graph_print(dag[key], depth + 1, indent, result) + "\n"
        if result:  # remove trailing newline
            result = result[:-1]
        return result

    def __str__(self) -> str:
        return self.graph_print(self.graph_dict())

    def __repr__(self) -> str:
        return f"{self.__class__.__name__}({self.name})"

    def __getitem__(self, key: str) -> "Node":
        return self.children[key]
|
caskade/context.py
ADDED
|
@@ -0,0 +1,73 @@
|
|
|
1
|
+
from .module import Module
|
|
2
|
+
from .param import Param
|
|
3
|
+
|
|
4
|
+
|
|
5
|
+
class ActiveContext:
    """
    Context manager that switches a module's ``active`` state for the
    duration of a ``with`` block. Only inside an ActiveContext is it possible
    to fill/clear the dynamic and live parameters.

    Entering with ``active=True`` on an inactive module activates it; on exit
    its params are cleared and it is deactivated again. Entering with
    ``active=False`` on an active module snapshots the current dynamic param
    values, clears them and deactivates the module; on exit it is reactivated
    and the snapshot is filled back in.
    """

    def __init__(self, module: Module, active: bool = True):
        self.module = module
        self.active = active

    def __enter__(self):
        module = self.module
        self.outer_active = module.active
        if self.outer_active and not self.active:
            # Snapshot the outer simulation's dynamic values before wiping them.
            self.outer_params = [param.value for param in module.dynamic_params]
            module.clear_params()
        module.active = self.active

    def __exit__(self, exc_type, exc_value, traceback):
        module = self.module
        if self.active and not self.outer_active:
            module.clear_params()
        module.active = self.outer_active
        if self.active or not self.outer_active:
            # Nothing was snapshotted on entry, so nothing to restore.
            return
        module.fill_params(self.outer_params)
|
|
28
|
+
|
|
29
|
+
|
|
30
|
+
class ValidContext:
    """
    Context manager to set valid values for parameters. Only inside a
    ValidContext will parameters automatically be assumed valid.

    The module's previous ``valid_context`` flag is restored on exit, so
    ``ValidContext`` blocks may be nested without an inner block switching
    the flag off for an enclosing one.
    """

    def __init__(self, module: "Module"):
        self.module = module

    def __enter__(self):
        # Remember the outer state (treated as False if never set) so that
        # leaving this context restores it instead of forcing False.
        self.outer_valid = getattr(self.module, "valid_context", False)
        self.module.valid_context = True

    def __exit__(self, exc_type, exc_value, traceback):
        self.module.valid_context = self.outer_valid
|
|
44
|
+
|
|
45
|
+
|
|
46
|
+
class OverrideParam:
    """
    Context manager that temporarily replaces a parameter's stored value.
    Only inside an OverrideParam will the parameter be set to the new value.

    On entry, the parameter's ``_value`` is swapped for the override, and any
    parent pointer ``Param`` has its ``_value`` saved and reset to None. On
    exit, all saved values are put back.
    """

    def __init__(self, param, value):
        self.param = param
        self.value = value

    @staticmethod
    def _is_pointer(node):
        # Only pointer ``Param`` parents have their values stashed/reset.
        return isinstance(node, Param) and node.pointer

    def __enter__(self):
        target = self.param
        # Saved values are keyed by the id of the owning node.
        self.old_values = {str(id(target)): target._value}
        target._value = self.value
        for parent in target.parents:
            if not self._is_pointer(parent):
                continue
            # Pointer values may depend on the overridden param; reset them.
            self.old_values[str(id(parent))] = parent._value
            parent._value = None

    def __exit__(self, exc_type, exc_value, traceback):
        target = self.param
        target._value = self.old_values[str(id(target))]
        for parent in filter(self._is_pointer, target.parents):
            parent._value = self.old_values[str(id(parent))]
|
caskade/decorators.py
ADDED
|
@@ -0,0 +1,75 @@
|
|
|
1
|
+
import inspect
|
|
2
|
+
import functools
|
|
3
|
+
|
|
4
|
+
from .context import ActiveContext
|
|
5
|
+
|
|
6
|
+
__all__ = ("forward",)
|
|
7
|
+
|
|
8
|
+
|
|
9
|
+
def _get_arguments(method):
|
|
10
|
+
sig = inspect.signature(method)
|
|
11
|
+
return tuple(sig.parameters.keys())
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
def forward(method):
    """
    Decorator to define a forward method for a module.

    When the module is already active (inside an ``ActiveContext``), keyword
    arguments matching the method's parameter names are filled via
    ``self.fill_kwargs``; explicitly passed keyword arguments take precedence.
    When the module is not active, the parameter values must be supplied
    either as the ``params`` keyword or as the last positional argument. They
    are loaded with ``self.fill_params`` inside an ``ActiveContext`` for the
    duration of the call. If the module has no dynamic params, no params are
    required, and a ``params`` keyword, if given, is dropped rather than
    passed on to the method.

    Parameters
    ----------
    method: (Callable)
        The forward method to be decorated.

    Examples
    --------
    Standard usage of the forward decorator::

        class ExampleSim(Module):
            def __init__(self, a, b, c):
                super().__init__("example_sim")
                self.a = a
                self.b = Param("b", b)
                self.c = Param("c", c)

            @forward
            def example_func(self, x, b=None):
                return x + self.a + b

        E = ExampleSim(a=1, b=None, c=3)
        print(E.example_func(4, params=[5]))
        # Output: 10

    Returns
    -------
    Callable
        The decorated forward method.

    Raises
    ------
    ValueError
        If the module is inactive, has dynamic params, and no params were
        provided.
    """

    # Get arguments from function signature
    method_params = _get_arguments(method)

    @functools.wraps(method)
    def wrapped(self, *args, **kwargs):
        if self.active:
            # Params are already filled; just pull values for this method's
            # arguments. Explicit kwargs override the filled ones.
            kwargs = {**self.fill_kwargs(method_params), **kwargs}
            return method(self, *args, **kwargs)

        # Extract params from the arguments
        if len(self.dynamic_params) == 0:
            # Nothing to fill. Drop a caller-supplied ``params`` keyword so it
            # is not forwarded to ``method`` as an unexpected argument.
            kwargs.pop("params", None)
            params = {}
        elif "params" in kwargs:
            params = kwargs.pop("params")
        elif args:
            params = args[-1]
            args = args[:-1]
        else:
            raise ValueError(
                "Params must be provided for a top level @forward method. Either by keyword 'method(params=params)' or as the last positional argument 'method(a, b, c, params)'"
            )

        with ActiveContext(self):
            self.fill_params(params)
            kwargs = {**self.fill_kwargs(method_params), **kwargs}
            return method(self, *args, **kwargs)

    return wrapped
|
caskade/errors.py
ADDED
|
@@ -0,0 +1,99 @@
|
|
|
1
|
+
from math import prod
|
|
2
|
+
from textwrap import dedent
|
|
3
|
+
|
|
4
|
+
|
|
5
|
+
class CaskadeException(Exception):
    """Root of the ``caskade`` exception hierarchy; catch it to handle any ``caskade`` error."""
|
|
7
|
+
|
|
8
|
+
|
|
9
|
+
class GraphError(CaskadeException):
    """Raised in ``caskade`` for invalid graph operations, such as duplicate links or links that would create a cycle."""
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
class NodeConfigurationError(CaskadeException):
    """Raised in ``caskade`` when a node is configured incorrectly (for example, an invalid node name)."""
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
class ParamConfigurationError(NodeConfigurationError):
    """Raised in ``caskade`` when a parameter node is configured incorrectly."""
|
|
19
|
+
|
|
20
|
+
|
|
21
|
+
class ParamTypeError(CaskadeException):
    """Raised in ``caskade`` for problems with the type of a parameter."""
|
|
23
|
+
|
|
24
|
+
|
|
25
|
+
class ActiveStateError(CaskadeException):
    """Raised in ``caskade`` for problems with the ``active`` state of a node."""
|
|
27
|
+
|
|
28
|
+
|
|
29
|
+
class FillDynamicParamsError(CaskadeException):
    """Base class for ``caskade`` errors raised while filling dynamic parameters."""
|
|
31
|
+
|
|
32
|
+
|
|
33
|
+
class FillDynamicParamsTensorError(FillDynamicParamsError):
    """Raised in ``caskade`` when a flattened tensor of values does not match the size of the dynamic parameters."""

    def __init__(self, name, input_params, dynamic_params):
        # Expected length of the last tensor dim: total element count of all
        # dynamic params, with every param counting for at least one entry.
        expected_size = sum(max(1, prod(param.shape)) for param in dynamic_params)
        input_shape = input_params.shape
        registered = ", ".join(f"{repr(param)}: {str(param.shape)}" for param in dynamic_params)
        message = dedent(
            f"""
            For flattened Tensor input, the (last) dim of the Tensor should
            equal the sum of all flattened dynamic params ({expected_size}).
            Input params shape {input_shape} does not match dynamic
            params shape of: {name}.

            Registered dynamic params (name: shape):
            {registered}"""
        )
        super().__init__(message)
|
|
49
|
+
|
|
50
|
+
|
|
51
|
+
class FillDynamicParamsSequenceError(FillDynamicParamsError):
    """Raised in ``caskade`` when a sequence (list, tuple, etc.) of values matches neither the number of dynamic params nor of dynamic modules."""

    def __init__(self, name, input_params, dynamic_params, dynamic_modules):
        n_input = len(input_params)
        n_params = len(dynamic_params)
        n_modules = len(dynamic_modules)
        module_list = ", ".join(repr(m) for m in dynamic_modules)
        param_list = ", ".join(repr(p) for p in dynamic_params)
        message = dedent(
            f"""
            Input params length ({n_input}) does not match dynamic
            params length ({n_params}) or number of dynamic
            modules ({n_modules}) of: {name}.

            Registered dynamic modules:
            {module_list}

            Registered dynamic params:
            {param_list}"""
        )
        super().__init__(message)
|
|
68
|
+
|
|
69
|
+
|
|
70
|
+
class FillDynamicParamsMappingError(FillDynamicParamsError):
    """Raised in ``caskade`` when a mapping (dict) of values has an unknown key or leaves a dynamic param unfilled."""

    def __init__(self, name, children, dynamic_modules, missing_key=None, missing_param=None):
        if missing_key is None:
            # A dynamic param was not filled by the supplied dict.
            param_name = missing_param.name
            parent_list = ", ".join(repr(p) for p in missing_param.parents)
            module_list = ", ".join(repr(m) for m in dynamic_modules)
            child_list = ", ".join(repr(c) for c in children.values() if c.dynamic)
            message = dedent(
                f"""
                Dynamic param "{param_name}" not filled with given input params dict passed to {name}.

                Dynamic param parent(s):
                {parent_list}

                Registered dynamic modules:
                {module_list}

                Registered dynamic children:
                {child_list}"""
            )
        else:
            # The supplied dict has a key that matches nothing dynamic.
            module_list = ", ".join(repr(m) for m in dynamic_modules)
            child_list = ", ".join(repr(c) for c in children.values() if c.dynamic)
            message = dedent(
                f"""
                Input params key "{missing_key}" not found in dynamic modules or children of: {name}.

                Registered dynamic modules:
                {module_list}

                Registered dynamic children:
                {child_list}"""
            )
        super().__init__(message)
|