tgraphx 0.2.0__tar.gz → 0.2.2__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {tgraphx-0.2.0 → tgraphx-0.2.2}/PKG-INFO +131 -16
- {tgraphx-0.2.0 → tgraphx-0.2.2}/README.md +130 -15
- {tgraphx-0.2.0 → tgraphx-0.2.2}/pyproject.toml +1 -1
- tgraphx-0.2.2/tests/test_amp_compile.py +710 -0
- tgraphx-0.2.2/tests/test_documentation_claims.py +411 -0
- {tgraphx-0.2.0 → tgraphx-0.2.2}/tgraphx/__init__.py +1 -1
- {tgraphx-0.2.0 → tgraphx-0.2.2}/tgraphx/graph_builders.py +33 -0
- {tgraphx-0.2.0 → tgraphx-0.2.2}/tgraphx/layers/_scatter.py +39 -3
- {tgraphx-0.2.0 → tgraphx-0.2.2}/tgraphx/layers/gat.py +10 -1
- {tgraphx-0.2.0 → tgraphx-0.2.2}/tgraphx.egg-info/PKG-INFO +131 -16
- {tgraphx-0.2.0 → tgraphx-0.2.2}/tgraphx.egg-info/SOURCES.txt +2 -0
- {tgraphx-0.2.0 → tgraphx-0.2.2}/LICENSE +0 -0
- {tgraphx-0.2.0 → tgraphx-0.2.2}/setup.cfg +0 -0
- {tgraphx-0.2.0 → tgraphx-0.2.2}/tests/test_3d_support.py +0 -0
- {tgraphx-0.2.0 → tgraphx-0.2.2}/tests/test_dashboard.py +0 -0
- {tgraphx-0.2.0 → tgraphx-0.2.2}/tests/test_devices.py +0 -0
- {tgraphx-0.2.0 → tgraphx-0.2.2}/tests/test_edge_features.py +0 -0
- {tgraphx-0.2.0 → tgraphx-0.2.2}/tests/test_edge_weight.py +0 -0
- {tgraphx-0.2.0 → tgraphx-0.2.2}/tests/test_factories.py +0 -0
- {tgraphx-0.2.0 → tgraphx-0.2.2}/tests/test_gnn_families.py +0 -0
- {tgraphx-0.2.0 → tgraphx-0.2.2}/tests/test_gradients.py +0 -0
- {tgraphx-0.2.0 → tgraphx-0.2.2}/tests/test_graph.py +0 -0
- {tgraphx-0.2.0 → tgraphx-0.2.2}/tests/test_graph_api.py +0 -0
- {tgraphx-0.2.0 → tgraphx-0.2.2}/tests/test_graph_builders.py +0 -0
- {tgraphx-0.2.0 → tgraphx-0.2.2}/tests/test_imports.py +0 -0
- {tgraphx-0.2.0 → tgraphx-0.2.2}/tests/test_layers.py +0 -0
- {tgraphx-0.2.0 → tgraphx-0.2.2}/tests/test_math.py +0 -0
- {tgraphx-0.2.0 → tgraphx-0.2.2}/tests/test_models.py +0 -0
- {tgraphx-0.2.0 → tgraphx-0.2.2}/tests/test_packaging.py +0 -0
- {tgraphx-0.2.0 → tgraphx-0.2.2}/tests/test_performance_smoke.py +0 -0
- {tgraphx-0.2.0 → tgraphx-0.2.2}/tests/test_tracking.py +0 -0
- {tgraphx-0.2.0 → tgraphx-0.2.2}/tests/test_training.py +0 -0
- {tgraphx-0.2.0 → tgraphx-0.2.2}/tgraphx/core/__init__.py +0 -0
- {tgraphx-0.2.0 → tgraphx-0.2.2}/tgraphx/core/dataloader.py +0 -0
- {tgraphx-0.2.0 → tgraphx-0.2.2}/tgraphx/core/graph.py +0 -0
- {tgraphx-0.2.0 → tgraphx-0.2.2}/tgraphx/core/graph_utils.py +0 -0
- {tgraphx-0.2.0 → tgraphx-0.2.2}/tgraphx/core/utils.py +0 -0
- {tgraphx-0.2.0 → tgraphx-0.2.2}/tgraphx/dashboard/__init__.py +0 -0
- {tgraphx-0.2.0 → tgraphx-0.2.2}/tgraphx/dashboard/__main__.py +0 -0
- {tgraphx-0.2.0 → tgraphx-0.2.2}/tgraphx/dashboard/app.py +0 -0
- {tgraphx-0.2.0 → tgraphx-0.2.2}/tgraphx/dashboard/static/dashboard.css +0 -0
- {tgraphx-0.2.0 → tgraphx-0.2.2}/tgraphx/dashboard/static/dashboard.js +0 -0
- {tgraphx-0.2.0 → tgraphx-0.2.2}/tgraphx/layers/__init__.py +0 -0
- {tgraphx-0.2.0 → tgraphx-0.2.2}/tgraphx/layers/_dim.py +0 -0
- {tgraphx-0.2.0 → tgraphx-0.2.2}/tgraphx/layers/aggregator.py +0 -0
- {tgraphx-0.2.0 → tgraphx-0.2.2}/tgraphx/layers/attention_message.py +0 -0
- {tgraphx-0.2.0 → tgraphx-0.2.2}/tgraphx/layers/base.py +0 -0
- {tgraphx-0.2.0 → tgraphx-0.2.2}/tgraphx/layers/conv_message.py +0 -0
- {tgraphx-0.2.0 → tgraphx-0.2.2}/tgraphx/layers/factory.py +0 -0
- {tgraphx-0.2.0 → tgraphx-0.2.2}/tgraphx/layers/gin.py +0 -0
- {tgraphx-0.2.0 → tgraphx-0.2.2}/tgraphx/layers/safe_pool.py +0 -0
- {tgraphx-0.2.0 → tgraphx-0.2.2}/tgraphx/layers/sage.py +0 -0
- {tgraphx-0.2.0 → tgraphx-0.2.2}/tgraphx/models/__init__.py +0 -0
- {tgraphx-0.2.0 → tgraphx-0.2.2}/tgraphx/models/cnn_encoder.py +0 -0
- {tgraphx-0.2.0 → tgraphx-0.2.2}/tgraphx/models/cnn_gnn_model.py +0 -0
- {tgraphx-0.2.0 → tgraphx-0.2.2}/tgraphx/models/edge_predictor.py +0 -0
- {tgraphx-0.2.0 → tgraphx-0.2.2}/tgraphx/models/factory.py +0 -0
- {tgraphx-0.2.0 → tgraphx-0.2.2}/tgraphx/models/graph_classifier.py +0 -0
- {tgraphx-0.2.0 → tgraphx-0.2.2}/tgraphx/models/node_classifier.py +0 -0
- {tgraphx-0.2.0 → tgraphx-0.2.2}/tgraphx/models/pre_encoder.py +0 -0
- {tgraphx-0.2.0 → tgraphx-0.2.2}/tgraphx/models/regressors.py +0 -0
- {tgraphx-0.2.0 → tgraphx-0.2.2}/tgraphx/performance.py +0 -0
- {tgraphx-0.2.0 → tgraphx-0.2.2}/tgraphx/tracking.py +0 -0
- {tgraphx-0.2.0 → tgraphx-0.2.2}/tgraphx/training.py +0 -0
- {tgraphx-0.2.0 → tgraphx-0.2.2}/tgraphx.egg-info/dependency_links.txt +0 -0
- {tgraphx-0.2.0 → tgraphx-0.2.2}/tgraphx.egg-info/entry_points.txt +0 -0
- {tgraphx-0.2.0 → tgraphx-0.2.2}/tgraphx.egg-info/requires.txt +0 -0
- {tgraphx-0.2.0 → tgraphx-0.2.2}/tgraphx.egg-info/top_level.txt +0 -0
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.4
|
|
2
2
|
Name: tgraphx
|
|
3
|
-
Version: 0.2.0
|
|
3
|
+
Version: 0.2.2
|
|
4
4
|
Summary: Tensor-aware graph neural networks preserving spatial node feature layouts
|
|
5
5
|
Author-email: Arash Sajjadi <arash.sajjadi@usask.ca>
|
|
6
6
|
Maintainer-email: Arash Sajjadi <arash.sajjadi@usask.ca>
|
|
@@ -186,10 +186,23 @@ python benchmarks/benchmark_graph_builders.py # full
|
|
|
186
186
|
|
|
187
187
|
```bash
|
|
188
188
|
python examples/torch_compile_benchmark.py # eager vs compiled, correctness check
|
|
189
|
-
python examples/mixed_precision_inference.py # autocast forward demo
|
|
189
|
+
python examples/mixed_precision_inference.py # autocast forward demo (finite-output check)
|
|
190
190
|
python examples/memory_report.py # env report + memory estimates
|
|
191
191
|
```
|
|
192
192
|
|
|
193
|
+
**AMP policy (v0.2.2):**
|
|
194
|
+
|
|
195
|
+
| Backend | Recommended dtype | Status | Notes |
|
|
196
|
+
|---------|:-----------------:|:------:|-------|
|
|
197
|
+
| CPU | bfloat16 | ⚠️ Best-effort | Tested in CI |
|
|
198
|
+
| CUDA | float16 / bfloat16 | ⚠️ Best-effort | bfloat16 needs Ampere+ |
|
|
199
|
+
| MPS | — | ❌ Not tested | PyTorch operator coverage varies |
|
|
200
|
+
|
|
201
|
+
v0.2.2 fixes: `broadcast_edge_weight` casts edge weights to activation dtype;
|
|
202
|
+
`TensorGATLayer` casts attention weights before `index_add_`; `edge_softmax`
|
|
203
|
+
upcasts to fp32 for numerical stability and casts back. See
|
|
204
|
+
[docs/performance.md](docs/performance.md#amp-policy) for full details.
|
|
205
|
+
|
|
193
206
|
### Optional chunked forward (ConvMessagePassing)
|
|
194
207
|
|
|
195
208
|
Reduce peak edge-buffer memory by processing edges in chunks:
|
|
@@ -208,14 +221,14 @@ standard path with a warning. GAT, SAGE, and GIN chunking are deferred
|
|
|
208
221
|
|
|
209
222
|
### Hardware compatibility
|
|
210
223
|
|
|
211
|
-
| Platform | Forward | AMP | torch.compile | Notes |
|
|
212
|
-
|
|
213
|
-
| CPU | ✅ | ⚠️ bfloat16 only | ✅ | Compile overhead may dominate small graphs |
|
|
214
|
-
| CUDA | ✅ | ⚠️ float16 (op-dependent) | ✅ | index_add_ ops require dtype match |
|
|
215
|
-
| MPS (Apple Silicon) | ✅ | limited | ⚠️ |
|
|
216
|
-
| Linux | ✅ | ✅ | ✅ |
|
|
217
|
-
| Windows | ✅ | ✅ | ✅ |
|
|
218
|
-
| macOS | ✅ | limited | ⚠️ | MPS support best-effort |
|
|
224
|
+
| Platform | Forward | AMP | torch.compile | CI coverage | Notes |
|
|
225
|
+
|----------|:-------:|:---:|:-------------:|:-----------:|-------|
|
|
226
|
+
| CPU | ✅ | ⚠️ bfloat16 only | ✅ | Full CI | Compile overhead may dominate small graphs |
|
|
227
|
+
| CUDA | ✅ | ⚠️ float16 (op-dependent) | ✅ | Full CI | `index_add_` ops require dtype match |
|
|
228
|
+
| MPS (Apple Silicon) | ✅ | ⚠️ limited | ⚠️ | No CI | Best-effort; some ops may not compile |
|
|
229
|
+
| Linux | ✅ | ✅ | ✅ | Full CI (ubuntu-latest) | Primary CI platform |
|
|
230
|
+
| Windows | ✅ | ✅ | ✅ | No CI | Best-effort; no automated tests |
|
|
231
|
+
| macOS | ✅ | ⚠️ limited | ⚠️ | No CI | MPS support best-effort |
|
|
219
232
|
|
|
220
233
|
---
|
|
221
234
|
|
|
@@ -593,11 +606,17 @@ The graph structure — which nodes are connected — is **not learned by the mo
|
|
|
593
606
|
|
|
594
607
|
| Layer | Vector `[E, D_e]` | Spatial `[E, C_e, H, W]` |
|
|
595
608
|
|-------|:----:|:----:|
|
|
596
|
-
| `ConvMessagePassing` | ✗ | ✓ (concatenated along channels) |
|
|
597
|
-
| `TensorGATLayer` | ✓ (additive attention bias on logits) |
|
|
609
|
+
| `ConvMessagePassing` | ✗ | ✓ (concatenated along channels; channel count must equal node channel count) |
|
|
610
|
+
| `TensorGATLayer` | ✓ (additive attention bias on logits) | ⚠️ accepted; mean-pooled over spatial dims, then projected to a scalar attention bias (no per-pixel attention) |
|
|
598
611
|
| `TensorGraphSAGELayer` | ✓ (additive channel bias post-`W_neigh`) | ✓ (concatenated to source) |
|
|
599
612
|
| `TensorGINLayer` | ✓ (broadcast bias before ReLU) | ✓ (1×1 Conv2d projection) |
|
|
600
613
|
|
|
614
|
+
> **TensorGATLayer spatial edge features:** `[E, C_e, H, W]` (or `[E, C_e, D, H, W]` for 3-D nodes) are
|
|
615
|
+
> accepted and mean-pooled over spatial dims before the attention bias projection. Spatial dims do
|
|
616
|
+
> **not** need to match the node spatial dims. Mismatched rank (e.g. 5-D edges into a 2-D-configured
|
|
617
|
+
> GAT) raises `NotImplementedError`. Use `TensorGraphSAGELayer` or `TensorGINLayer` for full
|
|
618
|
+
> spatial edge-feature processing (no pooling).
|
|
619
|
+
|
|
601
620
|
---
|
|
602
621
|
|
|
603
622
|
## Factory API
|
|
@@ -825,7 +844,7 @@ from tgraphx.layers import ConvMessagePassing
|
|
|
825
844
|
layer = ConvMessagePassing(
|
|
826
845
|
in_shape=(C, H, W), # tuple: per-node input shape (spatial only)
|
|
827
846
|
out_shape=(C_out, H, W), # H and W must stay equal to in_shape's H, W
|
|
828
|
-
aggr="sum", # "sum" (default) | "mean"
|
|
847
|
+
aggr="sum", # "sum" (default) | "mean" | "max"
|
|
829
848
|
use_edge_features=False, # set True to concatenate edge tensors into messages
|
|
830
849
|
aggregator_params=None, # dict forwarded to DeepCNNAggregator; e.g.
|
|
831
850
|
# {"num_layers": 2, "dropout_prob": 0.1}
|
|
@@ -835,7 +854,9 @@ out = layer(node_features, edge_index) # [N, C_out, H, W]
|
|
|
835
854
|
out = layer(node_features, edge_index, edge_features) # with edge features
|
|
836
855
|
```
|
|
837
856
|
|
|
838
|
-
>
|
|
857
|
+
> **`aggr="max"`** is supported via `scatter_reduce_(reduce='amax')`. When `chunk_size` is also
|
|
858
|
+
> set, `aggr="max"` falls back to the unchunked path and emits a warning via `warnings.warn`.
|
|
859
|
+
> Use `GraphClassifier(pooling="max")` for graph-level max readout.
|
|
839
860
|
|
|
840
861
|
### `AttentionMessagePassing`
|
|
841
862
|
|
|
@@ -894,8 +915,14 @@ Attention is **scalar per `(edge, head)`** in this implementation: the
|
|
|
894
915
|
projected query and key feature maps are mean-pooled over `H × W` before
|
|
895
916
|
being scored, while the value tensors keep their full spatial layout
|
|
896
917
|
during aggregation. Per-pixel and per-channel attention modes are not yet
|
|
897
|
-
supported.
|
|
898
|
-
|
|
918
|
+
supported.
|
|
919
|
+
|
|
920
|
+
**Spatial edge features** (`[E, C_e, H, W]` for `spatial_rank=2`;
|
|
921
|
+
`[E, C_e, D, H, W]` for `spatial_rank=3`) are accepted: spatial dims are
|
|
922
|
+
mean-pooled to a channel vector before the per-`(edge, head)` attention bias
|
|
923
|
+
projection (spatial dims need not match node spatial dims). Use
|
|
924
|
+
`TensorGraphSAGELayer` or `TensorGINLayer` for full spatial edge-feature
|
|
925
|
+
processing without pooling.
|
|
899
926
|
|
|
900
927
|
### `TensorGraphSAGELayer`
|
|
901
928
|
|
|
@@ -1146,6 +1173,94 @@ batch.to(device)
|
|
|
1146
1173
|
|
|
1147
1174
|
---
|
|
1148
1175
|
|
|
1176
|
+
## Support status
|
|
1177
|
+
|
|
1178
|
+
### Legend
|
|
1179
|
+
|
|
1180
|
+
| Label | Meaning |
|
|
1181
|
+
|-------|---------|
|
|
1182
|
+
| ✅ Stable | Tested in CI; API is stable |
|
|
1183
|
+
| 🧪 Experimental | Available but not yet guaranteed-stable |
|
|
1184
|
+
| ⚠️ Best-effort | Works in practice; known constraints documented |
|
|
1185
|
+
| ⏳ Planned | On roadmap; not yet implemented |
|
|
1186
|
+
| ❌ Not supported | Out of scope for the current release |
|
|
1187
|
+
| 🔒 Opt-in | Disabled by default; explicitly enabled by the user |
|
|
1188
|
+
|
|
1189
|
+
### Backend support
|
|
1190
|
+
|
|
1191
|
+
| Backend | Forward | AMP | torch.compile | CI coverage | Status | Notes |
|
|
1192
|
+
|---------|:-------:|:---:|:-------------:|:-----------:|:------:|-------|
|
|
1193
|
+
| CPU | ✅ | ⚠️ bfloat16 | ✅ | Full CI | ✅ Stable | Compile overhead for small graphs |
|
|
1194
|
+
| CUDA | ✅ | ⚠️ op-dependent | ✅ | Full CI | ✅ Stable | `index_add_` requires dtype match under float16 |
|
|
1195
|
+
| MPS (Apple Silicon) | ✅ | ⚠️ limited | ⚠️ partial | No CI | ⚠️ Best-effort | PyTorch operator coverage varies |
|
|
1196
|
+
| Linux | ✅ | ✅ | ✅ | Full CI (ubuntu-latest) | ✅ Stable | Primary CI platform |
|
|
1197
|
+
| Windows | ✅ | ✅ | ✅ | No CI | ⚠️ Best-effort | Not in CI; known to install correctly |
|
|
1198
|
+
| macOS | ✅ | ⚠️ limited | ⚠️ | No CI | ⚠️ Best-effort | MPS path; no CI coverage |
|
|
1199
|
+
| Multi-GPU | ❌ | ❌ | ❌ | No CI | ❌ Not supported | — |
|
|
1200
|
+
|
|
1201
|
+
> ⚠️ **Best-effort backend:** MPS support depends on PyTorch operator coverage per release.
|
|
1202
|
+
> CPU workflows are tested; MPS-specific AMP/compile paths may fall back or be skipped.
|
|
1203
|
+
>
|
|
1204
|
+
> ⚠️ **Windows/macOS:** The package installs and runs on Windows and macOS, but automated tests
|
|
1205
|
+
> run on Ubuntu only. Regressions on those platforms may not be caught until user reports.
|
|
1206
|
+
|
|
1207
|
+
### Feature support
|
|
1208
|
+
|
|
1209
|
+
| Feature | Status | Notes |
|
|
1210
|
+
|---------|:------:|-------|
|
|
1211
|
+
| Vector node features `[N, D]` | ✅ Stable | `LinearMessagePassing`, `"linear"` factory |
|
|
1212
|
+
| 2-D spatial node features `[N, C, H, W]` | ✅ Stable | All four spatial layers |
|
|
1213
|
+
| 3-D volumetric node features `[N, C, D, H, W]` | ✅ Stable | `spatial_rank=3` |
|
|
1214
|
+
| Arbitrary-rank tensors (spatial rank ≥ 4) | ❌ Not supported | Only vector, 2-D, 3-D |
|
|
1215
|
+
| Edge weights `[E]` | ✅ Stable | All layers |
|
|
1216
|
+
| Vector edge features `[E, D_e]` | ✅ Stable | GAT, SAGE, GIN |
|
|
1217
|
+
| Spatial edge features `[E, C_e, H, W]` | ⚠️ Best-effort | ConvMP (concat); GAT (mean-pooled); SAGE/GIN (full) |
|
|
1218
|
+
| Volumetric edge features `[E, C_e, D, H, W]` | ⚠️ Best-effort | Same as spatial; `spatial_rank=3` |
|
|
1219
|
+
| Graph Transformer | ❌ Not supported | ⏳ Planned v0.2.5 feasibility study |
|
|
1220
|
+
| Heterogeneous graphs | ❌ Not supported | ⏳ Planned v0.2.5+ |
|
|
1221
|
+
| Temporal graphs | ❌ Not supported | ⏳ Planned v0.2.5+ |
|
|
1222
|
+
| Learned graph construction | ❌ Not supported | `edge_index` is always user-supplied |
|
|
1223
|
+
| PyG/DGL converters | ❌ Not supported | ⏳ Planned v0.2.5 |
|
|
1224
|
+
| MLflowLogger | ❌ Not supported | Use `mlflow` client directly |
|
|
1225
|
+
| Dashboard | 🔒 Opt-in | Launch explicitly; zero overhead when off |
|
|
1226
|
+
| Offline dashboard export | ✅ Stable | `--export-html` or `export_dashboard_html()` |
|
|
1227
|
+
| Multi-run dashboard | ✅ Stable | Point `--logdir` at parent directory |
|
|
1228
|
+
| Hardware monitoring | 🔒 Opt-in | `pip install "tgraphx[monitoring]"` |
|
|
1229
|
+
| TensorBoard logging | 🔒 Opt-in | `pip install "tgraphx[tracking]"`; `TensorBoardLogger` |
|
|
1230
|
+
|
|
1231
|
+
### Scalability support
|
|
1232
|
+
|
|
1233
|
+
| Feature | Status | Notes |
|
|
1234
|
+
|---------|:------:|-------|
|
|
1235
|
+
| `ConvMessagePassing` chunked forward | ✅ Stable | `aggr="sum"` / `"mean"`; max falls back with warning |
|
|
1236
|
+
| `TensorGraphSAGELayer` chunked forward | ⏳ Planned v0.2.3 | Deferred |
|
|
1237
|
+
| `TensorGINLayer` chunked forward | ⏳ Planned v0.2.3 | Deferred |
|
|
1238
|
+
| `TensorGATLayer` chunked forward | ⏳ Planned v0.2.3+ | Destination-wise softmax makes chunking complex |
|
|
1239
|
+
| `build_grid_graph` / `build_grid_graph_3d` | ✅ Stable | O(E) — scales well |
|
|
1240
|
+
| `build_random_graph` | ✅ Stable | O(E) — scales well |
|
|
1241
|
+
| `build_knn_graph` / `build_radius_graph` | ⚠️ Best-effort | O(N²) via `torch.cdist`; N > 10 000 emits a warning |
|
|
1242
|
+
| `build_fully_connected_graph` / `build_iou_graph` | ⚠️ Best-effort | O(N²) edges; N > 5 000 emits a warning |
|
|
1243
|
+
| Dashboard metrics API | ✅ Stable | Incremental `?since_row=N`; `--max-metric-rows` cap |
|
|
1244
|
+
| Large `metrics.csv` tail-read | ⏳ Planned v0.2.3 | Current: mtime cache + full re-parse on miss |
|
|
1245
|
+
|
|
1246
|
+
> ⚠️ **Scalability warning:** `build_knn_graph`, `build_radius_graph`, `build_fully_connected_graph`,
|
|
1247
|
+
> and `build_iou_graph` use pairwise `torch.cdist` or enumerate all pairs. Memory and time grow as
|
|
1248
|
+
> **O(N²)**. A `warnings.warn` is emitted when node count exceeds the threshold (10 000 for kNN/radius,
|
|
1249
|
+
> 5 000 for fully-connected/IoU). For large graphs, use an approximate-NN library instead.
|
|
1250
|
+
|
|
1251
|
+
### Attention support
|
|
1252
|
+
|
|
1253
|
+
| Feature | Status | Notes |
|
|
1254
|
+
|---------|:------:|-------|
|
|
1255
|
+
| Scalar attention per `(edge, head)` | ✅ Stable | Default in `TensorGATLayer` |
|
|
1256
|
+
| Vector edge attention bias | ✅ Stable | `use_edge_features=True, edge_dim=D` |
|
|
1257
|
+
| Spatial edge attention bias (2-D/3-D) | ⚠️ Best-effort | Accepted; spatial dims mean-pooled to a channel vector, then projected to a scalar bias per `(edge, head)` |
|
|
1258
|
+
| Per-channel attention | ❌ Not supported | ⏳ Planned v0.2.4 |
|
|
1259
|
+
| Per-pixel attention | ❌ Not supported | ⏳ Planned v0.2.4 |
|
|
1260
|
+
| Per-voxel attention | ❌ Not supported | ⏳ Planned v0.2.4 |
|
|
1261
|
+
|
|
1262
|
+
---
|
|
1263
|
+
|
|
1149
1264
|
## Limitations
|
|
1150
1265
|
|
|
1151
1266
|
- **Scope:** TGraphX provides tensor-aware adaptations of GCN-style, GAT, GraphSAGE, and GIN. It is **not** a drop-in PyTorch Geometric replacement: heterogeneous graphs, temporal graphs, graph transformers, and learned graph construction are all out of scope for the current release.
|
|
@@ -141,10 +141,23 @@ python benchmarks/benchmark_graph_builders.py # full
|
|
|
141
141
|
|
|
142
142
|
```bash
|
|
143
143
|
python examples/torch_compile_benchmark.py # eager vs compiled, correctness check
|
|
144
|
-
python examples/mixed_precision_inference.py # autocast forward demo
|
|
144
|
+
python examples/mixed_precision_inference.py # autocast forward demo (finite-output check)
|
|
145
145
|
python examples/memory_report.py # env report + memory estimates
|
|
146
146
|
```
|
|
147
147
|
|
|
148
|
+
**AMP policy (v0.2.2):**
|
|
149
|
+
|
|
150
|
+
| Backend | Recommended dtype | Status | Notes |
|
|
151
|
+
|---------|:-----------------:|:------:|-------|
|
|
152
|
+
| CPU | bfloat16 | ⚠️ Best-effort | Tested in CI |
|
|
153
|
+
| CUDA | float16 / bfloat16 | ⚠️ Best-effort | bfloat16 needs Ampere+ |
|
|
154
|
+
| MPS | — | ❌ Not tested | PyTorch operator coverage varies |
|
|
155
|
+
|
|
156
|
+
v0.2.2 fixes: `broadcast_edge_weight` casts edge weights to activation dtype;
|
|
157
|
+
`TensorGATLayer` casts attention weights before `index_add_`; `edge_softmax`
|
|
158
|
+
upcasts to fp32 for numerical stability and casts back. See
|
|
159
|
+
[docs/performance.md](docs/performance.md#amp-policy) for full details.
|
|
160
|
+
|
|
148
161
|
### Optional chunked forward (ConvMessagePassing)
|
|
149
162
|
|
|
150
163
|
Reduce peak edge-buffer memory by processing edges in chunks:
|
|
@@ -163,14 +176,14 @@ standard path with a warning. GAT, SAGE, and GIN chunking are deferred
|
|
|
163
176
|
|
|
164
177
|
### Hardware compatibility
|
|
165
178
|
|
|
166
|
-
| Platform | Forward | AMP | torch.compile | Notes |
|
|
167
|
-
|
|
168
|
-
| CPU | ✅ | ⚠️ bfloat16 only | ✅ | Compile overhead may dominate small graphs |
|
|
169
|
-
| CUDA | ✅ | ⚠️ float16 (op-dependent) | ✅ | index_add_ ops require dtype match |
|
|
170
|
-
| MPS (Apple Silicon) | ✅ | limited | ⚠️ |
|
|
171
|
-
| Linux | ✅ | ✅ | ✅ |
|
|
172
|
-
| Windows | ✅ | ✅ | ✅ |
|
|
173
|
-
| macOS | ✅ | limited | ⚠️ | MPS support best-effort |
|
|
179
|
+
| Platform | Forward | AMP | torch.compile | CI coverage | Notes |
|
|
180
|
+
|----------|:-------:|:---:|:-------------:|:-----------:|-------|
|
|
181
|
+
| CPU | ✅ | ⚠️ bfloat16 only | ✅ | Full CI | Compile overhead may dominate small graphs |
|
|
182
|
+
| CUDA | ✅ | ⚠️ float16 (op-dependent) | ✅ | Full CI | `index_add_` ops require dtype match |
|
|
183
|
+
| MPS (Apple Silicon) | ✅ | ⚠️ limited | ⚠️ | No CI | Best-effort; some ops may not compile |
|
|
184
|
+
| Linux | ✅ | ✅ | ✅ | Full CI (ubuntu-latest) | Primary CI platform |
|
|
185
|
+
| Windows | ✅ | ✅ | ✅ | No CI | Best-effort; no automated tests |
|
|
186
|
+
| macOS | ✅ | ⚠️ limited | ⚠️ | No CI | MPS support best-effort |
|
|
174
187
|
|
|
175
188
|
---
|
|
176
189
|
|
|
@@ -548,11 +561,17 @@ The graph structure — which nodes are connected — is **not learned by the mo
|
|
|
548
561
|
|
|
549
562
|
| Layer | Vector `[E, D_e]` | Spatial `[E, C_e, H, W]` |
|
|
550
563
|
|-------|:----:|:----:|
|
|
551
|
-
| `ConvMessagePassing` | ✗ | ✓ (concatenated along channels) |
|
|
552
|
-
| `TensorGATLayer` | ✓ (additive attention bias on logits) |
|
|
564
|
+
| `ConvMessagePassing` | ✗ | ✓ (concatenated along channels; channel count must equal node channel count) |
|
|
565
|
+
| `TensorGATLayer` | ✓ (additive attention bias on logits) | ⚠️ accepted; mean-pooled over spatial dims, then projected to a scalar attention bias (no per-pixel attention) |
|
|
553
566
|
| `TensorGraphSAGELayer` | ✓ (additive channel bias post-`W_neigh`) | ✓ (concatenated to source) |
|
|
554
567
|
| `TensorGINLayer` | ✓ (broadcast bias before ReLU) | ✓ (1×1 Conv2d projection) |
|
|
555
568
|
|
|
569
|
+
> **TensorGATLayer spatial edge features:** `[E, C_e, H, W]` (or `[E, C_e, D, H, W]` for 3-D nodes) are
|
|
570
|
+
> accepted and mean-pooled over spatial dims before the attention bias projection. Spatial dims do
|
|
571
|
+
> **not** need to match the node spatial dims. Mismatched rank (e.g. 5-D edges into a 2-D-configured
|
|
572
|
+
> GAT) raises `NotImplementedError`. Use `TensorGraphSAGELayer` or `TensorGINLayer` for full
|
|
573
|
+
> spatial edge-feature processing (no pooling).
|
|
574
|
+
|
|
556
575
|
---
|
|
557
576
|
|
|
558
577
|
## Factory API
|
|
@@ -780,7 +799,7 @@ from tgraphx.layers import ConvMessagePassing
|
|
|
780
799
|
layer = ConvMessagePassing(
|
|
781
800
|
in_shape=(C, H, W), # tuple: per-node input shape (spatial only)
|
|
782
801
|
out_shape=(C_out, H, W), # H and W must stay equal to in_shape's H, W
|
|
783
|
-
aggr="sum", # "sum" (default) | "mean"
|
|
802
|
+
aggr="sum", # "sum" (default) | "mean" | "max"
|
|
784
803
|
use_edge_features=False, # set True to concatenate edge tensors into messages
|
|
785
804
|
aggregator_params=None, # dict forwarded to DeepCNNAggregator; e.g.
|
|
786
805
|
# {"num_layers": 2, "dropout_prob": 0.1}
|
|
@@ -790,7 +809,9 @@ out = layer(node_features, edge_index) # [N, C_out, H, W]
|
|
|
790
809
|
out = layer(node_features, edge_index, edge_features) # with edge features
|
|
791
810
|
```
|
|
792
811
|
|
|
793
|
-
>
|
|
812
|
+
> **`aggr="max"`** is supported via `scatter_reduce_(reduce='amax')`. When `chunk_size` is also
|
|
813
|
+
> set, `aggr="max"` falls back to the unchunked path and emits a warning via `warnings.warn`.
|
|
814
|
+
> Use `GraphClassifier(pooling="max")` for graph-level max readout.
|
|
794
815
|
|
|
795
816
|
### `AttentionMessagePassing`
|
|
796
817
|
|
|
@@ -849,8 +870,14 @@ Attention is **scalar per `(edge, head)`** in this implementation: the
|
|
|
849
870
|
projected query and key feature maps are mean-pooled over `H × W` before
|
|
850
871
|
being scored, while the value tensors keep their full spatial layout
|
|
851
872
|
during aggregation. Per-pixel and per-channel attention modes are not yet
|
|
852
|
-
supported.
|
|
853
|
-
|
|
873
|
+
supported.
|
|
874
|
+
|
|
875
|
+
**Spatial edge features** (`[E, C_e, H, W]` for `spatial_rank=2`;
|
|
876
|
+
`[E, C_e, D, H, W]` for `spatial_rank=3`) are accepted: spatial dims are
|
|
877
|
+
mean-pooled to a channel vector before the per-`(edge, head)` attention bias
|
|
878
|
+
projection (spatial dims need not match node spatial dims). Use
|
|
879
|
+
`TensorGraphSAGELayer` or `TensorGINLayer` for full spatial edge-feature
|
|
880
|
+
processing without pooling.
|
|
854
881
|
|
|
855
882
|
### `TensorGraphSAGELayer`
|
|
856
883
|
|
|
@@ -1101,6 +1128,94 @@ batch.to(device)
|
|
|
1101
1128
|
|
|
1102
1129
|
---
|
|
1103
1130
|
|
|
1131
|
+
## Support status
|
|
1132
|
+
|
|
1133
|
+
### Legend
|
|
1134
|
+
|
|
1135
|
+
| Label | Meaning |
|
|
1136
|
+
|-------|---------|
|
|
1137
|
+
| ✅ Stable | Tested in CI; API is stable |
|
|
1138
|
+
| 🧪 Experimental | Available but not yet guaranteed-stable |
|
|
1139
|
+
| ⚠️ Best-effort | Works in practice; known constraints documented |
|
|
1140
|
+
| ⏳ Planned | On roadmap; not yet implemented |
|
|
1141
|
+
| ❌ Not supported | Out of scope for the current release |
|
|
1142
|
+
| 🔒 Opt-in | Disabled by default; explicitly enabled by the user |
|
|
1143
|
+
|
|
1144
|
+
### Backend support
|
|
1145
|
+
|
|
1146
|
+
| Backend | Forward | AMP | torch.compile | CI coverage | Status | Notes |
|
|
1147
|
+
|---------|:-------:|:---:|:-------------:|:-----------:|:------:|-------|
|
|
1148
|
+
| CPU | ✅ | ⚠️ bfloat16 | ✅ | Full CI | ✅ Stable | Compile overhead for small graphs |
|
|
1149
|
+
| CUDA | ✅ | ⚠️ op-dependent | ✅ | Full CI | ✅ Stable | `index_add_` requires dtype match under float16 |
|
|
1150
|
+
| MPS (Apple Silicon) | ✅ | ⚠️ limited | ⚠️ partial | No CI | ⚠️ Best-effort | PyTorch operator coverage varies |
|
|
1151
|
+
| Linux | ✅ | ✅ | ✅ | Full CI (ubuntu-latest) | ✅ Stable | Primary CI platform |
|
|
1152
|
+
| Windows | ✅ | ✅ | ✅ | No CI | ⚠️ Best-effort | Not in CI; known to install correctly |
|
|
1153
|
+
| macOS | ✅ | ⚠️ limited | ⚠️ | No CI | ⚠️ Best-effort | MPS path; no CI coverage |
|
|
1154
|
+
| Multi-GPU | ❌ | ❌ | ❌ | No CI | ❌ Not supported | — |
|
|
1155
|
+
|
|
1156
|
+
> ⚠️ **Best-effort backend:** MPS support depends on PyTorch operator coverage per release.
|
|
1157
|
+
> CPU workflows are tested; MPS-specific AMP/compile paths may fall back or be skipped.
|
|
1158
|
+
>
|
|
1159
|
+
> ⚠️ **Windows/macOS:** The package installs and runs on Windows and macOS, but automated tests
|
|
1160
|
+
> run on Ubuntu only. Regressions on those platforms may not be caught until user reports.
|
|
1161
|
+
|
|
1162
|
+
### Feature support
|
|
1163
|
+
|
|
1164
|
+
| Feature | Status | Notes |
|
|
1165
|
+
|---------|:------:|-------|
|
|
1166
|
+
| Vector node features `[N, D]` | ✅ Stable | `LinearMessagePassing`, `"linear"` factory |
|
|
1167
|
+
| 2-D spatial node features `[N, C, H, W]` | ✅ Stable | All four spatial layers |
|
|
1168
|
+
| 3-D volumetric node features `[N, C, D, H, W]` | ✅ Stable | `spatial_rank=3` |
|
|
1169
|
+
| Arbitrary-rank tensors (spatial rank ≥ 4) | ❌ Not supported | Only vector, 2-D, 3-D |
|
|
1170
|
+
| Edge weights `[E]` | ✅ Stable | All layers |
|
|
1171
|
+
| Vector edge features `[E, D_e]` | ✅ Stable | GAT, SAGE, GIN |
|
|
1172
|
+
| Spatial edge features `[E, C_e, H, W]` | ⚠️ Best-effort | ConvMP (concat); GAT (mean-pooled); SAGE/GIN (full) |
|
|
1173
|
+
| Volumetric edge features `[E, C_e, D, H, W]` | ⚠️ Best-effort | Same as spatial; `spatial_rank=3` |
|
|
1174
|
+
| Graph Transformer | ❌ Not supported | ⏳ Planned v0.2.5 feasibility study |
|
|
1175
|
+
| Heterogeneous graphs | ❌ Not supported | ⏳ Planned v0.2.5+ |
|
|
1176
|
+
| Temporal graphs | ❌ Not supported | ⏳ Planned v0.2.5+ |
|
|
1177
|
+
| Learned graph construction | ❌ Not supported | `edge_index` is always user-supplied |
|
|
1178
|
+
| PyG/DGL converters | ❌ Not supported | ⏳ Planned v0.2.5 |
|
|
1179
|
+
| MLflowLogger | ❌ Not supported | Use `mlflow` client directly |
|
|
1180
|
+
| Dashboard | 🔒 Opt-in | Launch explicitly; zero overhead when off |
|
|
1181
|
+
| Offline dashboard export | ✅ Stable | `--export-html` or `export_dashboard_html()` |
|
|
1182
|
+
| Multi-run dashboard | ✅ Stable | Point `--logdir` at parent directory |
|
|
1183
|
+
| Hardware monitoring | 🔒 Opt-in | `pip install "tgraphx[monitoring]"` |
|
|
1184
|
+
| TensorBoard logging | 🔒 Opt-in | `pip install "tgraphx[tracking]"`; `TensorBoardLogger` |
|
|
1185
|
+
|
|
1186
|
+
### Scalability support
|
|
1187
|
+
|
|
1188
|
+
| Feature | Status | Notes |
|
|
1189
|
+
|---------|:------:|-------|
|
|
1190
|
+
| `ConvMessagePassing` chunked forward | ✅ Stable | `aggr="sum"` / `"mean"`; max falls back with warning |
|
|
1191
|
+
| `TensorGraphSAGELayer` chunked forward | ⏳ Planned v0.2.3 | Deferred |
|
|
1192
|
+
| `TensorGINLayer` chunked forward | ⏳ Planned v0.2.3 | Deferred |
|
|
1193
|
+
| `TensorGATLayer` chunked forward | ⏳ Planned v0.2.3+ | Destination-wise softmax makes chunking complex |
|
|
1194
|
+
| `build_grid_graph` / `build_grid_graph_3d` | ✅ Stable | O(E) — scales well |
|
|
1195
|
+
| `build_random_graph` | ✅ Stable | O(E) — scales well |
|
|
1196
|
+
| `build_knn_graph` / `build_radius_graph` | ⚠️ Best-effort | O(N²) via `torch.cdist`; N > 10 000 emits a warning |
|
|
1197
|
+
| `build_fully_connected_graph` / `build_iou_graph` | ⚠️ Best-effort | O(N²) edges; N > 5 000 emits a warning |
|
|
1198
|
+
| Dashboard metrics API | ✅ Stable | Incremental `?since_row=N`; `--max-metric-rows` cap |
|
|
1199
|
+
| Large `metrics.csv` tail-read | ⏳ Planned v0.2.3 | Current: mtime cache + full re-parse on miss |
|
|
1200
|
+
|
|
1201
|
+
> ⚠️ **Scalability warning:** `build_knn_graph`, `build_radius_graph`, `build_fully_connected_graph`,
|
|
1202
|
+
> and `build_iou_graph` use pairwise `torch.cdist` or enumerate all pairs. Memory and time grow as
|
|
1203
|
+
> **O(N²)**. A `warnings.warn` is emitted when node count exceeds the threshold (10 000 for kNN/radius,
|
|
1204
|
+
> 5 000 for fully-connected/IoU). For large graphs, use an approximate-NN library instead.
|
|
1205
|
+
|
|
1206
|
+
### Attention support
|
|
1207
|
+
|
|
1208
|
+
| Feature | Status | Notes |
|
|
1209
|
+
|---------|:------:|-------|
|
|
1210
|
+
| Scalar attention per `(edge, head)` | ✅ Stable | Default in `TensorGATLayer` |
|
|
1211
|
+
| Vector edge attention bias | ✅ Stable | `use_edge_features=True, edge_dim=D` |
|
|
1212
|
+
| Spatial edge attention bias (2-D/3-D) | ⚠️ Best-effort | Accepted; spatial dims mean-pooled to a channel vector, then projected to a scalar bias per `(edge, head)` |
|
|
1213
|
+
| Per-channel attention | ❌ Not supported | ⏳ Planned v0.2.4 |
|
|
1214
|
+
| Per-pixel attention | ❌ Not supported | ⏳ Planned v0.2.4 |
|
|
1215
|
+
| Per-voxel attention | ❌ Not supported | ⏳ Planned v0.2.4 |
|
|
1216
|
+
|
|
1217
|
+
---
|
|
1218
|
+
|
|
1104
1219
|
## Limitations
|
|
1105
1220
|
|
|
1106
1221
|
- **Scope:** TGraphX provides tensor-aware adaptations of GCN-style, GAT, GraphSAGE, and GIN. It is **not** a drop-in PyTorch Geometric replacement: heterogeneous graphs, temporal graphs, graph transformers, and learned graph construction are all out of scope for the current release.
|
|
@@ -5,7 +5,7 @@ build-backend = "setuptools.build_meta"
|
|
|
5
5
|
[project]
|
|
6
6
|
name = "tgraphx"
|
|
7
7
|
# Keep this in sync with tgraphx/__init__.py::__version__
|
|
8
|
-
version = "0.2.0"
|
|
8
|
+
version = "0.2.2"
|
|
9
9
|
description = "Tensor-aware graph neural networks preserving spatial node feature layouts"
|
|
10
10
|
readme = "README.md"
|
|
11
11
|
requires-python = ">=3.9"
|