nd2py 3.0.0__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (95)
  1. nd2py-3.0.0/LICENSE +21 -0
  2. nd2py-3.0.0/PKG-INFO +35 -0
  3. nd2py-3.0.0/README.md +6 -0
  4. nd2py-3.0.0/nd2py/__init__.py +19 -0
  5. nd2py-3.0.0/nd2py/core/__init__.py +5 -0
  6. nd2py-3.0.0/nd2py/core/base_visitor.py +64 -0
  7. nd2py-3.0.0/nd2py/core/basic/__init__.py +9 -0
  8. nd2py-3.0.0/nd2py/core/basic/get_copy.py +37 -0
  9. nd2py-3.0.0/nd2py/core/basic/get_length.py +16 -0
  10. nd2py-3.0.0/nd2py/core/calc/__init__.py +9 -0
  11. nd2py-3.0.0/nd2py/core/calc/numpy_calc.py +391 -0
  12. nd2py-3.0.0/nd2py/core/calc/torch_calc.py +346 -0
  13. nd2py-3.0.0/nd2py/core/context/__init__.py +5 -0
  14. nd2py-3.0.0/nd2py/core/context/copy_value.py +23 -0
  15. nd2py-3.0.0/nd2py/core/context/nettype_inference.py +23 -0
  16. nd2py-3.0.0/nd2py/core/context/set_fitable.py +20 -0
  17. nd2py-3.0.0/nd2py/core/context/warn_once.py +33 -0
  18. nd2py-3.0.0/nd2py/core/converter/__init__.py +11 -0
  19. nd2py-3.0.0/nd2py/core/converter/from_postorder.py +27 -0
  20. nd2py-3.0.0/nd2py/core/converter/from_preorder.py +35 -0
  21. nd2py-3.0.0/nd2py/core/converter/parser.py +74 -0
  22. nd2py-3.0.0/nd2py/core/converter/string_printer.py +255 -0
  23. nd2py-3.0.0/nd2py/core/converter/tree_printer.py +101 -0
  24. nd2py-3.0.0/nd2py/core/nettype/__init__.py +2 -0
  25. nd2py-3.0.0/nd2py/core/nettype/inter_nettype.py +83 -0
  26. nd2py-3.0.0/nd2py/core/nettype/nettype_mixin.py +99 -0
  27. nd2py-3.0.0/nd2py/core/symbol_api.py +243 -0
  28. nd2py-3.0.0/nd2py/core/symbols/__init__.py +7 -0
  29. nd2py-3.0.0/nd2py/core/symbols/empty.py +19 -0
  30. nd2py-3.0.0/nd2py/core/symbols/functions.py +84 -0
  31. nd2py-3.0.0/nd2py/core/symbols/number.py +40 -0
  32. nd2py-3.0.0/nd2py/core/symbols/operands.py +204 -0
  33. nd2py-3.0.0/nd2py/core/symbols/symbols.py +288 -0
  34. nd2py-3.0.0/nd2py/core/symbols/variable.py +26 -0
  35. nd2py-3.0.0/nd2py/core/transform/__init__.py +7 -0
  36. nd2py-3.0.0/nd2py/core/transform/bfgs_fit.py +101 -0
  37. nd2py-3.0.0/nd2py/core/transform/fix_nettype.py +275 -0
  38. nd2py-3.0.0/nd2py/core/transform/fold_constant.py +65 -0
  39. nd2py-3.0.0/nd2py/core/transform/simplify.py +245 -0
  40. nd2py-3.0.0/nd2py/core/transform/split_by_add.py +233 -0
  41. nd2py-3.0.0/nd2py/core/transform/split_by_mul.py +79 -0
  42. nd2py-3.0.0/nd2py/core/tree/__init__.py +1 -0
  43. nd2py-3.0.0/nd2py/core/tree/iter_postorder.py +14 -0
  44. nd2py-3.0.0/nd2py/core/tree/iter_preorder.py +62 -0
  45. nd2py-3.0.0/nd2py/core/tree/tree_mixin.py +88 -0
  46. nd2py-3.0.0/nd2py/dataset/__init__.py +3 -0
  47. nd2py-3.0.0/nd2py/dataset/tokenizer.py +63 -0
  48. nd2py-3.0.0/nd2py/generator/__init__.py +2 -0
  49. nd2py-3.0.0/nd2py/generator/data/__init__.py +2 -0
  50. nd2py-3.0.0/nd2py/generator/data/gmm_generator.py +50 -0
  51. nd2py-3.0.0/nd2py/generator/data/subeq_generator.py +41 -0
  52. nd2py-3.0.0/nd2py/generator/eq/__init__.py +3 -0
  53. nd2py-3.0.0/nd2py/generator/eq/gplearn_generator.py +101 -0
  54. nd2py-3.0.0/nd2py/generator/eq/metaai_generator.py +178 -0
  55. nd2py-3.0.0/nd2py/generator/eq/snip_generator.py +145 -0
  56. nd2py-3.0.0/nd2py/search/__init__.py +1 -0
  57. nd2py-3.0.0/nd2py/search/gp/__init__.py +2 -0
  58. nd2py-3.0.0/nd2py/search/gp/gp.py +482 -0
  59. nd2py-3.0.0/nd2py/search/llmsr/__init__.py +1 -0
  60. nd2py-3.0.0/nd2py/search/llmsr/llmsr.py +451 -0
  61. nd2py-3.0.0/nd2py/search/mcts/__init__.py +1 -0
  62. nd2py-3.0.0/nd2py/search/mcts/mcts.py +455 -0
  63. nd2py-3.0.0/nd2py/search/mcts/mcts_forest.py +9 -0
  64. nd2py-3.0.0/nd2py/search/mcts/utils.py +38 -0
  65. nd2py-3.0.0/nd2py/search/ndformer/__init__.py +6 -0
  66. nd2py-3.0.0/nd2py/search/ndformer/ndformer_config.py +51 -0
  67. nd2py-3.0.0/nd2py/search/ndformer/ndformer_dataset.py +138 -0
  68. nd2py-3.0.0/nd2py/search/ndformer/ndformer_generator.py +318 -0
  69. nd2py-3.0.0/nd2py/search/ndformer/ndformer_model.py +185 -0
  70. nd2py-3.0.0/nd2py/search/ndformer/ndformer_tokenizer.py +243 -0
  71. nd2py-3.0.0/nd2py/search/ndformer/ndformer_train.py +457 -0
  72. nd2py-3.0.0/nd2py/utils/__init__.py +13 -0
  73. nd2py-3.0.0/nd2py/utils/attr_dict.py +48 -0
  74. nd2py-3.0.0/nd2py/utils/auto_gpu.py +80 -0
  75. nd2py-3.0.0/nd2py/utils/classproperty.py +12 -0
  76. nd2py-3.0.0/nd2py/utils/fix_parser.py +67 -0
  77. nd2py-3.0.0/nd2py/utils/logger.py +137 -0
  78. nd2py-3.0.0/nd2py/utils/metrics.py +29 -0
  79. nd2py-3.0.0/nd2py/utils/nn/__init__.py +2 -0
  80. nd2py-3.0.0/nd2py/utils/nn/gnn.py +112 -0
  81. nd2py-3.0.0/nd2py/utils/nn/positional_encoding.py +21 -0
  82. nd2py-3.0.0/nd2py/utils/plot.py +368 -0
  83. nd2py-3.0.0/nd2py/utils/render_markdown.py +22 -0
  84. nd2py-3.0.0/nd2py/utils/render_python.py +17 -0
  85. nd2py-3.0.0/nd2py/utils/softmax.py +0 -0
  86. nd2py-3.0.0/nd2py/utils/tag2ansi.py +169 -0
  87. nd2py-3.0.0/nd2py/utils/timing.py +268 -0
  88. nd2py-3.0.0/nd2py/utils/utils.py +25 -0
  89. nd2py-3.0.0/nd2py.egg-info/PKG-INFO +35 -0
  90. nd2py-3.0.0/nd2py.egg-info/SOURCES.txt +93 -0
  91. nd2py-3.0.0/nd2py.egg-info/dependency_links.txt +1 -0
  92. nd2py-3.0.0/nd2py.egg-info/requires.txt +21 -0
  93. nd2py-3.0.0/nd2py.egg-info/top_level.txt +1 -0
  94. nd2py-3.0.0/pyproject.toml +51 -0
  95. nd2py-3.0.0/setup.cfg +4 -0
