tgraphx 0.2.2__tar.gz → 0.2.3__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (69) hide show
  1. {tgraphx-0.2.2 → tgraphx-0.2.3}/PKG-INFO +14 -11
  2. {tgraphx-0.2.2 → tgraphx-0.2.3}/README.md +13 -10
  3. {tgraphx-0.2.2 → tgraphx-0.2.3}/pyproject.toml +1 -1
  4. tgraphx-0.2.3/tests/test_chunking.py +500 -0
  5. {tgraphx-0.2.2 → tgraphx-0.2.3}/tgraphx/__init__.py +1 -1
  6. {tgraphx-0.2.2 → tgraphx-0.2.3}/tgraphx/dashboard/app.py +113 -17
  7. {tgraphx-0.2.2 → tgraphx-0.2.3}/tgraphx/graph_builders.py +246 -48
  8. {tgraphx-0.2.2 → tgraphx-0.2.3}/tgraphx/layers/gin.py +119 -10
  9. {tgraphx-0.2.2 → tgraphx-0.2.3}/tgraphx/layers/sage.py +160 -23
  10. {tgraphx-0.2.2 → tgraphx-0.2.3}/tgraphx.egg-info/PKG-INFO +14 -11
  11. {tgraphx-0.2.2 → tgraphx-0.2.3}/tgraphx.egg-info/SOURCES.txt +1 -0
  12. {tgraphx-0.2.2 → tgraphx-0.2.3}/LICENSE +0 -0
  13. {tgraphx-0.2.2 → tgraphx-0.2.3}/setup.cfg +0 -0
  14. {tgraphx-0.2.2 → tgraphx-0.2.3}/tests/test_3d_support.py +0 -0
  15. {tgraphx-0.2.2 → tgraphx-0.2.3}/tests/test_amp_compile.py +0 -0
  16. {tgraphx-0.2.2 → tgraphx-0.2.3}/tests/test_dashboard.py +0 -0
  17. {tgraphx-0.2.2 → tgraphx-0.2.3}/tests/test_devices.py +0 -0
  18. {tgraphx-0.2.2 → tgraphx-0.2.3}/tests/test_documentation_claims.py +0 -0
  19. {tgraphx-0.2.2 → tgraphx-0.2.3}/tests/test_edge_features.py +0 -0
  20. {tgraphx-0.2.2 → tgraphx-0.2.3}/tests/test_edge_weight.py +0 -0
  21. {tgraphx-0.2.2 → tgraphx-0.2.3}/tests/test_factories.py +0 -0
  22. {tgraphx-0.2.2 → tgraphx-0.2.3}/tests/test_gnn_families.py +0 -0
  23. {tgraphx-0.2.2 → tgraphx-0.2.3}/tests/test_gradients.py +0 -0
  24. {tgraphx-0.2.2 → tgraphx-0.2.3}/tests/test_graph.py +0 -0
  25. {tgraphx-0.2.2 → tgraphx-0.2.3}/tests/test_graph_api.py +0 -0
  26. {tgraphx-0.2.2 → tgraphx-0.2.3}/tests/test_graph_builders.py +0 -0
  27. {tgraphx-0.2.2 → tgraphx-0.2.3}/tests/test_imports.py +0 -0
  28. {tgraphx-0.2.2 → tgraphx-0.2.3}/tests/test_layers.py +0 -0
  29. {tgraphx-0.2.2 → tgraphx-0.2.3}/tests/test_math.py +0 -0
  30. {tgraphx-0.2.2 → tgraphx-0.2.3}/tests/test_models.py +0 -0
  31. {tgraphx-0.2.2 → tgraphx-0.2.3}/tests/test_packaging.py +0 -0
  32. {tgraphx-0.2.2 → tgraphx-0.2.3}/tests/test_performance_smoke.py +0 -0
  33. {tgraphx-0.2.2 → tgraphx-0.2.3}/tests/test_tracking.py +0 -0
  34. {tgraphx-0.2.2 → tgraphx-0.2.3}/tests/test_training.py +0 -0
  35. {tgraphx-0.2.2 → tgraphx-0.2.3}/tgraphx/core/__init__.py +0 -0
  36. {tgraphx-0.2.2 → tgraphx-0.2.3}/tgraphx/core/dataloader.py +0 -0
  37. {tgraphx-0.2.2 → tgraphx-0.2.3}/tgraphx/core/graph.py +0 -0
  38. {tgraphx-0.2.2 → tgraphx-0.2.3}/tgraphx/core/graph_utils.py +0 -0
  39. {tgraphx-0.2.2 → tgraphx-0.2.3}/tgraphx/core/utils.py +0 -0
  40. {tgraphx-0.2.2 → tgraphx-0.2.3}/tgraphx/dashboard/__init__.py +0 -0
  41. {tgraphx-0.2.2 → tgraphx-0.2.3}/tgraphx/dashboard/__main__.py +0 -0
  42. {tgraphx-0.2.2 → tgraphx-0.2.3}/tgraphx/dashboard/static/dashboard.css +0 -0
  43. {tgraphx-0.2.2 → tgraphx-0.2.3}/tgraphx/dashboard/static/dashboard.js +0 -0
  44. {tgraphx-0.2.2 → tgraphx-0.2.3}/tgraphx/layers/__init__.py +0 -0
  45. {tgraphx-0.2.2 → tgraphx-0.2.3}/tgraphx/layers/_dim.py +0 -0
  46. {tgraphx-0.2.2 → tgraphx-0.2.3}/tgraphx/layers/_scatter.py +0 -0
  47. {tgraphx-0.2.2 → tgraphx-0.2.3}/tgraphx/layers/aggregator.py +0 -0
  48. {tgraphx-0.2.2 → tgraphx-0.2.3}/tgraphx/layers/attention_message.py +0 -0
  49. {tgraphx-0.2.2 → tgraphx-0.2.3}/tgraphx/layers/base.py +0 -0
  50. {tgraphx-0.2.2 → tgraphx-0.2.3}/tgraphx/layers/conv_message.py +0 -0
  51. {tgraphx-0.2.2 → tgraphx-0.2.3}/tgraphx/layers/factory.py +0 -0
  52. {tgraphx-0.2.2 → tgraphx-0.2.3}/tgraphx/layers/gat.py +0 -0
  53. {tgraphx-0.2.2 → tgraphx-0.2.3}/tgraphx/layers/safe_pool.py +0 -0
  54. {tgraphx-0.2.2 → tgraphx-0.2.3}/tgraphx/models/__init__.py +0 -0
  55. {tgraphx-0.2.2 → tgraphx-0.2.3}/tgraphx/models/cnn_encoder.py +0 -0
  56. {tgraphx-0.2.2 → tgraphx-0.2.3}/tgraphx/models/cnn_gnn_model.py +0 -0
  57. {tgraphx-0.2.2 → tgraphx-0.2.3}/tgraphx/models/edge_predictor.py +0 -0
  58. {tgraphx-0.2.2 → tgraphx-0.2.3}/tgraphx/models/factory.py +0 -0
  59. {tgraphx-0.2.2 → tgraphx-0.2.3}/tgraphx/models/graph_classifier.py +0 -0
  60. {tgraphx-0.2.2 → tgraphx-0.2.3}/tgraphx/models/node_classifier.py +0 -0
  61. {tgraphx-0.2.2 → tgraphx-0.2.3}/tgraphx/models/pre_encoder.py +0 -0
  62. {tgraphx-0.2.2 → tgraphx-0.2.3}/tgraphx/models/regressors.py +0 -0
  63. {tgraphx-0.2.2 → tgraphx-0.2.3}/tgraphx/performance.py +0 -0
  64. {tgraphx-0.2.2 → tgraphx-0.2.3}/tgraphx/tracking.py +0 -0
  65. {tgraphx-0.2.2 → tgraphx-0.2.3}/tgraphx/training.py +0 -0
  66. {tgraphx-0.2.2 → tgraphx-0.2.3}/tgraphx.egg-info/dependency_links.txt +0 -0
  67. {tgraphx-0.2.2 → tgraphx-0.2.3}/tgraphx.egg-info/entry_points.txt +0 -0
  68. {tgraphx-0.2.2 → tgraphx-0.2.3}/tgraphx.egg-info/requires.txt +0 -0
  69. {tgraphx-0.2.2 → tgraphx-0.2.3}/tgraphx.egg-info/top_level.txt +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: tgraphx
