nd2py 3.0.0__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- nd2py-3.0.0/LICENSE +21 -0
- nd2py-3.0.0/PKG-INFO +35 -0
- nd2py-3.0.0/README.md +6 -0
- nd2py-3.0.0/nd2py/__init__.py +19 -0
- nd2py-3.0.0/nd2py/core/__init__.py +5 -0
- nd2py-3.0.0/nd2py/core/base_visitor.py +64 -0
- nd2py-3.0.0/nd2py/core/basic/__init__.py +9 -0
- nd2py-3.0.0/nd2py/core/basic/get_copy.py +37 -0
- nd2py-3.0.0/nd2py/core/basic/get_length.py +16 -0
- nd2py-3.0.0/nd2py/core/calc/__init__.py +9 -0
- nd2py-3.0.0/nd2py/core/calc/numpy_calc.py +391 -0
- nd2py-3.0.0/nd2py/core/calc/torch_calc.py +346 -0
- nd2py-3.0.0/nd2py/core/context/__init__.py +5 -0
- nd2py-3.0.0/nd2py/core/context/copy_value.py +23 -0
- nd2py-3.0.0/nd2py/core/context/nettype_inference.py +23 -0
- nd2py-3.0.0/nd2py/core/context/set_fitable.py +20 -0
- nd2py-3.0.0/nd2py/core/context/warn_once.py +33 -0
- nd2py-3.0.0/nd2py/core/converter/__init__.py +11 -0
- nd2py-3.0.0/nd2py/core/converter/from_postorder.py +27 -0
- nd2py-3.0.0/nd2py/core/converter/from_preorder.py +35 -0
- nd2py-3.0.0/nd2py/core/converter/parser.py +74 -0
- nd2py-3.0.0/nd2py/core/converter/string_printer.py +255 -0
- nd2py-3.0.0/nd2py/core/converter/tree_printer.py +101 -0
- nd2py-3.0.0/nd2py/core/nettype/__init__.py +2 -0
- nd2py-3.0.0/nd2py/core/nettype/inter_nettype.py +83 -0
- nd2py-3.0.0/nd2py/core/nettype/nettype_mixin.py +99 -0
- nd2py-3.0.0/nd2py/core/symbol_api.py +243 -0
- nd2py-3.0.0/nd2py/core/symbols/__init__.py +7 -0
- nd2py-3.0.0/nd2py/core/symbols/empty.py +19 -0
- nd2py-3.0.0/nd2py/core/symbols/functions.py +84 -0
- nd2py-3.0.0/nd2py/core/symbols/number.py +40 -0
- nd2py-3.0.0/nd2py/core/symbols/operands.py +204 -0
- nd2py-3.0.0/nd2py/core/symbols/symbols.py +288 -0
- nd2py-3.0.0/nd2py/core/symbols/variable.py +26 -0
- nd2py-3.0.0/nd2py/core/transform/__init__.py +7 -0
- nd2py-3.0.0/nd2py/core/transform/bfgs_fit.py +101 -0
- nd2py-3.0.0/nd2py/core/transform/fix_nettype.py +275 -0
- nd2py-3.0.0/nd2py/core/transform/fold_constant.py +65 -0
- nd2py-3.0.0/nd2py/core/transform/simplify.py +245 -0
- nd2py-3.0.0/nd2py/core/transform/split_by_add.py +233 -0
- nd2py-3.0.0/nd2py/core/transform/split_by_mul.py +79 -0
- nd2py-3.0.0/nd2py/core/tree/__init__.py +1 -0
- nd2py-3.0.0/nd2py/core/tree/iter_postorder.py +14 -0
- nd2py-3.0.0/nd2py/core/tree/iter_preorder.py +62 -0
- nd2py-3.0.0/nd2py/core/tree/tree_mixin.py +88 -0
- nd2py-3.0.0/nd2py/dataset/__init__.py +3 -0
- nd2py-3.0.0/nd2py/dataset/tokenizer.py +63 -0
- nd2py-3.0.0/nd2py/generator/__init__.py +2 -0
- nd2py-3.0.0/nd2py/generator/data/__init__.py +2 -0
- nd2py-3.0.0/nd2py/generator/data/gmm_generator.py +50 -0
- nd2py-3.0.0/nd2py/generator/data/subeq_generator.py +41 -0
- nd2py-3.0.0/nd2py/generator/eq/__init__.py +3 -0
- nd2py-3.0.0/nd2py/generator/eq/gplearn_generator.py +101 -0
- nd2py-3.0.0/nd2py/generator/eq/metaai_generator.py +178 -0
- nd2py-3.0.0/nd2py/generator/eq/snip_generator.py +145 -0
- nd2py-3.0.0/nd2py/search/__init__.py +1 -0
- nd2py-3.0.0/nd2py/search/gp/__init__.py +2 -0
- nd2py-3.0.0/nd2py/search/gp/gp.py +482 -0
- nd2py-3.0.0/nd2py/search/llmsr/__init__.py +1 -0
- nd2py-3.0.0/nd2py/search/llmsr/llmsr.py +451 -0
- nd2py-3.0.0/nd2py/search/mcts/__init__.py +1 -0
- nd2py-3.0.0/nd2py/search/mcts/mcts.py +455 -0
- nd2py-3.0.0/nd2py/search/mcts/mcts_forest.py +9 -0
- nd2py-3.0.0/nd2py/search/mcts/utils.py +38 -0
- nd2py-3.0.0/nd2py/search/ndformer/__init__.py +6 -0
- nd2py-3.0.0/nd2py/search/ndformer/ndformer_config.py +51 -0
- nd2py-3.0.0/nd2py/search/ndformer/ndformer_dataset.py +138 -0
- nd2py-3.0.0/nd2py/search/ndformer/ndformer_generator.py +318 -0
- nd2py-3.0.0/nd2py/search/ndformer/ndformer_model.py +185 -0
- nd2py-3.0.0/nd2py/search/ndformer/ndformer_tokenizer.py +243 -0
- nd2py-3.0.0/nd2py/search/ndformer/ndformer_train.py +457 -0
- nd2py-3.0.0/nd2py/utils/__init__.py +13 -0
- nd2py-3.0.0/nd2py/utils/attr_dict.py +48 -0
- nd2py-3.0.0/nd2py/utils/auto_gpu.py +80 -0
- nd2py-3.0.0/nd2py/utils/classproperty.py +12 -0
- nd2py-3.0.0/nd2py/utils/fix_parser.py +67 -0
- nd2py-3.0.0/nd2py/utils/logger.py +137 -0
- nd2py-3.0.0/nd2py/utils/metrics.py +29 -0
- nd2py-3.0.0/nd2py/utils/nn/__init__.py +2 -0
- nd2py-3.0.0/nd2py/utils/nn/gnn.py +112 -0
- nd2py-3.0.0/nd2py/utils/nn/positional_encoding.py +21 -0
- nd2py-3.0.0/nd2py/utils/plot.py +368 -0
- nd2py-3.0.0/nd2py/utils/render_markdown.py +22 -0
- nd2py-3.0.0/nd2py/utils/render_python.py +17 -0
- nd2py-3.0.0/nd2py/utils/softmax.py +0 -0
- nd2py-3.0.0/nd2py/utils/tag2ansi.py +169 -0
- nd2py-3.0.0/nd2py/utils/timing.py +268 -0
- nd2py-3.0.0/nd2py/utils/utils.py +25 -0
- nd2py-3.0.0/nd2py.egg-info/PKG-INFO +35 -0
- nd2py-3.0.0/nd2py.egg-info/SOURCES.txt +93 -0
- nd2py-3.0.0/nd2py.egg-info/dependency_links.txt +1 -0
- nd2py-3.0.0/nd2py.egg-info/requires.txt +21 -0
- nd2py-3.0.0/nd2py.egg-info/top_level.txt +1 -0
- nd2py-3.0.0/pyproject.toml +51 -0
- nd2py-3.0.0/setup.cfg +4 -0
nd2py-3.0.0/LICENSE
ADDED
|
@@ -0,0 +1,21 @@
|
|
|
1
|
+
MIT License
|
|
2
|
+
|
|
3
|
+
Copyright (c) 2024 YuMeow
|
|
4
|
+
|
|
5
|
+
Permission is hereby granted, free of charge, to any person obtaining a copy
|
|
6
|
+
of this software and associated documentation files (the "Software"), to deal
|
|
7
|
+
in the Software without restriction, including without limitation the rights
|
|
8
|
+
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
|
9
|
+
copies of the Software, and to permit persons to whom the Software is
|
|
10
|
+
furnished to do so, subject to the following conditions:
|
|
11
|
+
|
|
12
|
+
The above copyright notice and this permission notice shall be included in all
|
|
13
|
+
copies or substantial portions of the Software.
|
|
14
|
+
|
|
15
|
+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
|
16
|
+
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
|
17
|
+
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
|
18
|
+
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
|
19
|
+
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
|
20
|
+
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
|
21
|
+
SOFTWARE.
|
nd2py-3.0.0/PKG-INFO
ADDED
|
@@ -0,0 +1,35 @@
|
|
|
1
|
+
Metadata-Version: 2.4
|
|
2
|
+
Name: nd2py
|
|
3
|
+
Version: 3.0.0
|
|
4
|
+
Summary: nd2py (Neural Discovery of Network Dynamics) symbolic regression
|
|
5
|
+
Author-email: YuMeow <yuzh19@tsinghua.org.cn>
|
|
6
|
+
Requires-Python: >=3.8
|
|
7
|
+
Description-Content-Type: text/markdown
|
|
8
|
+
License-File: LICENSE
|
|
9
|
+
Requires-Dist: tqdm>=4.66
|
|
10
|
+
Requires-Dist: pytest>=8.3
|
|
11
|
+
Requires-Dist: numpy>=1.20
|
|
12
|
+
Requires-Dist: scipy>=1.14
|
|
13
|
+
Requires-Dist: pandas>=2.2
|
|
14
|
+
Requires-Dist: scikit-learn>=1.5
|
|
15
|
+
Requires-Dist: setproctitle
|
|
16
|
+
Requires-Dist: matplotlib
|
|
17
|
+
Requires-Dist: seaborn
|
|
18
|
+
Requires-Dist: pyyaml
|
|
19
|
+
Requires-Dist: rich
|
|
20
|
+
Requires-Dist: dotenv
|
|
21
|
+
Requires-Dist: requests
|
|
22
|
+
Requires-Dist: pyperclip
|
|
23
|
+
Provides-Extra: nn
|
|
24
|
+
Requires-Dist: torch>=1.12; extra == "nn"
|
|
25
|
+
Requires-Dist: torch_geometric>=2.0; extra == "nn"
|
|
26
|
+
Provides-Extra: all
|
|
27
|
+
Requires-Dist: nd2py[nn]; extra == "all"
|
|
28
|
+
Dynamic: license-file
|
|
29
|
+
|
|
30
|
+
|
|
31
|
+
|
|
32
|
+
```shell
|
|
33
|
+
conda create --prefix ./venv python=3.12
|
|
34
|
+
|
|
35
|
+
```
|
nd2py-3.0.0/README.md
ADDED
|
@@ -0,0 +1,19 @@
|
|
|
1
|
+
from .core import *
|
|
2
|
+
from . import dataset, generator, search, utils
|
|
3
|
+
|
|
4
|
+
# Module-level __getattr__ (PEP 562): called only for attribute lookups that
# are not found in this module's namespace.
def __getattr__(name):
    """Resolve attributes missing from the ``nd2py`` module namespace.

    ``from nd2py import *`` asks the module for ``__all__``; when ``__all__``
    is not defined in the module, that lookup lands here. In that case a
    ``UserWarning`` is issued (a star import shadows built-ins such as
    ``sum``) and every global name not starting with ``_`` is returned.

    Raises:
        AttributeError: for any other missing attribute (default behaviour).
    """
    if name != '__all__':
        # Anything other than __all__ keeps the standard "no attribute" error.
        raise AttributeError(f"module '{__name__}' has no attribute '{name}'")

    import warnings

    warnings.warn(
        "Detected 'from nd2py import *'. WARNING: This will shadow standard Python built-in functions (such as 'sum').\n"
        "It is strongly recommended to use explicit imports, e.g., 'from nd2py import sum' or 'import nd2py as nd'.",
        category=UserWarning,
        stacklevel=2,
    )
    # Public names = all module globals (objects and submodules) without a
    # leading underscore, in namespace order.
    public_names = [key for key in globals() if not key.startswith('_')]
    return public_names