nd2py-3.0.0/LICENSE ADDED
@@ -0,0 +1,21 @@
1
+ MIT License
2
+
3
+ Copyright (c) 2024 YuMeow
4
+
5
+ Permission is hereby granted, free of charge, to any person obtaining a copy
6
+ of this software and associated documentation files (the "Software"), to deal
7
+ in the Software without restriction, including without limitation the rights
8
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
+ copies of the Software, and to permit persons to whom the Software is
10
+ furnished to do so, subject to the following conditions:
11
+
12
+ The above copyright notice and this permission notice shall be included in all
13
+ copies or substantial portions of the Software.
14
+
15
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
+ SOFTWARE.
nd2py-3.0.0/PKG-INFO ADDED
@@ -0,0 +1,35 @@
1
+ Metadata-Version: 2.4
2
+ Name: nd2py
3
+ Version: 3.0.0
4
+ Summary: nd2py (Neural Discovery of Network Dynamics) symbolic regression
5
+ Author-email: YuMeow <yuzh19@tsinghua.org.cn>
6
+ Requires-Python: >=3.8
7
+ Description-Content-Type: text/markdown
8
+ License-File: LICENSE
9
+ Requires-Dist: tqdm>=4.66
10
+ Requires-Dist: pytest>=8.3
11
+ Requires-Dist: numpy>=1.20
12
+ Requires-Dist: scipy>=1.14
13
+ Requires-Dist: pandas>=2.2
14
+ Requires-Dist: scikit-learn>=1.5
15
+ Requires-Dist: setproctitle
16
+ Requires-Dist: matplotlib
17
+ Requires-Dist: seaborn
18
+ Requires-Dist: pyyaml
19
+ Requires-Dist: rich
20
+ Requires-Dist: dotenv
21
+ Requires-Dist: requests
22
+ Requires-Dist: pyperclip
23
+ Provides-Extra: nn
24
+ Requires-Dist: torch>=1.12; extra == "nn"
25
+ Requires-Dist: torch_geometric>=2.0; extra == "nn"
26
+ Provides-Extra: all
27
+ Requires-Dist: nd2py[nn]; extra == "all"
28
+ Dynamic: license-file
29
+
30
+
31
+
32
+ ```shell
33
+ conda create --prefix ./venv python=3.12
34
+
35
+ ```
nd2py-3.0.0/README.md ADDED
@@ -0,0 +1,6 @@
1
+
2
+
3
+ ```shell
4
+ conda create --prefix ./venv python=3.12
5
+
6
+ ```
nd2py-3.0.0/nd2py/__init__.py ADDED
@@ -0,0 +1,19 @@
1
+ from .core import *
2
+ from . import dataset, generator, search, utils
3
+
4
# Module-level __getattr__ hook (PEP 562): intercepts lookups of attributes
# this module does not define.
def __getattr__(name):
    """Resolve attributes missing from the ``nd2py`` package namespace.

    Only ``__all__`` is synthesized. This file does not define ``__all__``,
    so a star-import (``from nd2py import *``) that looks it up on the module
    reaches this hook. A ``UserWarning`` is emitted because the exported
    names include functions such as ``sum`` that shadow Python built-ins.

    NOTE(review): any other ``getattr(nd2py, "__all__")`` (e.g. from
    introspection tools) goes through the same path and also warns.

    Returns:
        list[str]: every global of this module whose name does not start
        with ``_``.

    Raises:
        AttributeError: for any name other than ``__all__``.
    """
    if name != '__all__':
        # Keep the default error for genuinely missing attributes.
        raise AttributeError(f"module '{__name__}' has no attribute '{name}'")

    import warnings

    message = (
        "Detected 'from nd2py import *'. WARNING: This will shadow standard Python built-in functions (such as 'sum').\n"
        "It is strongly recommended to use explicit imports, e.g., 'from nd2py import sum' or 'import nd2py as nd'."
    )
    # stacklevel=2 attributes the warning to the code performing the lookup.
    warnings.warn(message, category=UserWarning, stacklevel=2)

    # Collect every public global of this module: names re-exported from
    # ``.core`` plus the subpackages imported above.
    exported = []
    for candidate in globals():
        if not candidate.startswith('_'):
            exported.append(candidate)
    return exported