3
- Version: 0.2.2
3
+ Version: 0.2.3
4
4
  Summary: Tensor-aware graph neural networks preserving spatial node feature layouts
5
5
  Author-email: Arash Sajjadi <arash.sajjadi@usask.ca>
6
6
  Maintainer-email: Arash Sajjadi <arash.sajjadi@usask.ca>
@@ -129,9 +129,10 @@ drop-in clones of PyTorch Geometric's vector-feature implementations.
129
129
  - **Heterogeneous and temporal graphs.**
130
130
  - **MLflowLogger.** Not implemented. Use the `mlflow` client directly.
131
131
  `pip install mlflow`.
132
- - **GAT / SAGE / GIN chunked forward.** Deferred (GAT's destination-wise
133
- softmax requires all edge scores; SAGE/GIN chunking deferred for scope).
134
- `ConvMessagePassing` supports `chunk_size` for `sum` / `mean` aggregation.
132
+ - **GAT chunked forward.** Deferred to v0.2.4. GAT's destination-wise
133
+ softmax requires all incoming edge scores simultaneously; a correct two-pass
134
+ implementation is needed. `ConvMessagePassing`, `TensorGraphSAGELayer`, and
135
+ `TensorGINLayer` all support `chunk_size` in their `forward()` call (`ConvMessagePassing` only for `sum` / `mean` aggregation).
135
136
  - **Hardware-monitoring extras.** CPU/RAM/GPU metrics in the dashboard
136
137
  require optional packages: `pip install tgraphx[monitoring]`.
137
138
  - **torch.compile / AMP.** `torch.compile` is available in PyTorch ≥ 2.0
@@ -1233,15 +1234,17 @@ batch.to(device)
1233
1234
  | Feature | Status | Notes |
1234
1235
  |---------|:------:|-------|
1235
1236
  | `ConvMessagePassing` chunked forward | ✅ Stable | `aggr="sum"` / `"mean"`; max falls back with warning |
