tgraphx 0.2.0__tar.gz → 0.2.2__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (68) hide show
  1. {tgraphx-0.2.0 → tgraphx-0.2.2}/PKG-INFO +131 -16
  2. {tgraphx-0.2.0 → tgraphx-0.2.2}/README.md +130 -15
  3. {tgraphx-0.2.0 → tgraphx-0.2.2}/pyproject.toml +1 -1
  4. tgraphx-0.2.2/tests/test_amp_compile.py +710 -0
  5. tgraphx-0.2.2/tests/test_documentation_claims.py +411 -0
  6. {tgraphx-0.2.0 → tgraphx-0.2.2}/tgraphx/__init__.py +1 -1
  7. {tgraphx-0.2.0 → tgraphx-0.2.2}/tgraphx/graph_builders.py +33 -0
  8. {tgraphx-0.2.0 → tgraphx-0.2.2}/tgraphx/layers/_scatter.py +39 -3
  9. {tgraphx-0.2.0 → tgraphx-0.2.2}/tgraphx/layers/gat.py +10 -1
  10. {tgraphx-0.2.0 → tgraphx-0.2.2}/tgraphx.egg-info/PKG-INFO +131 -16
  11. {tgraphx-0.2.0 → tgraphx-0.2.2}/tgraphx.egg-info/SOURCES.txt +2 -0
  12. {tgraphx-0.2.0 → tgraphx-0.2.2}/LICENSE +0 -0
  13. {tgraphx-0.2.0 → tgraphx-0.2.2}/setup.cfg +0 -0
  14. {tgraphx-0.2.0 → tgraphx-0.2.2}/tests/test_3d_support.py +0 -0
  15. {tgraphx-0.2.0 → tgraphx-0.2.2}/tests/test_dashboard.py +0 -0
  16. {tgraphx-0.2.0 → tgraphx-0.2.2}/tests/test_devices.py +0 -0
  17. {tgraphx-0.2.0 → tgraphx-0.2.2}/tests/test_edge_features.py +0 -0
  18. {tgraphx-0.2.0 → tgraphx-0.2.2}/tests/test_edge_weight.py +0 -0
  19. {tgraphx-0.2.0 → tgraphx-0.2.2}/tests/test_factories.py +0 -0
  20. {tgraphx-0.2.0 → tgraphx-0.2.2}/tests/test_gnn_families.py +0 -0
  21. {tgraphx-0.2.0 → tgraphx-0.2.2}/tests/test_gradients.py +0 -0
  22. {tgraphx-0.2.0 → tgraphx-0.2.2}/tests/test_graph.py +0 -0
  23. {tgraphx-0.2.0 → tgraphx-0.2.2}/tests/test_graph_api.py +0 -0
  24. {tgraphx-0.2.0 → tgraphx-0.2.2}/tests/test_graph_builders.py +0 -0
  25. {tgraphx-0.2.0 → tgraphx-0.2.2}/tests/test_imports.py +0 -0
  26. {tgraphx-0.2.0 → tgraphx-0.2.2}/tests/test_layers.py +0 -0
  27. {tgraphx-0.2.0 → tgraphx-0.2.2}/tests/test_math.py +0 -0
  28. {tgraphx-0.2.0 → tgraphx-0.2.2}/tests/test_models.py +0 -0
  29. {tgraphx-0.2.0 → tgraphx-0.2.2}/tests/test_packaging.py +0 -0
  30. {tgraphx-0.2.0 → tgraphx-0.2.2}/tests/test_performance_smoke.py +0 -0
  31. {tgraphx-0.2.0 → tgraphx-0.2.2}/tests/test_tracking.py +0 -0
  32. {tgraphx-0.2.0 → tgraphx-0.2.2}/tests/test_training.py +0 -0
  33. {tgraphx-0.2.0 → tgraphx-0.2.2}/tgraphx/core/__init__.py +0 -0
  34. {tgraphx-0.2.0 → tgraphx-0.2.2}/tgraphx/core/dataloader.py +0 -0
  35. {tgraphx-0.2.0 → tgraphx-0.2.2}/tgraphx/core/graph.py +0 -0
  36. {tgraphx-0.2.0 → tgraphx-0.2.2}/tgraphx/core/graph_utils.py +0 -0
  37. {tgraphx-0.2.0 → tgraphx-0.2.2}/tgraphx/core/utils.py +0 -0
  38. {tgraphx-0.2.0 → tgraphx-0.2.2}/tgraphx/dashboard/__init__.py +0 -0
  39. {tgraphx-0.2.0 → tgraphx-0.2.2}/tgraphx/dashboard/__main__.py +0 -0
  40. {tgraphx-0.2.0 → tgraphx-0.2.2}/tgraphx/dashboard/app.py +0 -0
  41. {tgraphx-0.2.0 → tgraphx-0.2.2}/tgraphx/dashboard/static/dashboard.css +0 -0
  42. {tgraphx-0.2.0 → tgraphx-0.2.2}/tgraphx/dashboard/static/dashboard.js +0 -0
  43. {tgraphx-0.2.0 → tgraphx-0.2.2}/tgraphx/layers/__init__.py +0 -0
  44. {tgraphx-0.2.0 → tgraphx-0.2.2}/tgraphx/layers/_dim.py +0 -0
  45. {tgraphx-0.2.0 → tgraphx-0.2.2}/tgraphx/layers/aggregator.py +0 -0
  46. {tgraphx-0.2.0 → tgraphx-0.2.2}/tgraphx/layers/attention_message.py +0 -0
  47. {tgraphx-0.2.0 → tgraphx-0.2.2}/tgraphx/layers/base.py +0 -0
  48. {tgraphx-0.2.0 → tgraphx-0.2.2}/tgraphx/layers/conv_message.py +0 -0
  49. {tgraphx-0.2.0 → tgraphx-0.2.2}/tgraphx/layers/factory.py +0 -0
  50. {tgraphx-0.2.0 → tgraphx-0.2.2}/tgraphx/layers/gin.py +0 -0
  51. {tgraphx-0.2.0 → tgraphx-0.2.2}/tgraphx/layers/safe_pool.py +0 -0
  52. {tgraphx-0.2.0 → tgraphx-0.2.2}/tgraphx/layers/sage.py +0 -0
  53. {tgraphx-0.2.0 → tgraphx-0.2.2}/tgraphx/models/__init__.py +0 -0
  54. {tgraphx-0.2.0 → tgraphx-0.2.2}/tgraphx/models/cnn_encoder.py +0 -0
  55. {tgraphx-0.2.0 → tgraphx-0.2.2}/tgraphx/models/cnn_gnn_model.py +0 -0
  56. {tgraphx-0.2.0 → tgraphx-0.2.2}/tgraphx/models/edge_predictor.py +0 -0
  57. {tgraphx-0.2.0 → tgraphx-0.2.2}/tgraphx/models/factory.py +0 -0
  58. {tgraphx-0.2.0 → tgraphx-0.2.2}/tgraphx/models/graph_classifier.py +0 -0
  59. {tgraphx-0.2.0 → tgraphx-0.2.2}/tgraphx/models/node_classifier.py +0 -0
  60. {tgraphx-0.2.0 → tgraphx-0.2.2}/tgraphx/models/pre_encoder.py +0 -0
  61. {tgraphx-0.2.0 → tgraphx-0.2.2}/tgraphx/models/regressors.py +0 -0
  62. {tgraphx-0.2.0 → tgraphx-0.2.2}/tgraphx/performance.py +0 -0
  63. {tgraphx-0.2.0 → tgraphx-0.2.2}/tgraphx/tracking.py +0 -0
  64. {tgraphx-0.2.0 → tgraphx-0.2.2}/tgraphx/training.py +0 -0
  65. {tgraphx-0.2.0 → tgraphx-0.2.2}/tgraphx.egg-info/dependency_links.txt +0 -0
  66. {tgraphx-0.2.0 → tgraphx-0.2.2}/tgraphx.egg-info/entry_points.txt +0 -0
  67. {tgraphx-0.2.0 → tgraphx-0.2.2}/tgraphx.egg-info/requires.txt +0 -0
  68. {tgraphx-0.2.0 → tgraphx-0.2.2}/tgraphx.egg-info/top_level.txt +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: tgraphx