nd2py-3.0.0/nd2py/core/__init__.py ADDED
@@ -0,0 +1,5 @@
1
+ from .calc import *
2
+ from .symbols import *
3
+ from .context import *
4
+ from .converter import *
5
+ from .transform import *
nd2py-3.0.0/nd2py/core/base_visitor.py ADDED
@@ -0,0 +1,64 @@
1
+ # Copyright (c) 2024-present, Yumeow. Licensed under the MIT License.
2
+ from __future__ import annotations
3
+ from abc import ABC, abstractmethod
4
+ from typing import TYPE_CHECKING
5
+
6
+ if TYPE_CHECKING: # 避免循环引用,仅用于类型检查
7
+ from .symbols import Symbol
8
+
9
+
10
def yield_nothing():
    """Return a generator that finishes immediately without yielding.

    A visitor handler that needs no child results can start with
    ``yield from yield_nothing()``. That makes the handler a generator
    function, which ``Visitor.__call__`` requires: it raises ``TypeError``
    for handlers whose return value has no ``__next__``.
    """
    yield from ()
14
+
15
+
16
class Visitor(ABC):
    """Base class for stack-driven visitors over Symbol trees.

    Dispatch: for a node of class ``C`` the handler is ``visit_C`` if this
    visitor defines it, otherwise ``generic_visit``.

    Protocol: every handler must be a generator function. To obtain the
    result for a child, it yields ``(child, args, kwargs)``; that result comes
    back as the value of the ``yield`` expression. The value the handler
    finally ``return``s is the node's own result. Handlers that never need a
    child can begin with ``yield from yield_nothing()`` to stay generators.

    Pending handlers are kept on an explicit list rather than the Python call
    stack, so visiting a deep tree does not add Python call frames.
    """

    def __call__(self, node: Symbol, *args, **kwargs):
        """Visit ``node`` and return the value produced by its handler.

        ``args``/``kwargs`` go to the root handler. Each child is visited with
        the ``args``/``kwargs`` its parent yielded alongside it.

        Raises:
            TypeError: a handler returned an object without ``__next__``
                (i.e. it was not written as a generator).
        """
        # Each frame is (phase, node, args, kwargs, handler generator).
        # "start" creates a node's generator; "resume" feeds it a child result.
        frames = [("start", node, args, kwargs, None)]
        value = None
        while frames:
            phase, current, call_args, call_kwargs, pending = frames.pop()
            if phase == "start":
                handler = getattr(
                    self, f"visit_{type(current).__name__}", self.generic_visit
                )
                pending = handler(current, *call_args, **call_kwargs)
                if not hasattr(pending, "__next__"):
                    raise TypeError(
                        f"Expected a generator but got {type(pending).__name__}, please add `yield from yield_nothing()` in {type(self).__name__}.{handler.__name__}."
                    )
            elif phase != "resume":
                raise ValueError(f"Unknown state: {phase}")

            try:
                if phase == "start":
                    request = next(pending)
                else:
                    request = pending.send(value)
            except StopIteration as finished:
                # Handler finished: its return value is this node's result,
                # delivered to the parent's "resume" frame (if any).
                value = finished.value
                continue

            child, child_args, child_kwargs = request
            # Push the parent before the child so the child is processed first.
            frames.append(("resume", current, None, None, pending))
            frames.append(("start", child, child_args, child_kwargs, None))
        return value

    @abstractmethod
    def generic_visit(self, node: Symbol, *args, **kwargs):
        """Fallback handler for node classes without ``visit_<ClassName>``.

        Abstract: subclasses must override it. This default only raises.
        """
        raise NotImplementedError(
            f"generic_visit not implemented for {type(self).__name__}"
        )

    # Shared by all subclasses: symbol name -> object from the ``symbols`` module.
    _SYMBOL_CACHE = {}

    @classmethod
    def _get_symbol(cls, name):
        """Return ``symbols.<name>``, caching the lookup in ``_SYMBOL_CACHE``."""
        cache = cls._SYMBOL_CACHE
        if name not in cache:
            # Imported lazily to avoid a circular import with ``symbols``.
            from . import symbols

            cache[name] = getattr(symbols, name)
        return cache[name]
