tgraphx 0.2.2__tar.gz → 0.2.7__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {tgraphx-0.2.2 → tgraphx-0.2.7}/PKG-INFO +78 -44
- {tgraphx-0.2.2 → tgraphx-0.2.7}/README.md +75 -43
- {tgraphx-0.2.2 → tgraphx-0.2.7}/pyproject.toml +5 -3
- tgraphx-0.2.7/tests/test_backward_compatibility.py +153 -0
- tgraphx-0.2.7/tests/test_chunking.py +500 -0
- tgraphx-0.2.7/tests/test_distributed_compat.py +57 -0
- {tgraphx-0.2.2 → tgraphx-0.2.7}/tests/test_documentation_claims.py +126 -0
- tgraphx-0.2.7/tests/test_gat_chunking.py +168 -0
- tgraphx-0.2.7/tests/test_graph_transformer_v027.py +130 -0
- tgraphx-0.2.7/tests/test_hetero_batch.py +150 -0
- tgraphx-0.2.7/tests/test_hetero_layers.py +185 -0
- tgraphx-0.2.7/tests/test_sampling.py +237 -0
- tgraphx-0.2.7/tests/test_sampling_loaders.py +107 -0
- tgraphx-0.2.7/tests/test_temporal_v025.py +198 -0
- tgraphx-0.2.7/tests/test_v024_features.py +553 -0
- {tgraphx-0.2.2 → tgraphx-0.2.7}/tgraphx/__init__.py +36 -5
- tgraphx-0.2.7/tgraphx/core/hetero_batch.py +296 -0
- tgraphx-0.2.7/tgraphx/core/hetero_graph.py +292 -0
- tgraphx-0.2.7/tgraphx/core/temporal.py +142 -0
- tgraphx-0.2.7/tgraphx/core/temporal_batch.py +180 -0
- {tgraphx-0.2.2 → tgraphx-0.2.7}/tgraphx/dashboard/app.py +113 -17
- tgraphx-0.2.7/tgraphx/distributed.py +109 -0
- {tgraphx-0.2.2 → tgraphx-0.2.7}/tgraphx/graph_builders.py +309 -54
- tgraphx-0.2.7/tgraphx/interop.py +440 -0
- {tgraphx-0.2.2 → tgraphx-0.2.7}/tgraphx/layers/factory.py +29 -1
- tgraphx-0.2.7/tgraphx/layers/gat.py +582 -0
- {tgraphx-0.2.2 → tgraphx-0.2.7}/tgraphx/layers/gin.py +119 -10
- tgraphx-0.2.7/tgraphx/layers/graph_transformer.py +298 -0
- tgraphx-0.2.7/tgraphx/layers/hetero.py +241 -0
- tgraphx-0.2.7/tgraphx/layers/hetero_readout.py +148 -0
- {tgraphx-0.2.2 → tgraphx-0.2.7}/tgraphx/layers/sage.py +160 -23
- tgraphx-0.2.7/tgraphx/layers/temporal_readout.py +88 -0
- tgraphx-0.2.7/tgraphx/layers/transformer_encodings.py +181 -0
- tgraphx-0.2.7/tgraphx/learned_graph.py +251 -0
- tgraphx-0.2.7/tgraphx/models/hetero_models.py +200 -0
- tgraphx-0.2.7/tgraphx/models/temporal_models.py +149 -0
- tgraphx-0.2.7/tgraphx/sampling.py +482 -0
- tgraphx-0.2.7/tgraphx/sampling_loaders.py +201 -0
- {tgraphx-0.2.2 → tgraphx-0.2.7}/tgraphx/tracking.py +162 -5
- {tgraphx-0.2.2 → tgraphx-0.2.7}/tgraphx.egg-info/PKG-INFO +78 -44
- {tgraphx-0.2.2 → tgraphx-0.2.7}/tgraphx.egg-info/SOURCES.txt +28 -1
- {tgraphx-0.2.2 → tgraphx-0.2.7}/tgraphx.egg-info/requires.txt +3 -0
- tgraphx-0.2.2/tgraphx/layers/gat.py +0 -344
- {tgraphx-0.2.2 → tgraphx-0.2.7}/LICENSE +0 -0
- {tgraphx-0.2.2 → tgraphx-0.2.7}/setup.cfg +0 -0
- {tgraphx-0.2.2 → tgraphx-0.2.7}/tests/test_3d_support.py +0 -0
- {tgraphx-0.2.2 → tgraphx-0.2.7}/tests/test_amp_compile.py +0 -0
- {tgraphx-0.2.2 → tgraphx-0.2.7}/tests/test_dashboard.py +0 -0
- {tgraphx-0.2.2 → tgraphx-0.2.7}/tests/test_devices.py +0 -0
- {tgraphx-0.2.2 → tgraphx-0.2.7}/tests/test_edge_features.py +0 -0
- {tgraphx-0.2.2 → tgraphx-0.2.7}/tests/test_edge_weight.py +0 -0
- {tgraphx-0.2.2 → tgraphx-0.2.7}/tests/test_factories.py +0 -0
- {tgraphx-0.2.2 → tgraphx-0.2.7}/tests/test_gnn_families.py +0 -0
- {tgraphx-0.2.2 → tgraphx-0.2.7}/tests/test_gradients.py +0 -0
- {tgraphx-0.2.2 → tgraphx-0.2.7}/tests/test_graph.py +0 -0
- {tgraphx-0.2.2 → tgraphx-0.2.7}/tests/test_graph_api.py +0 -0
- {tgraphx-0.2.2 → tgraphx-0.2.7}/tests/test_graph_builders.py +0 -0
- {tgraphx-0.2.2 → tgraphx-0.2.7}/tests/test_imports.py +0 -0
- {tgraphx-0.2.2 → tgraphx-0.2.7}/tests/test_layers.py +0 -0
- {tgraphx-0.2.2 → tgraphx-0.2.7}/tests/test_math.py +0 -0
- {tgraphx-0.2.2 → tgraphx-0.2.7}/tests/test_models.py +0 -0
- {tgraphx-0.2.2 → tgraphx-0.2.7}/tests/test_packaging.py +0 -0
- {tgraphx-0.2.2 → tgraphx-0.2.7}/tests/test_performance_smoke.py +0 -0
- {tgraphx-0.2.2 → tgraphx-0.2.7}/tests/test_tracking.py +0 -0
- {tgraphx-0.2.2 → tgraphx-0.2.7}/tests/test_training.py +0 -0
- {tgraphx-0.2.2 → tgraphx-0.2.7}/tgraphx/core/__init__.py +0 -0
- {tgraphx-0.2.2 → tgraphx-0.2.7}/tgraphx/core/dataloader.py +0 -0
- {tgraphx-0.2.2 → tgraphx-0.2.7}/tgraphx/core/graph.py +0 -0
- {tgraphx-0.2.2 → tgraphx-0.2.7}/tgraphx/core/graph_utils.py +0 -0
- {tgraphx-0.2.2 → tgraphx-0.2.7}/tgraphx/core/utils.py +0 -0
- {tgraphx-0.2.2 → tgraphx-0.2.7}/tgraphx/dashboard/__init__.py +0 -0
- {tgraphx-0.2.2 → tgraphx-0.2.7}/tgraphx/dashboard/__main__.py +0 -0
- {tgraphx-0.2.2 → tgraphx-0.2.7}/tgraphx/dashboard/static/dashboard.css +0 -0
- {tgraphx-0.2.2 → tgraphx-0.2.7}/tgraphx/dashboard/static/dashboard.js +0 -0
- {tgraphx-0.2.2 → tgraphx-0.2.7}/tgraphx/layers/__init__.py +0 -0
- {tgraphx-0.2.2 → tgraphx-0.2.7}/tgraphx/layers/_dim.py +0 -0
- {tgraphx-0.2.2 → tgraphx-0.2.7}/tgraphx/layers/_scatter.py +0 -0
- {tgraphx-0.2.2 → tgraphx-0.2.7}/tgraphx/layers/aggregator.py +0 -0
- {tgraphx-0.2.2 → tgraphx-0.2.7}/tgraphx/layers/attention_message.py +0 -0
- {tgraphx-0.2.2 → tgraphx-0.2.7}/tgraphx/layers/base.py +0 -0
- {tgraphx-0.2.2 → tgraphx-0.2.7}/tgraphx/layers/conv_message.py +0 -0
- {tgraphx-0.2.2 → tgraphx-0.2.7}/tgraphx/layers/safe_pool.py +0 -0
- {tgraphx-0.2.2 → tgraphx-0.2.7}/tgraphx/models/__init__.py +0 -0
- {tgraphx-0.2.2 → tgraphx-0.2.7}/tgraphx/models/cnn_encoder.py +0 -0
- {tgraphx-0.2.2 → tgraphx-0.2.7}/tgraphx/models/cnn_gnn_model.py +0 -0
- {tgraphx-0.2.2 → tgraphx-0.2.7}/tgraphx/models/edge_predictor.py +0 -0
- {tgraphx-0.2.2 → tgraphx-0.2.7}/tgraphx/models/factory.py +0 -0
- {tgraphx-0.2.2 → tgraphx-0.2.7}/tgraphx/models/graph_classifier.py +0 -0
- {tgraphx-0.2.2 → tgraphx-0.2.7}/tgraphx/models/node_classifier.py +0 -0
- {tgraphx-0.2.2 → tgraphx-0.2.7}/tgraphx/models/pre_encoder.py +0 -0
- {tgraphx-0.2.2 → tgraphx-0.2.7}/tgraphx/models/regressors.py +0 -0
- {tgraphx-0.2.2 → tgraphx-0.2.7}/tgraphx/performance.py +0 -0
- {tgraphx-0.2.2 → tgraphx-0.2.7}/tgraphx/training.py +0 -0
- {tgraphx-0.2.2 → tgraphx-0.2.7}/tgraphx.egg-info/dependency_links.txt +0 -0
- {tgraphx-0.2.2 → tgraphx-0.2.7}/tgraphx.egg-info/entry_points.txt +0 -0
- {tgraphx-0.2.2 → tgraphx-0.2.7}/tgraphx.egg-info/top_level.txt +0 -0
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.4
|
|
2
2
|
Name: tgraphx
|
|
3
|
-
Version: 0.2.
|
|
3
|
+
Version: 0.2.7
|
|
4
4
|
Summary: Tensor-aware graph neural networks preserving spatial node feature layouts
|
|
5
5
|
Author-email: Arash Sajjadi <arash.sajjadi@usask.ca>
|
|
6
6
|
Maintainer-email: Arash Sajjadi <arash.sajjadi@usask.ca>
|
|
@@ -41,6 +41,8 @@ Requires-Dist: psutil>=5.9; extra == "monitoring"
|
|
|
41
41
|
Requires-Dist: pynvml>=11.0; extra == "monitoring"
|
|
42
42
|
Provides-Extra: tracking
|
|
43
43
|
Requires-Dist: tensorboard>=2.11; extra == "tracking"
|
|
44
|
+
Provides-Extra: mlflow
|
|
45
|
+
Requires-Dist: mlflow>=2.0; extra == "mlflow"
|
|
44
46
|
Dynamic: license-file
|
|
45
47
|
|
|
46
48
|
<p align="center">
|
|
@@ -119,28 +121,46 @@ drop-in clones of PyTorch Geometric's vector-feature implementations.
|
|
|
119
121
|
| Dataset & loader | `GraphDataset`, `GraphDataLoader` | — | Wraps `torch.utils.data` |
|
|
120
122
|
| Utilities | `load_config`, `get_device` | — | YAML/JSON config; CUDA→MPS→CPU |
|
|
121
123
|
|
|
122
|
-
##
|
|
123
|
-
|
|
124
|
-
|
|
125
|
-
|
|
126
|
-
|
|
127
|
-
|
|
128
|
-
|
|
129
|
-
|
|
130
|
-
|
|
131
|
-
|
|
132
|
-
|
|
133
|
-
|
|
134
|
-
|
|
135
|
-
|
|
136
|
-
|
|
137
|
-
|
|
138
|
-
|
|
139
|
-
|
|
140
|
-
|
|
141
|
-
|
|
142
|
-
|
|
143
|
-
|
|
124
|
+
## Current scope and boundaries
|
|
125
|
+
|
|
126
|
+
TGraphX is a focused library for tensor-aware patch-graph GNNs. Here is
|
|
127
|
+
what is stable, what is experimental, and what is intentionally out of scope.
|
|
128
|
+
Full details are in [docs/limitations.md](docs/limitations.md) and
|
|
129
|
+
[docs/roadmap.md](docs/roadmap.md).
|
|
130
|
+
|
|
131
|
+
### Optional and experimental features (v0.2.4)
|
|
132
|
+
|
|
133
|
+
| Feature | Status | Install / usage |
|
|
134
|
+
|---------|:------:|---------|
|
|
135
|
+
| `TensorGATLayer(attention_mode="channel")` | 🧪 Experimental | Constructor argument |
|
|
136
|
+
| `TensorGATLayer` chunked forward | ✅ Stable | `layer(x, edge_index, chunk_size=K)` |
|
|
137
|
+
| `GraphTransformerLayer` (vector features only) | 🧪 Experimental | `from tgraphx.layers.graph_transformer import GraphTransformerLayer` |
|
|
138
|
+
| `HeteroGraph` container | 🧪 Experimental | `from tgraphx.core.hetero_graph import HeteroGraph` |
|
|
139
|
+
| `TemporalGraphSequence` container | 🧪 Experimental | `from tgraphx.core.temporal import TemporalGraphSequence` |
|
|
140
|
+
| `MLflowLogger` | ✅ Opt-in | `pip install mlflow` or `pip install "tgraphx[mlflow]"` |
|
|
141
|
+
| PyG / DGL converters | ✅ Opt-in | `from tgraphx.interop import to_pyg_data, to_dgl_graph, …` |
|
|
142
|
+
| Learned graph helpers | ✅ Stable | `from tgraphx.learned_graph import soft_adjacency_from_embeddings, EdgeScorer, …` |
|
|
143
|
+
| Patch helper `padding="auto"` | ✅ Stable | `image_to_patches(imgs, ps, padding="auto")` |
|
|
144
|
+
| Hardware monitoring dashboard | 🔒 Opt-in | `pip install "tgraphx[monitoring]"` |
|
|
145
|
+
| TensorBoard logging | 🔒 Opt-in | `pip install "tgraphx[tracking]"` |
|
|
146
|
+
|
|
147
|
+
### Scope boundaries (design decisions, not bugs)
|
|
148
|
+
|
|
149
|
+
- **Spatial / volumetric node features:** `[N, D]`, `[N, C, H, W]`, and
|
|
150
|
+
`[N, C, D, H, W]` are the supported shapes. Arbitrary-rank tensors
|
|
151
|
+
(rank ≥ 5 node features) are out of scope.
|
|
152
|
+
- **GAT per-pixel / per-voxel attention:** score tensors would be
|
|
153
|
+
`O(E × K × H × W)` — prohibitive for typical spatial GNN workloads.
|
|
154
|
+
Planned for a future release after memory analysis.
|
|
155
|
+
- **Full hetero/temporal GNN layers:** `HeteroGraph` and
|
|
156
|
+
`TemporalGraphSequence` are *containers*, not GNN implementations.
|
|
157
|
+
- **PyG/DGL drop-in compatibility:** TGraphX is not a replacement.
|
|
158
|
+
The optional converters transfer data only; APIs differ.
|
|
159
|
+
- **Managed distributed training / multi-GPU** (beyond the opt-in `tgraphx.distributed` helpers): out of scope
|
|
160
|
+
for the current release.
|
|
161
|
+
- **Profiling and file writes:** disabled by default; all are opt-in.
|
|
162
|
+
|
|
163
|
+
See [docs/roadmap.md](docs/roadmap.md) for planned items.
|
|
144
164
|
|
|
145
165
|
---
|
|
146
166
|
|
|
@@ -216,8 +236,8 @@ out = layer(x, edge_index, chunk_size=512)
|
|
|
216
236
|
```
|
|
217
237
|
|
|
218
238
|
Supported aggregations: `"sum"` and `"mean"`. `"max"` falls back to the
|
|
219
|
-
standard path with a warning.
|
|
220
|
-
|
|
239
|
+
standard path with a warning. SAGE and GIN also support `chunk_size`; GAT
|
|
240
|
+
uses a two-pass algorithm — all four layers accept `chunk_size` in `forward()`.
|
|
221
241
|
|
|
222
242
|
### Hardware compatibility
|
|
223
243
|
|
|
@@ -291,7 +311,15 @@ with TensorBoardLogger("runs/tb") as tb:
|
|
|
291
311
|
```
|
|
292
312
|
|
|
293
313
|
Nothing is written unless you explicitly pass a logger.
|
|
294
|
-
|
|
314
|
+
|
|
315
|
+
### MLflow logging (optional)
|
|
316
|
+
|
|
317
|
+
```python
|
|
318
|
+
from tgraphx.tracking import MLflowLogger # pip install mlflow
|
|
319
|
+
|
|
320
|
+
with MLflowLogger(run_name="my_run", experiment="gnn") as mlf:
|
|
321
|
+
history = fit(model, train_loader, logger=mlf, ...)
|
|
322
|
+
```
|
|
295
323
|
|
|
296
324
|
---
|
|
297
325
|
|
|
@@ -1182,7 +1210,7 @@ batch.to(device)
|
|
|
1182
1210
|
| ✅ Stable | Tested in CI; API is stable |
|
|
1183
1211
|
| 🧪 Experimental | Available but not yet guaranteed-stable |
|
|
1184
1212
|
| ⚠️ Best-effort | Works in practice; known constraints documented |
|
|
1185
|
-
| ⏳ Planned | On roadmap
|
|
1213
|
+
| ⏳ Planned | On roadmap for a future release |
|
|
1186
1214
|
| ❌ Not supported | Out of scope for the current release |
|
|
1187
1215
|
| 🔒 Opt-in | Disabled by default; explicitly enabled by the user |
|
|
1188
1216
|
|
|
@@ -1216,12 +1244,12 @@ batch.to(device)
|
|
|
1216
1244
|
| Vector edge features `[E, D_e]` | ✅ Stable | GAT, SAGE, GIN |
|
|
1217
1245
|
| Spatial edge features `[E, C_e, H, W]` | ⚠️ Best-effort | ConvMP (concat); GAT (mean-pooled); SAGE/GIN (full) |
|
|
1218
1246
|
| Volumetric edge features `[E, C_e, D, H, W]` | ⚠️ Best-effort | Same as spatial; `spatial_rank=3` |
|
|
1219
|
-
| Graph Transformer
|
|
1220
|
-
| Heterogeneous graphs
|
|
1221
|
-
| Temporal graphs
|
|
1222
|
-
| Learned graph construction
|
|
1223
|
-
| PyG/DGL converters |
|
|
1224
|
-
| MLflowLogger |
|
|
1247
|
+
| Graph Transformer (vector node features only) | 🧪 Experimental | `tgraphx.layers.graph_transformer.GraphTransformerLayer`; tensor-aware variant ⏳ planned |
|
|
1248
|
+
| Heterogeneous graphs (container + batch + HeteroConv + classifiers) | 🧪 Experimental | `HeteroGraph`, `HeteroGraphBatch`, `HeteroConv`, `HeteroGraphClassifier`, `HeteroNodeClassifier`; vector node features only |
|
|
1249
|
+
| Temporal graphs (container + batch + readout + classifier) | 🧪 Experimental | `TemporalGraphSequence`, `TemporalGraphBatch`, `temporal_readout`, `TemporalGraphClassifier`/`Regressor`; snapshot-loop pattern, no recurrent memory module |
|
|
1250
|
+
| Learned graph construction (soft adjacency, edge scorer) | ✅ Stable | `tgraphx.learned_graph` — discrete top-k is non-differentiable |
|
|
1251
|
+
| PyG / DGL converters | ✅ Opt-in | `tgraphx.interop` — data converters only, not API replacement |
|
|
1252
|
+
| MLflowLogger | ✅ Opt-in | Lazy `mlflow` import; `pip install "tgraphx[mlflow]"` |
|
|
1225
1253
|
| Dashboard | 🔒 Opt-in | Launch explicitly; zero overhead when off |
|
|
1226
1254
|
| Offline dashboard export | ✅ Stable | `--export-html` or `export_dashboard_html()` |
|
|
1227
1255
|
| Multi-run dashboard | ✅ Stable | Point `--logdir` at parent directory |
|
|
@@ -1233,15 +1261,19 @@ batch.to(device)
|
|
|
1233
1261
|
| Feature | Status | Notes |
|
|
1234
1262
|
|---------|:------:|-------|
|
|
1235
1263
|
| `ConvMessagePassing` chunked forward | ✅ Stable | `aggr="sum"` / `"mean"`; max falls back with warning |
|
|
1236
|
-
| `TensorGraphSAGELayer` chunked forward |
|
|
1237
|
-
| `TensorGINLayer` chunked forward |
|
|
1238
|
-
| `TensorGATLayer` chunked forward | ⏳ Planned v0.2.
|
|
1264
|
+
| `TensorGraphSAGELayer` chunked forward | ✅ Stable | v0.2.3; mean and max; pass `chunk_size=K` to `forward()` |
|
|
1265
|
+
| `TensorGINLayer` chunked forward | ✅ Stable | v0.2.3; sum aggregation; pass `chunk_size=K` to `forward()` |
|
|
1266
|
+
| `TensorGATLayer` chunked forward | ✅ Stable | v0.2.4; two-pass algorithm for destination-wise softmax; pass `chunk_size=K` to `forward()` |
|
|
1239
1267
|
| `build_grid_graph` / `build_grid_graph_3d` | ✅ Stable | O(E) — scales well |
|
|
1240
1268
|
| `build_random_graph` | ✅ Stable | O(E) — scales well |
|
|
1241
|
-
| `build_knn_graph` / `build_radius_graph` | ⚠️ Best-effort | O(N²)
|
|
1242
|
-
| `build_fully_connected_graph`
|
|
1243
|
-
|
|
|
1244
|
-
|
|
|
1269
|
+
| `build_knn_graph` / `build_radius_graph` | ⚠️ Best-effort | O(N²) time; `chunk_size=K` reduces peak memory to O(K×N) |
|
|
1270
|
+
| `build_fully_connected_graph` | ⚠️ Best-effort | O(N²) edges; N > 5 000 emits warning |
|
|
1271
|
+
| `build_iou_graph` | ⚠️ Best-effort | O(N²) IoU; `chunk_size=K` reduces peak memory to O(K×N) |
|
|
1272
|
+
| `build_random_graph(algorithm="sample")` | ✅ Stable | O(num_edges) memory for large N |
|
|
1273
|
+
| Dashboard metrics API | ✅ Stable | Incremental `?since_row=N`; `--max-metric-rows` cap; byte-seek tail-read (v0.2.3) |
|
|
1274
|
+
| Large `metrics.csv` tail-read | ✅ Stable | v0.2.3: byte-seek on append; full reparse on rotation/truncation |
|
|
1275
|
+
| Subgraph / k-hop / neighbour sampling | ✅ Stable v0.2.6 | `tgraphx.sampling` + `SubgraphDataLoader` / `NeighborSamplerLoader` |
|
|
1276
|
+
| Distributed helpers (rank-zero, barrier) | ✅ Stable v0.2.6 | `tgraphx.distributed`; never auto-initialises DDP |
|
|
1245
1277
|
|
|
1246
1278
|
> ⚠️ **Scalability warning:** `build_knn_graph`, `build_radius_graph`, `build_fully_connected_graph`,
|
|
1247
1279
|
> and `build_iou_graph` use pairwise `torch.cdist` or enumerate all pairs. Memory and time grow as
|
|
@@ -1287,10 +1319,12 @@ batch.to(device)
|
|
|
1287
1319
|
| 3-D / volumetric node features | ✅ `ConvMessagePassing`, `TensorGATLayer`, `TensorGraphSAGELayer`, `TensorGINLayer` | ✅ | `[N, C, D, H, W]`; pass `spatial_rank=3` to GAT/SAGE/GIN, or `(C, D, H, W)` `in_shape` to `ConvMessagePassing`. `DeepCNNAggregator` is rank-aware. `LinearMessagePassing` covers vector `[N, D]` and is unaffected. |
|
|
1288
1320
|
| Edge-conditioned MP (vector) | ✅ `TensorGATLayer`, `TensorGraphSAGELayer`, `TensorGINLayer` | ✅ | edge features `[E, D_e]`; `edge_features_kind="vector"` |
|
|
1289
1321
|
| `aggr="sum"\|"mean"\|"max"` base | ✅ all three modes | ✅ hand-computed + backward | `ConvMessagePassing` `aggr="max"` routes through `scatter_max` |
|
|
1290
|
-
| Graph Transformer |
|
|
1291
|
-
| Heterogeneous graphs |
|
|
1292
|
-
| Temporal
|
|
1293
|
-
| Learned graph construction |
|
|
1322
|
+
| Graph Transformer (vector features) | 🧪 `GraphTransformerLayer` | ✅ | global multi-head self-attention `[N, D]`; O(N²); tensor-aware variant ⏳ planned |
|
|
1323
|
+
| Heterogeneous graphs | 🧪 `HeteroGraph`, `HeteroGraphBatch`, `HeteroConv`, `HeteroGraphClassifier`, `HeteroNodeClassifier` | ✅ | vector features; relation-dispatch wrapper + per-type classifier; full PyG-style layer zoo ⏳ planned |
|
|
1324
|
+
| Temporal graphs | 🧪 `TemporalGraphSequence`, `TemporalGraphBatch`, `temporal_readout`, `TemporalGraphClassifier`/`Regressor` | ✅ | snapshot loop + readout pattern; TGN/TGAT-style memory module ⏳ planned |
|
|
1325
|
+
| Learned graph construction | ✅ `tgraphx.learned_graph` | ✅ | soft adjacency, EdgeScorer (differentiable); top-k discrete (non-diff) |
|
|
1326
|
+
| PyG / DGL converters | ✅ `tgraphx.interop` | ✅ | data converters only (lazy imports); not an API replacement |
|
|
1327
|
+
| MLflowLogger | ✅ `tgraphx.tracking.MLflowLogger` | ✅ | lazy mlflow import; opt-in via `tgraphx[mlflow]` extra |
|
|
1294
1328
|
| Arbitrary-rank tensor support beyond rank 0 / 2 / 3 | ❌ | — | only vector, 2-D, and 3-D shapes are supported |
|
|
1295
1329
|
|
|
1296
1330
|
---
|
|
@@ -74,28 +74,46 @@ drop-in clones of PyTorch Geometric's vector-feature implementations.
|
|
|
74
74
|
| Dataset & loader | `GraphDataset`, `GraphDataLoader` | — | Wraps `torch.utils.data` |
|
|
75
75
|
| Utilities | `load_config`, `get_device` | — | YAML/JSON config; CUDA→MPS→CPU |
|
|
76
76
|
|
|
77
|
-
##
|
|
78
|
-
|
|
79
|
-
|
|
80
|
-
|
|
81
|
-
|
|
82
|
-
|
|
83
|
-
|
|
84
|
-
|
|
85
|
-
|
|
86
|
-
|
|
87
|
-
|
|
88
|
-
|
|
89
|
-
|
|
90
|
-
|
|
91
|
-
|
|
92
|
-
|
|
93
|
-
|
|
94
|
-
|
|
95
|
-
|
|
96
|
-
|
|
97
|
-
|
|
98
|
-
|
|
77
|
+
## Current scope and boundaries
|
|
78
|
+
|
|
79
|
+
TGraphX is a focused library for tensor-aware patch-graph GNNs. Here is
|
|
80
|
+
what is stable, what is experimental, and what is intentionally out of scope.
|
|
81
|
+
Full details are in [docs/limitations.md](docs/limitations.md) and
|
|
82
|
+
[docs/roadmap.md](docs/roadmap.md).
|
|
83
|
+
|
|
84
|
+
### Optional and experimental features (v0.2.4)
|
|
85
|
+
|
|
86
|
+
| Feature | Status | Install / usage |
|
|
87
|
+
|---------|:------:|---------|
|
|
88
|
+
| `TensorGATLayer(attention_mode="channel")` | 🧪 Experimental | Constructor argument |
|
|
89
|
+
| `TensorGATLayer` chunked forward | ✅ Stable | `layer(x, edge_index, chunk_size=K)` |
|
|
90
|
+
| `GraphTransformerLayer` (vector features only) | 🧪 Experimental | `from tgraphx.layers.graph_transformer import GraphTransformerLayer` |
|
|
91
|
+
| `HeteroGraph` container | 🧪 Experimental | `from tgraphx.core.hetero_graph import HeteroGraph` |
|
|
92
|
+
| `TemporalGraphSequence` container | 🧪 Experimental | `from tgraphx.core.temporal import TemporalGraphSequence` |
|
|
93
|
+
| `MLflowLogger` | ✅ Opt-in | `pip install mlflow` or `pip install "tgraphx[mlflow]"` |
|
|
94
|
+
| PyG / DGL converters | ✅ Opt-in | `from tgraphx.interop import to_pyg_data, to_dgl_graph, …` |
|
|
95
|
+
| Learned graph helpers | ✅ Stable | `from tgraphx.learned_graph import soft_adjacency_from_embeddings, EdgeScorer, …` |
|
|
96
|
+
| Patch helper `padding="auto"` | ✅ Stable | `image_to_patches(imgs, ps, padding="auto")` |
|
|
97
|
+
| Hardware monitoring dashboard | 🔒 Opt-in | `pip install "tgraphx[monitoring]"` |
|
|
98
|
+
| TensorBoard logging | 🔒 Opt-in | `pip install "tgraphx[tracking]"` |
|
|
99
|
+
|
|
100
|
+
### Scope boundaries (design decisions, not bugs)
|
|
101
|
+
|
|
102
|
+
- **Spatial / volumetric node features:** `[N, D]`, `[N, C, H, W]`, and
|
|
103
|
+
`[N, C, D, H, W]` are the supported shapes. Arbitrary-rank tensors
|
|
104
|
+
(rank ≥ 5 node features) are out of scope.
|
|
105
|
+
- **GAT per-pixel / per-voxel attention:** score tensors would be
|
|
106
|
+
`O(E × K × H × W)` — prohibitive for typical spatial GNN workloads.
|
|
107
|
+
Planned for a future release after memory analysis.
|
|
108
|
+
- **Full hetero/temporal GNN layers:** `HeteroGraph` and
|
|
109
|
+
`TemporalGraphSequence` are *containers*, not GNN implementations.
|
|
110
|
+
- **PyG/DGL drop-in compatibility:** TGraphX is not a replacement.
|
|
111
|
+
The optional converters transfer data only; APIs differ.
|
|
112
|
+
- **Managed distributed training / multi-GPU** (beyond the opt-in `tgraphx.distributed` helpers): out of scope
|
|
113
|
+
for the current release.
|
|
114
|
+
- **Profiling and file writes:** disabled by default; all are opt-in.
|
|
115
|
+
|
|
116
|
+
See [docs/roadmap.md](docs/roadmap.md) for planned items.
|
|
99
117
|
|
|
100
118
|
---
|
|
101
119
|
|
|
@@ -171,8 +189,8 @@ out = layer(x, edge_index, chunk_size=512)
|
|
|
171
189
|
```
|
|
172
190
|
|
|
173
191
|
Supported aggregations: `"sum"` and `"mean"`. `"max"` falls back to the
|
|
174
|
-
standard path with a warning.
|
|
175
|
-
|
|
192
|
+
standard path with a warning. SAGE and GIN also support `chunk_size`; GAT
|
|
193
|
+
uses a two-pass algorithm — all four layers accept `chunk_size` in `forward()`.
|
|
176
194
|
|
|
177
195
|
### Hardware compatibility
|
|
178
196
|
|
|
@@ -246,7 +264,15 @@ with TensorBoardLogger("runs/tb") as tb:
|
|
|
246
264
|
```
|
|
247
265
|
|
|
248
266
|
Nothing is written unless you explicitly pass a logger.
|
|
249
|
-
|
|
267
|
+
|
|
268
|
+
### MLflow logging (optional)
|
|
269
|
+
|
|
270
|
+
```python
|
|
271
|
+
from tgraphx.tracking import MLflowLogger # pip install mlflow
|
|
272
|
+
|
|
273
|
+
with MLflowLogger(run_name="my_run", experiment="gnn") as mlf:
|
|
274
|
+
history = fit(model, train_loader, logger=mlf, ...)
|
|
275
|
+
```
|
|
250
276
|
|
|
251
277
|
---
|
|
252
278
|
|
|
@@ -1137,7 +1163,7 @@ batch.to(device)
|
|
|
1137
1163
|
| ✅ Stable | Tested in CI; API is stable |
|
|
1138
1164
|
| 🧪 Experimental | Available but not yet guaranteed-stable |
|
|
1139
1165
|
| ⚠️ Best-effort | Works in practice; known constraints documented |
|
|
1140
|
-
| ⏳ Planned | On roadmap
|
|
1166
|
+
| ⏳ Planned | On roadmap for a future release |
|
|
1141
1167
|
| ❌ Not supported | Out of scope for the current release |
|
|
1142
1168
|
| 🔒 Opt-in | Disabled by default; explicitly enabled by the user |
|
|
1143
1169
|
|
|
@@ -1171,12 +1197,12 @@ batch.to(device)
|
|
|
1171
1197
|
| Vector edge features `[E, D_e]` | ✅ Stable | GAT, SAGE, GIN |
|
|
1172
1198
|
| Spatial edge features `[E, C_e, H, W]` | ⚠️ Best-effort | ConvMP (concat); GAT (mean-pooled); SAGE/GIN (full) |
|
|
1173
1199
|
| Volumetric edge features `[E, C_e, D, H, W]` | ⚠️ Best-effort | Same as spatial; `spatial_rank=3` |
|
|
1174
|
-
| Graph Transformer
|
|
1175
|
-
| Heterogeneous graphs
|
|
1176
|
-
| Temporal graphs
|
|
1177
|
-
| Learned graph construction
|
|
1178
|
-
| PyG/DGL converters |
|
|
1179
|
-
| MLflowLogger |
|
|
1200
|
+
| Graph Transformer (vector node features only) | 🧪 Experimental | `tgraphx.layers.graph_transformer.GraphTransformerLayer`; tensor-aware variant ⏳ planned |
|
|
1201
|
+
| Heterogeneous graphs (container + batch + HeteroConv + classifiers) | 🧪 Experimental | `HeteroGraph`, `HeteroGraphBatch`, `HeteroConv`, `HeteroGraphClassifier`, `HeteroNodeClassifier`; vector node features only |
|
|
1202
|
+
| Temporal graphs (container + batch + readout + classifier) | 🧪 Experimental | `TemporalGraphSequence`, `TemporalGraphBatch`, `temporal_readout`, `TemporalGraphClassifier`/`Regressor`; snapshot-loop pattern, no recurrent memory module |
|
|
1203
|
+
| Learned graph construction (soft adjacency, edge scorer) | ✅ Stable | `tgraphx.learned_graph` — discrete top-k is non-differentiable |
|
|
1204
|
+
| PyG / DGL converters | ✅ Opt-in | `tgraphx.interop` — data converters only, not API replacement |
|
|
1205
|
+
| MLflowLogger | ✅ Opt-in | Lazy `mlflow` import; `pip install "tgraphx[mlflow]"` |
|
|
1180
1206
|
| Dashboard | 🔒 Opt-in | Launch explicitly; zero overhead when off |
|
|
1181
1207
|
| Offline dashboard export | ✅ Stable | `--export-html` or `export_dashboard_html()` |
|
|
1182
1208
|
| Multi-run dashboard | ✅ Stable | Point `--logdir` at parent directory |
|
|
@@ -1188,15 +1214,19 @@ batch.to(device)
|
|
|
1188
1214
|
| Feature | Status | Notes |
|
|
1189
1215
|
|---------|:------:|-------|
|
|
1190
1216
|
| `ConvMessagePassing` chunked forward | ✅ Stable | `aggr="sum"` / `"mean"`; max falls back with warning |
|
|
1191
|
-
| `TensorGraphSAGELayer` chunked forward |
|
|
1192
|
-
| `TensorGINLayer` chunked forward |
|
|
1193
|
-
| `TensorGATLayer` chunked forward | ⏳ Planned v0.2.
|
|
1217
|
+
| `TensorGraphSAGELayer` chunked forward | ✅ Stable | v0.2.3; mean and max; pass `chunk_size=K` to `forward()` |
|
|
1218
|
+
| `TensorGINLayer` chunked forward | ✅ Stable | v0.2.3; sum aggregation; pass `chunk_size=K` to `forward()` |
|
|
1219
|
+
| `TensorGATLayer` chunked forward | ✅ Stable | v0.2.4; two-pass algorithm for destination-wise softmax; pass `chunk_size=K` to `forward()` |
|
|
1194
1220
|
| `build_grid_graph` / `build_grid_graph_3d` | ✅ Stable | O(E) — scales well |
|
|
1195
1221
|
| `build_random_graph` | ✅ Stable | O(E) — scales well |
|
|
1196
|
-
| `build_knn_graph` / `build_radius_graph` | ⚠️ Best-effort | O(N²)
|
|
1197
|
-
| `build_fully_connected_graph`
|
|
1198
|
-
|
|
|
1199
|
-
|
|
|
1222
|
+
| `build_knn_graph` / `build_radius_graph` | ⚠️ Best-effort | O(N²) time; `chunk_size=K` reduces peak memory to O(K×N) |
|
|
1223
|
+
| `build_fully_connected_graph` | ⚠️ Best-effort | O(N²) edges; N > 5 000 emits warning |
|
|
1224
|
+
| `build_iou_graph` | ⚠️ Best-effort | O(N²) IoU; `chunk_size=K` reduces peak memory to O(K×N) |
|
|
1225
|
+
| `build_random_graph(algorithm="sample")` | ✅ Stable | O(num_edges) memory for large N |
|
|
1226
|
+
| Dashboard metrics API | ✅ Stable | Incremental `?since_row=N`; `--max-metric-rows` cap; byte-seek tail-read (v0.2.3) |
|
|
1227
|
+
| Large `metrics.csv` tail-read | ✅ Stable | v0.2.3: byte-seek on append; full reparse on rotation/truncation |
|
|
1228
|
+
| Subgraph / k-hop / neighbour sampling | ✅ Stable v0.2.6 | `tgraphx.sampling` + `SubgraphDataLoader` / `NeighborSamplerLoader` |
|
|
1229
|
+
| Distributed helpers (rank-zero, barrier) | ✅ Stable v0.2.6 | `tgraphx.distributed`; never auto-initialises DDP |
|
|
1200
1230
|
|
|
1201
1231
|
> ⚠️ **Scalability warning:** `build_knn_graph`, `build_radius_graph`, `build_fully_connected_graph`,
|
|
1202
1232
|
> and `build_iou_graph` use pairwise `torch.cdist` or enumerate all pairs. Memory and time grow as
|
|
@@ -1242,10 +1272,12 @@ batch.to(device)
|
|
|
1242
1272
|
| 3-D / volumetric node features | ✅ `ConvMessagePassing`, `TensorGATLayer`, `TensorGraphSAGELayer`, `TensorGINLayer` | ✅ | `[N, C, D, H, W]`; pass `spatial_rank=3` to GAT/SAGE/GIN, or `(C, D, H, W)` `in_shape` to `ConvMessagePassing`. `DeepCNNAggregator` is rank-aware. `LinearMessagePassing` covers vector `[N, D]` and is unaffected. |
|
|
1243
1273
|
| Edge-conditioned MP (vector) | ✅ `TensorGATLayer`, `TensorGraphSAGELayer`, `TensorGINLayer` | ✅ | edge features `[E, D_e]`; `edge_features_kind="vector"` |
|
|
1244
1274
|
| `aggr="sum"\|"mean"\|"max"` base | ✅ all three modes | ✅ hand-computed + backward | `ConvMessagePassing` `aggr="max"` routes through `scatter_max` |
|
|
1245
|
-
| Graph Transformer |
|
|
1246
|
-
| Heterogeneous graphs |
|
|
1247
|
-
| Temporal
|
|
1248
|
-
| Learned graph construction |
|
|
1275
|
+
| Graph Transformer (vector features) | 🧪 `GraphTransformerLayer` | ✅ | global multi-head self-attention `[N, D]`; O(N²); tensor-aware variant ⏳ planned |
|
|
1276
|
+
| Heterogeneous graphs | 🧪 `HeteroGraph`, `HeteroGraphBatch`, `HeteroConv`, `HeteroGraphClassifier`, `HeteroNodeClassifier` | ✅ | vector features; relation-dispatch wrapper + per-type classifier; full PyG-style layer zoo ⏳ planned |
|
|
1277
|
+
| Temporal graphs | 🧪 `TemporalGraphSequence`, `TemporalGraphBatch`, `temporal_readout`, `TemporalGraphClassifier`/`Regressor` | ✅ | snapshot loop + readout pattern; TGN/TGAT-style memory module ⏳ planned |
|
|
1278
|
+
| Learned graph construction | ✅ `tgraphx.learned_graph` | ✅ | soft adjacency, EdgeScorer (differentiable); top-k discrete (non-diff) |
|
|
1279
|
+
| PyG / DGL converters | ✅ `tgraphx.interop` | ✅ | data converters only (lazy imports); not an API replacement |
|
|
1280
|
+
| MLflowLogger | ✅ `tgraphx.tracking.MLflowLogger` | ✅ | lazy mlflow import; opt-in via `tgraphx[mlflow]` extra |
|
|
1249
1281
|
| Arbitrary-rank tensor support beyond rank 0 / 2 / 3 | ❌ | — | only vector, 2-D, and 3-D shapes are supported |
|
|
1250
1282
|
|
|
1251
1283
|
---
|
|
@@ -5,7 +5,7 @@ build-backend = "setuptools.build_meta"
|
|
|
5
5
|
[project]
|
|
6
6
|
name = "tgraphx"
|
|
7
7
|
# Keep this in sync with tgraphx/__init__.py::__version__
|
|
8
|
-
version = "0.2.
|
|
8
|
+
version = "0.2.7"
|
|
9
9
|
description = "Tensor-aware graph neural networks preserving spatial node feature layouts"
|
|
10
10
|
readme = "README.md"
|
|
11
11
|
requires-python = ">=3.9"
|
|
@@ -62,11 +62,13 @@ monitoring = [
|
|
|
62
62
|
"pynvml>=11.0",
|
|
63
63
|
]
|
|
64
64
|
# TensorBoard integration (TensorBoardLogger uses torch.utils.tensorboard)
|
|
65
|
-
# MLflow is intentionally excluded: MLflowLogger is not implemented in TGraphX.
|
|
66
|
-
# Install mlflow separately if you use it: pip install mlflow
|
|
67
65
|
tracking = [
|
|
68
66
|
"tensorboard>=2.11",
|
|
69
67
|
]
|
|
68
|
+
# MLflow integration (MLflowLogger uses lazy mlflow import)
|
|
69
|
+
mlflow = [
|
|
70
|
+
"mlflow>=2.0",
|
|
71
|
+
]
|
|
70
72
|
|
|
71
73
|
[project.scripts]
|
|
72
74
|
tgraphx-dashboard = "tgraphx.dashboard.app:main"
|
|
@@ -0,0 +1,153 @@
|
|
|
1
|
+
"""Backward-compatibility regression tests (v0.2 → v0.3 prep).
|
|
2
|
+
|
|
3
|
+
Covers the stable public surface listed in ``docs/deprecation_policy.md``
|
|
4
|
+
and ``docs/migration_v0_2_to_v0_3.md``.
|
|
5
|
+
"""
|
|
6
|
+
from __future__ import annotations
|
|
7
|
+
|
|
8
|
+
import inspect
|
|
9
|
+
|
|
10
|
+
import pytest
|
|
11
|
+
import torch
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
# ── Stable top-level exports ─────────────────────────────────────────────────
|
|
15
|
+
|
|
16
|
+
def test_stable_imports_work():
    """Every name in the stable top-level API must be exported by ``tgraphx``.

    Each name is checked individually against the imported package so a
    regression reports *all* missing exports in one failure message, rather
    than stopping at the first ``ImportError`` raised by a single combined
    ``from tgraphx import (...)`` statement.
    """
    import tgraphx

    stable_names = (
        # Core containers and data loading
        "Graph", "GraphBatch", "GraphDataLoader", "GraphDataset",
        # Model / layer factories
        "build_model", "build_model_from_config", "make_layer",
        # Training utilities
        "fit", "train_epoch", "evaluate", "set_seed",
        "save_checkpoint", "load_checkpoint",
        # Logging
        "CSVLogger", "TensorBoardLogger",
        # Performance / environment helpers
        "env_report", "recommended_device", "estimate_message_memory",
        # Task models
        "GraphClassifier", "NodeClassifier", "EdgePredictor",
        "NodeRegressor", "GraphRegressor",
        # Layers
        "ConvMessagePassing", "TensorGATLayer", "TensorGraphSAGELayer",
        "TensorGINLayer", "LinearMessagePassing",
        "AttentionMessagePassing",
        # Graph builders and patch helpers
        "build_grid_graph", "build_grid_graph_3d",
        "build_fully_connected_graph", "build_knn_graph", "build_radius_graph",
        "build_iou_graph", "build_random_graph",
        "image_to_patches", "volume_to_patches",
        # Sampling
        "induced_subgraph", "k_hop_subgraph", "neighbor_sample",
        "SubgraphDataLoader", "NeighborSamplerLoader",
        # Misc
        "write_graph_stats",
    )

    missing = [name for name in stable_names if not hasattr(tgraphx, name)]
    assert not missing, f"missing stable top-level exports: {missing}"
