caskade 0.6.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
caskade/__init__.py ADDED
@@ -0,0 +1,48 @@
1
+ from ._version import version as VERSION # noqa
2
+
3
+ from .base import Node
4
+ from .context import ActiveContext, ValidContext, OverrideParam
5
+ from .decorators import forward
6
+ from .module import Module
7
+ from .param import Param
8
+ from .tests import test
9
+ from .errors import (
10
+ CaskadeException,
11
+ GraphError,
12
+ NodeConfigurationError,
13
+ ParamConfigurationError,
14
+ ParamTypeError,
15
+ ActiveStateError,
16
+ FillDynamicParamsError,
17
+ FillDynamicParamsTensorError,
18
+ FillDynamicParamsSequenceError,
19
+ FillDynamicParamsMappingError,
20
+ )
21
+ from .warnings import CaskadeWarning, InvalidValueWarning
22
+
23
+
24
# Package version string, re-exported from the setuptools_scm-generated
# ``_version`` module (imported above as ``VERSION``).
__version__ = VERSION
__author__ = "Connor Stone and Alexandre Adam"

# Public API of the package: the names exported by ``from caskade import *``.
__all__ = (
    # graph building blocks
    "Node",
    "Module",
    "Param",
    # context managers
    "ActiveContext",
    "ValidContext",
    "OverrideParam",
    # decorator for simulator methods
    "forward",
    "test",
    # exception hierarchy
    "CaskadeException",
    "GraphError",
    "NodeConfigurationError",
    "ParamConfigurationError",
    "ParamTypeError",
    "ActiveStateError",
    "FillDynamicParamsError",
    "FillDynamicParamsTensorError",
    "FillDynamicParamsSequenceError",
    "FillDynamicParamsMappingError",
    # warnings
    "CaskadeWarning",
    "InvalidValueWarning",
)
caskade/_version.py ADDED
@@ -0,0 +1,16 @@
1
# file generated by setuptools_scm
# don't change, don't track in version control
# ``TYPE_CHECKING`` is hard-coded to False, so at runtime the ``typing`` import
# is skipped and ``VERSION_TUPLE`` is simply ``object``; the richer alias is only
# relevant to tools that treat this branch as taken.
TYPE_CHECKING = False
if TYPE_CHECKING:
    from typing import Tuple, Union
    VERSION_TUPLE = Tuple[Union[int, str], ...]
else:
    VERSION_TUPLE = object

# Annotations only; the values are assigned just below.
version: str
__version__: str
__version_tuple__: VERSION_TUPLE
version_tuple: VERSION_TUPLE