nd2py-3.0.0/nd2py/core/basic/__init__.py ADDED
@@ -0,0 +1,9 @@
1
+ # Copyright (c) 2024-present, Yumeow. Licensed under the MIT License.
2
+ """
3
+ The implementation of basic functionality of Symbol as an object, including:
4
+ - GetCopy: Deep copy a Symbol tree.
5
+ - GetLength: Count the number of nodes in a Symbol tree.
6
+ """
7
+
8
+ from .get_copy import GetCopy
9
+ from .get_length import GetLength
nd2py-3.0.0/nd2py/core/basic/get_copy.py ADDED
@@ -0,0 +1,37 @@
1
+ # Copyright (c) 2024-present, Yumeow. Licensed under the MIT License.
2
+ from copy import deepcopy
3
+ from ..base_visitor import Visitor, yield_nothing
4
+ from ..context.copy_value import get_copy_value
5
+
6
+
7
class GetCopy(Visitor):
    """Visitor that rebuilds a Symbol tree node by node.

    Every node is re-created through its own class, carrying over the
    nettypes explicitly assigned to it (``_assigned_nettypes``). ``Number``
    leaves also keep their ``fitable`` flag. Their ``value`` is deep-copied
    when ``get_copy_value()`` is truthy and shared with the original otherwise.
    """

    def __call__(self, node, *args, **kwargs):
        """Return a freshly built copy of the tree rooted at ``node``."""
        return super().__call__(node, *args, **kwargs)

    def generic_visit(self, node, *args, **kwargs):
        yield from yield_nothing()
        # Copy the operands first, then rebuild this node around the copies.
        copied_operands = []
        for operand in node.operands:
            copied_operands.append((yield operand, args, kwargs))
        return node.__class__(*copied_operands, nettype=node._assigned_nettypes)

    def visit_Number(self, node, *args, **kwargs):
        yield from yield_nothing()
        if not get_copy_value():
            # Shallow mode: the copy shares the same value object.
            return node.__class__(
                node.value, nettype=node._assigned_nettypes, fitable=node.fitable
            )
        # NOTE(review): in copy-value mode the copy receives ``node.nettype``,
        # whereas every other branch of this visitor passes
        # ``node._assigned_nettypes``. Confirm this difference is intentional.
        return node.__class__(
            deepcopy(node.value), nettype=node.nettype, fitable=node.fitable
        )

    def visit_Variable(self, node, *args, **kwargs):
        yield from yield_nothing()
        # Variables are identified by name only.
        return node.__class__(node.name, nettype=node._assigned_nettypes)

    def visit_Empty(self, node, *args, **kwargs):
        yield from yield_nothing()
        # Placeholder node: nothing to copy but its assigned nettypes.
        return node.__class__(nettype=node._assigned_nettypes)