3
- Version: 0.2.0
3
+ Version: 0.2.2
4
4
  Summary: Tensor-aware graph neural networks preserving spatial node feature layouts
5
5
  Author-email: Arash Sajjadi <arash.sajjadi@usask.ca>
6
6
  Maintainer-email: Arash Sajjadi <arash.sajjadi@usask.ca>
@@ -186,10 +186,23 @@ python benchmarks/benchmark_graph_builders.py # full
186
186
 
187
187
  ```bash
188
188
  python examples/torch_compile_benchmark.py # eager vs compiled, correctness check
189
- python examples/mixed_precision_inference.py # autocast forward demo
189
+ python examples/mixed_precision_inference.py # autocast forward demo (finite-output check)
190
190
  python examples/memory_report.py # env report + memory estimates
191
191
  ```
192
192
 
193
+ **AMP policy (v0.2.2):**
194
+
195
+ | Backend | Recommended dtype | Status | Notes |
196
+ |---------|:-----------------:|:------:|-------|
197
+ | CPU | bfloat16 | ⚠️ Best-effort | Tested in CI |
198
+ | CUDA | float16 / bfloat16 | ⚠️ Best-effort | bfloat16 needs Ampere+ |
199
+ | MPS | — | ❌ Not tested | PyTorch operator coverage varies |
200
+
201
+ v0.2.2 fixes: `broadcast_edge_weight` casts edge weights to activation dtype;
202
+ `TensorGATLayer` casts attention weights before `index_add_`; `edge_softmax`
203
+ upcasts to fp32 for numerical stability and casts back. See
204
+ [docs/performance.md](docs/performance.md#amp-policy) for full details.
205
+
193
206
  ### Optional chunked forward (ConvMessagePassing)
194
207
 
195
208
  Reduce peak edge-buffer memory by processing edges in chunks:
@@ -208,14 +221,14 @@ standard path with a warning. GAT, SAGE, and GIN chunking are deferred
208
221
 
209
222
  ### Hardware compatibility
210
223
 
211
- | Platform | Forward | AMP | torch.compile | Notes |
212
- |----------|:-------:|:---:|:-------------:|-------|
213
- | CPU | ✅ | ⚠️ bfloat16 only | ✅ | Compile overhead may dominate small graphs |
214
- | CUDA | ✅ | ⚠️ float16 (op-dependent) | ✅ | index_add_ ops require dtype match |
215
- | MPS (Apple Silicon) | ✅ | limited | ⚠️ | Some ops may not be compiled |
216
- | Linux | ✅ | ✅ | ✅ | Fully supported |
217
- | Windows | ✅ | ✅ | ✅ | Fully supported |
218
- | macOS | ✅ | limited | ⚠️ | MPS support best-effort |
224
+ | Platform | Forward | AMP | torch.compile | CI coverage | Notes |
225
+ |----------|:-------:|:---:|:-------------:|:-----------:|-------|
226
+ | CPU | ✅ | ⚠️ bfloat16 only | ✅ | Full CI | Compile overhead may dominate small graphs |
227
+ | CUDA | ✅ | ⚠️ float16 (op-dependent) | ✅ | Full CI | `index_add_` ops require dtype match |
228
+ | MPS (Apple Silicon) | ✅ | ⚠️ limited | ⚠️ | No CI | Best-effort; some ops may not compile |
229
+ | Linux | ✅ | ✅ | ✅ | Full CI (ubuntu-latest) | Primary CI platform |
230
+ | Windows | ✅ | ✅ | ✅ | No CI | Best-effort; no automated tests |
231
+ | macOS | ✅ | ⚠️ limited | ⚠️ | No CI | MPS support best-effort |
219
232
 
220
233
  ---
221
234
 
@@ -593,11 +606,17 @@ The graph structure — which nodes are connected — is **not learned by the mo
593
606
 
594
607
  | Layer | Vector `[E, D_e]` | Spatial `[E, C_e, H, W]` |
595
608
  |-------|:----:|:----:|
596
- | `ConvMessagePassing` | ✗ | ✓ (concatenated along channels) |
597
- | `TensorGATLayer` | ✓ (additive attention bias on logits) | |
609
+ | `ConvMessagePassing` | ✗ | ✓ (concatenated along channels; channel count must equal node channel count) |
610
+ | `TensorGATLayer` | ✓ (additive attention bias on logits) | ⚠️ accepted; mean-pooled over spatial dims, then projected to a scalar attention bias (no per-pixel attention) |
598
611
  | `TensorGraphSAGELayer` | ✓ (additive channel bias post-`W_neigh`) | ✓ (concatenated to source) |
599
612
  | `TensorGINLayer` | ✓ (broadcast bias before ReLU) | ✓ (1×1 Conv2d projection) |
600
613
 
614
+ > **TensorGATLayer spatial edge features:** `[E, C_e, H, W]` (or `[E, C_e, D, H, W]` for 3-D nodes) are
615
+ > accepted and mean-pooled over spatial dims before the attention bias projection. Spatial dims do
616
+ > **not** need to match the node spatial dims. Mismatched rank (e.g. 5-D edges into a 2-D-configured
617
+ > GAT) raises `NotImplementedError`. Use `TensorGraphSAGELayer` or `TensorGINLayer` for full
618
+ > spatial edge-feature processing (no pooling).
619
+
601
620
  ---
602
621
 
603
622
  ## Factory API
@@ -825,7 +844,7 @@ from tgraphx.layers import ConvMessagePassing
825
844
  layer = ConvMessagePassing(
826
845
  in_shape=(C, H, W), # tuple: per-node input shape (spatial only)
827
846
  out_shape=(C_out, H, W), # H and W must stay equal to in_shape's H, W
828
- aggr="sum", # "sum" (default) | "mean"
847
+ aggr="sum", # "sum" (default) | "mean" | "max"
829
848
  use_edge_features=False, # set True to concatenate edge tensors into messages
830
849
  aggregator_params=None, # dict forwarded to DeepCNNAggregator; e.g.
831
850
  # {"num_layers": 2, "dropout_prob": 0.1}
@@ -835,7 +854,9 @@ out = layer(node_features, edge_index) # [N, C_out, H, W]
835
854
  out = layer(node_features, edge_index, edge_features) # with edge features
836
855
  ```
837
856
 
838
- > `aggr="max"` raises `NotImplementedError`. Use `GraphClassifier(pooling="max")` for graph-level max readout.
857
+ > **`aggr="max"`** is supported via `scatter_reduce_(reduce='amax')`. When `chunk_size` is also
858
+ > set, `aggr="max"` falls back to the unchunked path with a warning (emitted via `warnings.warn`).
859
+ > Use `GraphClassifier(pooling="max")` for graph-level max readout.
839
860
 
840
861
  ### `AttentionMessagePassing`
841
862
 
@@ -894,8 +915,14 @@ Attention is **scalar per `(edge, head)`** in this implementation: the
894
915
  projected query and key feature maps are mean-pooled over `H × W` before
895
916
  being scored, while the value tensors keep their full spatial layout
896
917
  during aggregation. Per-pixel and per-channel attention modes are not yet
897
- supported. **Spatial** edge feature tensors are not supported by this
898
- layer — use `TensorGraphSAGELayer` or `TensorGINLayer` for those.
918
+ supported.
919
+
920
+ **Spatial edge features** (`[E, C_e, H, W]` for `spatial_rank=2`;
921
+ `[E, C_e, D, H, W]` for `spatial_rank=3`) are accepted: spatial dims are
922
+ mean-pooled to a channel vector before the per-`(edge, head)` attention bias
923
+ projection (spatial dims need not match node spatial dims). Use
924
+ `TensorGraphSAGELayer` or `TensorGINLayer` for full spatial edge-feature
925
+ processing without pooling.
899
926
 
900
927
  ### `TensorGraphSAGELayer`
901
928
 
@@ -1146,6 +1173,94 @@ batch.to(device)
1146
1173
 
1147
1174
  ---
1148
1175
 
1176
+ ## Support status
1177
+
1178
+ ### Legend
1179
+
1180
+ | Label | Meaning |
1181
+ |-------|---------|
1182
+ | ✅ Stable | Tested in CI; API is stable |
1183
+ | 🧪 Experimental | Available but not yet guaranteed-stable |
1184
+ | ⚠️ Best-effort | Works in practice; known constraints documented |
1185
+ | ⏳ Planned | On roadmap; not yet implemented |
1186
+ | ❌ Not supported | Out of scope for the current release |
1187
+ | 🔒 Opt-in | Disabled by default; explicitly enabled by the user |
1188
+
1189
+ ### Backend support
1190
+
1191
+ | Backend | Forward | AMP | torch.compile | CI coverage | Status | Notes |
1192
+ |---------|:-------:|:---:|:-------------:|:-----------:|:------:|-------|
1193
+ | CPU | ✅ | ⚠️ bfloat16 | ✅ | Full CI | ✅ Stable | Compile overhead for small graphs |
1194
+ | CUDA | ✅ | ⚠️ op-dependent | ✅ | Full CI | ✅ Stable | `index_add_` requires dtype match under float16 |
1195
+ | MPS (Apple Silicon) | ✅ | ⚠️ limited | ⚠️ partial | No CI | ⚠️ Best-effort | PyTorch operator coverage varies |
1196
+ | Linux | ✅ | ✅ | ✅ | Full CI (ubuntu-latest) | ✅ Stable | Primary CI platform |
1197
+ | Windows | ✅ | ✅ | ✅ | No CI | ⚠️ Best-effort | Not in CI; known to install correctly |
1198
+ | macOS | ✅ | ⚠️ limited | ⚠️ | No CI | ⚠️ Best-effort | MPS path; no CI coverage |
1199
+ | Multi-GPU | ❌ | ❌ | ❌ | No CI | ❌ Not supported | — |
1200
+
1201
+ > ⚠️ **Best-effort backend:** MPS support depends on PyTorch operator coverage per release.
1202
+ > CPU workflows are tested; MPS-specific AMP/compile paths may fall back or be skipped.
1203
+ >
1204
+ > ⚠️ **Windows/macOS:** The package installs and runs on Windows and macOS, but automated tests
1205
+ > run on Ubuntu only. Regressions on those platforms may not be caught until users report them.
1206
+
1207
+ ### Feature support
1208
+
1209
+ | Feature | Status | Notes |
1210
+ |---------|:------:|-------|
1211
+ | Vector node features `[N, D]` | ✅ Stable | `LinearMessagePassing`, `"linear"` factory |
1212
+ | 2-D spatial node features `[N, C, H, W]` | ✅ Stable | All four spatial layers |
1213
+ | 3-D volumetric node features `[N, C, D, H, W]` | ✅ Stable | `spatial_rank=3` |
1214
+ | Higher-rank spatial tensors (spatial rank ≥ 4) | ❌ Not supported | Only vector, 2-D, 3-D |
1215
+ | Edge weights `[E]` | ✅ Stable | All layers |
1216
+ | Vector edge features `[E, D_e]` | ✅ Stable | GAT, SAGE, GIN |
1217
+ | Spatial edge features `[E, C_e, H, W]` | ⚠️ Best-effort | ConvMP (concat); GAT (mean-pooled); SAGE/GIN (full) |
1218
+ | Volumetric edge features `[E, C_e, D, H, W]` | ⚠️ Best-effort | Same as spatial; `spatial_rank=3` |
1219
+ | Graph Transformer | ❌ Not supported | ⏳ Planned v0.2.5 feasibility study |
1220
+ | Heterogeneous graphs | ❌ Not supported | ⏳ Planned v0.2.5+ |
1221
+ | Temporal graphs | ❌ Not supported | ⏳ Planned v0.2.5+ |
1222
+ | Learned graph construction | ❌ Not supported | `edge_index` is always user-supplied |
1223
+ | PyG/DGL converters | ❌ Not supported | ⏳ Planned v0.2.5 |
1224
+ | MLflowLogger | ❌ Not supported | Use `mlflow` client directly |
1225
+ | Dashboard | 🔒 Opt-in | Launch explicitly; zero overhead when off |
1226
+ | Offline dashboard export | ✅ Stable | `--export-html` or `export_dashboard_html()` |
1227
+ | Multi-run dashboard | ✅ Stable | Point `--logdir` at parent directory |
1228
+ | Hardware monitoring | 🔒 Opt-in | `pip install "tgraphx[monitoring]"` |
1229
+ | TensorBoard logging | 🔒 Opt-in | `pip install "tgraphx[tracking]"`; `TensorBoardLogger` |
1230
+
1231
+ ### Scalability support
1232
+
1233
+ | Feature | Status | Notes |
1234
+ |---------|:------:|-------|
1235
+ | `ConvMessagePassing` chunked forward | ✅ Stable | `aggr="sum"` / `"mean"`; `aggr="max"` falls back to the unchunked path with a warning |
1236
+ | `TensorGraphSAGELayer` chunked forward | ⏳ Planned v0.2.3 | Deferred |
1237
+ | `TensorGINLayer` chunked forward | ⏳ Planned v0.2.3 | Deferred |
1238
+ | `TensorGATLayer` chunked forward | ⏳ Planned v0.2.3+ | Destination-wise softmax makes chunking complex |
1239
+ | `build_grid_graph` / `build_grid_graph_3d` | ✅ Stable | O(E) — scales well |
1240
+ | `build_random_graph` | ✅ Stable | O(E) — scales well |
1241
+ | `build_knn_graph` / `build_radius_graph` | ⚠️ Best-effort | O(N²) via `torch.cdist`; N > 10 000 emits a warning |
1242
+ | `build_fully_connected_graph` / `build_iou_graph` | ⚠️ Best-effort | O(N²) edges; N > 5 000 emits a warning |
1243
+ | Dashboard metrics API | ✅ Stable | Incremental `?since_row=N`; `--max-metric-rows` cap |
1244
+ | Large `metrics.csv` tail-read | ⏳ Planned v0.2.3 | Current: mtime cache + full re-parse on miss |
1245
+
1246
+ > ⚠️ **Scalability warning:** `build_knn_graph`, `build_radius_graph`, `build_fully_connected_graph`,
1247
+ > and `build_iou_graph` use pairwise `torch.cdist` or enumerate all pairs. Memory and time grow as
1248
+ > **O(N²)**. A warning (`warnings.warn`) is emitted when the node count exceeds the threshold (10 000 for kNN/radius,
1249
+ > 5 000 for fully-connected/IoU). For large kNN/radius graphs, use an approximate nearest-neighbour library instead.
1250
+
1251
+ ### Attention support
1252
+
1253
+ | Feature | Status | Notes |
1254
+ |---------|:------:|-------|
1255
+ | Scalar attention per `(edge, head)` | ✅ Stable | Default in `TensorGATLayer` |
1256
+ | Vector edge attention bias | ✅ Stable | `use_edge_features=True, edge_dim=D` |
1257
+ | Spatial edge attention bias (2-D/3-D) | ⚠️ Best-effort | Accepted; spatial dims mean-pooled to a channel vector before projection |
1258
+ | Per-channel attention | ❌ Not supported | ⏳ Planned v0.2.4 |
1259
+ | Per-pixel attention | ❌ Not supported | ⏳ Planned v0.2.4 |
1260
+ | Per-voxel attention | ❌ Not supported | ⏳ Planned v0.2.4 |
1261
+
1262
+ ---
1263
+
1149
1264
  ## Limitations
1150
1265
 
1151
1266
  - **Scope:** TGraphX provides tensor-aware adaptations of GCN-style, GAT, GraphSAGE, and GIN. It is **not** a drop-in PyTorch Geometric replacement: heterogeneous graphs, temporal graphs, graph transformers, and learned graph construction are all out of scope for the current release.
@@ -141,10 +141,23 @@ python benchmarks/benchmark_graph_builders.py # full
141
141
 
142
142
  ```bash
143
143
  python examples/torch_compile_benchmark.py # eager vs compiled, correctness check
144
- python examples/mixed_precision_inference.py # autocast forward demo
144
+ python examples/mixed_precision_inference.py # autocast forward demo (finite-output check)
145
145
  python examples/memory_report.py # env report + memory estimates
146
146
  ```
147
147
 
148
+ **AMP policy (v0.2.2):**
149
+
150
+ | Backend | Recommended dtype | Status | Notes |
151
+ |---------|:-----------------:|:------:|-------|
152
+ | CPU | bfloat16 | ⚠️ Best-effort | Tested in CI |
153
+ | CUDA | float16 / bfloat16 | ⚠️ Best-effort | bfloat16 needs Ampere+ |
154
+ | MPS | — | ❌ Not tested | PyTorch operator coverage varies |
155
+
156
+ v0.2.2 fixes: `broadcast_edge_weight` casts edge weights to activation dtype;
157
+ `TensorGATLayer` casts attention weights before `index_add_`; `edge_softmax`
158
+ upcasts to fp32 for numerical stability and casts back. See
159
+ [docs/performance.md](docs/performance.md#amp-policy) for full details.
160
+
148
161
  ### Optional chunked forward (ConvMessagePassing)
149
162
 
150
163
  Reduce peak edge-buffer memory by processing edges in chunks:
@@ -163,14 +176,14 @@ standard path with a warning. GAT, SAGE, and GIN chunking are deferred
163
176
 
164
177
  ### Hardware compatibility
165
178
 
166
- | Platform | Forward | AMP | torch.compile | Notes |
167
- |----------|:-------:|:---:|:-------------:|-------|
168
- | CPU | ✅ | ⚠️ bfloat16 only | ✅ | Compile overhead may dominate small graphs |
169
- | CUDA | ✅ | ⚠️ float16 (op-dependent) | ✅ | index_add_ ops require dtype match |
170
- | MPS (Apple Silicon) | ✅ | limited | ⚠️ | Some ops may not be compiled |
171
- | Linux | ✅ | ✅ | ✅ | Fully supported |
172
- | Windows | ✅ | ✅ | ✅ | Fully supported |
173
- | macOS | ✅ | limited | ⚠️ | MPS support best-effort |
179
+ | Platform | Forward | AMP | torch.compile | CI coverage | Notes |
180
+ |----------|:-------:|:---:|:-------------:|:-----------:|-------|
181
+ | CPU | ✅ | ⚠️ bfloat16 only | ✅ | Full CI | Compile overhead may dominate small graphs |
182
+ | CUDA | ✅ | ⚠️ float16 (op-dependent) | ✅ | Full CI | `index_add_` ops require dtype match |
183
+ | MPS (Apple Silicon) | ✅ | ⚠️ limited | ⚠️ | No CI | Best-effort; some ops may not compile |
184
+ | Linux | ✅ | ✅ | ✅ | Full CI (ubuntu-latest) | Primary CI platform |
185
+ | Windows | ✅ | ✅ | ✅ | No CI | Best-effort; no automated tests |
186
+ | macOS | ✅ | ⚠️ limited | ⚠️ | No CI | MPS support best-effort |
174
187
 
175
188
  ---
176
189
 
@@ -548,11 +561,17 @@ The graph structure — which nodes are connected — is **not learned by the mo
548
561
 
549
562
  | Layer | Vector `[E, D_e]` | Spatial `[E, C_e, H, W]` |
550
563
  |-------|:----:|:----:|
551
- | `ConvMessagePassing` | ✗ | ✓ (concatenated along channels) |
552
- | `TensorGATLayer` | ✓ (additive attention bias on logits) | |
564
+ | `ConvMessagePassing` | ✗ | ✓ (concatenated along channels; channel count must equal node channel count) |
565
+ | `TensorGATLayer` | ✓ (additive attention bias on logits) | ⚠️ accepted; mean-pooled over spatial dims, then projected to a scalar attention bias (no per-pixel attention) |
553
566
  | `TensorGraphSAGELayer` | ✓ (additive channel bias post-`W_neigh`) | ✓ (concatenated to source) |
554
567
  | `TensorGINLayer` | ✓ (broadcast bias before ReLU) | ✓ (1×1 Conv2d projection) |
555
568
 
569
+ > **TensorGATLayer spatial edge features:** `[E, C_e, H, W]` (or `[E, C_e, D, H, W]` for 3-D nodes) are
570
+ > accepted and mean-pooled over spatial dims before the attention bias projection. Spatial dims do
571
+ > **not** need to match the node spatial dims. Mismatched rank (e.g. 5-D edges into a 2-D-configured
572
+ > GAT) raises `NotImplementedError`. Use `TensorGraphSAGELayer` or `TensorGINLayer` for full
573
+ > spatial edge-feature processing (no pooling).
574
+
556
575
  ---
557
576
 
558
577
  ## Factory API
@@ -780,7 +799,7 @@ from tgraphx.layers import ConvMessagePassing
780
799
  layer = ConvMessagePassing(
781
800
  in_shape=(C, H, W), # tuple: per-node input shape (spatial only)
782
801
  out_shape=(C_out, H, W), # H and W must stay equal to in_shape's H, W
783
- aggr="sum", # "sum" (default) | "mean"
802
+ aggr="sum", # "sum" (default) | "mean" | "max"
784
803
  use_edge_features=False, # set True to concatenate edge tensors into messages
785
804
  aggregator_params=None, # dict forwarded to DeepCNNAggregator; e.g.
786
805
  # {"num_layers": 2, "dropout_prob": 0.1}
@@ -790,7 +809,9 @@ out = layer(node_features, edge_index) # [N, C_out, H, W]
790
809
  out = layer(node_features, edge_index, edge_features) # with edge features
791
810
  ```
792
811
 
793
- > `aggr="max"` raises `NotImplementedError`. Use `GraphClassifier(pooling="max")` for graph-level max readout.
812
+ > **`aggr="max"`** is supported via `scatter_reduce_(reduce='amax')`. When `chunk_size` is also
813
+ > set, `aggr="max"` falls back to the unchunked path with a warning (emitted via `warnings.warn`).
814
+ > Use `GraphClassifier(pooling="max")` for graph-level max readout.
794
815
 
795
816
  ### `AttentionMessagePassing`
796
817
 
@@ -849,8 +870,14 @@ Attention is **scalar per `(edge, head)`** in this implementation: the
849
870
  projected query and key feature maps are mean-pooled over `H × W` before
850
871
  being scored, while the value tensors keep their full spatial layout
851
872
  during aggregation. Per-pixel and per-channel attention modes are not yet
852
- supported. **Spatial** edge feature tensors are not supported by this
853
- layer — use `TensorGraphSAGELayer` or `TensorGINLayer` for those.
873
+ supported.
874
+
875
+ **Spatial edge features** (`[E, C_e, H, W]` for `spatial_rank=2`;
876
+ `[E, C_e, D, H, W]` for `spatial_rank=3`) are accepted: spatial dims are
877
+ mean-pooled to a channel vector before the per-`(edge, head)` attention bias
878
+ projection (spatial dims need not match node spatial dims). Use
879
+ `TensorGraphSAGELayer` or `TensorGINLayer` for full spatial edge-feature
880
+ processing without pooling.
854
881
 
855
882
  ### `TensorGraphSAGELayer`
856
883
 
@@ -1101,6 +1128,94 @@ batch.to(device)
1101
1128
 
1102
1129
  ---
1103
1130
 
1131
+ ## Support status
1132
+
1133
+ ### Legend
1134
+
1135
+ | Label | Meaning |
1136
+ |-------|---------|
1137
+ | ✅ Stable | Tested in CI; API is stable |
1138
+ | 🧪 Experimental | Available but not yet guaranteed-stable |
1139
+ | ⚠️ Best-effort | Works in practice; known constraints documented |
1140
+ | ⏳ Planned | On roadmap; not yet implemented |
1141
+ | ❌ Not supported | Out of scope for the current release |
1142
+ | 🔒 Opt-in | Disabled by default; explicitly enabled by the user |
1143
+
1144
+ ### Backend support
1145
+
1146
+ | Backend | Forward | AMP | torch.compile | CI coverage | Status | Notes |
1147
+ |---------|:-------:|:---:|:-------------:|:-----------:|:------:|-------|
1148
+ | CPU | ✅ | ⚠️ bfloat16 | ✅ | Full CI | ✅ Stable | Compile overhead for small graphs |
1149
+ | CUDA | ✅ | ⚠️ op-dependent | ✅ | Full CI | ✅ Stable | `index_add_` requires dtype match under float16 |
1150
+ | MPS (Apple Silicon) | ✅ | ⚠️ limited | ⚠️ partial | No CI | ⚠️ Best-effort | PyTorch operator coverage varies |
1151
+ | Linux | ✅ | ✅ | ✅ | Full CI (ubuntu-latest) | ✅ Stable | Primary CI platform |
1152
+ | Windows | ✅ | ✅ | ✅ | No CI | ⚠️ Best-effort | Not in CI; known to install correctly |
1153
+ | macOS | ✅ | ⚠️ limited | ⚠️ | No CI | ⚠️ Best-effort | MPS path; no CI coverage |
1154
+ | Multi-GPU | ❌ | ❌ | ❌ | No CI | ❌ Not supported | — |
1155
+
1156
+ > ⚠️ **Best-effort backend:** MPS support depends on PyTorch operator coverage per release.
1157
+ > CPU workflows are tested; MPS-specific AMP/compile paths may fall back or be skipped.
1158
+ >
1159
+ > ⚠️ **Windows/macOS:** The package installs and runs on Windows and macOS, but automated tests
1160
+ > run on Ubuntu only. Regressions on those platforms may not be caught until users report them.
1161
+
1162
+ ### Feature support
1163
+
1164
+ | Feature | Status | Notes |
1165
+ |---------|:------:|-------|
1166
+ | Vector node features `[N, D]` | ✅ Stable | `LinearMessagePassing`, `"linear"` factory |
1167
+ | 2-D spatial node features `[N, C, H, W]` | ✅ Stable | All four spatial layers |
1168
+ | 3-D volumetric node features `[N, C, D, H, W]` | ✅ Stable | `spatial_rank=3` |
1169
+ | Higher-rank spatial tensors (spatial rank ≥ 4) | ❌ Not supported | Only vector, 2-D, 3-D |
1170
+ | Edge weights `[E]` | ✅ Stable | All layers |
1171
+ | Vector edge features `[E, D_e]` | ✅ Stable | GAT, SAGE, GIN |
1172
+ | Spatial edge features `[E, C_e, H, W]` | ⚠️ Best-effort | ConvMP (concat); GAT (mean-pooled); SAGE/GIN (full) |
1173
+ | Volumetric edge features `[E, C_e, D, H, W]` | ⚠️ Best-effort | Same as spatial; `spatial_rank=3` |
1174
+ | Graph Transformer | ❌ Not supported | ⏳ Planned v0.2.5 feasibility study |
1175
+ | Heterogeneous graphs | ❌ Not supported | ⏳ Planned v0.2.5+ |
1176
+ | Temporal graphs | ❌ Not supported | ⏳ Planned v0.2.5+ |
1177
+ | Learned graph construction | ❌ Not supported | `edge_index` is always user-supplied |
1178
+ | PyG/DGL converters | ❌ Not supported | ⏳ Planned v0.2.5 |
1179
+ | MLflowLogger | ❌ Not supported | Use `mlflow` client directly |
1180
+ | Dashboard | 🔒 Opt-in | Launch explicitly; zero overhead when off |
1181
+ | Offline dashboard export | ✅ Stable | `--export-html` or `export_dashboard_html()` |
1182
+ | Multi-run dashboard | ✅ Stable | Point `--logdir` at parent directory |
1183
+ | Hardware monitoring | 🔒 Opt-in | `pip install "tgraphx[monitoring]"` |
1184
+ | TensorBoard logging | 🔒 Opt-in | `pip install "tgraphx[tracking]"`; `TensorBoardLogger` |
1185
+
1186
+ ### Scalability support
1187
+
1188
+ | Feature | Status | Notes |
1189
+ |---------|:------:|-------|
1190
+ | `ConvMessagePassing` chunked forward | ✅ Stable | `aggr="sum"` / `"mean"`; `aggr="max"` falls back to the unchunked path with a warning |
1191
+ | `TensorGraphSAGELayer` chunked forward | ⏳ Planned v0.2.3 | Deferred |
1192
+ | `TensorGINLayer` chunked forward | ⏳ Planned v0.2.3 | Deferred |
1193
+ | `TensorGATLayer` chunked forward | ⏳ Planned v0.2.3+ | Destination-wise softmax makes chunking complex |
1194
+ | `build_grid_graph` / `build_grid_graph_3d` | ✅ Stable | O(E) — scales well |
1195
+ | `build_random_graph` | ✅ Stable | O(E) — scales well |
1196
+ | `build_knn_graph` / `build_radius_graph` | ⚠️ Best-effort | O(N²) via `torch.cdist`; N > 10 000 emits a warning |
1197
+ | `build_fully_connected_graph` / `build_iou_graph` | ⚠️ Best-effort | O(N²) edges; N > 5 000 emits a warning |
1198
+ | Dashboard metrics API | ✅ Stable | Incremental `?since_row=N`; `--max-metric-rows` cap |
1199
+ | Large `metrics.csv` tail-read | ⏳ Planned v0.2.3 | Current: mtime cache + full re-parse on miss |
1200
+
1201
+ > ⚠️ **Scalability warning:** `build_knn_graph`, `build_radius_graph`, `build_fully_connected_graph`,
1202
+ > and `build_iou_graph` use pairwise `torch.cdist` or enumerate all pairs. Memory and time grow as
1203
+ > **O(N²)**. A warning (`warnings.warn`) is emitted when the node count exceeds the threshold (10 000 for kNN/radius,
1204
+ > 5 000 for fully-connected/IoU). For large kNN/radius graphs, use an approximate nearest-neighbour library instead.
1205
+
1206
+ ### Attention support
1207
+
1208
+ | Feature | Status | Notes |
1209
+ |---------|:------:|-------|
1210
+ | Scalar attention per `(edge, head)` | ✅ Stable | Default in `TensorGATLayer` |
1211
+ | Vector edge attention bias | ✅ Stable | `use_edge_features=True, edge_dim=D` |
1212
+ | Spatial edge attention bias (2-D/3-D) | ⚠️ Best-effort | Accepted; spatial dims mean-pooled to a channel vector before projection |
1213
+ | Per-channel attention | ❌ Not supported | ⏳ Planned v0.2.4 |
1214
+ | Per-pixel attention | ❌ Not supported | ⏳ Planned v0.2.4 |
1215
+ | Per-voxel attention | ❌ Not supported | ⏳ Planned v0.2.4 |
1216
+
1217
+ ---
1218
+
1104
1219
  ## Limitations
1105
1220
 
1106
1221
  - **Scope:** TGraphX provides tensor-aware adaptations of GCN-style, GAT, GraphSAGE, and GIN. It is **not** a drop-in PyTorch Geometric replacement: heterogeneous graphs, temporal graphs, graph transformers, and learned graph construction are all out of scope for the current release.
@@ -5,7 +5,7 @@ build-backend = "setuptools.build_meta"
5
5
  [project]
6
6
  name = "tgraphx"
7
7
  # Keep this in sync with tgraphx/__init__.py::__version__
8
- version = "0.2.0"
8
+ version = "0.2.2"
9
9
  description = "Tensor-aware graph neural networks preserving spatial node feature layouts"
10
10
  readme = "README.md"
11
11
  requires-python = ">=3.9"