1236
- | `TensorGraphSAGELayer` chunked forward | Planned v0.2.3 | Deferred |
1237
- | `TensorGINLayer` chunked forward | Planned v0.2.3 | Deferred |
1238
- | `TensorGATLayer` chunked forward | ⏳ Planned v0.2.3+ | Destination-wise softmax makes chunking complex |
1237
+ | `TensorGraphSAGELayer` chunked forward | ✅ Stable | v0.2.3; mean and max; pass `chunk_size=K` to `forward()` |
1238
+ | `TensorGINLayer` chunked forward | ✅ Stable | v0.2.3; sum aggregation; pass `chunk_size=K` to `forward()` |
1239
+ | `TensorGATLayer` chunked forward | ⏳ Planned v0.2.4 | Requires two-pass algorithm for destination-wise softmax |
1239
1240
  | `build_grid_graph` / `build_grid_graph_3d` | ✅ Stable | O(E) — scales well |
1240
1241
  | `build_random_graph` | ✅ Stable | O(E) — scales well |
1241
- | `build_knn_graph` / `build_radius_graph` | ⚠️ Best-effort | O(N²) via `torch.cdist`; N > 10 000 emits a warning |
1242
- | `build_fully_connected_graph` / `build_iou_graph` | ⚠️ Best-effort | O(N²) edges; N > 5 000 emits a warning |
1243
- | Dashboard metrics API | Stable | Incremental `?since_row=N`; `--max-metric-rows` cap |
1244
- | Large `metrics.csv` tail-read | Planned v0.2.3 | Current: mtime cache + full re-parse on miss |
1242
+ | `build_knn_graph` / `build_radius_graph` | ⚠️ Best-effort | O(N²) time; `chunk_size=K` reduces peak memory to O(K×N) |
1243
+ | `build_fully_connected_graph` | ⚠️ Best-effort | O(N²) edges; N > 5 000 emits warning |
1244
+ | `build_iou_graph` | ⚠️ Best-effort | O(N²) IoU; `chunk_size=K` reduces peak memory to O(K×N) |
1245
+ | `build_random_graph(algorithm="sample")` | ✅ Stable | v0.2.3; O(num_edges) memory for large N; requires `directed=True` |
1246
+ | Dashboard metrics API | ✅ Stable | Incremental `?since_row=N`; `--max-metric-rows` cap; byte-seek tail-read (v0.2.3) |
1247
+ | Large `metrics.csv` tail-read | ✅ Stable | v0.2.3: byte-seek on append; full reparse on rotation/truncation |
1245
1248
 
1246
1249
  > ⚠️ **Scalability warning:** `build_knn_graph`, `build_radius_graph`, `build_fully_connected_graph`,
1247
1250
  > and `build_iou_graph` use pairwise `torch.cdist` or enumerate all pairs. Memory and time grow as
@@ -84,9 +84,10 @@ drop-in clones of PyTorch Geometric's vector-feature implementations.
84
84
  - **Heterogeneous and temporal graphs.**
85
85
  - **MLflowLogger.** Not implemented. Use the `mlflow` client directly.
86
86
  `pip install mlflow`.
87
- - **GAT / SAGE / GIN chunked forward.** Deferred (GAT's destination-wise
88
- softmax requires all edge scores; SAGE/GIN chunking deferred for scope).
89
- `ConvMessagePassing` supports `chunk_size` for `sum` / `mean` aggregation.
87
+ - **GAT chunked forward.** Deferred to v0.2.4. GAT's destination-wise
88
+ softmax requires all incoming edge scores simultaneously; a correct two-pass
89
+ implementation is needed. `ConvMessagePassing`, `TensorGraphSAGELayer`, and
90
+ `TensorGINLayer` all support `chunk_size` in their `forward()` call (`ConvMessagePassing` only for `sum` / `mean` aggregation).
90
91
  - **Hardware-monitoring extras.** CPU/RAM/GPU metrics in the dashboard
91
92
  require optional packages: `pip install tgraphx[monitoring]`.
92
93
  - **torch.compile / AMP.** `torch.compile` is available in PyTorch ≥ 2.0
@@ -1188,15 +1189,17 @@ batch.to(device)
1188
1189
  | Feature | Status | Notes |
1189
1190
  |---------|:------:|-------|
1190
1191
  | `ConvMessagePassing` chunked forward | ✅ Stable | `aggr="sum"` / `"mean"`; max falls back with warning |
1191
- | `TensorGraphSAGELayer` chunked forward | Planned v0.2.3 | Deferred |
1192
- | `TensorGINLayer` chunked forward | Planned v0.2.3 | Deferred |
1193
- | `TensorGATLayer` chunked forward | ⏳ Planned v0.2.3+ | Destination-wise softmax makes chunking complex |
1192
+ | `TensorGraphSAGELayer` chunked forward | ✅ Stable | v0.2.3; mean and max; pass `chunk_size=K` to `forward()` |
1193
+ | `TensorGINLayer` chunked forward | ✅ Stable | v0.2.3; sum aggregation; pass `chunk_size=K` to `forward()` |
1194
+ | `TensorGATLayer` chunked forward | ⏳ Planned v0.2.4 | Requires two-pass algorithm for destination-wise softmax |
1194
1195
  | `build_grid_graph` / `build_grid_graph_3d` | ✅ Stable | O(E) — scales well |
1195
1196
  | `build_random_graph` | ✅ Stable | O(E) — scales well |