|
|
@@ -0,0 +1,64 @@
|
|
|
1
|
+
# Copyright (c) 2024-present, Yumeow. Licensed under the MIT License.
|
|
2
|
+
from __future__ import annotations
|
|
3
|
+
from abc import ABC, abstractmethod
|
|
4
|
+
from typing import TYPE_CHECKING
|
|
5
|
+
|
|
6
|
+
if TYPE_CHECKING: # 避免循环引用,仅用于类型检查
|
|
7
|
+
from .symbols import Symbol
|
|
8
|
+
|
|
9
|
+
|
|
10
|
+
def yield_nothing():
    """Return a generator that finishes immediately without yielding anything.

    ``Visitor.__call__`` requires every visit method to return an object with
    ``__next__``. A method that never asks for a child's value can begin with
    ``yield from yield_nothing()`` to become a generator function while still
    yielding nothing.
    """
    yield from ()
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
class Visitor(ABC):
    """Base class for stack-driven visitors over Symbol trees.

    Subclasses write ``visit_<ClassName>`` methods (falling back to
    ``generic_visit``) as *generator functions*:

    * to evaluate a child, a method yields ``(child, args, kwargs)`` and the
      child's result is sent back as the value of that ``yield`` expression;
    * the method's ``return`` value becomes the result for its own node.

    ``__call__`` drives these generators from an explicit list used as a
    stack rather than by recursive calls, so traversal depth is not bounded
    by Python's recursion limit.
    """

    def __call__(self, node: Symbol, *args, **kwargs):
        """Visit ``node`` and return the result of its visit method.

        Dispatch: ``visit_<type(node).__name__>`` when the visitor defines it,
        otherwise ``generic_visit``.

        Raises:
            TypeError: if the dispatched method does not return a generator.
        """
        # Frame layout: (state, node, args, kwargs, generator).
        # "start"  -> the node's visit method has not been called yet.
        # "resume" -> a suspended generator waiting for its latest child result.
        frames = [("start", node, args, kwargs, None)]
        result = None
        while frames:
            state, current, call_args, call_kwargs, gen = frames.pop()
            if state == "start":
                handler = getattr(
                    self, "visit_" + type(current).__name__, self.generic_visit
                )
                gen = handler(current, *call_args, **call_kwargs)
                if not hasattr(gen, "__next__"):
                    raise TypeError(
                        f"Expected a generator but got {type(gen).__name__}, please add `yield from yield_nothing()` in {type(self).__name__}.{handler.__name__}."
                    )
                first_step = True
            elif state == "resume":
                first_step = False
            else:
                raise ValueError(f"Unknown state: {state}")

            try:
                # A fresh generator is started with next(); a suspended one
                # receives the result of the child it asked for.
                request = next(gen) if first_step else gen.send(result)
            except StopIteration as stop:
                # Generator finished: its return value goes to the parent
                # frame (or is the overall answer for the root node).
                result = stop.value
                continue

            child, child_args, child_kwargs = request
            # Re-queue this node, then the child on top so it runs first.
            frames.append(("resume", current, None, None, gen))
            frames.append(("start", child, child_args, child_kwargs, None))
        return result

    @abstractmethod
    def generic_visit(self, node: Symbol, *args, **kwargs):
        """Fallback for node types without a ``visit_<ClassName>`` method.

        Subclasses must override this; the base version only raises.
        """
        raise NotImplementedError(
            f"generic_visit not implemented for {type(self).__name__}"
        )

    # Class-level cache shared by subclasses that do not define their own:
    # symbol-class name -> object from the ``symbols`` package.
    _SYMBOL_CACHE = {}

    @classmethod
    def _get_symbol(cls, name):
        """Return ``symbols.<name>``, importing ``symbols`` lazily on first use."""
        if name in cls._SYMBOL_CACHE:
            return cls._SYMBOL_CACHE[name]
        # Deferred import: this module only imports ``symbols`` for type
        # checking, so it is resolved here at call time instead.
        from . import symbols

        symbol = getattr(symbols, name)
        cls._SYMBOL_CACHE[name] = symbol
        return symbol
