tgraphx 0.0.1__tar.gz

tgraphx-0.0.1/PKG-INFO ADDED
Metadata-Version: 2.4
Name: tgraphx
Version: 0.0.1
Summary: Early placeholder for TGraphX PyPI reservation
Author: Arash Sajjadi
Author-email: Arash Sajjadi <arash.sajjadi@usask.ca>
License: MIT
Requires-Python: >=3.8
Description-Content-Type: text/markdown
Dynamic: author
Dynamic: requires-python

# TGraphX

TGraphX is a **PyTorch**-based framework for building Graph Neural Networks (GNNs) that work with node and edge features of any dimension while retaining their **spatial layout**. The code is designed for flexibility, easy GPU acceleration, and rapid prototyping of new GNN ideas, **especially** those that need to preserve local spatial details (e.g., image or volumetric patches).

📄 **Preprint**: [TGraphX: Tensor-Aware Graph Neural Network for Multi-Dimensional Feature Learning](https://arxiv.org/abs/2504.03953)
✏️ *Authors: Arash Sajjadi, Mark Eramian*
🗓️ *Published on arXiv, April 2025*

> **Note:** TGraphX includes optional skip connections that help with
> stable gradient flow in deeper GNN stacks. The overall design is rooted
> in rigorous theoretical and practical foundations, aiming to unify
> convolutional neural networks (CNNs) with GNN-based relational reasoning.

---
## Table of Contents

- [Overview](#overview)
- [Key Features](#key-features)
- [Architecture Highlights](#architecture-highlights)
  - [Preserving Spatial Fidelity](#preserving-spatial-fidelity)
  - [Convolution-Based Message Passing](#convolution-based-message-passing)
  - [Deep CNN Aggregator with Residuals](#deep-cnn-aggregator-with-residuals)
  - [End-to-End Differentiability](#end-to-end-differentiability)
- [Future Works](#future-works)
- [Installation](#installation)
- [Folder Structure](#folder-structure)
- [Core Components](#core-components)
- [Layers](#layers)
- [Models](#models)
- [Configuration Options](#configuration-options)
- [Advanced Topics](#advanced-topics)
- [Novelties and Contributions](#novelties-and-contributions)
- [Conclusion](#conclusion)
- [Citations](#citations)
- [License](#license)

---

## Overview

TGraphX provides a modular way to create GNNs by combining several components:

1. **Graph Representation**
   A `Graph` class holds node features, edge indices, and optional edge features. Unlike traditional GNNs where node features are vectors, TGraphX supports multi-dimensional features such as `[C, H, W]` tensors—making it particularly effective for vision tasks.

2. **Message Passing Layers**
   Customizable layers process messages between nodes *while preserving the spatial layout of features*. In TGraphX:
   - **ConvMessagePassing** uses `1×1` convolutions on concatenated spatial features (e.g., `Conv1×1(Concat(Xi, Xj, Eij))`).
   - **DeepCNNAggregator** is a deep CNN (default 4 layers) that refines aggregated messages, keeping their spatial structure intact (i.e., `[C, H, W]` shape).

3. **Models**
   Pre-built models combine a CNN encoder with GNN layers:
   - **CNN Encoder** processes raw image patches into spatial feature maps (e.g., `[C, H, W]`).
   - **Optional Pre-Encoder** (e.g., ResNet-like) can be enabled to further refine raw patches before the main CNN encoder.
   - **Unified CNN‑GNN Model** uses CNN encoders for local features and GNN layers for global relational reasoning, then pools the result for final classification.
   - An extra *skip connection* (if enabled) merges the raw CNN patch output with the GNN output for better gradient flow and more stable learning.

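Concretely, the data a `Graph` holds can be sketched with plain tensors. The names and shapes below are a hypothetical illustration of the layout described above, not the actual constructor signature:

```python
import torch

# Hypothetical illustration of Graph contents; the real API may differ.
num_nodes, C, H, W = 4, 16, 5, 5
node_features = torch.randn(num_nodes, C, H, W)    # one [C, H, W] map per node
edge_index = torch.tensor([[0, 1, 2],              # source nodes
                           [1, 2, 3]])             # destination nodes: [2, num_edges]
edge_features = torch.randn(3, 4, H, W)            # optional per-edge spatial maps
```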
---
## Key Features

- **Support for Arbitrary Dimensions**
  Handle vectors, 2D images, or even volumetric 3D patches as node features.

- **Spatial Message Passing**
  Messages preserve spatial dimensions (e.g., `[C, H, W]`), letting convolutional filters capture local patterns and avoid destructive flattening of features.

- **Deep Aggregation**
  A deep CNN aggregator (with multiple `3×3` convolutions, batch normalization, dropout, and ReLU) refines messages across multiple hops, enabling better local–global fusion.

- **Optional Pre‑Encoder**
  Pre-process raw patches with a ResNet-like module (or even load a pretrained ResNet-18) to enrich features before the main GNN pipeline.

- **Flexible Data Loading**
  TGraphX includes custom dataset and data loader classes (`GraphDataset` and `GraphDataLoader`) for direct graph-based batching.

- **Configurable Skip Connections**
  Enable or disable skip connections that pass CNN outputs directly into the final stages, improving gradient flow.

---

## Architecture Highlights

### Preserving Spatial Fidelity
Unlike conventional GNNs that flatten node features into vectors, TGraphX retains the full spatial layout `[C, H, W]` at each node. This ensures that local pixel-level (or voxel-level) structure, which is crucial for vision tasks, remains intact throughout the message passing process.

### Convolution-Based Message Passing
TGraphX implements message passing via `Conv1×1(Concat(Xi, Xj, Eij))`. This approach:
- Respects spatial alignment (i.e., each spatial location in one node’s feature map can directly interact with the same location in its neighbors’ feature maps).
- Preserves the dimension `[C, H, W]`, avoiding vector flattening.
- Optionally incorporates edge features `Eij` for more advanced relational cues (e.g., distances, bounding-box overlaps).

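As a rough sketch of this operation (the sizes are arbitrary and the module is a stand-in, not the library's `ConvMessagePassing` implementation), the message for an edge `(i, j)` boils down to a channel-wise `1×1` convolution over the concatenated maps:

```python
import torch
import torch.nn as nn

C, H, W, C_e = 8, 5, 5, 4                  # illustrative sizes, chosen arbitrarily
x_i = torch.randn(1, C, H, W)              # receiver node feature map
x_j = torch.randn(1, C, H, W)              # sender node feature map
e_ij = torch.randn(1, C_e, H, W)           # optional edge feature map

# Mij = Conv1x1(Concat(Xi, Xj, Eij)): mixes channels at each spatial location,
# so location (h, w) in x_i interacts only with location (h, w) in x_j.
conv1x1 = nn.Conv2d(2 * C + C_e, C, kernel_size=1)
m_ij = conv1x1(torch.cat([x_i, x_j, e_ij], dim=1))   # still [1, C, H, W]
```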
### Deep CNN Aggregator with Residuals
Messages from neighbors are aggregated (summed or averaged) and then passed to a **deep CNN aggregator** that uses multiple `3×3` convolutions with *residual skips*. This design:
- Prevents the overwriting of original features by always adding `Aggregator(mj)` to the old node state `Xj`.
- Facilitates stable gradient flow in deep GNN stacks.
- Broadens the effective receptive field in feature space, capturing both local patches and more distant interactions.

### End-to-End Differentiability
Every stage of TGraphX—patch extraction, optional pre-encoder, CNN encoder, graph construction, message passing, aggregation, and classification—remains **fully differentiable** in PyTorch. This end-to-end design simplifies model development, parameter tuning, and experimentation with novel GNN layers.

---

## Future Works

- **Scalability and Data Requirements**
  Adapting TGraphX to higher-resolution inputs or massive datasets (e.g., MS COCO) may require further optimizations, including efficient graph construction or pruning strategies.

- **Domain-Specific Customization**
  Some tasks might not need full spatial fidelity at every message-passing step. Researchers could explore ways to selectively reduce resolution or apply specialized convolutions to different node subsets.

- **Alternative Edge Definitions**
  Learned adjacency or richer spatial features (e.g., IoU or geometric cues) can further improve performance in complex scenes.

- **Multimodal and Real-Time Extensions**
  Integrating TGraphX with sensor data or text embeddings could enable richer reasoning for applications like autonomous driving or real-time video surveillance.

---
## Installation

1. **Clone the Repository**
   ```bash
   git clone https://github.com/YourUsername/TGraphX.git
   cd TGraphX
   ```

2. **Set Up the Environment**
   Use the provided `environment.yml` to create a conda environment:
   ```bash
   conda env create -f environment.yml
   conda activate tgraphx
   ```

3. **Install PyTorch**
   Install a recent version of [PyTorch](https://pytorch.org/) (GPU version if possible).

4. **Install Additional Dependencies**
   ```bash
   pip install -r requirements.txt
   ```

5. **Editable Mode (Optional)**
   ```bash
   pip install -e .
   ```

---
163
+
164
+ ## Folder Structure
165
+
166
+ ```
167
+ TGraphX/
168
+ ├── __init__.py
169
+ ├── core/
170
+ │ ├── dataloader.py
171
+ │ ├── graph.py
172
+ │ └── utils.py
173
+ ├── layers/
174
+ │ ├── aggregator.py
175
+ │ ├── attention_message.py
176
+ │ ├── base.py
177
+ │ ├── conv_message.py
178
+ │ └── safe_pool.py
179
+ ├── models/
180
+ │ ├── cnn_encoder.py
181
+ │ ├── cnn_gnn_model.py
182
+ │ ├── graph_classifier.py
183
+ │ ├── node_classifier.py
184
+ │ └── pre_encoder.py
185
+ ├── environment.yml
186
+ └── README.md
187
+ ```
188
+
189
+ ---

## Core Components

### Graph and Data Loading

- **`Graph` & `GraphBatch`**
  Represent individual graphs (nodes, edges) and batches of graphs. The batch version offsets node indices to avoid collisions, allowing parallel processing in PyTorch.

- **`GraphDataset` & `GraphDataLoader`**
  Custom dataset and data loader classes that streamline the creation of graph batches from a set of images, patches, or other structured data.

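The index-offsetting idea behind batching can be sketched in plain Python (a hypothetical helper for illustration; the real `GraphBatch` operates on tensors):

```python
def batch_edge_indices(graphs):
    """graphs: list of (num_nodes, edges), where edges is a list of (src, dst).

    Shifts each graph's node indices by a running offset so the merged
    graphs share one index space without collisions.
    """
    offset, batched = 0, []
    for num_nodes, edges in graphs:
        batched.extend((src + offset, dst + offset) for src, dst in edges)
        offset += num_nodes
    return batched

# Two graphs with 3 and 2 nodes: the second graph's edge (0, 1) becomes (3, 4).
merged = batch_edge_indices([(3, [(0, 1), (1, 2)]), (2, [(0, 1)])])
```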
### Utility Functions

- **`load_config`**
  Load YAML/JSON configuration files to keep hyperparameters consistent across experiments.

- **`get_device`**
  Utility to automatically detect and return the correct device (GPU or CPU).

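A minimal sketch of the config-loading idea, restricted to JSON (the actual helper also advertises YAML support and may differ in detail):

```python
import json
import os
import tempfile

def load_config(path):
    # JSON-only sketch of a config loader; the real helper also handles YAML.
    with open(path) as f:
        return json.load(f)

# Round-trip a tiny config through a temporary file.
with tempfile.NamedTemporaryFile("w", suffix=".json", delete=False) as f:
    json.dump({"num_gnn_layers": 4}, f)
    path = f.name
cfg = load_config(path)
os.remove(path)
```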
---

## Layers

### Base Layer

- **`TensorMessagePassingLayer`**
  An abstract base class that defines the interface (message, aggregate, update steps) for all message passing. Crucially, it handles multi-dimensional node features (e.g., `[C, H, W]`).

### Convolution-Based Message Passing

- **`ConvMessagePassing`**
  Concatenates source and target node feature maps (plus optional edge features) along the channel dimension and applies a `1×1` convolution:
  ```
  Mij = Conv1×1(Concat(Xi, Xj, Eij))
  ```
  - **Message Phase**: Each pair `(i, j)` of nodes exchanges messages computed by a `1×1` conv.
  - **Aggregation + Residual Update**: After summing messages from neighbors, a deep CNN aggregator processes the sum, and the original node features are updated via a **residual skip**.

### Deep CNN Aggregator

- **`DeepCNNAggregator`**
  A stack of `3×3` convolutional layers with batch normalization, ReLU, and dropout. It refines the aggregated messages:
  ```
  X'_j = X_j + A(m_j)
  ```
  where `m_j` is the sum of messages sent to node `j`. Residual connections ensure stable gradient flow.
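  Under illustrative assumptions (the stack below is a stand-in, not the actual `DeepCNNAggregator` code), the residual update looks like:

  ```python
  import torch
  import torch.nn as nn

  C, H, W = 8, 5, 5
  aggregator = nn.Sequential(                  # stand-in for A(.)
      nn.Conv2d(C, C, kernel_size=3, padding=1),
      nn.BatchNorm2d(C),
      nn.ReLU(),
      nn.Conv2d(C, C, kernel_size=3, padding=1),
  )
  x_j = torch.randn(2, C, H, W)                # current node states
  m_j = torch.randn(2, C, H, W)                # summed incoming messages
  x_j_new = x_j + aggregator(m_j)              # residual update X'_j = X_j + A(m_j)
  ```

  Note that `padding=1` keeps the `3×3` convolutions shape-preserving, so the residual addition is always well-defined.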

### Attention-Based Message Passing

- **`AttentionMessagePassing`**
  An alternative that uses `1×1` convolutions to compute query, key, and value maps for each node. Spatial alignment is preserved while attention weights scale incoming messages. Useful for tasks needing dynamic connectivity or weighting.

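One plausible reading of this scheme (an assumption for illustration, not the library's actual layer) computes a per-location score from the query and key maps and uses it to scale the value map:

```python
import torch
import torch.nn as nn

C, H, W = 8, 5, 5
to_q, to_k, to_v = (nn.Conv2d(C, C, kernel_size=1) for _ in range(3))
x_i = torch.randn(1, C, H, W)                  # receiver node
x_j = torch.randn(1, C, H, W)                  # sender node

# Per-location scalar compatibility between receiver and sender.
score = (to_q(x_i) * to_k(x_j)).sum(dim=1, keepdim=True) / C ** 0.5
weight = torch.sigmoid(score)                  # [1, 1, H, W] attention weights
message = weight * to_v(x_j)                   # scaled, still spatially aligned
```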
### Safe Pooling

- **`SafeMaxPool2d`**
  A robust pooling module that checks if spatial dimensions `[H, W]` are large enough before applying max pooling. Prevents dimension mismatch errors in deeper aggregator stacks.

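The guard can be sketched as follows (a hypothetical stand-in that only mirrors the described behavior, not the actual module):

```python
import torch
import torch.nn as nn

class SafeMaxPool2dSketch(nn.Module):
    """Illustrative stand-in: pool only when both spatial dims allow it."""
    def __init__(self, kernel_size=2):
        super().__init__()
        self.kernel_size = kernel_size
        self.pool = nn.MaxPool2d(kernel_size)

    def forward(self, x):
        h, w = x.shape[-2:]
        if h >= self.kernel_size and w >= self.kernel_size:
            return self.pool(x)
        return x  # too small to pool: pass through unchanged

pool = SafeMaxPool2dSketch(2)
big = pool(torch.randn(1, 8, 6, 6))    # pooled down to [1, 8, 3, 3]
tiny = pool(torch.randn(1, 8, 1, 1))   # left untouched at [1, 8, 1, 1]
```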
---

## Models

### CNN Encoder and Pre-Encoder

- **`CNNEncoder`**
  Converts raw patches (`[C_in, patch_H, patch_W]`) into *spatial feature maps* (e.g., `[C_out, H2, W2]`). Includes:
  - Multiple 3×3 conv blocks with BN, ReLU, and dropout.
  - Optional residual connections.
  - Safe max pooling if the spatial size remains large.

- **Optional Pre‑Encoder**
  - If `use_preencoder` is `True`, a **ResNet‑like** (or fully custom) module first processes each patch, returning refined features with the same spatial structure.
  - `pretrained_resnet` can load weights from a standard ResNet‑18 for transfer learning.

### Unified CNN‑GNN Model

- **`CNN_GNN_Model`**
  A single pipeline that:
  1. Splits the image into patches, optionally using `PreEncoder`.
  2. Feeds patches into `CNNEncoder` to get `[C, H, W]` maps.
  3. Builds a graph where each node holds a `[C, H, W]` map.
  4. Applies multiple GNN layers (such as `ConvMessagePassing` + `DeepCNNAggregator`).
  5. Optionally uses a skip connection to combine CNN outputs with GNN outputs.
  6. Performs final spatial pooling before classification.

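The patch-splitting step can be sketched with `Tensor.unfold` (a generic non-overlapping split for illustration; not necessarily how `CNN_GNN_Model` implements it):

```python
import torch

def extract_patches(img, patch):
    """Split a [C, H, W] image into non-overlapping [C, patch, patch] tiles."""
    c, _, _ = img.shape
    # unfold twice: over H, then over W -> [C, H/p, W/p, p, p]
    tiles = img.unfold(1, patch, patch).unfold(2, patch, patch)
    # reorder so each tile becomes one row: [num_patches, C, p, p]
    return tiles.permute(1, 2, 0, 3, 4).reshape(-1, c, patch, patch)

patches = extract_patches(torch.zeros(3, 32, 32), 8)  # a 4 x 4 grid of patches
```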
### Graph & Node Classification Models

- **`GraphClassifier`**
  Intended for graph-level tasks (e.g., classification of an entire image or object ensemble). Combines message passing with a final pooling layer (mean, max, or attention) over nodes, then feeds the result into a classifier.

- **`NodeClassifier`**
  Suitable for node-level tasks (e.g., labeling each patch or region). Stacks simpler message passing layers and classifies each node separately.

---

## Configuration Options

TGraphX is highly configurable. Some key parameters include:

```python
config = {
    "cnn_params": {
        "in_channels": 3,
        "out_features": 64,
        "num_layers": 2,
        "hidden_channels": 64,
        "dropout_prob": 0.3,
        "use_batchnorm": True,
        "use_residual": True,
        "pool_layers": 2,
        "debug": False,
        "return_feature_map": True
    },
    "use_preencoder": False,
    "pretrained_resnet": False,
    "preencoder_params": {
        "in_channels": 3,
        "out_channels": 32,
        "hidden_channels": 32
    },
    "gnn_in_dim": (64, 5, 5),
    "gnn_hidden_dim": (128, 5, 5),
    "num_classes": 10,
    "num_gnn_layers": 4,
    "gnn_dropout": 0.3,
    "residual": True,
    "aggregator_params": {
        "num_layers": 4,
        "dropout_prob": 0.3,
        "use_batchnorm": True
    }
}
```

- **`cnn_params`**: Controls the CNN encoder architecture (e.g., channels, dropout, pooling).
- **`use_preencoder`**: Boolean indicating whether to preprocess patches with a custom or pretrained module.
- **`pretrained_resnet`**: If `True`, loads pretrained ResNet-18 weights in the pre-encoder.
- **`gnn_in_dim`, `gnn_hidden_dim`**: Shapes of the node features in the GNN layers, each given as a `(C, H, W)` tuple.
- **`num_gnn_layers`**: How many message passing layers to stack.
- **`aggregator_params`**: Depth, dropout, and BN usage in the aggregator.
- **`residual`**: Enables skip connections in the GNN layers.

---

## Advanced Topics

### Theoretical Insights

1. **Universal Approximation via Deep CNN**
   Stacking multiple convolutional layers with residual skips (in both the CNN encoder and the aggregator) enhances the effective receptive field and helps approximate complex local feature maps.

2. **Residual Learning for Gradient Flow**
   Residual connections in both the CNN encoder and aggregator mitigate vanishing gradients, allowing deeper structures to train effectively end-to-end.

3. **Spatial vs. Flattened Features**
   Preserving the `[C, H, W]` layout at each node addresses a key limitation of conventional GNNs—loss of local spatial semantics. TGraphX’s design is grounded in the observation that many vision tasks require capturing fine-grained local details alongside global relational structures.

### Possible Extensions

- **Adaptive Edge Construction**
  Dynamically compute adjacency based on patch similarity or learned attention, rather than fixed proximity thresholds.

- **Mixed Modalities**
  Combine image data with textual or numerical features by storing them as separate channels or separate GNN streams.

- **Task-Specific Losses**
  Add auxiliary losses (e.g., bounding-box IoU or segmentation overlap) for detection or segmentation tasks, integrated into the GNN training loop.

- **Performance Optimizations**
  Use group convolutions or low-rank factorization in the aggregator to reduce memory and computational overhead.

---

## Novelties and Contributions

TGraphX departs from traditional GNN designs in several ways:

1. **Full Spatial Fidelity**
   Each node in the graph remains a *multi-dimensional* feature map rather than a flattened vector, preserving local spatial relationships crucial for vision tasks.

2. **Convolution-Based Message Passing**
   Employing `1×1` convolutions on `[C, H, W]` feature maps lets neighboring patches exchange information at *every pixel location*, ensuring alignment and detail retention.

3. **Deep Residual Aggregation**
   Multiple `3×3` CNN layers in the aggregator—complete with batch normalization, ReLU, dropout, and skip connections—allow the model to fuse multi-hop messages in a stable, expressive manner.

4. **End-to-End Differentiability**
   From raw image patches to final classification or detection outputs, **all** steps—CNN feature extraction, graph construction, message passing, and aggregator updates—are trained jointly, strengthening the synergy between local feature extraction and relational reasoning.

5. **Modular & Extensible**
   - Allows easy substitution of the aggregator or attention-based message passing layers.
   - Accommodates multiple data modalities (image, volumetric, or otherwise).
   - Scales from small graphs (a few patches) to larger patch partitions for high-resolution images.

These innovations build on earlier GNN research while pushing further to **retain** the valuable local details that are typically lost in flattened GNN nodes.

---

## Conclusion

We have presented **TGraphX**, an architecture aimed at integrating convolutional neural networks (CNNs) and graph neural networks (GNNs) in a way that preserves spatial fidelity. By retaining multi-dimensional CNN feature maps as node representations and employing convolution-based message passing, TGraphX captures both local and global spatial context. Our experiments—particularly those involving detection refinement—demonstrate its potential to resolve detection discrepancies and refine localization accuracy in challenging vision tasks.

While we do not claim it to be universally optimal for all computer vision scenarios, TGraphX offers a flexible framework that other researchers can adapt or extend. This integration of CNN-based feature extraction with GNN-based relational reasoning is a promising direction for future AI and vision research.

---
## Citations

```bibtex
@misc{sajjadi2025tgraphxtensorawaregraphneural,
  title={TGraphX: Tensor-Aware Graph Neural Network for Multi-Dimensional Feature Learning},
  author={Arash Sajjadi and Mark Eramian},
  year={2025},
  eprint={2504.03953},
  archivePrefix={arXiv},
  primaryClass={cs.CV},
  url={https://arxiv.org/abs/2504.03953},
}
```
---

## License

TGraphX is released under the [MIT License](https://opensource.org/licenses/MIT). See the `LICENSE` file for more details.

---

**Enjoy exploring and developing your spatially-aware graph neural networks with TGraphX!**
If you have any questions, suggestions, or want to contribute, feel free to open an issue or submit a pull request.
@@ -0,0 +1,413 @@
1
+
2
+
3
+ # TGraphX
4
+
5
+ TGraphX is a **PyTorch**-based framework for building Graph Neural Networks (GNNs) that work with node and edge features of any dimension while retaining their **spatial layout**. The code is designed for flexibility, easy GPU-acceleration, and rapid prototyping of new GNN ideas, **especially** those that need to preserve local spatial details (e.g., image or volumetric patches).
6
+
7
+ đź“„ **Preprint**: [TGraphX: Tensor-Aware Graph Neural Network for Multi-Dimensional Feature Learning](https://arxiv.org/abs/2504.03953)
8
+ ✏️ *Authors: Arash Sajjadi, Mark Eramian*
9
+ 🗓️ *Published on arXiv, April 2025*
10
+
11
+
12
+ > **Note:** TGraphX includes optional skip connections that help with
13
+ > stable gradient flow in deeper GNN stacks. The overall design is rooted
14
+ > in rigorous theoretical and practical foundations, aiming to unify
15
+ > convolutional neural networks (CNNs) with GNN-based relational reasoning.
16
+
17
+ ---
18
+ ## Table of Contents
19
+
20
+ - [Overview](#overview)
21
+ - [Key Features](#key-features)
22
+ - [Architecture Highlights](#architecture-highlights)
23
+ - [Preserving Spatial Fidelity](#preserving-spatial-fidelity)
24
+ - [Convolution-Based Message Passing](#convolution-based-message-passing)
25
+ - [Deep CNN Aggregator with Residuals](#deep-cnn-aggregator-with-residuals)
26
+ - [End-to-End Differentiability](#end-to-end-differentiability)
27
+ - [Future Works](#future-works)
28
+ - [Installation](#installation)
29
+ - [Folder Structure](#folder-structure)
30
+ - [Core Components](#core-components)
31
+ - [Layers](#layers)
32
+ - [Models](#models)
33
+ - [Configuration Options](#configuration-options)
34
+ - [Advanced Topics](#advanced-topics)
35
+ - [Novelties and Contributions](#novelties-and-contributions)
36
+ - [Conclusion](#conclusion)
37
+ - [Citations](#citations)
38
+ - [License](#license)
39
+
40
+
41
+ ---
42
+
43
+ ## Overview
44
+
45
+ TGraphX provides a modular way to create GNNs by combining several components:
46
+
47
+ 1. **Graph Representation**
48
+ A `Graph` class holds node features, edge indices, and optional edge features. Unlike traditional GNNs where node features are vectors, TGraphX supports multi-dimensional features such as `[C, H, W]` tensors—making it particularly effective for vision tasks.
49
+
50
+ 2. **Message Passing Layers**
51
+ Customizable layers process messages between nodes *while preserving the spatial layout of features*. In TGraphX:
52
+ - **ConvMessagePassing** uses `1Ă—1` convolutions on concatenated spatial features (e.g., `Conv1Ă—1(Concat(Xi, Xj, Eij))`).
53
+ - **DeepCNNAggregator** is a deep CNN (default 4 layers) that refines aggregated messages, keeping their spatial structure intact (i.e., `[C, H, W]` shape).
54
+
55
+ 3. **Models**
56
+ Pre-built models combine a CNN encoder with GNN layers:
57
+ - **CNN Encoder** processes raw image patches into spatial feature maps (e.g., `[C, H, W]`).
58
+ - **Optional Pre-Encoder** (e.g., ResNet-like) can be enabled to further refine raw patches before the main CNN encoder.
59
+ - **Unified CNN‑GNN Model** uses CNN encoders for local features and GNN layers for global relational reasoning, then pools the result for final classification.
60
+ - An extra *skip connection* (if enabled) merges the raw CNN patch output with the GNN output for better gradient flow and more stable learning.
61
+
62
+ ---
63
+ ## Key Features
64
+
65
+ - **Support for Arbitrary Dimensions**
66
+ Handle vectors, 2D images, or even volumetric 3D patches as node features.
67
+
68
+ - **Spatial Message Passing**
69
+ Messages preserve spatial dimensions (e.g., `[C, H, W]`), letting convolutional filters capture local patterns and avoid destructive flattening of features.
70
+
71
+ - **Deep Aggregation**
72
+ A deep CNN aggregator (with multiple `3×3` convolutions, batch normalization, dropout, and ReLU) refines messages across multiple hops, enabling better local–global fusion.
73
+
74
+ - **Optional Pre‑Encoder**
75
+ Pre-process raw patches with a ResNet-like module (or even load a pretrained ResNet-18) to enrich features before the main GNN pipeline.
76
+
77
+ - **Flexible Data Loading**
78
+ TGraphX includes custom dataset and data loader classes (`GraphDataset` and `GraphDataLoader`) for direct graph-based batching.
79
+
80
+ - **Configurable Skip Connections**
81
+ Enable or disable skip connections that pass CNN outputs directly into the final stages, improving gradient flow.
82
+
83
+ ---
84
+
85
+ ## Architecture Highlights
86
+
87
+ ### Preserving Spatial Fidelity
88
+ Unlike conventional GNNs that flatten node features into vectors, TGraphX retains the full spatial layout `[C, H, W]` at each node. This ensures that local pixel-level (or voxel-level) structure, which is crucial for vision tasks, remains intact throughout the message passing process.
89
+
90
+ ### Convolution-Based Message Passing
91
+ TGraphX implements message passing via `Conv1Ă—1(Concat(Xi, Xj, Eij))`. This approach:
92
+ - Respects spatial alignment (i.e., each spatial location in one node’s feature map can directly interact with the same location in its neighbors’ feature maps).
93
+ - Preserves the dimension `[C, H, W]`, avoiding vector flattening.
94
+ - Optionally incorporates edge features `Eij` for more advanced relational cues (e.g., distances, bounding-box overlaps).
95
+
96
+ ### Deep CNN Aggregator with Residuals
97
+ Messages from neighbors are aggregated (summed or averaged) and then passed to a **deep CNN aggregator** that uses multiple `3Ă—3` convolutions with *residual skips*. This design:
98
+ - Prevents the overwriting of original features by always adding `Aggregator(mj)` to the old node state `Xj`.
99
+ - Facilitates stable gradient flow in deep GNN stacks.
100
+ - Broadens the effective receptive field in feature space, capturing both local patches and more distant interactions.
101
+
102
+ ### End-to-End Differentiability
103
+ Every stage of TGraphX—patch extraction, optional pre-encoder, CNN encoder, graph construction, message passing, aggregation, and classification—remains **fully differentiable** in PyTorch. This end-to-end design simplifies model development, parameter tuning, and experimentation with novel GNN layers.
104
+
105
+ ---
106
+
107
+ ## Future Works
108
+
109
+ - **Scalability and Data Requirements**
110
+ Adapting TGraphX to higher-resolution inputs or massive datasets (e.g., MS COCO) may require further optimizations, including efficient graph construction or pruning strategies.
111
+
112
+ - **Domain-Specific Customization**
113
+ Some tasks might not need full spatial fidelity at every message-passing step. Researchers could explore ways to selectively reduce resolution or apply specialized convolutions to different node subsets.
114
+
115
+ - **Alternative Edge Definitions**
116
+ Learned adjacency or richer spatial features (e.g., IoU or geometric cues) can further improve performance in complex scenes.
117
+
118
+ - **Multimodal and Real-Time Extensions**
119
+ Integrating TGraphX with sensor data or text embeddings could enable richer reasoning for applications like autonomous driving or real-time video surveillance.
120
+
121
+ ---
122
+ ## Installation
123
+
124
+ 1. **Clone the Repository**
125
+ ```bash
126
+ git clone https://github.com/YourUsername/TGraphX.git
127
+ cd TGraphX
128
+ ```
129
+
130
+ 2. **Set Up the Environment**
131
+ Use the provided `environment.yml` to create a conda environment:
132
+ ```bash
133
+ conda env create -f environment.yml
134
+ conda activate tgraphx
135
+ ```
136
+
137
+ 3. **Install PyTorch**
138
+ Install a recent version of [PyTorch](https://pytorch.org/) (GPU version if possible).
139
+
140
+ 4. **Install Additional Dependencies**
141
+ ```bash
142
+ pip install -r requirements.txt
143
+ ```
144
+
145
+ 5. **Editable Mode (Optional)**
146
+ ```bash
147
+ pip install -e .
148
+ ```
149
+
150
+ ---
151
+
152
+ ## Folder Structure
153
+
154
+ ```
155
+ TGraphX/
156
+ ├── __init__.py
157
+ ├── core/
158
+ │ ├── dataloader.py
159
+ │ ├── graph.py
160
+ │ └── utils.py
161
+ ├── layers/
162
+ │ ├── aggregator.py
163
+ │ ├── attention_message.py
164
+ │ ├── base.py
165
+ │ ├── conv_message.py
166
+ │ └── safe_pool.py
167
+ ├── models/
168
+ │ ├── cnn_encoder.py
169
+ │ ├── cnn_gnn_model.py
170
+ │ ├── graph_classifier.py
171
+ │ ├── node_classifier.py
172
+ │ └── pre_encoder.py
173
+ ├── environment.yml
174
+ └── README.md
175
+ ```
176
+
177
+ ---
178
+
179
+ ## Core Components
180
+
181
+ ### Graph and Data Loading
182
+
183
+ - **`Graph` & `GraphBatch`**
184
+ Represent individual graphs (nodes, edges) and batches of graphs. The batch version offsets node indices to avoid collisions, allowing parallel processing in PyTorch.
185
+
186
+ - **`GraphDataset` & `GraphDataLoader`**
187
+ Custom dataset and data loader classes that streamline the creation of graph batches from a set of images, patches, or other structured data.
188
+
189
+ ### Utility Functions
190
+
191
+ - **`load_config`**
192
+ Load YAML/JSON configuration files to keep hyperparameters consistent across experiments.
193
+
194
+ - **`get_device`**
195
+ Utility to automatically detect and return the correct device (GPU or CPU).
196
+
197
+ ---
198
+
199
+ ## Layers
200
+
201
+ ### Base Layer
202
+
203
+ - **`TensorMessagePassingLayer`**
204
+ An abstract base class that defines the interface (message, aggregate, update steps) for all message passing. Crucially, it handles multi-dimensional node features (e.g., `[C, H, W]`).
205
+
206
+ ### Convolution-Based Message Passing
207
+
208
+ - **`ConvMessagePassing`**
209
+ Concatenates source and target node feature maps (plus optional edge features) along the channel dimension and applies a `1Ă—1` convolution:
210
+ ```python
211
+ Mij = Conv1Ă—1(Concat(Xi, Xj, Eij))
212
+ ```
213
+ - **Message Phase**: Each pair `(i, j)` of nodes exchanges messages computed by a `1Ă—1` conv.
214
+ - **Aggregation + Residual Update**: After summing messages from neighbors, a deep CNN aggregator processes the sum, and the original node features are updated via a **residual skip**.
215
+
216
+ ### Deep CNN Aggregator
217
+
218
+ - **`DeepCNNAggregator`**
219
+ A stack of `3Ă—3` convolutional layers with batch normalization, ReLU, and dropout. It refines the aggregated messages:
220
+ ```python
221
+ X'_j = X_j + A( m_j )
222
+ ```
223
+ where `m_j = sum of messages to node j`. Residual connections ensure stable gradient flow.
224
+
225
+ ### Attention-Based Message Passing
226
+
227
+ - **`AttentionMessagePassing`**
228
+ An alternative that uses `1Ă—1` convolutions to compute query, key, and value maps for each node. Spatial alignment is preserved while attention weights scale incoming messages. Useful for tasks needing dynamic connectivity or weighting.
229
+
230
+ ### Safe Pooling
231
+
232
+ - **`SafeMaxPool2d`**
233
+ A robust pooling module that checks if spatial dimensions `[H, W]` are large enough before applying max pooling. Prevents dimension mismatch errors in deeper aggregator stacks.
234
+
235
+ ---
236
+
237
+ ## Models
238
+
239
+ ### CNN Encoder and Pre-Encoder
240
+
241
+ - **`CNNEncoder`**
242
+ Converts raw patches (`[C_in, patch_H, patch_W]`) into *spatial feature maps* (e.g., `[C_out, H2, W2]`). Includes:
243
+ - Multiple 3Ă—3 conv blocks with BN, ReLU, and dropout.
244
+ - Optional residual connections.
245
+ - Safe max pooling if the spatial size remains large.
+
+ - **Optional Pre‑Encoder**
+ - If `use_preencoder` is `True`, a **ResNet‑like** (or fully custom) module first processes each patch, returning refined features with the same spatial structure.
+ - `pretrained_resnet` can load weights from a standard ResNet‑18 for transfer learning.
+
+ ### Unified CNN‑GNN Model
+
+ - **`CNN_GNN_Model`**
+ A single pipeline that:
+ 1. Splits the image into patches and optionally refines them with the `PreEncoder`.
+ 2. Feeds patches into `CNNEncoder` to get `[C, H, W]` maps.
+ 3. Builds a graph where each node holds a `[C, H, W]` map.
+ 4. Applies multiple GNN layers (e.g., `ConvMessagePassing` + `DeepCNNAggregator`).
+ 5. Optionally uses a skip connection to combine CNN outputs with GNN outputs.
+ 6. Performs final spatial pooling before classification.
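The six steps can be sketched end to end as a single function. Everything here (the function name, the identity "GNN layer" interface, mean pooling as the final reduction) is an illustrative placeholder, not the actual TGraphX API:

```python
import torch

def cnn_gnn_forward_sketch(image, patch_size, encoder, gnn_layers, classifier,
                           edge_index, use_skip=True):
    """Hypothetical walk-through of the six pipeline steps above."""
    # 1-2. split the image into non-overlapping patches, encode each to a map
    c_in = image.size(0)
    patches = (image
               .unfold(1, patch_size, patch_size)
               .unfold(2, patch_size, patch_size)
               .reshape(c_in, -1, patch_size, patch_size)
               .permute(1, 0, 2, 3))               # [N, C_in, p, p]
    feats = encoder(patches)                        # [N, C, H, W]
    # 3-4. treat each patch map as a graph node and run the GNN stack
    x = feats
    for layer in gnn_layers:
        x = layer(x, edge_index)
    # 5. optional skip connection merging CNN and GNN outputs
    if use_skip:
        x = x + feats
    # 6. spatial pooling, then pooling over nodes, then classification
    pooled = x.mean(dim=(2, 3)).mean(dim=0)         # global feature vector [C]
    return classifier(pooled)
```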
+
+ ### Graph & Node Classification Models
+
+ - **`GraphClassifier`**
+ Intended for graph-level tasks (e.g., classification of an entire image or object ensemble). Combines message passing with a final pooling layer (mean, max, or attention) over nodes, then feeds the result into a classifier.
+
+ - **`NodeClassifier`**
+ Suitable for node-level tasks (e.g., labeling each patch or region). Stacks message passing layers and produces a prediction for each node separately.
+
+ ---
+
+ ## Configuration Options
+
+ TGraphX is highly configurable. Some key parameters include:
+
+ ```python
+ config = {
+     "cnn_params": {
+         "in_channels": 3,
+         "out_features": 64,
+         "num_layers": 2,
+         "hidden_channels": 64,
+         "dropout_prob": 0.3,
+         "use_batchnorm": True,
+         "use_residual": True,
+         "pool_layers": 2,
+         "debug": False,
+         "return_feature_map": True
+     },
+     "use_preencoder": False,
+     "pretrained_resnet": False,
+     "preencoder_params": {
+         "in_channels": 3,
+         "out_channels": 32,
+         "hidden_channels": 32
+     },
+     "gnn_in_dim": (64, 5, 5),
+     "gnn_hidden_dim": (128, 5, 5),
+     "num_classes": 10,
+     "num_gnn_layers": 4,
+     "gnn_dropout": 0.3,
+     "residual": True,
+     "aggregator_params": {
+         "num_layers": 4,
+         "dropout_prob": 0.3,
+         "use_batchnorm": True
+     }
+ }
+ ```
+
+ - **`cnn_params`**: Controls the CNN encoder architecture (e.g., channels, dropout, pooling).
+ - **`use_preencoder`**: Boolean indicating whether to preprocess patches with a custom or pretrained module.
+ - **`pretrained_resnet`**: If `True`, loads pretrained ResNet-18 weights in the pre-encoder.
+ - **`gnn_in_dim`, `gnn_hidden_dim`**: Shapes of the node features in the GNN layers, each given as a `(C, H, W)` tuple.
+ - **`num_gnn_layers`**: How many message passing layers to stack.
+ - **`aggregator_params`**: Depth, dropout, and BN usage in the aggregator.
+ - **`residual`**: Enables skip connections in the GNN layers.
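As a quick illustration of how these keys interact, one might sanity-check a config before building the model. The channel constraint follows from the shapes above (the CNN's `out_features` feeds the first GNN layer); the helper itself is hypothetical and not part of TGraphX:

```python
def check_config(config):
    """Illustrative sanity checks on a TGraphX-style configuration dict."""
    c, h, w = config["gnn_in_dim"]
    # The CNN encoder's output channels must match the GNN input channels.
    assert config["cnn_params"]["out_features"] == c, (
        "cnn_params.out_features must equal gnn_in_dim[0]"
    )
    assert config["num_gnn_layers"] >= 1, "need at least one GNN layer"
    assert 0.0 <= config["gnn_dropout"] < 1.0, "dropout must be in [0, 1)"
    return True
```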
+
+ ---
+
+ ## Advanced Topics
+
+ ### Theoretical Insights
+
+ 1. **Universal Approximation via Deep CNN**
+ Stacking multiple convolutional layers with residual skips (in both the CNN encoder and the aggregator) enlarges the effective receptive field and helps approximate complex local feature maps.
+
+ 2. **Residual Learning for Gradient Flow**
+ Residual connections in both the CNN encoder and aggregator mitigate vanishing gradients, allowing deeper structures to train effectively end-to-end.
+
+ 3. **Spatial vs. Flattened Features**
+ Preserving the `[C, H, W]` layout at each node addresses a key limitation of conventional GNNs—loss of local spatial semantics. TGraphX’s design is grounded in the observation that many vision tasks require capturing fine-grained local details alongside global relational structures.
+
+ ### Possible Extensions
+
+ - **Adaptive Edge Construction**
+ Dynamically compute adjacency based on patch similarity or learned attention, rather than fixed proximity thresholds.
+
+ - **Mixed Modalities**
+ Combine image data with textual or numerical features by storing them as separate channels or separate GNN streams.
+
+ - **Task-Specific Losses**
+ Add auxiliary losses (e.g., bounding-box IoU or segmentation overlap) for detection or segmentation tasks, integrated into the GNN training loop.
+
+ - **Performance Optimizations**
+ Use group convolutions or low-rank factorization in the aggregator to reduce memory and computational overhead.
+
+ ---
+
+ ## Novelties and Contributions
+
+ TGraphX departs from traditional GNN designs in several ways:
+
+ 1. **Full Spatial Fidelity**
+ Each node in the graph remains a *multi-dimensional* feature map rather than a flattened vector, preserving local spatial relationships crucial for vision tasks.
+
+ 2. **Convolution-Based Message Passing**
+ Employing `1Ă—1` convolutions on `[C, H, W]` feature maps lets neighboring patches exchange information at *every pixel location*, ensuring alignment and detail retention.
+
+ 3. **Deep Residual Aggregation**
+ Multiple `3×3` CNN layers in the aggregator—complete with batch normalization, ReLU, dropout, and skip connections—allow the model to fuse multi-hop messages in a stable, expressive manner.
+
+ 4. **End-to-End Differentiable**
+ From raw image patches to final classification or detection outputs, **all** steps—CNN feature extraction, graph construction, message passing, and aggregator updates—are trained jointly, strengthening the synergy between local feature extraction and relational reasoning.
+
+ 5. **Modular & Extensible**
+ - Allows easy substitution of the aggregator or attention-based message passing layers.
+ - Accommodates multiple data modalities (image, volumetric, or otherwise).
+ - Scales from small graphs (a few patches) to larger patch partitions for high-resolution images.
+
+ These innovations build on earlier GNN research while pushing further to **retain** the valuable local details that are typically lost in flattened GNN nodes.
+
+ ---
+
+ ## Conclusion
+
+ We have presented **TGraphX**, an architecture aimed at integrating convolutional neural
+ networks (CNNs) and graph neural networks (GNNs) in a way that preserves spatial fidelity.
+ By retaining multi-dimensional CNN feature maps as node representations and employing
+ convolution-based message passing, TGraphX captures both local and global spatial context.
+ Our experiments—particularly those involving detection refinement—demonstrate its potential
+ to resolve detection discrepancies and refine localization accuracy in challenging vision tasks.
+
+ While we do not claim it to be universally optimal for all computer vision scenarios, TGraphX
+ offers a flexible framework that other researchers can adapt or extend. This integration of
+ CNN-based feature extraction with GNN-based relational reasoning is a promising direction
+ for future AI and vision research.
+
+ ---
+ ## Citations
+
+ ```bibtex
+ @misc{sajjadi2025tgraphxtensorawaregraphneural,
+   title={TGraphX: Tensor-Aware Graph Neural Network for Multi-Dimensional Feature Learning},
+   author={Arash Sajjadi and Mark Eramian},
+   year={2025},
+   eprint={2504.03953},
+   archivePrefix={arXiv},
+   primaryClass={cs.CV},
+   url={https://arxiv.org/abs/2504.03953}
+ }
+ ```
+ ---
+
+ ## License
+
+ TGraphX is released under the [MIT License](https://opensource.org/licenses/MIT). See the `LICENSE` file for more details.
+
+ ---
+
+ **Enjoy exploring and developing your spatially-aware graph neural networks with TGraphX!**
+ If you have any questions, suggestions, or want to contribute, feel free to open an issue or submit a pull request.
tgraphx-0.0.1/pyproject.toml ADDED
@@ -0,0 +1,14 @@
+ [build-system]
+ requires = ["setuptools", "wheel"]
+ build-backend = "setuptools.build_meta"
+
+ [project]
+ name = "tgraphx"
+ version = "0.0.1"
+ authors = [
+     {name = "Arash Sajjadi", email = "arash.sajjadi@usask.ca"}
+ ]
+ description = "Early placeholder for TGraphX PyPI reservation"
+ readme = "README.md"
+ requires-python = ">=3.8"
+ license = {text = "MIT"}
tgraphx-0.0.1/setup.cfg ADDED
@@ -0,0 +1,4 @@
+ [egg_info]
+ tag_build =
+ tag_date = 0
+
tgraphx-0.0.1/setup.py ADDED
@@ -0,0 +1,11 @@
+ from setuptools import setup
+
+ setup(
+     name="tgraphx",
+     version="0.0.1",
+     description="Early placeholder for TGraphX PyPI reservation",
+     author="Arash Sajjadi",
+     author_email="arash.sajjadi@usask.ca",
+     packages=["tgraphx"],
+     python_requires=">=3.8",
+ )
tgraphx-0.0.1/tgraphx/__init__.py ADDED
@@ -0,0 +1 @@
+ # Placeholder
tgraphx-0.0.1/tgraphx.egg-info/SOURCES.txt ADDED
@@ -0,0 +1,8 @@
+ README.md
+ pyproject.toml
+ setup.py
+ tgraphx/__init__.py
+ tgraphx.egg-info/PKG-INFO
+ tgraphx.egg-info/SOURCES.txt
+ tgraphx.egg-info/dependency_links.txt
+ tgraphx.egg-info/top_level.txt
tgraphx-0.0.1/tgraphx.egg-info/top_level.txt ADDED
@@ -0,0 +1 @@
+ tgraphx