1196
- | `build_knn_graph` / `build_radius_graph` | ⚠️ Best-effort | O(N²) via `torch.cdist`; N > 10 000 emits a warning |
1197
- | `build_fully_connected_graph` / `build_iou_graph` | ⚠️ Best-effort | O(N²) edges; N > 5 000 emits a warning |
1198
- | Dashboard metrics API | Stable | Incremental `?since_row=N`; `--max-metric-rows` cap |
1199
- | Large `metrics.csv` tail-read | Planned v0.2.3 | Current: mtime cache + full re-parse on miss |
1197
+ | `build_knn_graph` / `build_radius_graph` | ⚠️ Best-effort | O(N²) time; `chunk_size=K` reduces peak memory to O(K×N) |
1198
+ | `build_fully_connected_graph` | ⚠️ Best-effort | O(N²) edges; N > 5 000 emits warning |
1199
+ | `build_iou_graph` | ⚠️ Best-effort | O(N²) IoU; `chunk_size=K` reduces peak memory to O(K×N) |
1200
+ | `build_random_graph(algorithm="sample")` | ✅ Stable | v0.2.3; O(num_edges) memory for large N; requires `directed=True` |
1201
+ | Dashboard metrics API | ✅ Stable | Incremental `?since_row=N`; `--max-metric-rows` cap; byte-seek tail-read (v0.2.3) |
1202
+ | Large `metrics.csv` tail-read | ✅ Stable | v0.2.3: byte-seek on append; full reparse on rotation/truncation |
1200
1203
 
1201
1204
  > ⚠️ **Scalability warning:** `build_knn_graph`, `build_radius_graph`, `build_fully_connected_graph`,
1202
1205
  > and `build_iou_graph` use pairwise `torch.cdist` or enumerate all pairs. Memory and time grow as
@@ -5,7 +5,7 @@ build-backend = "setuptools.build_meta"
5
5
  [project]
6
6
  name = "tgraphx"
7
7
  # Keep this in sync with tgraphx/__init__.py::__version__
8
- version = "0.2.2"
8
+ version = "0.2.3"
9
9
  description = "Tensor-aware graph neural networks preserving spatial node feature layouts"
10
10
  readme = "README.md"
11
11
  requires-python = ">=3.9"
@@ -0,0 +1,500 @@
1
+ """Chunked forward tests for TensorGraphSAGELayer and TensorGINLayer (v0.2.3).
2
+
3
+ Verifies that chunked forward produces output identical to unchunked within
4
+ floating-point tolerance for all supported configurations:
5
+ - 2-D spatial (spatial_rank=2)
6
+ - 3-D volumetric (spatial_rank=3)
7
+ - aggr="mean" and aggr="max" (SAGE)
8
+ - edge_weight
9
+ - vector edge_features
10
+ - spatial/volumetric edge_features
11
+ - isolated nodes (no incoming edges)
12
+ - chunk_size=None unchanged (identity)
13
+ - empty edge_index
14
+ - gradients finite
15
+ - AMP/bfloat16 smoke (skipped when unsupported)
16
+
17
+ Also tests graph builder chunked paths (kNN, radius, IoU) and random graph
18
+ algorithm="sample".
19
+ """
20
+
21
+ from __future__ import annotations
22
+
23
+ import warnings
24
+
25
+ import pytest
26
+ import torch
27
+
28
+ from tgraphx.graph_builders import (
29
+ build_grid_graph,
30
+ build_grid_graph_3d,
31
+ build_iou_graph,
32
+ build_knn_graph,
33
+ build_radius_graph,
34
+ build_random_graph,
35
+ )
36
+ from tgraphx.layers import TensorGINLayer, TensorGraphSAGELayer
37
+
38
+
39
+ # ── Helpers ───────────────────────────────────────────────────────────────────
40
+
41
# Absolute tolerance used when comparing chunked and unchunked float32 outputs.
ATOL = 1.0e-5
42
+
43
+
44
def _small_2d(N=9, C=4, H=4, W=4, seed=0, grid_shape=(3, 3)):
    """Return a seeded 2-D fixture ``(x, edge_index, N, C, H, W)``.

    ``x`` is ``torch.randn(N, C, H, W)`` drawn after seeding with ``seed``.
    ``edge_index`` comes from ``build_grid_graph(rows, cols, directed=False,
    self_loops=True)`` where ``(rows, cols) = grid_shape``.

    Raises:
        ValueError: if ``N`` differs from ``rows * cols``. Previously the grid
            was fixed at 3x3, so a non-default ``N`` silently produced an
            edge_index that did not match the node tensor.
    """
    rows, cols = grid_shape
    if rows * cols != N:
        raise ValueError(
            f"N={N} does not match grid_shape={tuple(grid_shape)} "
            f"({rows * cols} nodes)"
        )
    torch.manual_seed(seed)
    x = torch.randn(N, C, H, W)
    ei = build_grid_graph(rows, cols, directed=False, self_loops=True)
    return x, ei, N, C, H, W
49
+
50
+
51
def _small_3d(N=8, C=4, D=4, H=4, W=4, seed=0, grid_shape=(2, 2, 2)):
    """Return a seeded 3-D fixture ``(x, edge_index, N, C, D, H, W)``.

    ``x`` is ``torch.randn(N, C, D, H, W)`` drawn after seeding with ``seed``.
    ``edge_index`` comes from ``build_grid_graph_3d(d, h, w, directed=False,
    self_loops=True)`` where ``(d, h, w) = grid_shape``.

    Raises:
        ValueError: if ``N`` differs from ``d * h * w``. Previously the grid
            was fixed at 2x2x2, so a non-default ``N`` silently produced an
            edge_index that did not match the node tensor.
    """
    gd, gh, gw = grid_shape
    if gd * gh * gw != N:
        raise ValueError(
            f"N={N} does not match grid_shape={tuple(grid_shape)} "
            f"({gd * gh * gw} nodes)"
        )
    torch.manual_seed(seed)
    x = torch.randn(N, C, D, H, W)
    ei = build_grid_graph_3d(gd, gh, gw, directed=False, self_loops=True)
    return x, ei, N, C, D, H, W