__version__ = version = '0.6.1'
__version_tuple__ = version_tuple = (0, 6, 1)
caskade/base.py ADDED
@@ -0,0 +1,224 @@
1
+ from typing import Optional, Union
2
+
3
+ from .errors import GraphError, NodeConfigurationError
4
+
5
+
6
class Node(object):
    """
    Base graph node class for ``caskade`` objects.

    The ``Node`` object is the base class for all ``caskade`` objects. It is used to
    construct the directed acyclic graph (DAG). The primary function of the
    ``Node`` object is to manage the parent-child relationships between nodes in
    the graph. There is limited functionality for the ``Node`` object, though it
    implements the base versions of the ``active`` state and ``to`` /
    ``update_graph`` methods. The ``active`` state is used to communicate
    through the graph that the simulator is currently running. The ``to`` method
    is used to move and/or cast the values of the parameter. The ``update_graph``
    method is used to signal all parents that the graph below them has changed.

    Parameters
    ----------
    name: (Optional[str], optional)
        Name of the node. Defaults to the class name. Must be a ``str`` and
        must not contain ``"|"``.

    Raises
    ------
    NodeConfigurationError
        If ``name`` is not a string or contains ``"|"``.

    Examples
    --------

    Example making some ``Node`` objects and then linking/unlinking them::

        n1 = Node()
        n2 = Node()
        n1.link("subnode", n2)  # link n2 as a child of n1, may use any str as the key
        n1.unlink("subnode")  # alternately n1.unlink(n2) to unlink by object
    """

    # graphviz attributes applied to nodes in ``graphviz``, keyed by ``_type``.
    graphviz_types = {"node": {"style": "solid", "color": "black", "shape": "circle"}}

    def __init__(self, name: Optional[str] = None):
        if name is None:
            name = self.__class__.__name__
        if not isinstance(name, str):
            raise NodeConfigurationError(f"{self.__class__.__name__} name must be a string")
        if "|" in name:
            # ``graph_dict`` joins name and type with "|", so it is reserved.
            raise NodeConfigurationError(f"{self.__class__.__name__} cannot contain '|'")
        self._name = name
        self._children = {}
        self._parents = set()
        self._active = False
        self._type = "node"

    @property
    def name(self) -> str:
        """Name given at construction (defaults to the class name)."""
        return self._name

    @property
    def children(self) -> dict:
        """Mapping of link key -> child ``Node`` (the internal dict itself, not a copy)."""
        return self._children

    @property
    def parents(self) -> set:
        """Set of nodes that have this node as a child (the internal set itself)."""
        return self._parents

    def link(self, key: Union[str, "Node"], child: Optional["Node"] = None):
        """Link the current ``Node`` object to another ``Node`` object as a child.

        Parameters
        ----------
        key: (Union[str, Node])
            The key to link the child node with.
        child: (Optional[Node], optional)
            The child ``Node`` object to link to. Defaults to None in which
            case the key is used as the child and its ``name`` as the key.

        Raises
        ------
        GraphError
            If the key or the child is already linked to this node, or if the
            link would create a cycle.

        Examples
        --------

        Example making some ``Node`` objects and then linking/unlinking them. demonstrating multiple ways to link/unlink::

            n1 = Node()
            n2 = Node()

            n1.link("subnode", n2)  # may use any str as the key
            n1.unlink("subnode")

            # Alternately, link by object
            n1.link(n2)
            n1.unlink(n2)
        """
        if child is None:
            child = key
            key = child.name
        # Avoid double linking to the same object
        if key in self.children:
            raise GraphError(f"Child key {key} already linked to parent {self.name}")
        if child in self.children.values():
            raise GraphError(f"Child {child.name} already linked to parent {self.name}")
        # Avoid cycles: self must not already be reachable from the child.
        if self in child.topological_ordering():
            raise GraphError(
                f"Linking {child.name} to {self.name} would create a cycle in the graph"
            )

        self._children[key] = child
        child._parents.add(self)
        self.update_graph()

    def unlink(self, key: Union[str, "Node"]):
        """Unlink a child ``Node`` object from the current ``Node`` object.

        Parameters
        ----------
        key: (Union[str, Node])
            Either the key the child was linked under, or the child ``Node``
            object itself.

        Raises
        ------
        KeyError
            If ``key`` does not identify a child of this node.
        """
        if isinstance(key, Node):
            for child_key, child in self.children.items():
                if child == key:
                    key = child_key
                    break
            else:
                # Same exception type the dict lookup below would raise, but
                # with a message that names the problem instead of the object.
                raise KeyError(f"{key!r} is not a child of {self.name}")
        child = self._children[key]
        child._parents.remove(self)
        child.update_graph()
        del self._children[key]
        self.update_graph()

    def topological_ordering(self, with_type: Optional[str] = None) -> tuple["Node", ...]:
        """Return the nodes of the graph below the current node.

        The current node comes first, followed by every reachable node in
        depth-first pre-order (children visited in link order), each node
        appearing once at its first visit. Note that this is a pre-order walk:
        a shared descendant may appear before another of its parents if that
        parent is reached later through a different branch.

        Parameters
        ----------
        with_type: (Optional[str], optional)
            If given, only nodes whose ``_type`` equals this value are returned.

        Returns
        -------
        tuple[Node, ...]
            The ordered nodes.
        """
        ordering = []
        seen = set()  # ids of nodes already placed in ``ordering``
        # Explicit stack of child iterators: each node is expanded exactly once,
        # so shared sub-graphs are not re-walked for every path reaching them.
        # Identity (id) is used so custom ``__eq__`` on subclasses cannot interfere.
        pending = [iter((self,))]
        while pending:
            try:
                node = next(pending[-1])
            except StopIteration:
                pending.pop()
                continue
            if id(node) in seen:
                continue
            seen.add(id(node))
            ordering.append(node)
            pending.append(iter(node.children.values()))
        if with_type is None:
            return tuple(ordering)
        return tuple(node for node in ordering if node._type == with_type)

    def update_graph(self):
        """Triggers a call to all parents that the graph below them has been
        updated. The base ``Node`` object does nothing with this information, but
        other node types may use this to update internal state."""
        for parent in self.parents:
            parent.update_graph()

    @property
    def active(self) -> bool:
        """Whether the node is active.

        Setting a new value stores it and forwards it to each child; a child
        already holding that value stops the propagation below it.
        """
        return self._active

    @active.setter
    def active(self, value: bool):
        # Avoid unnecessary updates
        if self._active == value:
            return

        # Set self active level
        self._active = value

        # Propagate active level to children
        for child in self._children.values():
            child.active = value

    def to(self, device=None, dtype=None):
        """
        Moves and/or casts the PyTorch values of the ``Node``.

        The base ``Node`` holds no values itself; it forwards the call to every
        child.

        Parameters
        ----------
        device: (Optional[torch.device], optional)
            The device to move the values to. Defaults to None.
        dtype: (Optional[torch.dtype], optional)
            The desired data type. Defaults to None.

        Returns
        -------
        Node
            ``self``, to allow chaining.
        """

        for child in self.children.values():
            child.to(device=device, dtype=dtype)

        return self

    def graphviz(self, top_down=True) -> "graphviz.Digraph":
        """Return a graphviz object representing the graph below the current
        node in the DAG.

        ``graphviz`` is imported inside the method, so it is only required when
        this method is called.

        Parameters
        ----------
        top_down: (bool, optional)
            Whether to draw the graph top-down (current node at top) or
            bottom-up (current node at bottom). Defaults to True.
        """
        import graphviz

        components = set()

        def add_node(node, dot):
            # Each node is drawn once, even if reachable via several parents.
            if node in components:
                return
            dot.attr("node", **node.graphviz_types[node._type])
            dot.node(str(id(node)), repr(node))
            components.add(node)

            for child in node.children.values():
                add_node(child, dot)
                if top_down:
                    dot.edge(str(id(node)), str(id(child)))
                else:
                    dot.edge(str(id(child)), str(id(node)))

        dot = graphviz.Digraph(strict=True)
        add_node(self, dot)
        return dot

    def graph_dict(self) -> dict[str, dict]:
        """Return a dictionary representation of the graph below the current
        node.

        Each node is keyed as ``"<name>|<_type>"`` and maps to the dict of its
        children; children sharing both name and type collapse into one entry.
        """
        graph = {
            f"{self.name}|{self._type}": {},
        }
        for node in self.children.values():
            graph[f"{self.name}|{self._type}"].update(node.graph_dict())
        return graph

    def graph_print(self, dag: dict, depth: int = 0, indent: int = 4, result: str = "") -> str:
        """Print the graph dictionary in a human-readable format.

        Parameters
        ----------
        dag: (dict)
            Nested dict as produced by ``graph_dict``.
        depth: (int, optional)
            Current nesting depth; each level is indented ``indent`` spaces.
        indent: (int, optional)
            Spaces per nesting level. Defaults to 4.
        result: (str, optional)
            Text accumulated so far (used by the recursion).

        Returns
        -------
        str
            One key per line, without a trailing newline.
        """
        for key in dag:
            result = f"{result}{' ' * indent * depth}{key}\n"
            result = self.graph_print(dag[key], depth + 1, indent, result) + "\n"
        if result:  # remove trailing newline
            result = result[:-1]
        return result

    def __str__(self) -> str:
        return self.graph_print(self.graph_dict())

    def __repr__(self) -> str:
        return f"{self.__class__.__name__}({self.name})"

    def __getitem__(self, key: str) -> "Node":
        """Return the child linked under ``key``."""
        return self.children[key]