|
|
39
|
+
|
|
40
|
+
|
|
41
|
+
def test_stable_graph_constructor_signature():
    """Old keyword args must still work.

    Builds a 3-node, 2-edge graph using every legacy ``Graph`` keyword and
    checks the node/edge counts derived from it.
    """
    # ``torch`` comes from the module-level import; re-importing it here
    # was redundant.
    from tgraphx import Graph

    x = torch.randn(3, 4)
    ei = torch.tensor([[0, 1], [1, 2]], dtype=torch.long)
    ew = torch.tensor([1.0, 0.5])
    ef = torch.randn(2, 3)
    nl = torch.tensor([0, 1, 2])
    el = torch.tensor([0, 1])
    gl = torch.tensor(7)
    g = Graph(
        node_features=x,
        edge_index=ei,
        edge_weight=ew,
        edge_features=ef,
        node_labels=nl,
        edge_labels=el,
        graph_label=gl,
        metadata={"k": "v"},
    )
    assert g.num_nodes == 3
    assert g.num_edges == 2
|
|
64
|
+
|
|
65
|
+
|
|
66
|
+
def test_stable_layer_constructors():
    """All four spatial layers + LinearMessagePassing keep working.

    Each layer is constructed with its legacy positional/keyword arguments,
    in a fixed order, and must expose a ``forward`` method.
    """
    from tgraphx.layers import (
        ConvMessagePassing,
        TensorGATLayer,
        TensorGraphSAGELayer,
        TensorGINLayer,
        LinearMessagePassing,
    )

    # Construction order matches the legacy test (conv, gat, sage, gin, linear).
    built_layers = [
        ConvMessagePassing((4, 4, 4), (4, 4, 4)),
        TensorGATLayer(4, 4, num_heads=2),
        TensorGraphSAGELayer(4, 4),
        TensorGINLayer(4, 4),
        LinearMessagePassing((4,), (4,)),
    ]
    for layer in built_layers:
        assert hasattr(layer, "forward")