56
+
57
+
58
+ def _cpu_bf16_ok():
59
+ try:
60
+ with torch.autocast("cpu", dtype=torch.bfloat16):
61
+ t = torch.tensor([1.0])
62
+ _ = t + t
63
+ return True
64
+ except Exception:
65
+ return False
66
+
67
+
68
# Skip marker for tests that need CPU bfloat16 autocast; the probe is
# evaluated once, when this module is imported.
skip_bf16 = pytest.mark.skipif(not _cpu_bf16_ok(),
                               reason="CPU bfloat16 autocast not available")
71
+
72
+
73
+ # ── SAGE chunking — mean aggregation ─────────────────────────────────────────
74
+
75
class TestSAGEChunkedMean:
    """Chunked-vs-unchunked parity for ``TensorGraphSAGELayer(aggr="mean")``."""

    def _layer(self, **kw):
        return TensorGraphSAGELayer(4, 4, aggr="mean", **kw).eval()

    @staticmethod
    def _run_both(layer, x, ei, chunk_size, **kw):
        """Run ``layer`` unchunked, then chunked, with the same kwargs, under no_grad."""
        with torch.no_grad():
            out_full = layer(x, ei, **kw)
            out_chunk = layer(x, ei, chunk_size=chunk_size, **kw)
        return out_full, out_chunk

    def test_2d_parity(self):
        x, ei, N, C, H, W = _small_2d()
        layer = self._layer()
        out_full, out_chunk = self._run_both(layer, x, ei, chunk_size=5)
        assert out_chunk.shape == out_full.shape
        assert torch.allclose(out_full, out_chunk, atol=ATOL), \
            f"max diff = {(out_full - out_chunk).abs().max():.2e}"

    def test_3d_parity(self):
        x, ei, N, C, D, H, W = _small_3d()
        layer = TensorGraphSAGELayer(C, C, aggr="mean", spatial_rank=3).eval()
        out_full, out_chunk = self._run_both(layer, x, ei, chunk_size=5)
        assert torch.allclose(out_full, out_chunk, atol=ATOL)

    def test_chunk_size_1(self):
        x, ei, N, C, H, W = _small_2d()
        layer = self._layer()
        out_full, out_chunk = self._run_both(layer, x, ei, chunk_size=1)
        assert torch.allclose(out_full, out_chunk, atol=ATOL)

    def test_chunk_size_exceeds_edges(self):
        x, ei, N, C, H, W = _small_2d()
        layer = self._layer()
        out_full, out_chunk = self._run_both(layer, x, ei, chunk_size=10_000)
        assert torch.allclose(out_full, out_chunk, atol=ATOL)

    def test_chunk_size_none_unchanged(self):
        x, ei, N, C, H, W = _small_2d()
        layer = self._layer()
        with torch.no_grad():
            out1 = layer(x, ei, chunk_size=None)
            out2 = layer(x, ei)
        assert torch.equal(out1, out2)

    def test_edge_weight_parity(self):
        x, ei, N, C, H, W = _small_2d()
        layer = self._layer()
        ew = torch.rand(ei.size(1))
        out_full, out_chunk = self._run_both(
            layer, x, ei, chunk_size=5, edge_weight=ew
        )
        assert torch.allclose(out_full, out_chunk, atol=ATOL)

    def test_vector_edge_features_parity(self):
        x, ei, N, C, H, W = _small_2d()
        layer = TensorGraphSAGELayer(
            C, C, aggr="mean", use_edge_features=True,
            edge_dim=3, edge_features_kind="vector",
        ).eval()
        ef = torch.randn(ei.size(1), 3)
        out_full, out_chunk = self._run_both(
            layer, x, ei, chunk_size=5, edge_features=ef
        )
        assert torch.allclose(out_full, out_chunk, atol=ATOL)

    def test_spatial_edge_features_parity(self):
        x, ei, N, C, H, W = _small_2d()
        layer = TensorGraphSAGELayer(
            C, C, aggr="mean", use_edge_features=True,
            edge_dim=3, edge_features_kind="spatial",
        ).eval()
        ef = torch.randn(ei.size(1), 3, H, W)
        out_full, out_chunk = self._run_both(
            layer, x, ei, chunk_size=5, edge_features=ef
        )
        assert torch.allclose(out_full, out_chunk, atol=ATOL)

    def test_isolated_nodes(self):
        """Nodes with no incoming edges (0 and 4 here) must match and stay finite."""
        torch.manual_seed(1)
        N, C, H, W = 5, 4, 4, 4
        x = torch.randn(N, C, H, W)
        # Edges 0->1, 1->2, 2->3: nodes 0 and 4 receive no messages.
        ei = torch.tensor([[0, 1, 2], [1, 2, 3]], dtype=torch.long)
        layer = self._layer()
        out_full, out_chunk = self._run_both(layer, x, ei, chunk_size=2)
        assert torch.allclose(out_full, out_chunk, atol=ATOL)
        # An empty mean must not leak NaN (0/0) into isolated nodes.
        isolated = [0, 4]
        assert torch.isfinite(out_full[isolated]).all()
        assert torch.isfinite(out_chunk[isolated]).all()

    def test_gradient_finite(self):
        x, ei, N, C, H, W = _small_2d()
        x = x.requires_grad_(True)
        layer = TensorGraphSAGELayer(C, C, aggr="mean").train()
        out = layer(x, ei, chunk_size=5)
        out.sum().backward()
        assert x.grad is not None
        assert torch.isfinite(x.grad).all()

    @skip_bf16
    def test_amp_bfloat16_smoke(self):
        x, ei, N, C, H, W = _small_2d()
        layer = self._layer()
        with torch.no_grad(), torch.autocast("cpu", dtype=torch.bfloat16):
            out = layer(x, ei, chunk_size=5)
        assert torch.isfinite(out).all()