caskade/context.py ADDED
@@ -0,0 +1,73 @@
1
+ from .module import Module
2
+ from .param import Param
3
+
4
+
5
class ActiveContext:
    """
    Context manager to activate a module for a simulation. Only inside an
    ActiveContext is it possible to fill/clear the dynamic and live parameters.

    Parameters
    ----------
    module: (Module)
        The module whose ``active`` state is set for the duration of the context.
    active: (bool, optional)
        The ``active`` state used inside the context. Defaults to True.
        Passing False for a module that is already active clears its dynamic
        values on entry and re-fills them with the saved values on exit.
    """

    def __init__(self, module: Module, active: bool = True):
        self.module = module
        self.active = active

    def __enter__(self):
        module = self.module
        self.outer_active = module.active
        if not self.active and self.outer_active:
            # Switching an active module off: snapshot its current dynamic
            # values so that __exit__ can put them back.
            self.outer_params = [param.value for param in module.dynamic_params]
            module.clear_params()
        module.active = self.active

    def __exit__(self, exc_type, exc_value, traceback):
        module = self.module
        switched_on = self.active and not self.outer_active
        switched_off = self.outer_active and not self.active
        if switched_on:
            # This context activated the module: drop the values filled inside it.
            module.clear_params()
        module.active = self.outer_active
        if switched_off:
            # This context deactivated the module: restore the snapshot.
            module.fill_params(self.outer_params)