nd2py-3.0.0/nd2py/core/basic/get_length.py ADDED
@@ -0,0 +1,16 @@
1
+ # Copyright (c) 2024-present, Yumeow. Licensed under the MIT License.
2
+ from ..base_visitor import Visitor, yield_nothing
3
+
4
+
5
class GetLength(Visitor):
    """Visitor that counts every node of a Symbol tree, leaves included."""

    def __call__(self, node, *args, **kwargs):
        """Return the number of nodes in the tree rooted at ``node``."""
        return super().__call__(node, *args, **kwargs)

    def generic_visit(self, node, *args, **kwargs):
        yield from yield_nothing()
        size = 1  # the current node itself
        for operand in node.operands:
            # Each yield returns the node count of that operand's subtree.
            size += yield operand, args, kwargs
        return size
nd2py-3.0.0/nd2py/core/calc/__init__.py ADDED
@@ -0,0 +1,9 @@
1
+ # Copyright (c) 2024-present, Yumeow. Licensed under the MIT License.
2
+ """
3
+ The implementation of calculation functionality of Symbol, including:
4
+ - NumpyCalc: Calculate the value of a Symbol tree using numpy.
5
+ - TorchCalc: Calculate the value of a Symbol tree using torch.
6
+ """
7
+
8
+ from .numpy_calc import NumpyCalc
9
+ from .torch_calc import TorchCalc
nd2py-3.0.0/nd2py/core/calc/numpy_calc.py ADDED
@@ -0,0 +1,391 @@
1
+ # Copyright (c) 2024-present, Yumeow. Licensed under the MIT License.
2
+ from __future__ import annotations
3
+ import numbers
4
+ import warnings
5
+ import functools
6
+ import traceback
7
+ import numpy as np
8
+ from typing import List, Tuple, TYPE_CHECKING
9
+ from ..base_visitor import Visitor, yield_nothing
10
+ if TYPE_CHECKING:
11
+ from ..symbols import *
12
+
13
+
14
def unpack_operands():
    """Build a decorator that turns an operand-level ``visit_<Op>`` into a
    Visitor-compatible generator method.

    The decorated function is written as
    ``visit_<Op>(self, node, *operand_values, *args, **kwargs)``. The
    generated wrapper does the following:

    1. Yields each of ``node.operands`` (with the current ``args``/``kwargs``)
       so the visitor evaluates them, collecting the results in order.
    2. Calls the decorated function with those values placed right after
       ``node``, inside ``np.errstate`` with divide-by-zero, invalid-operation
       and overflow floating-point warnings ignored.
    """

    def decorator(func):
        def wrapper(self, node, *args, **kwargs):
            yield from yield_nothing()
            operand_values = []
            for operand in node.operands:
                operand_values.append((yield (operand, args, kwargs)))
            # Silence NumPy floating-point warnings while applying the operator.
            with np.errstate(divide="ignore", invalid="ignore", over="ignore"):
                return func(self, node, *operand_values, *args, **kwargs)

        return functools.wraps(func)(wrapper)

    return decorator