|
|
@@ -0,0 +1,9 @@
|
|
|
1
|
+
# Copyright (c) 2024-present, Yumeow. Licensed under the MIT License.
|
|
2
|
+
"""
|
|
3
|
+
The implementation of basic functionality of Symbol as an object, including:
|
|
4
|
+
- GetCopy: Deep copy a Symbol tree.
|
|
5
|
+
- GetLength: Count the number of nodes in a Symbol tree.
|
|
6
|
+
"""
|
|
7
|
+
|
|
8
|
+
from .get_copy import GetCopy
|
|
9
|
+
from .get_length import GetLength
|
|
@@ -0,0 +1,37 @@
|
|
|
1
|
+
# Copyright (c) 2024-present, Yumeow. Licensed under the MIT License.
|
|
2
|
+
from copy import deepcopy
|
|
3
|
+
from ..base_visitor import Visitor, yield_nothing
|
|
4
|
+
from ..context.copy_value import get_copy_value
|
|
5
|
+
|
|
6
|
+
|
|
7
|
+
class GetCopy(Visitor):
    """Visitor that returns a structural copy of a Symbol tree.

    Every node is rebuilt through its own class with the nettype that was
    explicitly assigned to it (``_assigned_nettypes``). ``Number`` values are
    deep-copied only when the copy-value context (``get_copy_value()``) is
    enabled; otherwise the copy shares the original value object.
    """

    def __call__(self, node, *args, **kwargs):
        """Create a copy of the node."""
        return super().__call__(node, *args, **kwargs)

    def generic_visit(self, node, *args, **kwargs):
        """Copy an operator node after copying each of its operands."""
        yield from yield_nothing()
        children = []
        for child in node.operands:
            # The driver copies the operand and sends the copy back here.
            child = yield child, args, kwargs
            children.append(child)
        return node.__class__(*children, nettype=node._assigned_nettypes)

    def visit_Number(self, node, *args, **kwargs):
        """Copy a Number leaf, preserving its ``fitable`` flag."""
        yield from yield_nothing()
        value = deepcopy(node.value) if get_copy_value() else node.value
        # Only value handling depends on the copy-value context. The nettype
        # passed is always the *assigned* one, as in every other visit method
        # here, so a copy never gains a nettype constraint the original node
        # did not carry (``node.nettype`` may be an inferred value — NOTE:
        # confirm against nettype_mixin if this is ever revisited).
        return node.__class__(
            value, nettype=node._assigned_nettypes, fitable=node.fitable
        )

    def visit_Variable(self, node, *args, **kwargs):
        """Copy a Variable leaf (same name, same assigned nettype)."""
        yield from yield_nothing()
        return node.__class__(node.name, nettype=node._assigned_nettypes)

    def visit_Empty(self, node, *args, **kwargs):
        """Copy an Empty placeholder leaf."""
        yield from yield_nothing()
        return node.__class__(nettype=node._assigned_nettypes)