28
+
29
+
30
class ValidContext:
    """
    Context manager to set valid values for parameters. Only inside a
    ValidContext will parameters automatically be assumed valid.

    On exit the module's previous ``valid_context`` value is restored (False
    if it had none), so nested ``ValidContext`` blocks do not switch the flag
    off while an outer block is still open.

    Parameters
    ----------
    module: (Module)
        The module whose ``valid_context`` flag is set.
    """

    def __init__(self, module: "Module"):
        self.module = module

    def __enter__(self):
        # Remember the surrounding state so that exiting restores it rather
        # than unconditionally clearing it.
        self.outer_valid = getattr(self.module, "valid_context", False)
        self.module.valid_context = True

    def __exit__(self, exc_type, exc_value, traceback):
        self.module.valid_context = self.outer_valid
44
+
45
+
46
class OverrideParam:
    """
    Context manager to override a parameter value. Only inside an
    OverrideParam will the parameter be set to the new value.

    Parameters
    ----------
    param: (Param)
        The parameter whose ``_value`` is overridden.
    value:
        The value stored in ``param._value`` while inside the context.
    """

    def __init__(self, param, value):
        self.param = param
        self.value = value

    def __enter__(self):
        target = self.param
        # Snapshot the raw values (keyed by object identity) for __exit__.
        self.old_values = {str(id(target)): target._value}
        target._value = self.value
        # Pointer parents are cleared since their values may depend on this param.
        for parent in target.parents:
            if not (isinstance(parent, Param) and parent.pointer):
                continue
            self.old_values[str(id(parent))] = parent._value
            parent._value = None

    def __exit__(self, exc_type, exc_value, traceback):
        target = self.param
        target._value = self.old_values[str(id(target))]
        # Put back the pointer-parent values that __enter__ cleared.
        for parent in target.parents:
            if isinstance(parent, Param) and parent.pointer:
                parent._value = self.old_values[str(id(parent))]