183
+
184
+
185
+ # ── SAGE chunking — max aggregation ──────────────────────────────────────────
186
+
187
class TestSAGEChunkedMax:
    """Chunked-vs-unchunked parity for ``TensorGraphSAGELayer(aggr="max")``."""

    def _layer(self, **kw):
        return TensorGraphSAGELayer(4, 4, aggr="max", **kw).eval()

    @staticmethod
    def _run_both(layer, x, ei, chunk_size, **kw):
        """Run ``layer`` unchunked, then chunked, with the same kwargs, under no_grad."""
        with torch.no_grad():
            out_full = layer(x, ei, **kw)
            out_chunk = layer(x, ei, chunk_size=chunk_size, **kw)
        return out_full, out_chunk

    def test_2d_parity(self):
        x, ei, N, C, H, W = _small_2d()
        layer = self._layer()
        out_full, out_chunk = self._run_both(layer, x, ei, chunk_size=5)
        assert torch.allclose(out_full, out_chunk, atol=ATOL)

    def test_3d_parity(self):
        x, ei, N, C, D, H, W = _small_3d()
        layer = TensorGraphSAGELayer(C, C, aggr="max", spatial_rank=3).eval()
        out_full, out_chunk = self._run_both(layer, x, ei, chunk_size=5)
        assert torch.allclose(out_full, out_chunk, atol=ATOL)

    def test_edge_weight_parity(self):
        x, ei, N, C, H, W = _small_2d()
        layer = self._layer()
        ew = torch.rand(ei.size(1))
        out_full, out_chunk = self._run_both(
            layer, x, ei, chunk_size=5, edge_weight=ew
        )
        assert torch.allclose(out_full, out_chunk, atol=ATOL)

    def test_isolated_nodes_max(self):
        """Isolated nodes must stay finite: their max-agg term is 0, not -inf."""
        torch.manual_seed(2)
        N, C, H, W = 4, 4, 4, 4
        x = torch.randn(N, C, H, W)
        # Edges 0->1 and 1->2: nodes 0 and 3 receive no messages.
        ei = torch.tensor([[0, 1], [1, 2]], dtype=torch.long)
        layer = self._layer()
        out_full, out_chunk = self._run_both(layer, x, ei, chunk_size=1)
        assert torch.allclose(out_full, out_chunk, atol=ATOL)
        isolated = [0, 3]
        assert torch.isfinite(out_full[isolated]).all()
        assert torch.isfinite(out_chunk[isolated]).all()

    def test_gradient_finite(self):
        x, ei, N, C, H, W = _small_2d()
        x = x.requires_grad_(True)
        layer = TensorGraphSAGELayer(C, C, aggr="max").train()
        out = layer(x, ei, chunk_size=5)
        out.sum().backward()
        assert x.grad is not None
        assert torch.isfinite(x.grad).all()