|
|
@@ -0,0 +1,16 @@
|
|
|
1
|
+
# Copyright (c) 2024-present, Yumeow. Licensed under the MIT License.
|
|
2
|
+
from ..base_visitor import Visitor, yield_nothing
|
|
3
|
+
|
|
4
|
+
|
|
5
|
+
class GetLength(Visitor):
    """Visitor that counts the nodes of a Symbol tree."""

    def __call__(self, node, *args, **kwargs):
        """Count the number of nodes in the tree."""
        return super().__call__(node, *args, **kwargs)

    def generic_visit(self, node, *args, **kwargs):
        """Return 1 for this node plus the sizes of all operand subtrees."""
        yield from yield_nothing()
        size = 1
        for operand in node.operands:
            # Subtree size is computed by the driver and sent back here.
            size += (yield (operand, args, kwargs))
        return size
|
|
@@ -0,0 +1,9 @@
|
|
|
1
|
+
# Copyright (c) 2024-present, Yumeow. Licensed under the MIT License.
|
|
2
|
+
"""
|
|
3
|
+
The implementation of calculation functionality of Symbol, including:
|
|
4
|
+
- NumpyCalc: Calculate the value of a Symbol tree using numpy.
|
|
5
|
+
- TorchCalc: Calculate the value of a Symbol tree using torch.
|
|
6
|
+
"""
|
|
7
|
+
|
|
8
|
+
from .numpy_calc import NumpyCalc
|
|
9
|
+
from .torch_calc import TorchCalc
|
|
@@ -0,0 +1,391 @@
|
|
|
1
|
+
# Copyright (c) 2024-present, Yumeow. Licensed under the MIT License.
|
|
2
|
+
from __future__ import annotations
|
|
3
|
+
import numbers
|
|
4
|
+
import warnings
|
|
5
|
+
import functools
|
|
6
|
+
import traceback
|
|
7
|
+
import numpy as np
|
|
8
|
+
from typing import List, Tuple, TYPE_CHECKING
|
|
9
|
+
from ..base_visitor import Visitor, yield_nothing
|
|
10
|
+
if TYPE_CHECKING:
|
|
11
|
+
from ..symbols import *
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
def unpack_operands():
    """Build a decorator that evaluates a node's operands before visiting it.

    The decorated method is called as
    ``func(self, node, *operand_values, *args, **kwargs)``: every operand's
    computed value is passed positionally, ahead of the visitor's own
    ``args``. The returned wrapper is a generator that yields each operand
    (with the same ``args``/``kwargs``) to the visitor driver and collects
    the values sent back. The wrapped call runs inside
    ``np.errstate(divide="ignore", invalid="ignore", over="ignore")``, so
    those numpy floating-point errors are silenced.
    """

    def decorator(func):
        @functools.wraps(func)
        def wrapper(self, node, *args, **kwargs):
            yield from yield_nothing()
            operand_values = []
            for operand in node.operands:
                operand_values.append((yield (operand, args, kwargs)))
            with np.errstate(divide="ignore", invalid="ignore", over="ignore"):
                return func(self, node, *operand_values, *args, **kwargs)

        return wrapper

    return decorator