caskade/decorators.py ADDED
@@ -0,0 +1,75 @@
1
+ import inspect
2
+ import functools
3
+
4
+ from .context import ActiveContext
5
+
6
# Only the decorator is public; ``_get_arguments`` is an internal helper.
__all__ = ("forward",)
7
+
8
+
9
+ def _get_arguments(method):
10
+ sig = inspect.signature(method)
11
+ return tuple(sig.parameters.keys())
12
+
13
+
14
def forward(method):
    """
    Decorator to define a forward method for a module.

    If the module is already active, the wrapper passes the values returned
    by ``self.fill_kwargs`` for the method's argument names, with explicitly
    given keyword arguments taking precedence. If the module is not active
    (a top-level call) and it has dynamic params, their values must be given
    either as the ``params`` keyword or as the last positional argument. The
    call then runs inside ``ActiveContext(self)`` after ``self.fill_params``.

    Parameters
    ----------
    method: (Callable)
        The forward method to be decorated.

    Raises
    ------
    ValueError
        On a top-level call with dynamic params but no params supplied.

    Examples
    --------
    Standard usage of the forward decorator::

        class ExampleSim(Module):
            def __init__(self, a, b, c):
                super().__init__("example_sim")
                self.a = a
                self.b = Param("b", b)
                self.c = Param("c", c)

            @forward
            def example_func(self, x, b=None):
                return x + self.a + b

        E = ExampleSim(a=1, b=None, c=3)
        print(E.example_func(4, params=[5]))
        # Output: 10

    Returns
    -------
    Callable
        The decorated forward method.
    """

    # Inspect the signature once, at decoration time.
    arg_names = _get_arguments(method)

    @functools.wraps(method)
    def wrapped(self, *args, **kwargs):
        if not self.active:
            # Top-level call: the dynamic param values come from the caller.
            if len(self.dynamic_params) == 0:
                params = {}
            elif "params" in kwargs:
                params = kwargs.pop("params")
            elif args:
                args, params = args[:-1], args[-1]
            else:
                raise ValueError(
                    "Params must be provided for a top level @forward method. Either by keyword 'method(params=params)' or as the last positional argument 'method(a, b, c, params)'"
                )
            with ActiveContext(self):
                self.fill_params(params)
                return method(self, *args, **{**self.fill_kwargs(arg_names), **kwargs})

        # Nested call inside an active module: the params are already filled.
        return method(self, *args, **{**self.fill_kwargs(arg_names), **kwargs})

    return wrapped
caskade/errors.py ADDED
@@ -0,0 +1,99 @@
1
+ from math import prod
2
+ from textwrap import dedent
3
+
4
+
5
class CaskadeException(Exception):
    """Common base class of the exception types defined by ``caskade``."""


class GraphError(CaskadeException):
    """Raised for invalid graph operations, e.g. duplicate links or links that would form a cycle."""


class NodeConfigurationError(CaskadeException):
    """Raised when a node is configured incorrectly, e.g. given an invalid name."""


class ParamConfigurationError(NodeConfigurationError):
    """Node configuration error specific to parameters."""


class ParamTypeError(CaskadeException):
    """Raised for problems with the type of a parameter."""


class ActiveStateError(CaskadeException):
    """Raised for problems related to the ``active`` state of a node."""


class FillDynamicParamsError(CaskadeException):
    """Common base class for errors raised while filling dynamic parameters."""