34
+
35
+
36
class NumpyCalc(Visitor):
    """Evaluate a Symbol tree numerically with NumPy.

    Each operator's ``visit_<Op>`` receives its operands already evaluated
    (see ``unpack_operands``) and returns a NumPy value. Graph operators work
    along the last axis:

    - ``Sour``/``Targ`` gather node values onto edges, using
      ``edge_list[0]`` (sources) / ``edge_list[1]`` (targets).
    - ``Aggr``/``Rgga`` sum edge values into nodes at ``edge_list[1]`` /
      ``edge_list[0]``.
    - ``Readout`` sums over the last axis.
    """

    def __call__(
        self,
        node: Symbol,
        vars: dict = None,
        edge_list: Tuple[List[int], List[int]] = None,
        num_nodes: int = None,
        use_eps: float = 0.0,
    ):
        """
        Args:
        - vars: a dictionary of variable names and their values.
          Defaults to a new empty dict on every call.
        - edge_list: pair (sources, targets); edge k goes from node
          edge_list[0][k] to node edge_list[1][k]. Node indices start at 0.
        - num_nodes: the number of nodes in the graph.
          If not provided, it is inferred from edge_list as the largest node
          index + 1 (0 for an empty edge list).
        - use_eps: a small value added to exact zeros before dividing or
          taking a log, to avoid division by zero

        Returns:
        - the NumPy value of ``node``
        """
        if vars is None:
            # Avoid a shared mutable default argument.
            vars = {}
        if num_nodes is None and edge_list is not None:
            flat = np.asarray(edge_list).reshape(-1)
            num_nodes = int(flat.max()) + 1 if flat.size else 0

        return super().__call__(
            node,
            vars=vars,
            edge_list=edge_list,
            num_nodes=num_nodes,
            use_eps=use_eps,
        )

    # ----------------------------------------------------------------- helpers

    @staticmethod
    def _nonzero(v, eps):
        """Return ``v`` with exact zeros shifted by ``eps`` (other entries kept)."""
        return v + eps * (v == 0)

    @staticmethod
    def _gather_from_nodes(node, x, edge_list, end):
        """Gather node values onto edges via ``edge_list[end]``.

        (*, n_nodes or 1) -> (*, n_edges or 1). Size-1 values and
        scalar-nettype / single-column operands are returned without indexing.
        ``edge_list`` is only indexed in the per-node case.
        """
        if isinstance(x, numbers.Number) or x.size == 1:
            return x  # (1,) -> (1,)
        if node.operands[0].nettype == "scalar" or x.shape[-1] == 1:
            if x.shape[-1] != 1:
                x = x[..., np.newaxis]
            return x  # (*, 1) -> (*, 1)
        return x[..., edge_list[end]]  # (*, V) -> (*, E)

    @staticmethod
    def _scatter_to_nodes(node, x, edge_list, end, num_nodes):
        """Sum edge values into the nodes listed in ``edge_list[end]``.

        (*, n_edges or 1) -> (*, n_nodes)
        """
        index = edge_list[end]
        if isinstance(x, numbers.Number) or x.size == 1:
            # One value shared by every edge: each node receives it once per
            # edge whose ``end`` is that node.
            y = np.zeros((num_nodes,))
            np.add.at(y, index, float(np.asarray(x).item()))
            return y
        if node.operands[0].nettype == "scalar" or x.shape[-1] == 1:
            if x.shape[-1] != 1:
                x = x[..., np.newaxis]
            degree = np.zeros((num_nodes,))
            np.add.at(degree, index, 1)
            return degree * x
        y = np.zeros((*x.shape[:-1], num_nodes))
        # Scatter-add along the last axis. np.moveaxis returns a view of y, so
        # np.add.at writes into y in place; repeated indices accumulate
        # (unbuffered), as an explicit per-edge loop would.
        np.add.at(np.moveaxis(y, -1, 0), index, np.moveaxis(x, -1, 0))
        return y

    # ------------------------------------------------------------------ leaves

    def generic_visit(self, node: Symbol, *args, **kwargs):
        raise NotImplementedError(
            f"{type(self).__name__}.visit_{type(node).__name__} not implemented"
        )

    def visit_Empty(self, node: Empty, *args, **kwargs):
        raise ValueError(
            f"Incomplete expression with Empty node is not allowed to evaluate: {node}"
        )

    def visit_Number(self, node: Number, *args, **kwargs):
        yield from yield_nothing()
        return np.asarray(node.value)

    def visit_Variable(self, node: Variable, *args, **kwargs):
        yield from yield_nothing()
        return np.asarray(kwargs["vars"][node.name])

    # ------------------------------------------------------ binary arithmetic

    @unpack_operands()
    def visit_Add(self, node: Add, x1, x2, *args, **kwargs):
        return x1 + x2

    @unpack_operands()
    def visit_Sub(self, node: Sub, x1, x2, *args, **kwargs):
        return x1 - x2

    @unpack_operands()
    def visit_Mul(self, node: Mul, x1, x2, *args, **kwargs):
        return x1 * x2

    @unpack_operands()
    def visit_Div(self, node: Div, x1, x2, *args, **kwargs):
        return x1 / self._nonzero(x2, kwargs.get("use_eps"))

    @unpack_operands()
    def visit_Pow(self, node: Pow, x1, x2, *args, **kwargs):
        return x1**x2

    @unpack_operands()
    def visit_Max(self, node: Max, x1, x2, *args, **kwargs):
        return np.maximum(x1, x2)

    @unpack_operands()
    def visit_Min(self, node: Min, x1, x2, *args, **kwargs):
        return np.minimum(x1, x2)

    # -------------------------------------------------------- unary functions

    @unpack_operands()
    def visit_Identity(self, node: Identity, x, *args, **kwargs):
        return x

    @unpack_operands()
    def visit_Sin(self, node: Sin, x, *args, **kwargs):
        return np.sin(x)

    @unpack_operands()
    def visit_Cos(self, node: Cos, x, *args, **kwargs):
        return np.cos(x)

    @unpack_operands()
    def visit_Tan(self, node: Tan, x, *args, **kwargs):
        return np.tan(x)

    @unpack_operands()
    def visit_Sec(self, node: Sec, x, *args, **kwargs):
        return 1 / self._nonzero(np.cos(x), kwargs.get("use_eps"))

    @unpack_operands()
    def visit_Csc(self, node: Csc, x, *args, **kwargs):
        return 1 / self._nonzero(np.sin(x), kwargs.get("use_eps"))

    @unpack_operands()
    def visit_Cot(self, node: Cot, x, *args, **kwargs):
        return 1 / self._nonzero(np.tan(x), kwargs.get("use_eps"))

    @unpack_operands()
    def visit_Log(self, node: Log, x, *args, **kwargs):
        return np.log(self._nonzero(x, kwargs.get("use_eps")))

    @unpack_operands()
    def visit_LogAbs(self, node: LogAbs, x, *args, **kwargs):
        return np.log(np.abs(self._nonzero(x, kwargs.get("use_eps"))))

    @unpack_operands()
    def visit_Exp(self, node: Exp, x, *args, **kwargs):
        return np.exp(x)

    @unpack_operands()
    def visit_Abs(self, node: Abs, x, *args, **kwargs):
        return np.abs(x)

    @unpack_operands()
    def visit_Neg(self, node: Neg, x, *args, **kwargs):
        return -x

    @unpack_operands()
    def visit_Inv(self, node: Inv, x, *args, **kwargs):
        return 1 / self._nonzero(x, kwargs.get("use_eps"))

    @unpack_operands()
    def visit_Sqrt(self, node: Sqrt, x, *args, **kwargs):
        return np.sqrt(x)

    @unpack_operands()
    def visit_SqrtAbs(self, node: SqrtAbs, x, *args, **kwargs):
        return np.sqrt(np.abs(x))

    @unpack_operands()
    def visit_Pow2(self, node: Pow2, x, *args, **kwargs):
        return x**2

    @unpack_operands()
    def visit_Pow3(self, node: Pow3, x, *args, **kwargs):
        return x**3

    @unpack_operands()
    def visit_Arcsin(self, node: Arcsin, x, *args, **kwargs):
        return np.arcsin(x)

    @unpack_operands()
    def visit_Arccos(self, node: Arccos, x, *args, **kwargs):
        return np.arccos(x)

    @unpack_operands()
    def visit_Arctan(self, node: Arctan, x, *args, **kwargs):
        return np.arctan(x)

    @unpack_operands()
    def visit_Sinh(self, node: Sinh, x, *args, **kwargs):
        return np.sinh(x)

    @unpack_operands()
    def visit_Cosh(self, node: Cosh, x, *args, **kwargs):
        return np.cosh(x)

    @unpack_operands()
    def visit_Tanh(self, node: Tanh, x, *args, **kwargs):
        return np.tanh(x)

    @unpack_operands()
    def visit_Sech(self, node: Sech, x, *args, **kwargs):
        return 1 / self._nonzero(np.cosh(x), kwargs.get("use_eps"))

    @unpack_operands()
    def visit_Csch(self, node: Csch, x, *args, **kwargs):
        return 1 / self._nonzero(np.sinh(x), kwargs.get("use_eps"))

    @unpack_operands()
    def visit_Coth(self, node: Coth, x, *args, **kwargs):
        return 1 / self._nonzero(np.tanh(x), kwargs.get("use_eps"))

    @unpack_operands()
    def visit_Sigmoid(self, node: Sigmoid, x, *args, **kwargs):
        return 1 / (1 + np.exp(-x))

    @unpack_operands()
    def visit_Regular(self, node: Regular, x1, x2, *args, **kwargs):
        eps = kwargs.get("use_eps")
        return 1 / (1 + (np.abs(x1) + eps * (x1 == 0)) ** (-x2))

    # -------------------------------------------------------- graph operators

    @unpack_operands()
    def visit_Sour(self, node: Sour, x, *args, **kwargs):
        """(*, n_nodes or 1) -> (*, n_edges or 1)"""
        edge_list = kwargs.get("edge_list", ([], []))
        return self._gather_from_nodes(node, x, edge_list, 0)

    @unpack_operands()
    def visit_Targ(self, node: Targ, x, *args, **kwargs):
        """(*, n_nodes or 1) -> (*, n_edges or 1)"""
        edge_list = kwargs.get("edge_list", ([], []))
        return self._gather_from_nodes(node, x, edge_list, 1)

    @unpack_operands()
    def visit_Aggr(self, node: Aggr, x, *args, **kwargs):
        """(*, n_edges or 1) -> (*, n_nodes)

        Sums each edge value into the edge's target node (``edge_list[1]``).
        """
        edge_list = kwargs.get("edge_list", ([], []))
        return self._scatter_to_nodes(node, x, edge_list, 1, kwargs.get("num_nodes"))

    @unpack_operands()
    def visit_Rgga(self, node: Rgga, x, *args, **kwargs):
        """(*, n_edges or 1) -> (*, n_nodes)

        Sums each edge value into the edge's source node (``edge_list[0]``)
        in every branch, including the size-1 case.
        """
        edge_list = kwargs.get("edge_list", ([], []))
        return self._scatter_to_nodes(node, x, edge_list, 0, kwargs.get("num_nodes"))

    @unpack_operands()
    def visit_Readout(self, node: Readout, x, *args, **kwargs):
        """(*, n_nodes or n_edges or 1) -> (*, 1)"""
        return np.sum(x, axis=-1, keepdims=True)