|
|
34
|
+
|
|
35
|
+
|
|
36
|
+
class NumpyCalc(Visitor):
    """Evaluate a Symbol tree numerically with numpy.

    Operator methods are wrapped by ``unpack_operands()``, so each receives
    the already-evaluated operand values and returns the value of its own
    node. The graph operators (``Sour``, ``Targ``, ``Aggr``, ``Rgga``) move
    values between per-node and per-edge layouts along the last axis, using
    ``edge_list[0]`` as edge sources and ``edge_list[1]`` as edge targets.
    """

    def __call__(
        self,
        node: Symbol,
        vars: dict = None,
        edge_list: Tuple[List[int], List[int]] = None,
        num_nodes: int = None,
        use_eps: float = 0.0,
    ):
        """
        Args:
            - vars: a dictionary of variable names and their values
                (a fresh empty dict is used when omitted)
            - edge_list: edges (i,j) in the graph
                i and j are the indices of the nodes (starting from 0)
            - num_nodes: the number of nodes in the graph
                if not provided, it will be inferred from edge_list as
                (largest node index + 1), or 0 if edge_list has no edges
            - use_eps: a small value to avoid division by zero
                (added only where a denominator / log argument equals 0)
        Returns:
            The numpy value of ``node``.
        """
        # Fresh dict per call instead of a shared mutable default argument.
        if vars is None:
            vars = {}
        if num_nodes is None and edge_list is not None:
            node_ids = np.asarray(edge_list).reshape(-1)
            # max() of an empty array raises, so handle "no edges" explicitly.
            num_nodes = int(node_ids.max()) + 1 if node_ids.size else 0

        return super().__call__(
            node,
            vars=vars,
            edge_list=edge_list,
            num_nodes=num_nodes,
            use_eps=use_eps,
        )

    def generic_visit(self, node: Symbol, *args, **kwargs):
        """Reject node types that have no dedicated ``visit_<Op>`` method."""
        raise NotImplementedError(
            f"{type(self).__name__}.visit_{type(node).__name__} not implemented"
        )

    def visit_Empty(self, node: Empty, *args, **kwargs):
        """Reject incomplete expressions.

        Deliberately not a generator: the error is raised as soon as the
        visitor driver calls this method.
        """
        raise ValueError(
            f"Incomplete expression with Empty node is not allowed to evaluate: {node}"
        )

    def visit_Number(self, node: Number, *args, **kwargs):
        yield from yield_nothing()
        return np.asarray(node.value)

    def visit_Variable(self, node: Variable, *args, **kwargs):
        yield from yield_nothing()
        return np.asarray(kwargs["vars"][node.name])

    @unpack_operands()
    def visit_Add(self, node: Add, x1, x2, *args, **kwargs):
        return x1 + x2

    @unpack_operands()
    def visit_Sub(self, node: Sub, x1, x2, *args, **kwargs):
        return x1 - x2

    @unpack_operands()
    def visit_Mul(self, node: Mul, x1, x2, *args, **kwargs):
        return x1 * x2

    @unpack_operands()
    def visit_Div(self, node: Div, x1, x2, *args, **kwargs):
        eps = kwargs.get("use_eps")
        return x1 / (x2 + eps * (x2 == 0))

    @unpack_operands()
    def visit_Pow(self, node: Pow, x1, x2, *args, **kwargs):
        return x1**x2

    @unpack_operands()
    def visit_Max(self, node: Max, x1, x2, *args, **kwargs):
        return np.maximum(x1, x2)

    @unpack_operands()
    def visit_Min(self, node: Min, x1, x2, *args, **kwargs):
        return np.minimum(x1, x2)

    @unpack_operands()
    def visit_Identity(self, node: Identity, x, *args, **kwargs):
        return x

    @unpack_operands()
    def visit_Sin(self, node: Sin, x, *args, **kwargs):
        return np.sin(x)

    @unpack_operands()
    def visit_Cos(self, node: Cos, x, *args, **kwargs):
        return np.cos(x)

    @unpack_operands()
    def visit_Tan(self, node: Tan, x, *args, **kwargs):
        return np.tan(x)

    @unpack_operands()
    def visit_Sec(self, node: Sec, x, *args, **kwargs):
        eps = kwargs.get("use_eps")
        return 1 / (np.cos(x) + eps * (np.cos(x) == 0))

    @unpack_operands()
    def visit_Csc(self, node: Csc, x, *args, **kwargs):
        eps = kwargs.get("use_eps")
        return 1 / (np.sin(x) + eps * (np.sin(x) == 0))

    @unpack_operands()
    def visit_Cot(self, node: Cot, x, *args, **kwargs):
        eps = kwargs.get("use_eps")
        return 1 / (np.tan(x) + eps * (np.tan(x) == 0))

    @unpack_operands()
    def visit_Log(self, node: Log, x, *args, **kwargs):
        eps = kwargs.get("use_eps")
        return np.log(x + eps * (x == 0))

    @unpack_operands()
    def visit_LogAbs(self, node: LogAbs, x, *args, **kwargs):
        eps = kwargs.get("use_eps")
        return np.log(np.abs(x + eps * (x == 0)))

    @unpack_operands()
    def visit_Exp(self, node: Exp, x, *args, **kwargs):
        return np.exp(x)

    @unpack_operands()
    def visit_Abs(self, node: Abs, x, *args, **kwargs):
        return np.abs(x)

    @unpack_operands()
    def visit_Neg(self, node: Neg, x, *args, **kwargs):
        return -x

    @unpack_operands()
    def visit_Inv(self, node: Inv, x, *args, **kwargs):
        eps = kwargs.get("use_eps")
        return 1 / (x + eps * (x == 0))

    @unpack_operands()
    def visit_Sqrt(self, node: Sqrt, x, *args, **kwargs):
        return np.sqrt(x)

    @unpack_operands()
    def visit_SqrtAbs(self, node: SqrtAbs, x, *args, **kwargs):
        return np.sqrt(np.abs(x))

    @unpack_operands()
    def visit_Pow2(self, node: Pow2, x, *args, **kwargs):
        return x**2

    @unpack_operands()
    def visit_Pow3(self, node: Pow3, x, *args, **kwargs):
        return x**3

    @unpack_operands()
    def visit_Arcsin(self, node: Arcsin, x, *args, **kwargs):
        return np.arcsin(x)

    @unpack_operands()
    def visit_Arccos(self, node: Arccos, x, *args, **kwargs):
        return np.arccos(x)

    @unpack_operands()
    def visit_Arctan(self, node: Arctan, x, *args, **kwargs):
        return np.arctan(x)

    @unpack_operands()
    def visit_Sinh(self, node: Sinh, x, *args, **kwargs):
        return np.sinh(x)

    @unpack_operands()
    def visit_Cosh(self, node: Cosh, x, *args, **kwargs):
        return np.cosh(x)

    @unpack_operands()
    def visit_Tanh(self, node: Tanh, x, *args, **kwargs):
        return np.tanh(x)

    @unpack_operands()
    def visit_Sech(self, node: Sech, x, *args, **kwargs):
        eps = kwargs.get("use_eps")
        return 1 / (np.cosh(x) + eps * (np.cosh(x) == 0))

    @unpack_operands()
    def visit_Csch(self, node: Csch, x, *args, **kwargs):
        eps = kwargs.get("use_eps")
        return 1 / (np.sinh(x) + eps * (np.sinh(x) == 0))

    @unpack_operands()
    def visit_Coth(self, node: Coth, x, *args, **kwargs):
        eps = kwargs.get("use_eps")
        return 1 / (np.tanh(x) + eps * (np.tanh(x) == 0))

    @unpack_operands()
    def visit_Sigmoid(self, node: Sigmoid, x, *args, **kwargs):
        return 1 / (1 + np.exp(-x))

    @unpack_operands()
    def visit_Regular(self, node: Regular, x1, x2, *args, **kwargs):
        eps = kwargs.get("use_eps")
        return 1 / (1 + (np.abs(x1) + eps * (x1 == 0)) ** (-x2))

    @unpack_operands()
    def visit_Sour(self, node: Sour, x, *args, **kwargs):
        """(*, n_nodes or 1) -> (*, n_edges or 1)"""
        edge_list = kwargs.get("edge_list", ([], []))

        if isinstance(x, numbers.Number) or x.size == 1:
            return x  # (1,) -> (1,)
        elif node.operands[0].nettype == "scalar" or x.shape[-1] == 1:
            if x.shape[-1] != 1:
                x = x[..., np.newaxis]
            return x  # (*, 1) -> (*, 1)
        else:
            return x[..., edge_list[0]]  # (*, V) -> (*, E): value at source

    @unpack_operands()
    def visit_Targ(self, node: Targ, x, *args, **kwargs):
        """(*, n_nodes or 1) -> (*, n_edges or 1)"""
        edge_list = kwargs.get("edge_list", ([], []))

        if isinstance(x, numbers.Number) or x.size == 1:
            return x  # (1,) -> (1,)
        elif node.operands[0].nettype == "scalar" or x.shape[-1] == 1:
            if x.shape[-1] != 1:
                x = x[..., np.newaxis]
            return x  # (*, 1) -> (*, 1)
        else:
            return x[..., edge_list[1]]  # (*, V) -> (*, E): value at target

    @unpack_operands()
    def visit_Aggr(self, node: Aggr, x, *args, **kwargs):
        """(*, n_edges or 1) -> (*, n_nodes): sum edge values into targets."""
        edge_list = kwargs.get("edge_list", ([], []))
        num_nodes = kwargs.get("num_nodes")

        if isinstance(x, numbers.Number) or x.size == 1:
            # One value shared by all edges: each target node receives it once
            # per edge ending there. ``.item()`` turns a size-1 array of any
            # rank into a scalar (``float()`` on ndim > 0 arrays is deprecated
            # in recent numpy).
            y = np.zeros((num_nodes,))
            np.add.at(y, edge_list[1], np.asarray(x).item())
            return y
        elif node.operands[0].nettype == "scalar" or x.shape[-1] == 1:
            if x.shape[-1] != 1:
                x = x[..., np.newaxis]
            # (*, 1) edge values: count edges per target node, then scale.
            y = np.zeros((num_nodes,))
            np.add.at(y, edge_list[1], 1)
            y = y * x
            return y
        else:
            y = np.zeros((*x.shape[:-1], num_nodes))
            for k, j in enumerate(edge_list[1]):
                y[..., j] += x[..., k]
            return y

    @unpack_operands()
    def visit_Rgga(self, node: Rgga, x, *args, **kwargs):
        """(*, n_edges or 1) -> (*, n_nodes): sum edge values into sources."""
        edge_list = kwargs.get("edge_list", ([], []))
        num_nodes = kwargs.get("num_nodes")

        if isinstance(x, numbers.Number) or x.size == 1:
            # Reverse aggregation collects onto *source* nodes (edge_list[0]),
            # exactly like the two branches below; the value is reduced to a
            # scalar so it broadcasts against the indexed selection.
            y = np.zeros((num_nodes,))
            np.add.at(y, edge_list[0], np.asarray(x).item())
            return y
        elif node.operands[0].nettype == "scalar" or x.shape[-1] == 1:
            if x.shape[-1] != 1:
                x = x[..., np.newaxis]
            # (*, 1) edge values: count edges per source node, then scale.
            y = np.zeros((num_nodes,))
            np.add.at(y, edge_list[0], 1)
            y = y * x
            return y
        else:
            y = np.zeros((*x.shape[:-1], num_nodes))
            for k, i in enumerate(edge_list[0]):
                y[..., i] += x[..., k]
            return y

    @unpack_operands()
    def visit_Readout(self, node: Readout, x, *args, **kwargs):
        """(*, n_nodes or n_edges or 1) -> (*, 1)"""
        return np.sum(x, axis=-1, keepdims=True)