31
+
32
+
33
class FillDynamicParamsTensorError(FillDynamicParamsError):
    """Raised when a flattened tensor input does not match the registered dynamic params.

    Parameters
    ----------
    name: (str)
        Identifier interpolated into the message (the object being filled).
    input_params:
        The offending input; only its ``shape`` attribute is read
        (presumably a tensor).
    dynamic_params: (Iterable[Param])
        The registered dynamic params; each must expose ``shape``.
    """

    def __init__(self, name, input_params, dynamic_params):
        # Every param counts for at least one element, even when prod(shape) is 0.
        expected = sum(max(1, prod(param.shape)) for param in dynamic_params)
        registered = ", ".join(f"{param!r}: {param.shape!s}" for param in dynamic_params)
        text = f"""
            For flattened Tensor input, the (last) dim of the Tensor should
            equal the sum of all flattened dynamic params ({expected}).
            Input params shape {input_params.shape} does not match dynamic
            params shape of: {name}.

            Registered dynamic params (name: shape):
            {registered}"""
        super().__init__(dedent(text))
49
+
50
+
51
class FillDynamicParamsSequenceError(FillDynamicParamsError):
    """Raised when a sequence (list, tuple, etc.) input has the wrong length.

    The length must match either the number of dynamic params or the number
    of dynamic modules.

    Parameters
    ----------
    name: (str)
        Identifier interpolated into the message (the object being filled).
    input_params: (Sequence)
        The offending input; only its length is used.
    dynamic_params: (Sequence)
        The registered dynamic params.
    dynamic_modules: (Sequence)
        The registered dynamic modules.
    """

    def __init__(self, name, input_params, dynamic_params, dynamic_modules):
        modules_text = ", ".join(repr(module) for module in dynamic_modules)
        params_text = ", ".join(repr(param) for param in dynamic_params)
        text = f"""
            Input params length ({len(input_params)}) does not match dynamic
            params length ({len(dynamic_params)}) or number of dynamic
            modules ({len(dynamic_modules)}) of: {name}.

            Registered dynamic modules:
            {modules_text}

            Registered dynamic params:
            {params_text}"""
        super().__init__(dedent(text))
68
+
69
+
70
class FillDynamicParamsMappingError(FillDynamicParamsError):
    """Raised when a mapping (dict) input cannot be matched to the dynamic params.

    Parameters
    ----------
    name: (str)
        Identifier interpolated into the message (the object being filled).
    children: (dict)
        Children of the object being filled; those with a truthy ``dynamic``
        attribute are listed in the message.
    dynamic_modules: (Iterable)
        The registered dynamic modules, listed in the message.
    missing_key: (optional)
        A key of the input dict that matched no dynamic module or child.
    missing_param: (Param, optional)
        A dynamic param left unfilled by the input dict. Only used when
        ``missing_key`` is None.
    """

    def __init__(self, name, children, dynamic_modules, missing_key=None, missing_param=None):
        # ``getattr`` rather than ``c.dynamic``: the base ``Node`` class defines no
        # ``dynamic`` attribute, and building an error message must not itself raise.
        dynamic_children = [c for c in children.values() if getattr(c, "dynamic", False)]
        if missing_key is not None:
            message = dedent(
                f"""
                Input params key "{missing_key}" not found in dynamic modules or children of: {name}.

                Registered dynamic modules:
                {', '.join(repr(m) for m in dynamic_modules)}

                Registered dynamic children:
                {', '.join(repr(c) for c in dynamic_children)}"""
            )
        elif missing_param is not None:
            message = dedent(
                f"""
                Dynamic param "{missing_param.name}" not filled with given input params dict passed to {name}.

                Dynamic param parent(s):
                {', '.join(repr(p) for p in missing_param.parents)}

                Registered dynamic modules:
                {', '.join(repr(m) for m in dynamic_modules)}

                Registered dynamic children:
                {', '.join(repr(c) for c in dynamic_children)}"""
            )
        else:
            # Neither detail supplied: still produce a useful message instead of
            # failing on ``missing_param.name``.
            message = dedent(
                f"""
                Input params dict passed to {name} could not be used to fill its dynamic params.

                Registered dynamic modules:
                {', '.join(repr(m) for m in dynamic_modules)}

                Registered dynamic children:
                {', '.join(repr(c) for c in dynamic_children)}"""
            )
        super().__init__(message)