313
+
314
+
315
+ """
316
+ # 比较 aggr 不同实现方式的性能
317
+
318
+ import numpy as np
319
+
320
+ T = 100
321
+ V = 10
322
+ E = 10
323
+ x = np.random.rand(T, E)
324
+ index = np.random.randint(0, V, size=(E,))
325
+
326
+ def aggr0(x, index, V):
327
+ y = np.zeros((*x.shape[:-1], V), dtype=x.dtype)
328
+ for k, j in enumerate(index):
329
+ y[..., j] += x[..., k]
330
+ return y
331
+
332
+
333
+ def aggr1(x, index, V):
334
+ x = np.asarray(x)
335
+ orig_shape = x.shape
336
+ x_flat = x.reshape(-1, x.shape[-1]) # shape: (B, N), B = product of leading dims
337
+
338
+ T = x_flat.shape[0]
339
+ y = np.zeros((T, V), dtype=x.dtype)
340
+ for i in range(T):
341
+ np.add.at(y[i, :], index, x_flat[i, :])
342
+ y = y.reshape(*orig_shape[:-1], V)
343
+ return y
344
+
345
+ def aggr2(x, index, V):
346
+ x = np.asarray(x)
347
+ orig_shape = x.shape
348
+ x_flat = x.reshape(-1, x.shape[-1]) # shape: (B, N), B = product of leading dims
349
+
350
+ T = x_flat.shape[0]
351
+ y = np.zeros((T, V), dtype=x.dtype)
352
+ for i in range(T):
353
+ y[i, :] = np.bincount(index, weights=x[i, :], minlength=V)
354
+ y = y.reshape(*orig_shape[:-1], V)
355
+ return y
356
+
357
+ def aggr3(x, index, V):
358
+ x = np.asarray(x)
359
+ orig_shape = x.shape
360
+ x_flat = x.reshape(-1, x.shape[-1]) # shape: (B, E)
361
+
362
+ T = x_flat.shape[0]
363
+ # 扩展 index 为 shape (T, E) 以对齐每个 batch
364
+ index_broadcast = np.broadcast_to(index, (T, E))
365
+
366
+ # 准备目标数组
367
+ y_flat = np.zeros((T, V), dtype=x.dtype)
368
+
369
+ # 扁平化批次索引:行号 (0,0,...,1,1,1,...T-1) 与 index 构成 2D 索引
370
+ row_idx = np.repeat(np.arange(T), E)
371
+ col_idx = index_broadcast.ravel()
372
+ values = x_flat.ravel()
373
+
374
+ # 聚合加和
375
+ np.add.at(y_flat, (row_idx, col_idx), values)
376
+
377
+ # reshape 回原来的批次形状
378
+ return y_flat.reshape(*orig_shape[:-1], V)
379
+
380
+
381
+ y0 = aggr0(x, index, V)
382
+
383
+ y1 = aggr1(x, index, V)
384
+ assert (y0 == y1).all()
385
+
386
+ y2 = aggr2(x, index, V)
387
+ assert (y0 == y2).all()
388
+
389
+ y3 = aggr3(x, index, V)
390
+ assert (y0 == y3).all()
391
+ """