239
+
240
+
241
+ # ── GIN chunking ──────────────────────────────────────────────────────────────
242
+
243
class TestGINChunked:
    """A chunked ``TensorGINLayer`` forward must reproduce the unchunked output."""

    def _layer(self, **kw):
        return TensorGINLayer(4, 4, **kw).eval()

    @staticmethod
    def _both(layer, x, ei, chunk_size, **kw):
        # Reference pass first, then the chunked pass, both without autograd.
        with torch.no_grad():
            reference = layer(x, ei, **kw)
            chunked = layer(x, ei, chunk_size=chunk_size, **kw)
        return reference, chunked

    def test_2d_parity(self):
        x, ei, *_ = _small_2d()
        ref, got = self._both(self._layer(), x, ei, 5)
        assert torch.allclose(ref, got, atol=ATOL)

    def test_3d_parity(self):
        x, ei, _, C, *_ = _small_3d()
        gin = TensorGINLayer(C, C, spatial_rank=3).eval()
        ref, got = self._both(gin, x, ei, 5)
        assert torch.allclose(ref, got, atol=ATOL)

    def test_chunk_size_1(self):
        x, ei, *_ = _small_2d()
        ref, got = self._both(self._layer(), x, ei, 1)
        assert torch.allclose(ref, got, atol=ATOL)

    def test_chunk_size_none_unchanged(self):
        x, ei, *_ = _small_2d()
        gin = self._layer()
        with torch.no_grad():
            implicit = gin(x, ei)
            explicit_none = gin(x, ei, chunk_size=None)
        assert torch.equal(implicit, explicit_none)

    def test_edge_weight_parity(self):
        x, ei, *_ = _small_2d()
        gin = self._layer()
        weights = torch.rand(ei.size(1))
        ref, got = self._both(gin, x, ei, 5, edge_weight=weights)
        assert torch.allclose(ref, got, atol=ATOL)

    def test_vector_edge_features_parity(self):
        x, ei, _, C, _, _ = _small_2d()
        gin = TensorGINLayer(
            C,
            C,
            use_edge_features=True,
            edge_dim=3,
            edge_features_kind="vector",
        ).eval()
        feats = torch.randn(ei.size(1), 3)
        ref, got = self._both(gin, x, ei, 5, edge_features=feats)
        assert torch.allclose(ref, got, atol=ATOL)

    def test_spatial_edge_features_parity(self):
        x, ei, _, C, H, W = _small_2d()
        gin = TensorGINLayer(
            C,
            C,
            use_edge_features=True,
            edge_dim=3,
            edge_features_kind="spatial",
        ).eval()
        feats = torch.randn(ei.size(1), 3, H, W)
        ref, got = self._both(gin, x, ei, 5, edge_features=feats)
        assert torch.allclose(ref, got, atol=ATOL)

    def test_train_eps_parity(self):
        x, ei, _, C, _, _ = _small_2d()
        gin = TensorGINLayer(C, C, train_eps=True, eps=0.5).eval()
        ref, got = self._both(gin, x, ei, 5)
        assert torch.allclose(ref, got, atol=ATOL)

    def test_custom_mlp(self):
        x, ei, _, C, _, _ = _small_2d()
        head = torch.nn.Sequential(
            torch.nn.Conv2d(C, C, 1),
            torch.nn.ReLU(),
            torch.nn.Conv2d(C, C, 1),
        )
        gin = TensorGINLayer(C, C, mlp=head).eval()
        ref, got = self._both(gin, x, ei, 5)
        assert torch.allclose(ref, got, atol=ATOL)

    def test_isolated_nodes(self):
        torch.manual_seed(3)
        num_nodes, channels, height, width = 5, 4, 4, 4
        x = torch.randn(num_nodes, channels, height, width)
        # Only nodes 1 and 2 receive messages; 0, 3 and 4 are isolated.
        ei = torch.tensor([[0, 1], [1, 2]], dtype=torch.long)
        ref, got = self._both(self._layer(), x, ei, 1)
        assert torch.allclose(ref, got, atol=ATOL)

    def test_gradient_finite(self):
        x, ei, _, C, _, _ = _small_2d()
        x.requires_grad_(True)
        gin = TensorGINLayer(C, C).train()
        gin(x, ei, chunk_size=5).sum().backward()
        assert x.grad is not None
        assert torch.isfinite(x.grad).all()

    @skip_bf16
    def test_amp_bfloat16_smoke(self):
        x, ei, *_ = _small_2d()
        gin = self._layer()
        with torch.no_grad(), torch.autocast("cpu", dtype=torch.bfloat16):
            result = gin(x, ei, chunk_size=5)
        assert torch.isfinite(result).all()
359
+
360
+
361
+ # ── Graph builder chunked paths ───────────────────────────────────────────────
362
+
363
class TestKNNChunked:
    """``build_knn_graph(..., chunk_size=K)`` parity and warning behavior."""

    @staticmethod
    def _sort(ei):
        """Return ``ei`` with its columns ordered lexicographically by (src, dst).

        The key base is derived from the largest node id rather than a fixed
        10000, so the ordering stays lexicographic for any graph size.
        """
        base = int(ei.max()) + 1 if ei.numel() else 1
        keys = ei[0] * base + ei[1]
        return ei[:, keys.sort().indices]

    def _assert_same_edges(self, ei_a, ei_b):
        """Assert both edge sets are identical, ignoring column order."""
        a, b = self._sort(ei_a), self._sort(ei_b)
        assert a.shape == b.shape
        assert torch.all(a == b)

    def test_parity_small(self):
        torch.manual_seed(0)
        coords = torch.randn(20, 2)
        ei_full = build_knn_graph(coords, k=3)
        ei_chunk = build_knn_graph(coords, k=3, chunk_size=5)
        self._assert_same_edges(ei_full, ei_chunk)

    def test_parity_directed(self):
        torch.manual_seed(1)
        coords = torch.randn(12, 3)
        ei_full = build_knn_graph(coords, k=2, directed=True)
        ei_chunk = build_knn_graph(coords, k=2, directed=True, chunk_size=4)
        self._assert_same_edges(ei_full, ei_chunk)

    def test_no_warning_with_chunk_size(self):
        coords = torch.randn(10_001, 2)
        with warnings.catch_warnings(record=True) as caught:
            warnings.simplefilter("always")
            build_knn_graph(coords, k=2, chunk_size=100)
        # chunk_size must suppress the O(N²) memory warning. Match on
        # "num_nodes" (the text test_warning_without_chunk_size proves the
        # warning contains); the old "O.N.2" was a literal substring, not a
        # regex, so it could never match and the test passed vacuously.
        mem_warns = [
            rec for rec in caught
            if "num_nodes" in str(rec.message) or "O(N²)" in str(rec.message)
        ]
        assert not mem_warns, "unexpected O(N²) warning with chunk_size set"

    def test_warning_without_chunk_size(self):
        coords = torch.randn(10_001, 2)
        with pytest.warns(UserWarning, match="num_nodes"):
            build_knn_graph(coords, k=2)