|
|
79
|
+
|
|
80
|
+
|
|
81
|
+
def test_stable_factory_signatures():
    """Legacy ``make_layer`` / ``build_model`` call signatures still work."""
    from tgraphx import make_layer, build_model

    # Old-style factory calls. ``layer``/``model`` replace the ambiguous
    # single-letter names ``l``/``m`` (PEP 8, E741).
    layer = make_layer("gat", in_shape=(4, 4, 4), out_shape=(4, 4, 4), heads=2)
    assert layer is not None

    model = build_model(
        task="graph_classification",
        layer="gat",
        in_shape=(4, 4, 4), hidden_shape=(8, 4, 4),
        num_layers=2, num_classes=3, heads=2, pooling="mean",
    )
    assert hasattr(model, "forward")
|
|
93
|
+
|
|
94
|
+
|
|
95
|
+
def test_stable_training_helpers():
    """The stable training entry points are still exported and callable."""
    from tgraphx import set_seed, fit, train_epoch, evaluate

    for helper in (set_seed, fit, train_epoch, evaluate):
        assert callable(helper)
|
|
101
|
+
|
|
102
|
+
|
|
103
|
+
def test_stable_logger_classes():
    """The stable logger exports are still classes, not factory functions."""
    from tgraphx import CSVLogger, TensorBoardLogger

    for logger_cls in (CSVLogger, TensorBoardLogger):
        assert inspect.isclass(logger_cls)
