tgraphx 0.2.2__tar.gz → 0.2.3__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {tgraphx-0.2.2 → tgraphx-0.2.3}/PKG-INFO +14 -11
- {tgraphx-0.2.2 → tgraphx-0.2.3}/README.md +13 -10
- {tgraphx-0.2.2 → tgraphx-0.2.3}/pyproject.toml +1 -1
- tgraphx-0.2.3/tests/test_chunking.py +500 -0
- {tgraphx-0.2.2 → tgraphx-0.2.3}/tgraphx/__init__.py +1 -1
- {tgraphx-0.2.2 → tgraphx-0.2.3}/tgraphx/dashboard/app.py +113 -17
- {tgraphx-0.2.2 → tgraphx-0.2.3}/tgraphx/graph_builders.py +246 -48
- {tgraphx-0.2.2 → tgraphx-0.2.3}/tgraphx/layers/gin.py +119 -10
- {tgraphx-0.2.2 → tgraphx-0.2.3}/tgraphx/layers/sage.py +160 -23
- {tgraphx-0.2.2 → tgraphx-0.2.3}/tgraphx.egg-info/PKG-INFO +14 -11
- {tgraphx-0.2.2 → tgraphx-0.2.3}/tgraphx.egg-info/SOURCES.txt +1 -0
- {tgraphx-0.2.2 → tgraphx-0.2.3}/LICENSE +0 -0
- {tgraphx-0.2.2 → tgraphx-0.2.3}/setup.cfg +0 -0
- {tgraphx-0.2.2 → tgraphx-0.2.3}/tests/test_3d_support.py +0 -0
- {tgraphx-0.2.2 → tgraphx-0.2.3}/tests/test_amp_compile.py +0 -0
- {tgraphx-0.2.2 → tgraphx-0.2.3}/tests/test_dashboard.py +0 -0
- {tgraphx-0.2.2 → tgraphx-0.2.3}/tests/test_devices.py +0 -0
- {tgraphx-0.2.2 → tgraphx-0.2.3}/tests/test_documentation_claims.py +0 -0
- {tgraphx-0.2.2 → tgraphx-0.2.3}/tests/test_edge_features.py +0 -0
- {tgraphx-0.2.2 → tgraphx-0.2.3}/tests/test_edge_weight.py +0 -0
- {tgraphx-0.2.2 → tgraphx-0.2.3}/tests/test_factories.py +0 -0
- {tgraphx-0.2.2 → tgraphx-0.2.3}/tests/test_gnn_families.py +0 -0
- {tgraphx-0.2.2 → tgraphx-0.2.3}/tests/test_gradients.py +0 -0
- {tgraphx-0.2.2 → tgraphx-0.2.3}/tests/test_graph.py +0 -0
- {tgraphx-0.2.2 → tgraphx-0.2.3}/tests/test_graph_api.py +0 -0
- {tgraphx-0.2.2 → tgraphx-0.2.3}/tests/test_graph_builders.py +0 -0
- {tgraphx-0.2.2 → tgraphx-0.2.3}/tests/test_imports.py +0 -0
- {tgraphx-0.2.2 → tgraphx-0.2.3}/tests/test_layers.py +0 -0
- {tgraphx-0.2.2 → tgraphx-0.2.3}/tests/test_math.py +0 -0
- {tgraphx-0.2.2 → tgraphx-0.2.3}/tests/test_models.py +0 -0
- {tgraphx-0.2.2 → tgraphx-0.2.3}/tests/test_packaging.py +0 -0
- {tgraphx-0.2.2 → tgraphx-0.2.3}/tests/test_performance_smoke.py +0 -0
- {tgraphx-0.2.2 → tgraphx-0.2.3}/tests/test_tracking.py +0 -0
- {tgraphx-0.2.2 → tgraphx-0.2.3}/tests/test_training.py +0 -0
- {tgraphx-0.2.2 → tgraphx-0.2.3}/tgraphx/core/__init__.py +0 -0
- {tgraphx-0.2.2 → tgraphx-0.2.3}/tgraphx/core/dataloader.py +0 -0
- {tgraphx-0.2.2 → tgraphx-0.2.3}/tgraphx/core/graph.py +0 -0
- {tgraphx-0.2.2 → tgraphx-0.2.3}/tgraphx/core/graph_utils.py +0 -0
- {tgraphx-0.2.2 → tgraphx-0.2.3}/tgraphx/core/utils.py +0 -0
- {tgraphx-0.2.2 → tgraphx-0.2.3}/tgraphx/dashboard/__init__.py +0 -0
- {tgraphx-0.2.2 → tgraphx-0.2.3}/tgraphx/dashboard/__main__.py +0 -0
- {tgraphx-0.2.2 → tgraphx-0.2.3}/tgraphx/dashboard/static/dashboard.css +0 -0
- {tgraphx-0.2.2 → tgraphx-0.2.3}/tgraphx/dashboard/static/dashboard.js +0 -0
- {tgraphx-0.2.2 → tgraphx-0.2.3}/tgraphx/layers/__init__.py +0 -0
- {tgraphx-0.2.2 → tgraphx-0.2.3}/tgraphx/layers/_dim.py +0 -0
- {tgraphx-0.2.2 → tgraphx-0.2.3}/tgraphx/layers/_scatter.py +0 -0
- {tgraphx-0.2.2 → tgraphx-0.2.3}/tgraphx/layers/aggregator.py +0 -0
- {tgraphx-0.2.2 → tgraphx-0.2.3}/tgraphx/layers/attention_message.py +0 -0
- {tgraphx-0.2.2 → tgraphx-0.2.3}/tgraphx/layers/base.py +0 -0
- {tgraphx-0.2.2 → tgraphx-0.2.3}/tgraphx/layers/conv_message.py +0 -0
- {tgraphx-0.2.2 → tgraphx-0.2.3}/tgraphx/layers/factory.py +0 -0
- {tgraphx-0.2.2 → tgraphx-0.2.3}/tgraphx/layers/gat.py +0 -0
- {tgraphx-0.2.2 → tgraphx-0.2.3}/tgraphx/layers/safe_pool.py +0 -0
- {tgraphx-0.2.2 → tgraphx-0.2.3}/tgraphx/models/__init__.py +0 -0
- {tgraphx-0.2.2 → tgraphx-0.2.3}/tgraphx/models/cnn_encoder.py +0 -0
- {tgraphx-0.2.2 → tgraphx-0.2.3}/tgraphx/models/cnn_gnn_model.py +0 -0
- {tgraphx-0.2.2 → tgraphx-0.2.3}/tgraphx/models/edge_predictor.py +0 -0
- {tgraphx-0.2.2 → tgraphx-0.2.3}/tgraphx/models/factory.py +0 -0
- {tgraphx-0.2.2 → tgraphx-0.2.3}/tgraphx/models/graph_classifier.py +0 -0
- {tgraphx-0.2.2 → tgraphx-0.2.3}/tgraphx/models/node_classifier.py +0 -0
- {tgraphx-0.2.2 → tgraphx-0.2.3}/tgraphx/models/pre_encoder.py +0 -0
- {tgraphx-0.2.2 → tgraphx-0.2.3}/tgraphx/models/regressors.py +0 -0
- {tgraphx-0.2.2 → tgraphx-0.2.3}/tgraphx/performance.py +0 -0
- {tgraphx-0.2.2 → tgraphx-0.2.3}/tgraphx/tracking.py +0 -0
- {tgraphx-0.2.2 → tgraphx-0.2.3}/tgraphx/training.py +0 -0
- {tgraphx-0.2.2 → tgraphx-0.2.3}/tgraphx.egg-info/dependency_links.txt +0 -0
- {tgraphx-0.2.2 → tgraphx-0.2.3}/tgraphx.egg-info/entry_points.txt +0 -0
- {tgraphx-0.2.2 → tgraphx-0.2.3}/tgraphx.egg-info/requires.txt +0 -0
- {tgraphx-0.2.2 → tgraphx-0.2.3}/tgraphx.egg-info/top_level.txt +0 -0
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.4
|
|
2
2
|
Name: tgraphx
|
|
3
|
-
Version: 0.2.2
|
|
3
|
+
Version: 0.2.3
|
|
4
4
|
Summary: Tensor-aware graph neural networks preserving spatial node feature layouts
|
|
5
5
|
Author-email: Arash Sajjadi <arash.sajjadi@usask.ca>
|
|
6
6
|
Maintainer-email: Arash Sajjadi <arash.sajjadi@usask.ca>
|
|
@@ -129,9 +129,10 @@ drop-in clones of PyTorch Geometric's vector-feature implementations.
|
|
|
129
129
|
- **Heterogeneous and temporal graphs.**
|
|
130
130
|
- **MLflowLogger.** Not implemented. Use the `mlflow` client directly.
|
|
131
131
|
`pip install mlflow`.
|
|
132
|
-
- **GAT
|
|
133
|
-
softmax requires all edge scores;
|
|
134
|
-
|
|
132
|
+
- **GAT chunked forward.** Deferred to v0.2.4. GAT's destination-wise
|
|
133
|
+
softmax requires all incoming edge scores simultaneously; a correct two-pass
|
|
134
|
+
implementation is needed. `ConvMessagePassing`, `TensorGraphSAGELayer`, and
|
|
135
|
+
`TensorGINLayer` all support `chunk_size` in their `forward()` call.
|
|
135
136
|
- **Hardware-monitoring extras.** CPU/RAM/GPU metrics in the dashboard
|
|
136
137
|
require optional packages: `pip install tgraphx[monitoring]`.
|
|
137
138
|
- **torch.compile / AMP.** `torch.compile` is available in PyTorch ≥ 2.0
|
|
@@ -1233,15 +1234,17 @@ batch.to(device)
|
|
|
1233
1234
|
| Feature | Status | Notes |
|
|
1234
1235
|
|---------|:------:|-------|
|
|
1235
1236
|
| `ConvMessagePassing` chunked forward | ✅ Stable | `aggr="sum"` / `"mean"`; max falls back with warning |
|
|
1236
|
-
| `TensorGraphSAGELayer` chunked forward |
|
|
1237
|
-
| `TensorGINLayer` chunked forward |
|
|
1238
|
-
| `TensorGATLayer` chunked forward | ⏳ Planned v0.2.
|
|
1237
|
+
| `TensorGraphSAGELayer` chunked forward | ✅ Stable | v0.2.3; mean and max; pass `chunk_size=K` to `forward()` |
|
|
1238
|
+
| `TensorGINLayer` chunked forward | ✅ Stable | v0.2.3; sum aggregation; pass `chunk_size=K` to `forward()` |
|
|
1239
|
+
| `TensorGATLayer` chunked forward | ⏳ Planned v0.2.4 | Requires two-pass algorithm for destination-wise softmax |
|
|
1239
1240
|
| `build_grid_graph` / `build_grid_graph_3d` | ✅ Stable | O(E) — scales well |
|
|
1240
1241
|
| `build_random_graph` | ✅ Stable | O(E) — scales well |
|
|
1241
|
-
| `build_knn_graph` / `build_radius_graph` | ⚠️ Best-effort | O(N²)
|
|
1242
|
-
| `build_fully_connected_graph`
|
|
1243
|
-
|
|
|
1244
|
-
|
|
|
1242
|
+
| `build_knn_graph` / `build_radius_graph` | ⚠️ Best-effort | O(N²) time; `chunk_size=K` reduces peak memory to O(K×N) |
|
|
1243
|
+
| `build_fully_connected_graph` | ⚠️ Best-effort | O(N²) edges; N > 5 000 emits warning |
|
|
1244
|
+
| `build_iou_graph` | ⚠️ Best-effort | O(N²) IoU; `chunk_size=K` reduces peak memory to O(K×N) |
|
|
1245
|
+
| `build_random_graph` | ✅ Stable | `algorithm="sample"` uses O(num_edges) memory for large N |
|
|
1246
|
+
| Dashboard metrics API | ✅ Stable | Incremental `?since_row=N`; `--max-metric-rows` cap; byte-seek tail-read (v0.2.3) |
|
|
1247
|
+
| Large `metrics.csv` tail-read | ✅ Stable | v0.2.3: byte-seek on append; full reparse on rotation/truncation |
|
|
1245
1248
|
|
|
1246
1249
|
> ⚠️ **Scalability warning:** `build_knn_graph`, `build_radius_graph`, `build_fully_connected_graph`,
|
|
1247
1250
|
> and `build_iou_graph` use pairwise `torch.cdist` or enumerate all pairs. Memory and time grow as
|
|
@@ -84,9 +84,10 @@ drop-in clones of PyTorch Geometric's vector-feature implementations.
|
|
|
84
84
|
- **Heterogeneous and temporal graphs.**
|
|
85
85
|
- **MLflowLogger.** Not implemented. Use the `mlflow` client directly.
|
|
86
86
|
`pip install mlflow`.
|
|
87
|
-
- **GAT
|
|
88
|
-
softmax requires all edge scores;
|
|
89
|
-
|
|
87
|
+
- **GAT chunked forward.** Deferred to v0.2.4. GAT's destination-wise
|
|
88
|
+
softmax requires all incoming edge scores simultaneously; a correct two-pass
|
|
89
|
+
implementation is needed. `ConvMessagePassing`, `TensorGraphSAGELayer`, and
|
|
90
|
+
`TensorGINLayer` all support `chunk_size` in their `forward()` call.
|
|
90
91
|
- **Hardware-monitoring extras.** CPU/RAM/GPU metrics in the dashboard
|
|
91
92
|
require optional packages: `pip install tgraphx[monitoring]`.
|
|
92
93
|
- **torch.compile / AMP.** `torch.compile` is available in PyTorch ≥ 2.0
|
|
@@ -1188,15 +1189,17 @@ batch.to(device)
|
|
|
1188
1189
|
| Feature | Status | Notes |
|
|
1189
1190
|
|---------|:------:|-------|
|
|
1190
1191
|
| `ConvMessagePassing` chunked forward | ✅ Stable | `aggr="sum"` / `"mean"`; max falls back with warning |
|
|
1191
|
-
| `TensorGraphSAGELayer` chunked forward |
|
|
1192
|
-
| `TensorGINLayer` chunked forward |
|
|
1193
|
-
| `TensorGATLayer` chunked forward | ⏳ Planned v0.2.
|
|
1192
|
+
| `TensorGraphSAGELayer` chunked forward | ✅ Stable | v0.2.3; mean and max; pass `chunk_size=K` to `forward()` |
|
|
1193
|
+
| `TensorGINLayer` chunked forward | ✅ Stable | v0.2.3; sum aggregation; pass `chunk_size=K` to `forward()` |
|
|
1194
|
+
| `TensorGATLayer` chunked forward | ⏳ Planned v0.2.4 | Requires two-pass algorithm for destination-wise softmax |
|
|
1194
1195
|
| `build_grid_graph` / `build_grid_graph_3d` | ✅ Stable | O(E) — scales well |
|
|
1195
1196
|
| `build_random_graph` | ✅ Stable | O(E) — scales well |
|
|
1196
|
-
| `build_knn_graph` / `build_radius_graph` | ⚠️ Best-effort | O(N²)
|
|
1197
|
-
| `build_fully_connected_graph`
|
|
1198
|
-
|
|
|
1199
|
-
|
|
|
1197
|
+
| `build_knn_graph` / `build_radius_graph` | ⚠️ Best-effort | O(N²) time; `chunk_size=K` reduces peak memory to O(K×N) |
|
|
1198
|
+
| `build_fully_connected_graph` | ⚠️ Best-effort | O(N²) edges; N > 5 000 emits warning |
|
|
1199
|
+
| `build_iou_graph` | ⚠️ Best-effort | O(N²) IoU; `chunk_size=K` reduces peak memory to O(K×N) |
|
|
1200
|
+
| `build_random_graph` | ✅ Stable | `algorithm="sample"` uses O(num_edges) memory for large N |
|
|
1201
|
+
| Dashboard metrics API | ✅ Stable | Incremental `?since_row=N`; `--max-metric-rows` cap; byte-seek tail-read (v0.2.3) |
|
|
1202
|
+
| Large `metrics.csv` tail-read | ✅ Stable | v0.2.3: byte-seek on append; full reparse on rotation/truncation |
|
|
1200
1203
|
|
|
1201
1204
|
> ⚠️ **Scalability warning:** `build_knn_graph`, `build_radius_graph`, `build_fully_connected_graph`,
|
|
1202
1205
|
> and `build_iou_graph` use pairwise `torch.cdist` or enumerate all pairs. Memory and time grow as
|
|
@@ -5,7 +5,7 @@ build-backend = "setuptools.build_meta"
|
|
|
5
5
|
[project]
|
|
6
6
|
name = "tgraphx"
|
|
7
7
|
# Keep this in sync with tgraphx/__init__.py::__version__
|
|
8
|
-
version = "0.2.2"
|
|
8
|
+
version = "0.2.3"
|
|
9
9
|
description = "Tensor-aware graph neural networks preserving spatial node feature layouts"
|
|
10
10
|
readme = "README.md"
|
|
11
11
|
requires-python = ">=3.9"
|
|
@@ -0,0 +1,500 @@
|
|
|
1
|
+
"""Chunked forward tests for TensorGraphSAGELayer and TensorGINLayer (v0.2.3).
|
|
2
|
+
|
|
3
|
+
Verifies that chunked forward produces output identical to unchunked within
|
|
4
|
+
floating-point tolerance for all supported configurations:
|
|
5
|
+
- 2-D spatial (spatial_rank=2)
|
|
6
|
+
- 3-D volumetric (spatial_rank=3)
|
|
7
|
+
- aggr="mean" and aggr="max" (SAGE)
|
|
8
|
+
- edge_weight
|
|
9
|
+
- vector edge_features
|
|
10
|
+
- spatial/volumetric edge_features
|
|
11
|
+
- isolated nodes (no incoming edges)
|
|
12
|
+
- chunk_size=None unchanged (identity)
|
|
13
|
+
- empty edge_index
|
|
14
|
+
- gradients finite
|
|
15
|
+
- AMP/bfloat16 smoke (skipped when unsupported)
|
|
16
|
+
|
|
17
|
+
Also tests graph builder chunked paths (kNN, radius, IoU) and random graph
|
|
18
|
+
algorithm="sample".
|
|
19
|
+
"""
|
|
20
|
+
|
|
21
|
+
from __future__ import annotations
|
|
22
|
+
|
|
23
|
+
import warnings
|
|
24
|
+
|
|
25
|
+
import pytest
|
|
26
|
+
import torch
|
|
27
|
+
|
|
28
|
+
from tgraphx.graph_builders import (
|
|
29
|
+
build_grid_graph,
|
|
30
|
+
build_grid_graph_3d,
|
|
31
|
+
build_iou_graph,
|
|
32
|
+
build_knn_graph,
|
|
33
|
+
build_radius_graph,
|
|
34
|
+
build_random_graph,
|
|
35
|
+
)
|
|
36
|
+
from tgraphx.layers import TensorGINLayer, TensorGraphSAGELayer
|
|
37
|
+
|
|
38
|
+
|
|
39
|
+
# ── Helpers ───────────────────────────────────────────────────────────────────
|
|
40
|
+
|
|
41
|
+
ATOL = 1e-5 # float32 tolerance for chunk vs unchunked comparison
|
|
42
|
+
|
|
43
|
+
|
|
44
|
+
def _small_2d(N=9, C=4, H=4, W=4, seed=0):
|
|
45
|
+
torch.manual_seed(seed)
|
|
46
|
+
x = torch.randn(N, C, H, W)
|
|
47
|
+
ei = build_grid_graph(3, 3, directed=False, self_loops=True)
|
|
48
|
+
return x, ei, N, C, H, W
|
|
49
|
+
|
|
50
|
+
|
|
51
|
+
def _small_3d(N=8, C=4, D=4, H=4, W=4, seed=0):
|
|
52
|
+
torch.manual_seed(seed)
|
|
53
|
+
x = torch.randn(N, C, D, H, W)
|
|
54
|
+
ei = build_grid_graph_3d(2, 2, 2, directed=False, self_loops=True)
|
|
55
|
+
return x, ei, N, C, D, H, W
|
|
56
|
+
|
|
57
|
+
|
|
58
|
+
def _cpu_bf16_ok():
|
|
59
|
+
try:
|
|
60
|
+
with torch.autocast("cpu", dtype=torch.bfloat16):
|
|
61
|
+
t = torch.tensor([1.0])
|
|
62
|
+
_ = t + t
|
|
63
|
+
return True
|
|
64
|
+
except Exception:
|
|
65
|
+
return False
|
|
66
|
+
|
|
67
|
+
|
|
68
|
+
skip_bf16 = pytest.mark.skipif(
|
|
69
|
+
not _cpu_bf16_ok(), reason="CPU bfloat16 autocast not available"
|
|
70
|
+
)
|
|
71
|
+
|
|
72
|
+
|
|
73
|
+
# ── SAGE chunking — mean aggregation ─────────────────────────────────────────
|
|
74
|
+
|
|
75
|
+
class TestSAGEChunkedMean:
|
|
76
|
+
def _layer(self, **kw):
|
|
77
|
+
return TensorGraphSAGELayer(4, 4, aggr="mean", **kw).eval()
|
|
78
|
+
|
|
79
|
+
def test_2d_parity(self):
|
|
80
|
+
x, ei, N, C, H, W = _small_2d()
|
|
81
|
+
layer = self._layer()
|
|
82
|
+
with torch.no_grad():
|
|
83
|
+
out_full = layer(x, ei)
|
|
84
|
+
out_chunk = layer(x, ei, chunk_size=5)
|
|
85
|
+
assert out_chunk.shape == out_full.shape
|
|
86
|
+
assert torch.allclose(out_full, out_chunk, atol=ATOL), \
|
|
87
|
+
f"max diff = {(out_full - out_chunk).abs().max():.2e}"
|
|
88
|
+
|
|
89
|
+
def test_3d_parity(self):
|
|
90
|
+
x, ei, N, C, D, H, W = _small_3d()
|
|
91
|
+
layer = TensorGraphSAGELayer(C, C, aggr="mean", spatial_rank=3).eval()
|
|
92
|
+
with torch.no_grad():
|
|
93
|
+
out_full = layer(x, ei)
|
|
94
|
+
out_chunk = layer(x, ei, chunk_size=5)
|
|
95
|
+
assert torch.allclose(out_full, out_chunk, atol=ATOL)
|
|
96
|
+
|
|
97
|
+
def test_chunk_size_1(self):
|
|
98
|
+
x, ei, N, C, H, W = _small_2d()
|
|
99
|
+
layer = self._layer()
|
|
100
|
+
with torch.no_grad():
|
|
101
|
+
out_full = layer(x, ei)
|
|
102
|
+
out_chunk = layer(x, ei, chunk_size=1)
|
|
103
|
+
assert torch.allclose(out_full, out_chunk, atol=ATOL)
|
|
104
|
+
|
|
105
|
+
def test_chunk_size_exceeds_edges(self):
|
|
106
|
+
x, ei, N, C, H, W = _small_2d()
|
|
107
|
+
layer = self._layer()
|
|
108
|
+
with torch.no_grad():
|
|
109
|
+
out_full = layer(x, ei)
|
|
110
|
+
out_chunk = layer(x, ei, chunk_size=10_000)
|
|
111
|
+
assert torch.allclose(out_full, out_chunk, atol=ATOL)
|
|
112
|
+
|
|
113
|
+
def test_chunk_size_none_unchanged(self):
|
|
114
|
+
x, ei, N, C, H, W = _small_2d()
|
|
115
|
+
layer = self._layer()
|
|
116
|
+
with torch.no_grad():
|
|
117
|
+
out1 = layer(x, ei, chunk_size=None)
|
|
118
|
+
out2 = layer(x, ei)
|
|
119
|
+
assert torch.equal(out1, out2)
|
|
120
|
+
|
|
121
|
+
def test_edge_weight_parity(self):
|
|
122
|
+
x, ei, N, C, H, W = _small_2d()
|
|
123
|
+
layer = self._layer()
|
|
124
|
+
ew = torch.rand(ei.size(1))
|
|
125
|
+
with torch.no_grad():
|
|
126
|
+
out_full = layer(x, ei, edge_weight=ew)
|
|
127
|
+
out_chunk = layer(x, ei, edge_weight=ew, chunk_size=5)
|
|
128
|
+
assert torch.allclose(out_full, out_chunk, atol=ATOL)
|
|
129
|
+
|
|
130
|
+
def test_vector_edge_features_parity(self):
|
|
131
|
+
x, ei, N, C, H, W = _small_2d()
|
|
132
|
+
layer = TensorGraphSAGELayer(
|
|
133
|
+
C, C, aggr="mean", use_edge_features=True,
|
|
134
|
+
edge_dim=3, edge_features_kind="vector",
|
|
135
|
+
).eval()
|
|
136
|
+
ef = torch.randn(ei.size(1), 3)
|
|
137
|
+
with torch.no_grad():
|
|
138
|
+
out_full = layer(x, ei, edge_features=ef)
|
|
139
|
+
out_chunk = layer(x, ei, edge_features=ef, chunk_size=5)
|
|
140
|
+
assert torch.allclose(out_full, out_chunk, atol=ATOL)
|
|
141
|
+
|
|
142
|
+
def test_spatial_edge_features_parity(self):
|
|
143
|
+
x, ei, N, C, H, W = _small_2d()
|
|
144
|
+
layer = TensorGraphSAGELayer(
|
|
145
|
+
C, C, aggr="mean", use_edge_features=True,
|
|
146
|
+
edge_dim=3, edge_features_kind="spatial",
|
|
147
|
+
).eval()
|
|
148
|
+
ef = torch.randn(ei.size(1), 3, H, W)
|
|
149
|
+
with torch.no_grad():
|
|
150
|
+
out_full = layer(x, ei, edge_features=ef)
|
|
151
|
+
out_chunk = layer(x, ei, edge_features=ef, chunk_size=5)
|
|
152
|
+
assert torch.allclose(out_full, out_chunk, atol=ATOL)
|
|
153
|
+
|
|
154
|
+
def test_isolated_nodes(self):
|
|
155
|
+
"""Nodes with no incoming edges must produce zeros in the agg term."""
|
|
156
|
+
torch.manual_seed(1)
|
|
157
|
+
N, C, H, W = 5, 4, 4, 4
|
|
158
|
+
x = torch.randn(N, C, H, W)
|
|
159
|
+
# Node 4 has no incoming edges.
|
|
160
|
+
ei = torch.tensor([[0, 1, 2], [1, 2, 3]], dtype=torch.long)
|
|
161
|
+
layer = self._layer()
|
|
162
|
+
with torch.no_grad():
|
|
163
|
+
out_full = layer(x, ei)
|
|
164
|
+
out_chunk = layer(x, ei, chunk_size=2)
|
|
165
|
+
assert torch.allclose(out_full, out_chunk, atol=ATOL)
|
|
166
|
+
|
|
167
|
+
def test_gradient_finite(self):
|
|
168
|
+
x, ei, N, C, H, W = _small_2d()
|
|
169
|
+
x = x.requires_grad_(True)
|
|
170
|
+
layer = TensorGraphSAGELayer(C, C, aggr="mean").train()
|
|
171
|
+
out = layer(x, ei, chunk_size=5)
|
|
172
|
+
out.sum().backward()
|
|
173
|
+
assert x.grad is not None
|
|
174
|
+
assert torch.isfinite(x.grad).all()
|
|
175
|
+
|
|
176
|
+
@skip_bf16
|
|
177
|
+
def test_amp_bfloat16_smoke(self):
|
|
178
|
+
x, ei, N, C, H, W = _small_2d()
|
|
179
|
+
layer = self._layer()
|
|
180
|
+
with torch.no_grad(), torch.autocast("cpu", dtype=torch.bfloat16):
|
|
181
|
+
out = layer(x, ei, chunk_size=5)
|
|
182
|
+
assert torch.isfinite(out).all()
|
|
183
|
+
|
|
184
|
+
|
|
185
|
+
# ── SAGE chunking — max aggregation ──────────────────────────────────────────
|
|
186
|
+
|
|
187
|
+
class TestSAGEChunkedMax:
|
|
188
|
+
def _layer(self, **kw):
|
|
189
|
+
return TensorGraphSAGELayer(4, 4, aggr="max", **kw).eval()
|
|
190
|
+
|
|
191
|
+
def test_2d_parity(self):
|
|
192
|
+
x, ei, N, C, H, W = _small_2d()
|
|
193
|
+
layer = self._layer()
|
|
194
|
+
with torch.no_grad():
|
|
195
|
+
out_full = layer(x, ei)
|
|
196
|
+
out_chunk = layer(x, ei, chunk_size=5)
|
|
197
|
+
assert torch.allclose(out_full, out_chunk, atol=ATOL)
|
|
198
|
+
|
|
199
|
+
def test_3d_parity(self):
|
|
200
|
+
x, ei, N, C, D, H, W = _small_3d()
|
|
201
|
+
layer = TensorGraphSAGELayer(C, C, aggr="max", spatial_rank=3).eval()
|
|
202
|
+
with torch.no_grad():
|
|
203
|
+
out_full = layer(x, ei)
|
|
204
|
+
out_chunk = layer(x, ei, chunk_size=5)
|
|
205
|
+
assert torch.allclose(out_full, out_chunk, atol=ATOL)
|
|
206
|
+
|
|
207
|
+
def test_edge_weight_parity(self):
|
|
208
|
+
x, ei, N, C, H, W = _small_2d()
|
|
209
|
+
layer = self._layer()
|
|
210
|
+
ew = torch.rand(ei.size(1))
|
|
211
|
+
with torch.no_grad():
|
|
212
|
+
out_full = layer(x, ei, edge_weight=ew)
|
|
213
|
+
out_chunk = layer(x, ei, edge_weight=ew, chunk_size=5)
|
|
214
|
+
assert torch.allclose(out_full, out_chunk, atol=ATOL)
|
|
215
|
+
|
|
216
|
+
def test_isolated_nodes_max(self):
|
|
217
|
+
"""Isolated nodes must be zero in max-aggregated output (not -inf)."""
|
|
218
|
+
torch.manual_seed(2)
|
|
219
|
+
N, C, H, W = 4, 4, 4, 4
|
|
220
|
+
x = torch.randn(N, C, H, W)
|
|
221
|
+
ei = torch.tensor([[0, 1], [1, 2]], dtype=torch.long)
|
|
222
|
+
layer = self._layer()
|
|
223
|
+
with torch.no_grad():
|
|
224
|
+
out_full = layer(x, ei)
|
|
225
|
+
out_chunk = layer(x, ei, chunk_size=1)
|
|
226
|
+
assert torch.allclose(out_full, out_chunk, atol=ATOL)
|
|
227
|
+
# Node 3 has no incoming edges → agg must be 0 (not -inf)
|
|
228
|
+
assert torch.isfinite(out_full[3]).all()
|
|
229
|
+
assert torch.isfinite(out_chunk[3]).all()
|
|
230
|
+
|
|
231
|
+
def test_gradient_finite(self):
|
|
232
|
+
x, ei, N, C, H, W = _small_2d()
|
|
233
|
+
x = x.requires_grad_(True)
|
|
234
|
+
layer = TensorGraphSAGELayer(C, C, aggr="max").train()
|
|
235
|
+
out = layer(x, ei, chunk_size=5)
|
|
236
|
+
out.sum().backward()
|
|
237
|
+
assert x.grad is not None
|
|
238
|
+
assert torch.isfinite(x.grad).all()
|
|
239
|
+
|
|
240
|
+
|
|
241
|
+
# ── GIN chunking ──────────────────────────────────────────────────────────────
|
|
242
|
+
|
|
243
|
+
class TestGINChunked:
|
|
244
|
+
def _layer(self, **kw):
|
|
245
|
+
return TensorGINLayer(4, 4, **kw).eval()
|
|
246
|
+
|
|
247
|
+
def test_2d_parity(self):
|
|
248
|
+
x, ei, N, C, H, W = _small_2d()
|
|
249
|
+
layer = self._layer()
|
|
250
|
+
with torch.no_grad():
|
|
251
|
+
out_full = layer(x, ei)
|
|
252
|
+
out_chunk = layer(x, ei, chunk_size=5)
|
|
253
|
+
assert torch.allclose(out_full, out_chunk, atol=ATOL)
|
|
254
|
+
|
|
255
|
+
def test_3d_parity(self):
|
|
256
|
+
x, ei, N, C, D, H, W = _small_3d()
|
|
257
|
+
layer = TensorGINLayer(C, C, spatial_rank=3).eval()
|
|
258
|
+
with torch.no_grad():
|
|
259
|
+
out_full = layer(x, ei)
|
|
260
|
+
out_chunk = layer(x, ei, chunk_size=5)
|
|
261
|
+
assert torch.allclose(out_full, out_chunk, atol=ATOL)
|
|
262
|
+
|
|
263
|
+
def test_chunk_size_1(self):
|
|
264
|
+
x, ei, N, C, H, W = _small_2d()
|
|
265
|
+
layer = self._layer()
|
|
266
|
+
with torch.no_grad():
|
|
267
|
+
out_full = layer(x, ei)
|
|
268
|
+
out_chunk = layer(x, ei, chunk_size=1)
|
|
269
|
+
assert torch.allclose(out_full, out_chunk, atol=ATOL)
|
|
270
|
+
|
|
271
|
+
def test_chunk_size_none_unchanged(self):
|
|
272
|
+
x, ei, N, C, H, W = _small_2d()
|
|
273
|
+
layer = self._layer()
|
|
274
|
+
with torch.no_grad():
|
|
275
|
+
out1 = layer(x, ei)
|
|
276
|
+
out2 = layer(x, ei, chunk_size=None)
|
|
277
|
+
assert torch.equal(out1, out2)
|
|
278
|
+
|
|
279
|
+
def test_edge_weight_parity(self):
|
|
280
|
+
x, ei, N, C, H, W = _small_2d()
|
|
281
|
+
layer = self._layer()
|
|
282
|
+
ew = torch.rand(ei.size(1))
|
|
283
|
+
with torch.no_grad():
|
|
284
|
+
out_full = layer(x, ei, edge_weight=ew)
|
|
285
|
+
out_chunk = layer(x, ei, edge_weight=ew, chunk_size=5)
|
|
286
|
+
assert torch.allclose(out_full, out_chunk, atol=ATOL)
|
|
287
|
+
|
|
288
|
+
def test_vector_edge_features_parity(self):
|
|
289
|
+
x, ei, N, C, H, W = _small_2d()
|
|
290
|
+
layer = TensorGINLayer(
|
|
291
|
+
C, C, use_edge_features=True, edge_dim=3,
|
|
292
|
+
edge_features_kind="vector",
|
|
293
|
+
).eval()
|
|
294
|
+
ef = torch.randn(ei.size(1), 3)
|
|
295
|
+
with torch.no_grad():
|
|
296
|
+
out_full = layer(x, ei, edge_features=ef)
|
|
297
|
+
out_chunk = layer(x, ei, edge_features=ef, chunk_size=5)
|
|
298
|
+
assert torch.allclose(out_full, out_chunk, atol=ATOL)
|
|
299
|
+
|
|
300
|
+
def test_spatial_edge_features_parity(self):
|
|
301
|
+
x, ei, N, C, H, W = _small_2d()
|
|
302
|
+
layer = TensorGINLayer(
|
|
303
|
+
C, C, use_edge_features=True, edge_dim=3,
|
|
304
|
+
edge_features_kind="spatial",
|
|
305
|
+
).eval()
|
|
306
|
+
ef = torch.randn(ei.size(1), 3, H, W)
|
|
307
|
+
with torch.no_grad():
|
|
308
|
+
out_full = layer(x, ei, edge_features=ef)
|
|
309
|
+
out_chunk = layer(x, ei, edge_features=ef, chunk_size=5)
|
|
310
|
+
assert torch.allclose(out_full, out_chunk, atol=ATOL)
|
|
311
|
+
|
|
312
|
+
def test_train_eps_parity(self):
|
|
313
|
+
x, ei, N, C, H, W = _small_2d()
|
|
314
|
+
layer = TensorGINLayer(C, C, train_eps=True, eps=0.5).eval()
|
|
315
|
+
with torch.no_grad():
|
|
316
|
+
out_full = layer(x, ei)
|
|
317
|
+
out_chunk = layer(x, ei, chunk_size=5)
|
|
318
|
+
assert torch.allclose(out_full, out_chunk, atol=ATOL)
|
|
319
|
+
|
|
320
|
+
def test_custom_mlp(self):
|
|
321
|
+
import torch.nn as nn
|
|
322
|
+
x, ei, N, C, H, W = _small_2d()
|
|
323
|
+
mlp = nn.Sequential(
|
|
324
|
+
nn.Conv2d(C, C, 1), nn.ReLU(), nn.Conv2d(C, C, 1)
|
|
325
|
+
)
|
|
326
|
+
layer = TensorGINLayer(C, C, mlp=mlp).eval()
|
|
327
|
+
with torch.no_grad():
|
|
328
|
+
out_full = layer(x, ei)
|
|
329
|
+
out_chunk = layer(x, ei, chunk_size=5)
|
|
330
|
+
assert torch.allclose(out_full, out_chunk, atol=ATOL)
|
|
331
|
+
|
|
332
|
+
def test_isolated_nodes(self):
|
|
333
|
+
torch.manual_seed(3)
|
|
334
|
+
N, C, H, W = 5, 4, 4, 4
|
|
335
|
+
x = torch.randn(N, C, H, W)
|
|
336
|
+
ei = torch.tensor([[0, 1], [1, 2]], dtype=torch.long)
|
|
337
|
+
layer = self._layer()
|
|
338
|
+
with torch.no_grad():
|
|
339
|
+
out_full = layer(x, ei)
|
|
340
|
+
out_chunk = layer(x, ei, chunk_size=1)
|
|
341
|
+
assert torch.allclose(out_full, out_chunk, atol=ATOL)
|
|
342
|
+
|
|
343
|
+
def test_gradient_finite(self):
|
|
344
|
+
x, ei, N, C, H, W = _small_2d()
|
|
345
|
+
x = x.requires_grad_(True)
|
|
346
|
+
layer = TensorGINLayer(C, C).train()
|
|
347
|
+
out = layer(x, ei, chunk_size=5)
|
|
348
|
+
out.sum().backward()
|
|
349
|
+
assert x.grad is not None
|
|
350
|
+
assert torch.isfinite(x.grad).all()
|
|
351
|
+
|
|
352
|
+
@skip_bf16
|
|
353
|
+
def test_amp_bfloat16_smoke(self):
|
|
354
|
+
x, ei, N, C, H, W = _small_2d()
|
|
355
|
+
layer = self._layer()
|
|
356
|
+
with torch.no_grad(), torch.autocast("cpu", dtype=torch.bfloat16):
|
|
357
|
+
out = layer(x, ei, chunk_size=5)
|
|
358
|
+
assert torch.isfinite(out).all()
|
|
359
|
+
|
|
360
|
+
|
|
361
|
+
# ── Graph builder chunked paths ───────────────────────────────────────────────
|
|
362
|
+
|
|
363
|
+
class TestKNNChunked:
|
|
364
|
+
def _sort(self, ei):
|
|
365
|
+
keys = ei[0] * 10000 + ei[1]
|
|
366
|
+
return ei[:, keys.sort().indices]
|
|
367
|
+
|
|
368
|
+
def test_parity_small(self):
|
|
369
|
+
torch.manual_seed(0)
|
|
370
|
+
coords = torch.randn(20, 2)
|
|
371
|
+
ei_full = build_knn_graph(coords, k=3)
|
|
372
|
+
ei_chunk = build_knn_graph(coords, k=3, chunk_size=5)
|
|
373
|
+
assert self._sort(ei_full).shape == self._sort(ei_chunk).shape
|
|
374
|
+
assert torch.all(self._sort(ei_full) == self._sort(ei_chunk))
|
|
375
|
+
|
|
376
|
+
def test_parity_directed(self):
|
|
377
|
+
torch.manual_seed(1)
|
|
378
|
+
coords = torch.randn(12, 3)
|
|
379
|
+
ei_full = build_knn_graph(coords, k=2, directed=True)
|
|
380
|
+
ei_chunk = build_knn_graph(coords, k=2, directed=True, chunk_size=4)
|
|
381
|
+
assert self._sort(ei_full).shape == self._sort(ei_chunk).shape
|
|
382
|
+
assert torch.all(self._sort(ei_full) == self._sort(ei_chunk))
|
|
383
|
+
|
|
384
|
+
def test_no_warning_with_chunk_size(self):
|
|
385
|
+
coords = torch.randn(10_001, 2)
|
|
386
|
+
with warnings.catch_warnings(record=True) as w:
|
|
387
|
+
warnings.simplefilter("always")
|
|
388
|
+
ei = build_knn_graph(coords, k=2, chunk_size=100)
|
|
389
|
+
# chunk_size suppresses the O(N²) memory warning
|
|
390
|
+
mem_warns = [x for x in w if "O(N²)" in str(x.message) or "O.N.2" in str(x.message)]
|
|
391
|
+
assert len(mem_warns) == 0, "unexpected O(N²) warning with chunk_size set"
|
|
392
|
+
|
|
393
|
+
def test_warning_without_chunk_size(self):
|
|
394
|
+
coords = torch.randn(10_001, 2)
|
|
395
|
+
with pytest.warns(UserWarning, match="num_nodes"):
|
|
396
|
+
build_knn_graph(coords, k=2)
|
|
397
|
+
|
|
398
|
+
|
|
399
|
+
class TestRadiusChunked:
|
|
400
|
+
def _sort(self, ei):
|
|
401
|
+
keys = ei[0] * 10000 + ei[1]
|
|
402
|
+
return ei[:, keys.sort().indices]
|
|
403
|
+
|
|
404
|
+
def test_parity_undirected(self):
|
|
405
|
+
torch.manual_seed(2)
|
|
406
|
+
coords = torch.randn(20, 2)
|
|
407
|
+
ei_full = build_radius_graph(coords, radius=0.8)
|
|
408
|
+
ei_chunk = build_radius_graph(coords, radius=0.8, chunk_size=5)
|
|
409
|
+
assert self._sort(ei_full).shape == self._sort(ei_chunk).shape
|
|
410
|
+
assert torch.all(self._sort(ei_full) == self._sort(ei_chunk))
|
|
411
|
+
|
|
412
|
+
def test_parity_directed(self):
|
|
413
|
+
torch.manual_seed(3)
|
|
414
|
+
coords = torch.randn(16, 2)
|
|
415
|
+
ei_full = build_radius_graph(coords, radius=0.5, directed=True)
|
|
416
|
+
ei_chunk = build_radius_graph(coords, radius=0.5, directed=True, chunk_size=4)
|
|
417
|
+
assert self._sort(ei_full).shape == self._sort(ei_chunk).shape
|
|
418
|
+
assert torch.all(self._sort(ei_full) == self._sort(ei_chunk))
|
|
419
|
+
|
|
420
|
+
def test_no_warning_with_chunk_size(self):
|
|
421
|
+
coords = torch.randn(10_001, 2)
|
|
422
|
+
with warnings.catch_warnings(record=True) as w:
|
|
423
|
+
warnings.simplefilter("always")
|
|
424
|
+
build_radius_graph(coords, radius=0.01, chunk_size=100)
|
|
425
|
+
mem_warns = [x for x in w if "num_nodes" in str(x.message)]
|
|
426
|
+
assert len(mem_warns) == 0
|
|
427
|
+
|
|
428
|
+
|
|
429
|
+
class TestIoUChunked:
|
|
430
|
+
def _sort(self, ei):
|
|
431
|
+
keys = ei[0] * 10000 + ei[1]
|
|
432
|
+
return ei[:, keys.sort().indices]
|
|
433
|
+
|
|
434
|
+
def _boxes(self, n=10, seed=4):
|
|
435
|
+
torch.manual_seed(seed)
|
|
436
|
+
x1 = torch.rand(n)
|
|
437
|
+
y1 = torch.rand(n)
|
|
438
|
+
x2 = x1 + torch.rand(n) * 0.5 + 0.1
|
|
439
|
+
y2 = y1 + torch.rand(n) * 0.5 + 0.1
|
|
440
|
+
return torch.stack([x1, y1, x2, y2], dim=1)
|
|
441
|
+
|
|
442
|
+
def test_parity(self):
|
|
443
|
+
boxes = self._boxes(16)
|
|
444
|
+
ei_full = build_iou_graph(boxes, threshold=0.1)
|
|
445
|
+
ei_chunk = build_iou_graph(boxes, threshold=0.1, chunk_size=4)
|
|
446
|
+
assert self._sort(ei_full).shape == self._sort(ei_chunk).shape
|
|
447
|
+
assert torch.all(self._sort(ei_full) == self._sort(ei_chunk))
|
|
448
|
+
|
|
449
|
+
def test_parity_directed(self):
|
|
450
|
+
boxes = self._boxes(12, seed=5)
|
|
451
|
+
ei_full = build_iou_graph(boxes, threshold=0.05, directed=True)
|
|
452
|
+
ei_chunk = build_iou_graph(boxes, threshold=0.05, directed=True, chunk_size=3)
|
|
453
|
+
assert self._sort(ei_full).shape == self._sort(ei_chunk).shape
|
|
454
|
+
assert torch.all(self._sort(ei_full) == self._sort(ei_chunk))
|
|
455
|
+
|
|
456
|
+
def test_no_warning_with_chunk_size(self):
|
|
457
|
+
boxes = self._boxes(5_001)
|
|
458
|
+
with warnings.catch_warnings(record=True) as w:
|
|
459
|
+
warnings.simplefilter("always")
|
|
460
|
+
build_iou_graph(boxes, threshold=0.5, chunk_size=100)
|
|
461
|
+
mem_warns = [x for x in w if "num_nodes" in str(x.message)]
|
|
462
|
+
assert len(mem_warns) == 0
|
|
463
|
+
|
|
464
|
+
|
|
465
|
+
class TestRandomGraphSample:
|
|
466
|
+
def test_correct_edge_count(self):
|
|
467
|
+
ei = build_random_graph(100, 200, algorithm="sample", seed=42)
|
|
468
|
+
assert ei.shape == (2, 200)
|
|
469
|
+
|
|
470
|
+
def test_no_self_loops(self):
|
|
471
|
+
ei = build_random_graph(100, 200, algorithm="sample", seed=42)
|
|
472
|
+
assert (ei[0] != ei[1]).all()
|
|
473
|
+
|
|
474
|
+
def test_no_duplicates(self):
|
|
475
|
+
ei = build_random_graph(100, 200, algorithm="sample", seed=42)
|
|
476
|
+
keys = ei[0] * 100 + ei[1]
|
|
477
|
+
assert keys.unique().numel() == 200
|
|
478
|
+
|
|
479
|
+
def test_deterministic(self):
|
|
480
|
+
ei1 = build_random_graph(100, 200, algorithm="sample", seed=99)
|
|
481
|
+
ei2 = build_random_graph(100, 200, algorithm="sample", seed=99)
|
|
482
|
+
assert torch.equal(ei1, ei2)
|
|
483
|
+
|
|
484
|
+
def test_zero_edges(self):
|
|
485
|
+
ei = build_random_graph(10, 0, algorithm="sample", seed=0)
|
|
486
|
+
assert ei.shape == (2, 0)
|
|
487
|
+
|
|
488
|
+
def test_too_many_edges_raises(self):
|
|
489
|
+
with pytest.raises(ValueError, match="Cannot sample"):
|
|
490
|
+
build_random_graph(5, 100, algorithm="sample", seed=0)
|
|
491
|
+
|
|
492
|
+
def test_unsupported_undirected_raises(self):
|
|
493
|
+
with pytest.raises(ValueError, match="directed=True"):
|
|
494
|
+
build_random_graph(10, 5, directed=False, algorithm="sample")
|
|
495
|
+
|
|
496
|
+
def test_exact_is_default_and_unchanged(self):
|
|
497
|
+
"""Default algorithm='exact' must match previous behavior."""
|
|
498
|
+
ei_default = build_random_graph(10, 20, seed=0)
|
|
499
|
+
ei_exact = build_random_graph(10, 20, seed=0, algorithm="exact")
|
|
500
|
+
assert torch.equal(ei_default, ei_exact)
|
|
@@ -10,7 +10,7 @@ Common one-liner imports::
|
|
|
10
10
|
"""
|
|
11
11
|
|
|
12
12
|
# Keep this in sync with [project].version in pyproject.toml.
|
|
13
|
-
__version__ = "0.2.2"
|
|
13
|
+
__version__ = "0.2.3"
|
|
14
14
|
|
|
15
15
|
# ── Core data structures ──────────────────────────────────────────────────────
|
|
16
16
|
from .core.graph import Graph, GraphBatch
|