397
+
398
+
399
class TestRadiusChunked:
    """``build_radius_graph(..., chunk_size=K)`` parity and warning behavior."""

    @staticmethod
    def _canonical(ei):
        # Column order from the builder is not guaranteed; sort by (src, dst).
        order = (ei[0] * 10000 + ei[1]).sort().indices
        return ei[:, order]

    def _assert_same_edges(self, first, second):
        lhs = self._canonical(first)
        rhs = self._canonical(second)
        assert lhs.shape == rhs.shape
        assert torch.all(lhs == rhs)

    def test_parity_undirected(self):
        torch.manual_seed(2)
        points = torch.randn(20, 2)
        self._assert_same_edges(
            build_radius_graph(points, radius=0.8),
            build_radius_graph(points, radius=0.8, chunk_size=5),
        )

    def test_parity_directed(self):
        torch.manual_seed(3)
        points = torch.randn(16, 2)
        self._assert_same_edges(
            build_radius_graph(points, radius=0.5, directed=True),
            build_radius_graph(points, radius=0.5, directed=True, chunk_size=4),
        )

    def test_no_warning_with_chunk_size(self):
        points = torch.randn(10_001, 2)
        with warnings.catch_warnings(record=True) as caught:
            warnings.simplefilter("always")
            build_radius_graph(points, radius=0.01, chunk_size=100)
        assert not any("num_nodes" in str(rec.message) for rec in caught)
427
+
428
+
429
class TestIoUChunked:
    """``build_iou_graph(..., chunk_size=K)`` parity and warning behavior."""

    @staticmethod
    def _canonical(ei):
        # Column order from the builder is not guaranteed; sort by (src, dst).
        order = (ei[0] * 10000 + ei[1]).sort().indices
        return ei[:, order]

    def _assert_same_edges(self, first, second):
        lhs = self._canonical(first)
        rhs = self._canonical(second)
        assert lhs.shape == rhs.shape
        assert torch.all(lhs == rhs)

    @staticmethod
    def _random_boxes(n=10, seed=4):
        """Seeded ``(n, 4)`` boxes ``[x1, y1, x2, y2]`` with positive width/height."""
        torch.manual_seed(seed)
        left = torch.rand(n)
        top = torch.rand(n)
        right = left + torch.rand(n) * 0.5 + 0.1
        bottom = top + torch.rand(n) * 0.5 + 0.1
        return torch.stack([left, top, right, bottom], dim=1)

    def test_parity(self):
        boxes = self._random_boxes(16)
        self._assert_same_edges(
            build_iou_graph(boxes, threshold=0.1),
            build_iou_graph(boxes, threshold=0.1, chunk_size=4),
        )

    def test_parity_directed(self):
        boxes = self._random_boxes(12, seed=5)
        self._assert_same_edges(
            build_iou_graph(boxes, threshold=0.05, directed=True),
            build_iou_graph(boxes, threshold=0.05, directed=True, chunk_size=3),
        )

    def test_no_warning_with_chunk_size(self):
        boxes = self._random_boxes(5_001)
        with warnings.catch_warnings(record=True) as caught:
            warnings.simplefilter("always")
            build_iou_graph(boxes, threshold=0.5, chunk_size=100)
        assert not any("num_nodes" in str(rec.message) for rec in caught)
463
+
464
+
465
class TestRandomGraphSample:
    """Contract checks for ``build_random_graph(..., algorithm="sample")``."""

    @staticmethod
    def _sampled(num_nodes=100, num_edges=200, seed=42):
        return build_random_graph(num_nodes, num_edges, algorithm="sample", seed=seed)

    def test_correct_edge_count(self):
        assert self._sampled().shape == (2, 200)

    def test_no_self_loops(self):
        src, dst = self._sampled()
        assert (src != dst).all()

    def test_no_duplicates(self):
        edges = self._sampled()
        # Encode (src, dst) as one integer; node ids are < 100 here.
        pair_ids = edges[0] * 100 + edges[1]
        assert torch.unique(pair_ids).numel() == 200

    def test_deterministic(self):
        first = self._sampled(seed=99)
        second = self._sampled(seed=99)
        assert torch.equal(first, second)

    def test_zero_edges(self):
        assert self._sampled(num_nodes=10, num_edges=0, seed=0).shape == (2, 0)

    def test_too_many_edges_raises(self):
        with pytest.raises(ValueError, match="Cannot sample"):
            self._sampled(num_nodes=5, num_edges=100, seed=0)

    def test_unsupported_undirected_raises(self):
        with pytest.raises(ValueError, match="directed=True"):
            build_random_graph(10, 5, directed=False, algorithm="sample")

    def test_exact_is_default_and_unchanged(self):
        """Default algorithm='exact' must match previous behavior."""
        implicit = build_random_graph(10, 20, seed=0)
        explicit = build_random_graph(10, 20, seed=0, algorithm="exact")
        assert torch.equal(implicit, explicit)
@@ -10,7 +10,7 @@ Common one-liner imports::
10
10
  """
11
11
 
12
12
  # Keep this in sync with [project].version in pyproject.toml.
13
- __version__ = "0.2.2"
13
+ __version__ = "0.2.3"
14
14
 
15
15
  # ── Core data structures ──────────────────────────────────────────────────────
16
16
  from .core.graph import Graph, GraphBatch