|
|
107
|
+
|
|
108
|
+
|
|
109
|
+
# ── Experimental imports — should still be available ─────────────────────────
|
|
110
|
+
|
|
111
|
+
def test_experimental_imports_available():
    """Experimental top-level exports remain importable and are classes."""
    from tgraphx import (
        HeteroGraph, HeteroGraphBatch,
        TemporalGraphSequence, TemporalGraphBatch,
        MLflowLogger,
    )

    experimental = (
        HeteroGraph,
        HeteroGraphBatch,
        TemporalGraphSequence,
        TemporalGraphBatch,
        MLflowLogger,
    )
    non_classes = [obj for obj in experimental if not inspect.isclass(obj)]
    assert not non_classes
|
|
122
|
+
|
|
123
|
+
|
|
124
|
+
def test_experimental_modules_importable():
    """Experimental submodules and layer classes can still be imported.

    Only import success matters; the bound names are intentionally unused.
    Imports run in the same order as before so any import-time side effects
    happen in the same sequence.
    """
    import tgraphx.interop as _interop  # noqa: F401
    import tgraphx.learned_graph as _learned_graph  # noqa: F401
    import tgraphx.distributed as _distributed  # noqa: F401
    import tgraphx.sampling as _sampling  # noqa: F401
    import tgraphx.sampling_loaders as _sampling_loaders  # noqa: F401
    from tgraphx.layers.hetero import HeteroConv  # noqa: F401
    from tgraphx.layers.graph_transformer import GraphTransformerLayer  # noqa: F401
|
|
132
|
+
|
|
133
|
+
|
|
134
|
+
# ── No optional heavy import on package import ───────────────────────────────
|
|
135
|
+
|
|
136
|
+
def test_no_eager_optional_imports():
    """Importing tgraphx must not pull in mlflow/torch_geometric/dgl/tensorboard.

    The check runs in a fresh interpreter, so modules already imported by the
    current test session cannot mask (or fake) an eager import.
    """
    import subprocess
    import sys

    # A ``for`` statement cannot follow ``;`` on a single line, so the child
    # script is passed to ``python -c`` as newline-separated source.
    code = (
        "import tgraphx, sys\n"
        "for m in ('mlflow', 'torch_geometric', 'dgl', 'tensorboard'):\n"
        "    assert m not in sys.modules, f'eager import: {m}'\n"
        "print('OK')\n"
    )
    result = subprocess.run(
        [sys.executable, "-c", code],
        capture_output=True,
        text=True,
    )
    assert result.returncode == 0, result.stderr
    # The sentinel confirms the child script actually ran to completion.
    assert "OK" in result.stdout, result.stdout
|