|
|
313
|
+
|
|
314
|
+
|
|
315
# NOTE(review): the triple-quoted string below is inert scratch code. It is a
# bare expression statement at module level, so it is evaluated and discarded
# on import and nothing inside it ever runs. Its heading says it compares the
# performance of different edge->node aggregation implementations; the code
# builds four variants (the Python loop used by ``visit_Aggr``'s last branch,
# per-row ``np.add.at``, per-row ``np.bincount``, and one flattened
# ``np.add.at``) and asserts that they agree.
# If it is ever revived as real code: ``aggr2`` reads ``x[i, :]`` instead of
# ``x_flat[i, :]`` (only correct for 2-D input), and ``aggr3`` uses the
# snippet's global ``E`` rather than ``x_flat.shape[1]``.
"""
# 比较 aggr 不同实现方式的性能

import numpy as np

T = 100
V = 10
E = 10
x = np.random.rand(T, E)
index = np.random.randint(0, V, size=(E,))

def aggr0(x, index, V):
    y = np.zeros((*x.shape[:-1], V), dtype=x.dtype)
    for k, j in enumerate(index):
        y[..., j] += x[..., k]
    return y


def aggr1(x, index, V):
    x = np.asarray(x)
    orig_shape = x.shape
    x_flat = x.reshape(-1, x.shape[-1]) # shape: (B, N), B = product of leading dims

    T = x_flat.shape[0]
    y = np.zeros((T, V), dtype=x.dtype)
    for i in range(T):
        np.add.at(y[i, :], index, x_flat[i, :])
    y = y.reshape(*orig_shape[:-1], V)
    return y

def aggr2(x, index, V):
    x = np.asarray(x)
    orig_shape = x.shape
    x_flat = x.reshape(-1, x.shape[-1]) # shape: (B, N), B = product of leading dims

    T = x_flat.shape[0]
    y = np.zeros((T, V), dtype=x.dtype)
    for i in range(T):
        y[i, :] = np.bincount(index, weights=x[i, :], minlength=V)
    y = y.reshape(*orig_shape[:-1], V)
    return y

def aggr3(x, index, V):
    x = np.asarray(x)
    orig_shape = x.shape
    x_flat = x.reshape(-1, x.shape[-1]) # shape: (B, E)

    T = x_flat.shape[0]
    # 扩展 index 为 shape (T, E) 以对齐每个 batch
    index_broadcast = np.broadcast_to(index, (T, E))

    # 准备目标数组
    y_flat = np.zeros((T, V), dtype=x.dtype)

    # 扁平化批次索引:行号 (0,0,...,1,1,1,...T-1) 与 index 构成 2D 索引
    row_idx = np.repeat(np.arange(T), E)
    col_idx = index_broadcast.ravel()
    values = x_flat.ravel()

    # 聚合加和
    np.add.at(y_flat, (row_idx, col_idx), values)

    # reshape 回原来的批次形状
    return y_flat.reshape(*orig_shape[:-1], V)


y0 = aggr0(x, index, V)

y1 = aggr1(x, index, V)
assert (y0 == y1).all()

y2 = aggr2(x, index, V)
assert (y0 == y2).all()

y3 = aggr3(x, index, V)
assert (y0 == y3).all()
"""
|