tgraphx 0.0.1__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- tgraphx-0.0.1/PKG-INFO +425 -0
- tgraphx-0.0.1/README.md +413 -0
- tgraphx-0.0.1/pyproject.toml +14 -0
- tgraphx-0.0.1/setup.cfg +4 -0
- tgraphx-0.0.1/setup.py +11 -0
- tgraphx-0.0.1/tgraphx/__init__.py +1 -0
- tgraphx-0.0.1/tgraphx.egg-info/PKG-INFO +425 -0
- tgraphx-0.0.1/tgraphx.egg-info/SOURCES.txt +8 -0
- tgraphx-0.0.1/tgraphx.egg-info/dependency_links.txt +1 -0
- tgraphx-0.0.1/tgraphx.egg-info/top_level.txt +1 -0
tgraphx-0.0.1/PKG-INFO
ADDED
@@ -0,0 +1,425 @@

Metadata-Version: 2.4
Name: tgraphx
Version: 0.0.1
Summary: Early placeholder for TGraphX PyPI reservation
Author: Arash Sajjadi
Author-email: Arash Sajjadi <arash.sajjadi@usask.ca>
License: MIT
Requires-Python: >=3.8
Description-Content-Type: text/markdown
Dynamic: author
Dynamic: requires-python

# TGraphX

TGraphX is a **PyTorch**-based framework for building Graph Neural Networks (GNNs) that work with node and edge features of any dimension while retaining their **spatial layout**. The code is designed for flexibility, easy GPU acceleration, and rapid prototyping of new GNN ideas, **especially** those that need to preserve local spatial details (e.g., image or volumetric patches).

📄 **Preprint**: [TGraphX: Tensor-Aware Graph Neural Network for Multi-Dimensional Feature Learning](https://arxiv.org/abs/2504.03953)
✏️ *Authors: Arash Sajjadi, Mark Eramian*
🗓️ *Published on arXiv, April 2025*

> **Note:** TGraphX includes optional skip connections that help with
> stable gradient flow in deeper GNN stacks. The overall design is rooted
> in rigorous theoretical and practical foundations, aiming to unify
> convolutional neural networks (CNNs) with GNN-based relational reasoning.

---

## Table of Contents

- [Overview](#overview)
- [Key Features](#key-features)
- [Architecture Highlights](#architecture-highlights)
  - [Preserving Spatial Fidelity](#preserving-spatial-fidelity)
  - [Convolution-Based Message Passing](#convolution-based-message-passing)
  - [Deep CNN Aggregator with Residuals](#deep-cnn-aggregator-with-residuals)
  - [End-to-End Differentiability](#end-to-end-differentiability)
- [Future Works](#future-works)
- [Installation](#installation)
- [Folder Structure](#folder-structure)
- [Core Components](#core-components)
- [Layers](#layers)
- [Models](#models)
- [Configuration Options](#configuration-options)
- [Advanced Topics](#advanced-topics)
- [Novelties and Contributions](#novelties-and-contributions)
- [Conclusion](#conclusion)
- [Citations](#citations)
- [License](#license)

---

## Overview

TGraphX provides a modular way to create GNNs by combining several components:

1. **Graph Representation**
   A `Graph` class holds node features, edge indices, and optional edge features. Unlike traditional GNNs, where node features are vectors, TGraphX supports multi-dimensional features such as `[C, H, W]` tensors, making it particularly effective for vision tasks.

2. **Message Passing Layers**
   Customizable layers process messages between nodes *while preserving the spatial layout of features*. In TGraphX:
   - **ConvMessagePassing** uses `1×1` convolutions on concatenated spatial features (e.g., `Conv1×1(Concat(Xi, Xj, Eij))`).
   - **DeepCNNAggregator** is a deep CNN (default 4 layers) that refines aggregated messages, keeping their spatial structure intact (i.e., `[C, H, W]` shape).

3. **Models**
   Pre-built models combine a CNN encoder with GNN layers:
   - **CNN Encoder** processes raw image patches into spatial feature maps (e.g., `[C, H, W]`).
   - **Optional Pre-Encoder** (e.g., ResNet-like) can be enabled to further refine raw patches before the main CNN encoder.
   - **Unified CNN-GNN Model** uses CNN encoders for local features and GNN layers for global relational reasoning, then pools the result for final classification.
   - An extra *skip connection* (if enabled) merges the raw CNN patch output with the GNN output for better gradient flow and more stable learning.

---

## Key Features

- **Support for Arbitrary Dimensions**
  Handle vectors, 2D images, or even volumetric 3D patches as node features.

- **Spatial Message Passing**
  Messages preserve spatial dimensions (e.g., `[C, H, W]`), letting convolutional filters capture local patterns and avoiding destructive flattening of features.

- **Deep Aggregation**
  A deep CNN aggregator (with multiple `3×3` convolutions, batch normalization, dropout, and ReLU) refines messages across multiple hops, enabling better local-global fusion.

- **Optional Pre-Encoder**
  Pre-process raw patches with a ResNet-like module (or even load a pretrained ResNet-18) to enrich features before the main GNN pipeline.

- **Flexible Data Loading**
  TGraphX includes custom dataset and data loader classes (`GraphDataset` and `GraphDataLoader`) for direct graph-based batching.

- **Configurable Skip Connections**
  Enable or disable skip connections that pass CNN outputs directly into the final stages, improving gradient flow.

---

## Architecture Highlights

### Preserving Spatial Fidelity
Unlike conventional GNNs that flatten node features into vectors, TGraphX retains the full spatial layout `[C, H, W]` at each node. This ensures that local pixel-level (or voxel-level) structure, which is crucial for vision tasks, remains intact throughout the message passing process.

### Convolution-Based Message Passing
TGraphX implements message passing via `Conv1×1(Concat(Xi, Xj, Eij))`. This approach:
- Respects spatial alignment (i.e., each spatial location in one node's feature map can directly interact with the same location in its neighbors' feature maps).
- Preserves the `[C, H, W]` dimensions, avoiding vector flattening.
- Optionally incorporates edge features `Eij` for more advanced relational cues (e.g., distances, bounding-box overlaps).

### Deep CNN Aggregator with Residuals
Messages from neighbors are aggregated (summed or averaged) and then passed to a **deep CNN aggregator** that uses multiple `3×3` convolutions with *residual skips*. This design:
- Prevents the overwriting of original features by always adding `Aggregator(mj)` to the old node state `Xj`.
- Facilitates stable gradient flow in deep GNN stacks.
- Broadens the effective receptive field in feature space, capturing both local patches and more distant interactions.

### End-to-End Differentiability
Every stage of TGraphX (patch extraction, optional pre-encoder, CNN encoder, graph construction, message passing, aggregation, and classification) remains **fully differentiable** in PyTorch. This end-to-end design simplifies model development, parameter tuning, and experimentation with novel GNN layers.

---

## Future Works

- **Scalability and Data Requirements**
  Adapting TGraphX to higher-resolution inputs or massive datasets (e.g., MS COCO) may require further optimizations, including efficient graph construction or pruning strategies.

- **Domain-Specific Customization**
  Some tasks might not need full spatial fidelity at every message-passing step. Researchers could explore ways to selectively reduce resolution or apply specialized convolutions to different node subsets.

- **Alternative Edge Definitions**
  Learned adjacency or richer spatial features (e.g., IoU or geometric cues) can further improve performance in complex scenes.

- **Multimodal and Real-Time Extensions**
  Integrating TGraphX with sensor data or text embeddings could enable richer reasoning for applications like autonomous driving or real-time video surveillance.

---

## Installation

1. **Clone the Repository**
   ```bash
   git clone https://github.com/YourUsername/TGraphX.git
   cd TGraphX
   ```

2. **Set Up the Environment**
   Use the provided `environment.yml` to create a conda environment:
   ```bash
   conda env create -f environment.yml
   conda activate tgraphx
   ```

3. **Install PyTorch**
   Install a recent version of [PyTorch](https://pytorch.org/) (the GPU build if possible).

4. **Install Additional Dependencies**
   ```bash
   pip install -r requirements.txt
   ```

5. **Editable Mode (Optional)**
   ```bash
   pip install -e .
   ```

---

## Folder Structure

```
TGraphX/
├── __init__.py
├── core/
│   ├── dataloader.py
│   ├── graph.py
│   └── utils.py
├── layers/
│   ├── aggregator.py
│   ├── attention_message.py
│   ├── base.py
│   ├── conv_message.py
│   └── safe_pool.py
├── models/
│   ├── cnn_encoder.py
│   ├── cnn_gnn_model.py
│   ├── graph_classifier.py
│   ├── node_classifier.py
│   └── pre_encoder.py
├── environment.yml
└── README.md
```

---

## Core Components

### Graph and Data Loading

- **`Graph` & `GraphBatch`**
  Represent individual graphs (nodes, edges) and batches of graphs. The batch version offsets node indices to avoid collisions, allowing parallel processing in PyTorch.

- **`GraphDataset` & `GraphDataLoader`**
  Custom dataset and data loader classes that streamline the creation of graph batches from a set of images, patches, or other structured data.

### Utility Functions

- **`load_config`**
  Loads YAML/JSON configuration files to keep hyperparameters consistent across experiments.

- **`get_device`**
  Utility to automatically detect and return the correct device (GPU or CPU).

---

## Layers

### Base Layer

- **`TensorMessagePassingLayer`**
  An abstract base class that defines the interface (message, aggregate, update steps) for all message passing. Crucially, it handles multi-dimensional node features (e.g., `[C, H, W]`).

### Convolution-Based Message Passing

- **`ConvMessagePassing`**
  Concatenates source and target node feature maps (plus optional edge features) along the channel dimension and applies a `1×1` convolution:
  ```python
  Mij = Conv1×1(Concat(Xi, Xj, Eij))
  ```
  - **Message Phase**: Each pair `(i, j)` of nodes exchanges messages computed by a `1×1` conv.
  - **Aggregation + Residual Update**: After summing messages from neighbors, a deep CNN aggregator processes the sum, and the original node features are updated via a **residual skip**.
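The message phase can be sketched roughly as below (a simplified stand-in that omits edge features; class and argument names are illustrative, not the library's exact layer):

```python
import torch
import torch.nn as nn

class ConvMessage(nn.Module):
    """1x1-conv message function over per-edge endpoint feature maps.
    Edge features, when present, would simply be concatenated alongside
    x_i and x_j on the channel axis."""
    def __init__(self, in_channels: int, out_channels: int):
        super().__init__()
        # 2 * in_channels: source and target maps stacked channel-wise
        self.conv = nn.Conv2d(2 * in_channels, out_channels, kernel_size=1)

    def forward(self, x_i: torch.Tensor, x_j: torch.Tensor) -> torch.Tensor:
        # x_i, x_j: [E, C, H, W] maps for each edge's endpoints
        return self.conv(torch.cat([x_i, x_j], dim=1))

x_i = torch.randn(4, 8, 5, 5)   # 4 edges, each endpoint an [8, 5, 5] map
x_j = torch.randn(4, 8, 5, 5)
msg = ConvMessage(8, 8)(x_i, x_j)
print(msg.shape)  # torch.Size([4, 8, 5, 5]): spatial layout preserved
```

Because the kernel is `1×1`, each spatial location in the message depends only on the same location in both endpoint maps, which is exactly the alignment property described above.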

### Deep CNN Aggregator

- **`DeepCNNAggregator`**
  A stack of `3×3` convolutional layers with batch normalization, ReLU, and dropout. It refines the aggregated messages:
  ```python
  X'_j = X_j + A(m_j)
  ```
  where `m_j` is the sum of messages to node `j`. Residual connections ensure stable gradient flow.

### Attention-Based Message Passing

- **`AttentionMessagePassing`**
  An alternative that uses `1×1` convolutions to compute query, key, and value maps for each node. Spatial alignment is preserved while attention weights scale incoming messages. Useful for tasks needing dynamic connectivity or weighting.

### Safe Pooling

- **`SafeMaxPool2d`**
  A robust pooling module that checks whether the spatial dimensions `[H, W]` are large enough before applying max pooling. This prevents dimension mismatch errors in deeper aggregator stacks.

---

## Models

### CNN Encoder and Pre-Encoder

- **`CNNEncoder`**
  Converts raw patches (`[C_in, patch_H, patch_W]`) into *spatial feature maps* (e.g., `[C_out, H2, W2]`). Includes:
  - Multiple `3×3` conv blocks with batch normalization, ReLU, and dropout.
  - Optional residual connections.
  - Safe max pooling if the spatial size remains large enough.

- **Optional Pre-Encoder**
  - If `use_preencoder` is `True`, a **ResNet-like** (or fully custom) module first processes each patch, returning refined features with the same spatial structure.
  - `pretrained_resnet` can load weights from a standard ResNet-18 for transfer learning.

### Unified CNN-GNN Model

- **`CNN_GNN_Model`**
  A single pipeline that:
  1. Splits the image into patches and optionally applies the `PreEncoder`.
  2. Feeds patches into `CNNEncoder` to get `[C, H, W]` maps.
  3. Builds a graph where each node holds a `[C, H, W]` map.
  4. Applies multiple GNN layers (such as `ConvMessagePassing` + `DeepCNNAggregator`).
  5. Optionally uses a skip connection to combine CNN outputs with GNN outputs.
  6. Performs final spatial pooling before classification.

### Graph & Node Classification Models

- **`GraphClassifier`**
  Intended for graph-level tasks (e.g., classification of an entire image or object ensemble). Combines message passing with a final pooling layer (mean, max, or attention) over nodes, then feeds the result into a classifier.

- **`NodeClassifier`**
  Suitable for node-level tasks (e.g., labeling each patch or region). Stacks simpler message passing layers to classify each node separately.

---

## Configuration Options

TGraphX is highly configurable. Some key parameters include:

```python
config = {
    "cnn_params": {
        "in_channels": 3,
        "out_features": 64,
        "num_layers": 2,
        "hidden_channels": 64,
        "dropout_prob": 0.3,
        "use_batchnorm": True,
        "use_residual": True,
        "pool_layers": 2,
        "debug": False,
        "return_feature_map": True
    },
    "use_preencoder": False,
    "pretrained_resnet": False,
    "preencoder_params": {
        "in_channels": 3,
        "out_channels": 32,
        "hidden_channels": 32
    },
    "gnn_in_dim": (64, 5, 5),
    "gnn_hidden_dim": (128, 5, 5),
    "num_classes": 10,
    "num_gnn_layers": 4,
    "gnn_dropout": 0.3,
    "residual": True,
    "aggregator_params": {
        "num_layers": 4,
        "dropout_prob": 0.3,
        "use_batchnorm": True
    }
}
```

- **`cnn_params`**: Controls the CNN encoder architecture (e.g., channels, dropout, pooling).
- **`use_preencoder`**: Whether to preprocess patches with a custom or pretrained module.
- **`pretrained_resnet`**: If `True`, loads pretrained ResNet-18 weights in the pre-encoder.
- **`gnn_in_dim`, `gnn_hidden_dim`**: Shapes of the node features in GNN layers, given as `(C, H, W)` tuples.
- **`num_gnn_layers`**: How many message passing layers to stack.
- **`aggregator_params`**: Depth, dropout, and batch normalization usage in the aggregator.
- **`residual`**: Enables skip connections in the GNN layers.

---

## Advanced Topics

### Theoretical Insights

1. **Universal Approximation via Deep CNNs**
   Stacking multiple convolutional layers with residual skips (in both the CNN encoder and the aggregator) enlarges the effective receptive field and helps approximate complex local feature maps.

2. **Residual Learning for Gradient Flow**
   Residual connections in both the CNN encoder and aggregator mitigate vanishing gradients, allowing deeper structures to train effectively end-to-end.

3. **Spatial vs. Flattened Features**
   Preserving the `[C, H, W]` layout at each node addresses a key limitation of conventional GNNs: the loss of local spatial semantics. TGraphX's design is grounded in the observation that many vision tasks require capturing fine-grained local details alongside global relational structures.

### Possible Extensions

- **Adaptive Edge Construction**
  Dynamically compute adjacency based on patch similarity or learned attention, rather than fixed proximity thresholds.

- **Mixed Modalities**
  Combine image data with textual or numerical features by storing them as separate channels or separate GNN streams.

- **Task-Specific Losses**
  Add auxiliary losses (e.g., bounding-box IoU or segmentation overlap) for detection or segmentation tasks, integrated into the GNN training loop.

- **Performance Optimizations**
  Use group convolutions or low-rank factorization in the aggregator to reduce memory and computational overhead.

---

## Novelties and Contributions

TGraphX departs from traditional GNN designs in several ways:

1. **Full Spatial Fidelity**
   Each node in the graph remains a *multi-dimensional* feature map rather than a flattened vector, preserving the local spatial relationships crucial for vision tasks.

2. **Convolution-Based Message Passing**
   Employing `1×1` convolutions on `[C, H, W]` feature maps lets neighboring patches exchange information at *every pixel location*, ensuring alignment and detail retention.

3. **Deep Residual Aggregation**
   Multiple `3×3` CNN layers in the aggregator, complete with batch normalization, ReLU, dropout, and skip connections, allow the model to fuse multi-hop messages in a stable, expressive manner.

4. **End-to-End Differentiability**
   From raw image patches to final classification or detection outputs, **all** steps (CNN feature extraction, graph construction, message passing, and aggregator updates) are trained jointly, strengthening the synergy between local feature extraction and relational reasoning.

5. **Modular & Extensible**
   - Allows easy substitution of the aggregator or attention-based message passing layers.
   - Accommodates multiple data modalities (image, volumetric, or otherwise).
   - Scales from small graphs (a few patches) to larger patch partitions for high-resolution images.

These innovations build on earlier GNN research while pushing further to **retain** the valuable local details that are typically lost in flattened GNN nodes.

---

## Conclusion

We have presented **TGraphX**, an architecture aimed at integrating convolutional neural networks (CNNs) and graph neural networks (GNNs) in a way that preserves spatial fidelity. By retaining multi-dimensional CNN feature maps as node representations and employing convolution-based message passing, TGraphX captures both local and global spatial context. Our experiments, particularly those involving detection refinement, demonstrate its potential to resolve detection discrepancies and refine localization accuracy in challenging vision tasks.

While we do not claim it to be universally optimal for all computer vision scenarios, TGraphX offers a flexible framework that other researchers can adapt or extend. This integration of CNN-based feature extraction with GNN-based relational reasoning is a promising direction for future AI and vision research.

---

## Citations

```bibtex
@misc{sajjadi2025tgraphxtensorawaregraphneural,
  title={TGraphX: Tensor-Aware Graph Neural Network for Multi-Dimensional Feature Learning},
  author={Arash Sajjadi and Mark Eramian},
  year={2025},
  eprint={2504.03953},
  archivePrefix={arXiv},
  primaryClass={cs.CV},
  url={https://arxiv.org/abs/2504.03953},
}
```

---

## License

TGraphX is released under the [MIT License](https://opensource.org/licenses/MIT). See the `LICENSE` file for more details.

---

**Enjoy exploring and developing your spatially aware graph neural networks with TGraphX!**
If you have any questions, suggestions, or want to contribute, feel free to open an issue or submit a pull request.
tgraphx-0.0.1/README.md
ADDED
|
@@ -0,0 +1,413 @@
|
|
|
1
|
+
|
|
2
|
+
|
|
3
|
+
# TGraphX
|
|
4
|
+
|
|
5
|
+
TGraphX is a **PyTorch**-based framework for building Graph Neural Networks (GNNs) that work with node and edge features of any dimension while retaining their **spatial layout**. The code is designed for flexibility, easy GPU-acceleration, and rapid prototyping of new GNN ideas, **especially** those that need to preserve local spatial details (e.g., image or volumetric patches).
|
|
6
|
+
|
|
7
|
+
đź“„ **Preprint**: [TGraphX: Tensor-Aware Graph Neural Network for Multi-Dimensional Feature Learning](https://arxiv.org/abs/2504.03953)
|
|
8
|
+
✏️ *Authors: Arash Sajjadi, Mark Eramian*
|
|
9
|
+
🗓️ *Published on arXiv, April 2025*
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
> **Note:** TGraphX includes optional skip connections that help with
|
|
13
|
+
> stable gradient flow in deeper GNN stacks. The overall design is rooted
|
|
14
|
+
> in rigorous theoretical and practical foundations, aiming to unify
|
|
15
|
+
> convolutional neural networks (CNNs) with GNN-based relational reasoning.
|
|
16
|
+
|
|
17
|
+
---
|
|
18
|
+
## Table of Contents
|
|
19
|
+
|
|
20
|
+
- [Overview](#overview)
|
|
21
|
+
- [Key Features](#key-features)
|
|
22
|
+
- [Architecture Highlights](#architecture-highlights)
|
|
23
|
+
- [Preserving Spatial Fidelity](#preserving-spatial-fidelity)
|
|
24
|
+
- [Convolution-Based Message Passing](#convolution-based-message-passing)
|
|
25
|
+
- [Deep CNN Aggregator with Residuals](#deep-cnn-aggregator-with-residuals)
|
|
26
|
+
- [End-to-End Differentiability](#end-to-end-differentiability)
|
|
27
|
+
- [Future Works](#future-works)
|
|
28
|
+
- [Installation](#installation)
|
|
29
|
+
- [Folder Structure](#folder-structure)
|
|
30
|
+
- [Core Components](#core-components)
|
|
31
|
+
- [Layers](#layers)
|
|
32
|
+
- [Models](#models)
|
|
33
|
+
- [Configuration Options](#configuration-options)
|
|
34
|
+
- [Advanced Topics](#advanced-topics)
|
|
35
|
+
- [Novelties and Contributions](#novelties-and-contributions)
|
|
36
|
+
- [Conclusion](#conclusion)
|
|
37
|
+
- [Citations](#citations)
|
|
38
|
+
- [License](#license)
|
|
39
|
+
|
|
40
|
+
|
|
41
|
+
---
|
|
42
|
+
|
|
43
|
+
## Overview
|
|
44
|
+
|
|
45
|
+
TGraphX provides a modular way to create GNNs by combining several components:
|
|
46
|
+
|
|
47
|
+
1. **Graph Representation**
|
|
48
|
+
A `Graph` class holds node features, edge indices, and optional edge features. Unlike traditional GNNs where node features are vectors, TGraphX supports multi-dimensional features such as `[C, H, W]` tensors—making it particularly effective for vision tasks.
|
|
49
|
+
|
|
50
|
+
2. **Message Passing Layers**
|
|
51
|
+
Customizable layers process messages between nodes *while preserving the spatial layout of features*. In TGraphX:
|
|
52
|
+
- **ConvMessagePassing** uses `1Ă—1` convolutions on concatenated spatial features (e.g., `Conv1Ă—1(Concat(Xi, Xj, Eij))`).
|
|
53
|
+
- **DeepCNNAggregator** is a deep CNN (default 4 layers) that refines aggregated messages, keeping their spatial structure intact (i.e., `[C, H, W]` shape).
|
|
54
|
+
|
|
55
|
+
3. **Models**
|
|
56
|
+
Pre-built models combine a CNN encoder with GNN layers:
|
|
57
|
+
- **CNN Encoder** processes raw image patches into spatial feature maps (e.g., `[C, H, W]`).
|
|
58
|
+
- **Optional Pre-Encoder** (e.g., ResNet-like) can be enabled to further refine raw patches before the main CNN encoder.
|
|
59
|
+
- **Unified CNN‑GNN Model** uses CNN encoders for local features and GNN layers for global relational reasoning, then pools the result for final classification.
|
|
60
|
+
- An extra *skip connection* (if enabled) merges the raw CNN patch output with the GNN output for better gradient flow and more stable learning.
|
|
61
|
+
|
|
62
|
+
---
|
|
63
|
+
## Key Features
|
|
64
|
+
|
|
65
|
+
- **Support for Arbitrary Dimensions**
|
|
66
|
+
Handle vectors, 2D images, or even volumetric 3D patches as node features.
|
|
67
|
+
|
|
68
|
+
- **Spatial Message Passing**
|
|
69
|
+
Messages preserve spatial dimensions (e.g., `[C, H, W]`), letting convolutional filters capture local patterns and avoid destructive flattening of features.
|
|
70
|
+
|
|
71
|
+
- **Deep Aggregation**
|
|
72
|
+
A deep CNN aggregator (with multiple `3×3` convolutions, batch normalization, dropout, and ReLU) refines messages across multiple hops, enabling better local–global fusion.
|
|
73
|
+
|
|
74
|
+
- **Optional Pre‑Encoder**
|
|
75
|
+
Pre-process raw patches with a ResNet-like module (or even load a pretrained ResNet-18) to enrich features before the main GNN pipeline.
|
|
76
|
+
|
|
77
|
+
- **Flexible Data Loading**
|
|
78
|
+
TGraphX includes custom dataset and data loader classes (`GraphDataset` and `GraphDataLoader`) for direct graph-based batching.
|
|
79
|
+
|
|
80
|
+
- **Configurable Skip Connections**
|
|
81
|
+
Enable or disable skip connections that pass CNN outputs directly into the final stages, improving gradient flow.
|
|
82
|
+
|
|
83
|
+
---
|
|
84
|
+
|
|
85
|
+
## Architecture Highlights
|
|
86
|
+
|
|
87
|
+
### Preserving Spatial Fidelity
|
|
88
|
+
Unlike conventional GNNs that flatten node features into vectors, TGraphX retains the full spatial layout `[C, H, W]` at each node. This ensures that local pixel-level (or voxel-level) structure, which is crucial for vision tasks, remains intact throughout the message passing process.
|
|
89
|
+
|
|
90
|
+
### Convolution-Based Message Passing
|
|
91
|
+
TGraphX implements message passing via `Conv1Ă—1(Concat(Xi, Xj, Eij))`. This approach:
|
|
92
|
+
- Respects spatial alignment (i.e., each spatial location in one node’s feature map can directly interact with the same location in its neighbors’ feature maps).
|
|
93
|
+
- Preserves the dimension `[C, H, W]`, avoiding vector flattening.
|
|
94
|
+
- Optionally incorporates edge features `Eij` for more advanced relational cues (e.g., distances, bounding-box overlaps).
|
|
95
|
+
|
|
96
|
+
### Deep CNN Aggregator with Residuals
|
|
97
|
+
Messages from neighbors are aggregated (summed or averaged) and then passed to a **deep CNN aggregator** that uses multiple `3Ă—3` convolutions with *residual skips*. This design:
|
|
98
|
+
- Prevents the overwriting of original features by always adding `Aggregator(mj)` to the old node state `Xj`.
|
|
99
|
+
- Facilitates stable gradient flow in deep GNN stacks.
|
|
100
|
+
- Broadens the effective receptive field in feature space, capturing both local patches and more distant interactions.
|
|
101
|
+
|
|
102
|
+
### End-to-End Differentiability
|
|
103
|
+
Every stage of TGraphX—patch extraction, optional pre-encoder, CNN encoder, graph construction, message passing, aggregation, and classification—remains **fully differentiable** in PyTorch. This end-to-end design simplifies model development, parameter tuning, and experimentation with novel GNN layers.
|
|
104
|
+
|
|
105
|
+
---
|
|
106
|
+
|
|
107
|
+
## Future Works
|
|
108
|
+
|
|
109
|
+
- **Scalability and Data Requirements**
|
|
110
|
+
Adapting TGraphX to higher-resolution inputs or massive datasets (e.g., MS COCO) may require further optimizations, including efficient graph construction or pruning strategies.
|
|
111
|
+
|
|
112
|
+
- **Domain-Specific Customization**
|
|
113
|
+
Some tasks might not need full spatial fidelity at every message-passing step. Researchers could explore ways to selectively reduce resolution or apply specialized convolutions to different node subsets.
|
|
114
|
+
|
|
115
|
+
- **Alternative Edge Definitions**
|
|
116
|
+
Learned adjacency or richer spatial features (e.g., IoU or geometric cues) can further improve performance in complex scenes.
|
|
117
|
+
|
|
118
|
+
- **Multimodal and Real-Time Extensions**
|
|
119
|
+
Integrating TGraphX with sensor data or text embeddings could enable richer reasoning for applications like autonomous driving or real-time video surveillance.
|
|
120
|
+
|
|
121
|
+
---

## Installation

1. **Clone the Repository**
   ```bash
   git clone https://github.com/YourUsername/TGraphX.git
   cd TGraphX
   ```

2. **Set Up the Environment**
   Use the provided `environment.yml` to create a conda environment:
   ```bash
   conda env create -f environment.yml
   conda activate tgraphx
   ```

3. **Install PyTorch**
   Install a recent version of [PyTorch](https://pytorch.org/) (the GPU build if possible).

4. **Install Additional Dependencies**
   ```bash
   pip install -r requirements.txt
   ```

5. **Editable Mode (Optional)**
   ```bash
   pip install -e .
   ```

---

## Folder Structure

```
TGraphX/
├── __init__.py
├── core/
│   ├── dataloader.py
│   ├── graph.py
│   └── utils.py
├── layers/
│   ├── aggregator.py
│   ├── attention_message.py
│   ├── base.py
│   ├── conv_message.py
│   └── safe_pool.py
├── models/
│   ├── cnn_encoder.py
│   ├── cnn_gnn_model.py
│   ├── graph_classifier.py
│   ├── node_classifier.py
│   └── pre_encoder.py
├── environment.yml
└── README.md
```

---

## Core Components

### Graph and Data Loading

- **`Graph` & `GraphBatch`**
  Represent individual graphs (nodes, edges) and batches of graphs. The batch version offsets node indices to avoid collisions, allowing parallel processing in PyTorch.

- **`GraphDataset` & `GraphDataLoader`**
  Custom dataset and data loader classes that streamline the creation of graph batches from a set of images, patches, or other structured data.
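
The index-offset trick that `GraphBatch` relies on can be illustrated with a minimal sketch (a hypothetical `batch_graphs` helper, not the library's actual implementation):

```python
import torch

def batch_graphs(node_feats_list, edge_index_list):
    """Merge several graphs into one disjoint graph.

    Each graph's node features have shape [N_i, C, H, W]; its edge index has
    shape [2, E_i]. Node indices of graph i are shifted by the number of nodes
    in all previous graphs so edges from different graphs never collide.
    """
    offsets, total = [], 0
    for x in node_feats_list:
        offsets.append(total)
        total += x.shape[0]
    batched_x = torch.cat(node_feats_list, dim=0)  # [sum N_i, C, H, W]
    batched_e = torch.cat(
        [e + off for e, off in zip(edge_index_list, offsets)], dim=1
    )  # [2, sum E_i]
    return batched_x, batched_e

# Two tiny graphs: 2 nodes and 3 nodes, each node a [4, 5, 5] feature map.
x1, e1 = torch.randn(2, 4, 5, 5), torch.tensor([[0], [1]])
x2, e2 = torch.randn(3, 4, 5, 5), torch.tensor([[0, 1], [2, 2]])
bx, be = batch_graphs([x1, x2], [e1, e2])
print(bx.shape)     # torch.Size([5, 4, 5, 5])
print(be.tolist())  # [[0, 2, 3], [1, 4, 4]]
```

After merging, each graph's edges still point at its own nodes, so one forward pass can process the whole batch.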

### Utility Functions

- **`load_config`**
  Loads YAML/JSON configuration files to keep hyperparameters consistent across experiments.

- **`get_device`**
  Automatically detects and returns the appropriate device (GPU or CPU).
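
A device helper like this typically reduces to a one-liner; a sketch of the assumed behaviour:

```python
import torch

def get_device() -> torch.device:
    """Return a CUDA device when one is available, otherwise the CPU."""
    return torch.device("cuda" if torch.cuda.is_available() else "cpu")

device = get_device()
```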

---

## Layers

### Base Layer

- **`TensorMessagePassingLayer`**
  An abstract base class that defines the interface (message, aggregate, and update steps) for all message passing layers. Crucially, it handles multi-dimensional node features (e.g., `[C, H, W]`).

### Convolution-Based Message Passing

- **`ConvMessagePassing`**
  Concatenates source and target node feature maps (plus optional edge features) along the channel dimension and applies a `1×1` convolution:
  ```
  Mij = Conv1×1(Concat(Xi, Xj, Eij))
  ```
  - **Message Phase**: Each connected pair `(i, j)` of nodes exchanges messages computed by the `1×1` convolution.
  - **Aggregation + Residual Update**: After summing messages from neighbors, a deep CNN aggregator processes the sum, and the original node features are updated via a **residual skip**.
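
The message phase can be sketched as follows (a hypothetical `ConvMessage` module; edge features are omitted for brevity):

```python
import torch
import torch.nn as nn

class ConvMessage(nn.Module):
    """Compute M_ij = Conv1x1(Concat(X_i, X_j)) per edge, then sum per node."""

    def __init__(self, channels: int):
        super().__init__()
        # A 1x1 conv mixes the concatenated channels at every spatial location,
        # so spatial alignment between neighbouring feature maps is preserved.
        self.mix = nn.Conv2d(2 * channels, channels, kernel_size=1)

    def forward(self, x, edge_index):
        src, dst = edge_index                          # [E], [E]
        pair = torch.cat([x[src], x[dst]], dim=1)      # [E, 2C, H, W]
        messages = self.mix(pair)                      # [E, C, H, W]
        # Sum messages into each destination node.
        out = torch.zeros_like(x)
        out.index_add_(0, dst, messages)
        return out

x = torch.randn(4, 8, 5, 5)                            # 4 nodes, [C=8, H=5, W=5]
edge_index = torch.tensor([[0, 1, 2], [1, 2, 3]])      # 0→1, 1→2, 2→3
m = ConvMessage(8)(x, edge_index)
print(m.shape)  # torch.Size([4, 8, 5, 5])
```

Node 0 receives no messages here, so its aggregated slot stays zero; the residual update then leaves its state dominated by the original features.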

### Deep CNN Aggregator

- **`DeepCNNAggregator`**
  A stack of `3×3` convolutional layers with batch normalization, ReLU, and dropout. It refines the aggregated messages:
  ```
  X'_j = X_j + A(m_j)
  ```
  where `m_j` is the sum of messages sent to node `j`. Residual connections ensure stable gradient flow.
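
A minimal sketch of this residual update, assuming a two-layer aggregator (the library's defaults may differ):

```python
import torch
import torch.nn as nn

class ResidualAggregator(nn.Module):
    """X'_j = X_j + A(m_j), with A a small stack of 3x3 conv blocks."""

    def __init__(self, channels: int, num_layers: int = 2, p: float = 0.3):
        super().__init__()
        layers = []
        for _ in range(num_layers):
            layers += [
                nn.Conv2d(channels, channels, 3, padding=1),  # keeps [H, W]
                nn.BatchNorm2d(channels),
                nn.ReLU(),
                nn.Dropout2d(p),
            ]
        self.agg = nn.Sequential(*layers)

    def forward(self, x, m):
        # Residual skip: the refined messages are added to the old node state,
        # so the original features are never overwritten outright.
        return x + self.agg(m)

x = torch.randn(4, 8, 5, 5)   # current node states
m = torch.randn(4, 8, 5, 5)   # summed incoming messages
out = ResidualAggregator(8)(x, m)
print(out.shape)  # torch.Size([4, 8, 5, 5])
```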

### Attention-Based Message Passing

- **`AttentionMessagePassing`**
  An alternative that uses `1×1` convolutions to compute query, key, and value maps for each node. Spatial alignment is preserved while attention weights scale incoming messages. Useful for tasks needing dynamic connectivity or weighting.
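
One way to realize this idea, shown here with a per-edge sigmoid gate in place of a full softmax over neighbors (a simplification, not the library's exact scheme):

```python
import torch
import torch.nn as nn

class AttentionMessage(nn.Module):
    """Q/K/V maps via 1x1 convs; a scalar gate per edge scales each message."""

    def __init__(self, channels: int):
        super().__init__()
        self.q = nn.Conv2d(channels, channels, 1)
        self.k = nn.Conv2d(channels, channels, 1)
        self.v = nn.Conv2d(channels, channels, 1)

    def forward(self, x, edge_index):
        src, dst = edge_index
        q, k, v = self.q(x), self.k(x), self.v(x)
        # One attention score per edge, from the spatially pooled Q·K product.
        score = (q[dst] * k[src]).mean(dim=(1, 2, 3))   # [E]
        gate = torch.sigmoid(score).view(-1, 1, 1, 1)
        messages = gate * v[src]                        # [E, C, H, W]
        out = torch.zeros_like(x)
        out.index_add_(0, dst, messages)
        return out

x = torch.randn(4, 8, 5, 5)
edge_index = torch.tensor([[0, 1, 2], [1, 2, 3]])
out = AttentionMessage(8)(x, edge_index)
print(out.shape)  # torch.Size([4, 8, 5, 5])
```

Note that the value maps keep their `[C, H, W]` shape throughout; only the scalar gate is attention-weighted.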

### Safe Pooling

- **`SafeMaxPool2d`**
  A robust pooling module that checks whether the spatial dimensions `[H, W]` are large enough before applying max pooling. This prevents dimension mismatch errors in deeper aggregator stacks.
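
The guard can be sketched as follows (assumed behaviour: pass the input through unchanged when it is too small to pool):

```python
import torch
import torch.nn as nn

class SafeMaxPool2d(nn.Module):
    """Apply max pooling only when both spatial dims fit the kernel."""

    def __init__(self, kernel_size: int = 2):
        super().__init__()
        self.kernel_size = kernel_size
        self.pool = nn.MaxPool2d(kernel_size)

    def forward(self, x):
        h, w = x.shape[-2:]
        if h >= self.kernel_size and w >= self.kernel_size:
            return self.pool(x)
        return x  # too small to pool: pass through unchanged

pool = SafeMaxPool2d()
pooled = pool(torch.randn(1, 8, 6, 6))    # large enough: halved to [1, 8, 3, 3]
skipped = pool(torch.randn(1, 8, 1, 1))   # too small: returned unchanged
```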

---

## Models

### CNN Encoder and Pre-Encoder

- **`CNNEncoder`**
  Converts raw patches (`[C_in, patch_H, patch_W]`) into *spatial feature maps* (e.g., `[C_out, H2, W2]`). Includes:
  - Multiple 3×3 conv blocks with BN, ReLU, and dropout.
  - Optional residual connections.
  - Safe max pooling if the spatial size remains large.

- **Optional Pre-Encoder**
  - If `use_preencoder` is `True`, a **ResNet-like** (or fully custom) module first processes each patch, returning refined features with the same spatial structure.
  - `pretrained_resnet` can load weights from a standard ResNet-18 for transfer learning.

### Unified CNN-GNN Model

- **`CNN_GNN_Model`**
  A single pipeline that:
  1. Splits the image into patches and optionally applies the `PreEncoder`.
  2. Feeds the patches into `CNNEncoder` to get `[C, H, W]` maps.
  3. Builds a graph where each node holds a `[C, H, W]` map.
  4. Applies multiple GNN layers (e.g., `ConvMessagePassing` + `DeepCNNAggregator`).
  5. Optionally uses a skip connection to combine CNN outputs with GNN outputs.
  6. Performs final spatial pooling before classification.
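
End to end, the steps above can be sketched in plain PyTorch (illustrative shapes and throwaway modules, not the `CNN_GNN_Model` API itself):

```python
import torch
import torch.nn as nn

# Hypothetical setup: one 3x32x32 image split into a 4x4 grid of 8x8 patches.
image = torch.randn(3, 32, 32)
patches = image.unfold(1, 8, 8).unfold(2, 8, 8)                # [3, 4, 4, 8, 8]
patches = patches.permute(1, 2, 0, 3, 4).reshape(16, 3, 8, 8)  # 16 patch nodes

# Step 2: a stand-in CNN encoder produces [C, H, W] node maps.
encoder = nn.Sequential(nn.Conv2d(3, 16, 3, padding=1), nn.ReLU())
x = encoder(patches)                                           # [16, 16, 8, 8]

# Step 3: 4-neighbour grid adjacency over the 4x4 patch grid.
edges = []
for r in range(4):
    for c in range(4):
        i = r * 4 + c
        if c + 1 < 4:
            edges += [(i, i + 1), (i + 1, i)]
        if r + 1 < 4:
            edges += [(i, i + 4), (i + 4, i)]
edge_index = torch.tensor(edges).t()                           # [2, 48]

# Step 4: one conv message-passing step with a residual update.
msg = nn.Conv2d(32, 16, 1)
src, dst = edge_index
agg = torch.zeros_like(x)
agg.index_add_(0, dst, msg(torch.cat([x[src], x[dst]], dim=1)))
x = x + agg

# Step 6: pool over nodes and spatial dims, then classify (10 classes assumed).
logits = nn.Linear(16, 10)(x.mean(dim=(0, 2, 3)))
print(logits.shape)  # torch.Size([10])
```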

### Graph & Node Classification Models

- **`GraphClassifier`**
  Intended for graph-level tasks (e.g., classifying an entire image or object ensemble). Combines message passing with a final pooling layer (mean, max, or attention) over nodes, then feeds the result into a classifier.

- **`NodeClassifier`**
  Suitable for node-level tasks (e.g., labeling each patch or region). Stacks simpler message passing layers and classifies each node separately.

---

## Configuration Options

TGraphX is highly configurable. Some key parameters include:

```python
config = {
    "cnn_params": {
        "in_channels": 3,
        "out_features": 64,
        "num_layers": 2,
        "hidden_channels": 64,
        "dropout_prob": 0.3,
        "use_batchnorm": True,
        "use_residual": True,
        "pool_layers": 2,
        "debug": False,
        "return_feature_map": True
    },
    "use_preencoder": False,
    "pretrained_resnet": False,
    "preencoder_params": {
        "in_channels": 3,
        "out_channels": 32,
        "hidden_channels": 32
    },
    "gnn_in_dim": (64, 5, 5),
    "gnn_hidden_dim": (128, 5, 5),
    "num_classes": 10,
    "num_gnn_layers": 4,
    "gnn_dropout": 0.3,
    "residual": True,
    "aggregator_params": {
        "num_layers": 4,
        "dropout_prob": 0.3,
        "use_batchnorm": True
    }
}
```

- **`cnn_params`**: Controls the CNN encoder architecture (e.g., channels, dropout, pooling).
- **`use_preencoder`**: Boolean indicating whether to preprocess patches with a custom or pretrained module.
- **`pretrained_resnet`**: If `True`, loads pretrained ResNet-18 weights in the pre-encoder.
- **`gnn_in_dim`, `gnn_hidden_dim`**: Shapes of the node features in the GNN layers, given as `(C, H, W)` tuples.
- **`num_gnn_layers`**: How many message passing layers to stack.
- **`aggregator_params`**: Depth, dropout, and BN usage in the aggregator.
- **`residual`**: Enables skip connections in the GNN layers.

---

## Advanced Topics

### Theoretical Insights

1. **Universal Approximation via Deep CNNs**
   Stacking multiple convolutional layers with residual skips (in both the CNN encoder and the aggregator) enhances the effective receptive field and helps approximate complex local feature maps.

2. **Residual Learning for Gradient Flow**
   Residual connections in both the CNN encoder and the aggregator mitigate vanishing gradients, allowing deeper structures to train effectively end-to-end.

3. **Spatial vs. Flattened Features**
   Preserving the `[C, H, W]` layout at each node addresses a key limitation of conventional GNNs: the loss of local spatial semantics. TGraphX's design is grounded in the observation that many vision tasks require capturing fine-grained local details alongside global relational structure.

### Possible Extensions

- **Adaptive Edge Construction**
  Dynamically compute adjacency based on patch similarity or learned attention, rather than fixed proximity thresholds.
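
A simple instance of this idea, assuming cosine similarity over pooled node features (the threshold is a free parameter, not something the library prescribes):

```python
import torch

def similarity_edges(x, threshold: float = 0.5):
    """Connect nodes whose flattened features have cosine similarity above a threshold."""
    feats = x.flatten(1)                               # [N, C*H*W]
    feats = feats / feats.norm(dim=1, keepdim=True).clamp_min(1e-8)
    sim = feats @ feats.t()                            # [N, N] cosine similarities
    sim.fill_diagonal_(-1.0)                           # exclude self-loops
    src, dst = (sim > threshold).nonzero(as_tuple=True)
    return torch.stack([src, dst])                     # [2, E]

x = torch.randn(6, 8, 5, 5)                            # 6 nodes of [8, 5, 5]
edge_index = similarity_edges(x, threshold=0.2)
```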

- **Mixed Modalities**
  Combine image data with textual or numerical features by storing them as separate channels or separate GNN streams.

- **Task-Specific Losses**
  Add auxiliary losses (e.g., bounding-box IoU or segmentation overlap) for detection or segmentation tasks, integrated into the GNN training loop.

- **Performance Optimizations**
  Use grouped convolutions or low-rank factorization in the aggregator to reduce memory and computational overhead.

---

## Novelties and Contributions

TGraphX departs from traditional GNN designs in several ways:

1. **Full Spatial Fidelity**
   Each node in the graph remains a *multi-dimensional* feature map rather than a flattened vector, preserving the local spatial relationships crucial for vision tasks.

2. **Convolution-Based Message Passing**
   Employing `1×1` convolutions on `[C, H, W]` feature maps lets neighboring patches exchange information at *every pixel location*, ensuring alignment and detail retention.

3. **Deep Residual Aggregation**
   Multiple `3×3` CNN layers in the aggregator, complete with batch normalization, ReLU, dropout, and skip connections, allow the model to fuse multi-hop messages in a stable, expressive manner.

4. **End-to-End Differentiability**
   From raw image patches to final classification or detection outputs, **all** steps (CNN feature extraction, graph construction, message passing, and aggregator updates) are trained jointly, strengthening the synergy between local feature extraction and relational reasoning.

5. **Modular & Extensible**
   - Allows easy substitution of the aggregator or attention-based message passing layers.
   - Accommodates multiple data modalities (image, volumetric, or otherwise).
   - Scales from small graphs (a few patches) to larger patch partitions for high-resolution images.

These innovations build on earlier GNN research while pushing further to **retain** the valuable local details that are typically lost in flattened GNN nodes.

---

## Conclusion

We have presented **TGraphX**, an architecture aimed at integrating convolutional neural networks (CNNs) and graph neural networks (GNNs) in a way that preserves spatial fidelity. By retaining multi-dimensional CNN feature maps as node representations and employing convolution-based message passing, TGraphX captures both local and global spatial context. Our experiments, particularly those involving detection refinement, demonstrate its potential to resolve detection discrepancies and refine localization accuracy in challenging vision tasks.

While we do not claim it to be universally optimal for all computer vision scenarios, TGraphX offers a flexible framework that other researchers can adapt or extend. This integration of CNN-based feature extraction with GNN-based relational reasoning is a promising direction for future AI and vision research.

---

## Citations

```bibtex
@misc{sajjadi2025tgraphxtensorawaregraphneural,
  title={TGraphX: Tensor-Aware Graph Neural Network for Multi-Dimensional Feature Learning},
  author={Arash Sajjadi and Mark Eramian},
  year={2025},
  eprint={2504.03953},
  archivePrefix={arXiv},
  primaryClass={cs.CV},
  url={https://arxiv.org/abs/2504.03953},
}
```

---

## License

TGraphX is released under the [MIT License](https://opensource.org/licenses/MIT). See the `LICENSE` file for more details.

---

**Enjoy exploring and developing your spatially-aware graph neural networks with TGraphX!**
If you have any questions, suggestions, or want to contribute, feel free to open an issue or submit a pull request.

tgraphx-0.0.1/pyproject.toml
ADDED

```toml
[build-system]
requires = ["setuptools", "wheel"]
build-backend = "setuptools.build_meta"

[project]
name = "tgraphx"
version = "0.0.1"
authors = [
  {name = "Arash Sajjadi", email = "arash.sajjadi@usask.ca"}
]
description = "Early placeholder for TGraphX PyPI reservation"
readme = "README.md"
requires-python = ">=3.8"
license = {text = "MIT"}
```

tgraphx-0.0.1/setup.cfg
ADDED

tgraphx-0.0.1/setup.py
ADDED

```python
from setuptools import setup

setup(
    name="tgraphx",
    version="0.0.1",
    description="Early placeholder for TGraphX PyPI reservation",
    author="Arash Sajjadi",
    author_email="arash.sajjadi@usask.ca",
    packages=["tgraphx"],
    python_requires=">=3.8",
)
```

tgraphx-0.0.1/tgraphx/__init__.py
ADDED

```python
# Placeholder
```
@@ -0,0 +1,425 @@
|
|
|
1
|
+
Metadata-Version: 2.4
|
|
2
|
+
Name: tgraphx
|
|
3
|
+
Version: 0.0.1
|
|
4
|
+
Summary: Early placeholder for TGraphX PyPI reservation
|
|
5
|
+
Author: Arash Sajjadi
|
|
6
|
+
Author-email: Arash Sajjadi <arash.sajjadi@usask.ca>
|
|
7
|
+
License: MIT
|
|
8
|
+
Requires-Python: >=3.8
|
|
9
|
+
Description-Content-Type: text/markdown
|
|
10
|
+
Dynamic: author
|
|
11
|
+
Dynamic: requires-python
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
# TGraphX
|
|
16
|
+
|
|
17
|
+
TGraphX is a **PyTorch**-based framework for building Graph Neural Networks (GNNs) that work with node and edge features of any dimension while retaining their **spatial layout**. The code is designed for flexibility, easy GPU-acceleration, and rapid prototyping of new GNN ideas, **especially** those that need to preserve local spatial details (e.g., image or volumetric patches).
|
|
18
|
+
|
|
19
|
+
đź“„ **Preprint**: [TGraphX: Tensor-Aware Graph Neural Network for Multi-Dimensional Feature Learning](https://arxiv.org/abs/2504.03953)
|
|
20
|
+
✏️ *Authors: Arash Sajjadi, Mark Eramian*
|
|
21
|
+
🗓️ *Published on arXiv, April 2025*
|
|
22
|
+
|
|
23
|
+
|
|
24
|
+
> **Note:** TGraphX includes optional skip connections that help with
|
|
25
|
+
> stable gradient flow in deeper GNN stacks. The overall design is rooted
|
|
26
|
+
> in rigorous theoretical and practical foundations, aiming to unify
|
|
27
|
+
> convolutional neural networks (CNNs) with GNN-based relational reasoning.
|
|
28
|
+
|
|
29
|
+
---
|
|
30
|
+
## Table of Contents
|
|
31
|
+
|
|
32
|
+
- [Overview](#overview)
|
|
33
|
+
- [Key Features](#key-features)
|
|
34
|
+
- [Architecture Highlights](#architecture-highlights)
|
|
35
|
+
- [Preserving Spatial Fidelity](#preserving-spatial-fidelity)
|
|
36
|
+
- [Convolution-Based Message Passing](#convolution-based-message-passing)
|
|
37
|
+
- [Deep CNN Aggregator with Residuals](#deep-cnn-aggregator-with-residuals)
|
|
38
|
+
- [End-to-End Differentiability](#end-to-end-differentiability)
|
|
39
|
+
- [Future Works](#future-works)
|
|
40
|
+
- [Installation](#installation)
|
|
41
|
+
- [Folder Structure](#folder-structure)
|
|
42
|
+
- [Core Components](#core-components)
|
|
43
|
+
- [Layers](#layers)
|
|
44
|
+
- [Models](#models)
|
|
45
|
+
- [Configuration Options](#configuration-options)
|
|
46
|
+
- [Advanced Topics](#advanced-topics)
|
|
47
|
+
- [Novelties and Contributions](#novelties-and-contributions)
|
|
48
|
+
- [Conclusion](#conclusion)
|
|
49
|
+
- [Citations](#citations)
|
|
50
|
+
- [License](#license)
|
|
51
|
+
|
|
52
|
+
|
|
53
|
+
---
|
|
54
|
+
|
|
55
|
+
## Overview
|
|
56
|
+
|
|
57
|
+
TGraphX provides a modular way to create GNNs by combining several components:
|
|
58
|
+
|
|
59
|
+
1. **Graph Representation**
|
|
60
|
+
A `Graph` class holds node features, edge indices, and optional edge features. Unlike traditional GNNs where node features are vectors, TGraphX supports multi-dimensional features such as `[C, H, W]` tensors—making it particularly effective for vision tasks.
|
|
61
|
+
|
|
62
|
+
2. **Message Passing Layers**
|
|
63
|
+
Customizable layers process messages between nodes *while preserving the spatial layout of features*. In TGraphX:
|
|
64
|
+
- **ConvMessagePassing** uses `1Ă—1` convolutions on concatenated spatial features (e.g., `Conv1Ă—1(Concat(Xi, Xj, Eij))`).
|
|
65
|
+
- **DeepCNNAggregator** is a deep CNN (default 4 layers) that refines aggregated messages, keeping their spatial structure intact (i.e., `[C, H, W]` shape).
|
|
66
|
+
|
|
67
|
+
3. **Models**
|
|
68
|
+
Pre-built models combine a CNN encoder with GNN layers:
|
|
69
|
+
- **CNN Encoder** processes raw image patches into spatial feature maps (e.g., `[C, H, W]`).
|
|
70
|
+
- **Optional Pre-Encoder** (e.g., ResNet-like) can be enabled to further refine raw patches before the main CNN encoder.
|
|
71
|
+
- **Unified CNN‑GNN Model** uses CNN encoders for local features and GNN layers for global relational reasoning, then pools the result for final classification.
|
|
72
|
+
- An extra *skip connection* (if enabled) merges the raw CNN patch output with the GNN output for better gradient flow and more stable learning.
|
|
73
|
+
|
|
74
|
+
---
|
|
75
|
+
## Key Features
|
|
76
|
+
|
|
77
|
+
- **Support for Arbitrary Dimensions**
|
|
78
|
+
Handle vectors, 2D images, or even volumetric 3D patches as node features.
|
|
79
|
+
|
|
80
|
+
- **Spatial Message Passing**
|
|
81
|
+
Messages preserve spatial dimensions (e.g., `[C, H, W]`), letting convolutional filters capture local patterns and avoid destructive flattening of features.
|
|
82
|
+
|
|
83
|
+
- **Deep Aggregation**
|
|
84
|
+
A deep CNN aggregator (with multiple `3×3` convolutions, batch normalization, dropout, and ReLU) refines messages across multiple hops, enabling better local–global fusion.
|
|
85
|
+
|
|
86
|
+
- **Optional Pre‑Encoder**
|
|
87
|
+
Pre-process raw patches with a ResNet-like module (or even load a pretrained ResNet-18) to enrich features before the main GNN pipeline.
|
|
88
|
+
|
|
89
|
+
- **Flexible Data Loading**
|
|
90
|
+
TGraphX includes custom dataset and data loader classes (`GraphDataset` and `GraphDataLoader`) for direct graph-based batching.
|
|
91
|
+
|
|
92
|
+
- **Configurable Skip Connections**
|
|
93
|
+
Enable or disable skip connections that pass CNN outputs directly into the final stages, improving gradient flow.
|
|
94
|
+
|
|
95
|
+
---
|
|
96
|
+
|
|
97
|
+
## Architecture Highlights
|
|
98
|
+
|
|
99
|
+
### Preserving Spatial Fidelity
|
|
100
|
+
Unlike conventional GNNs that flatten node features into vectors, TGraphX retains the full spatial layout `[C, H, W]` at each node. This ensures that local pixel-level (or voxel-level) structure, which is crucial for vision tasks, remains intact throughout the message passing process.
|
|
101
|
+
|
|
102
|
+
### Convolution-Based Message Passing
|
|
103
|
+
TGraphX implements message passing via `Conv1Ă—1(Concat(Xi, Xj, Eij))`. This approach:
|
|
104
|
+
- Respects spatial alignment (i.e., each spatial location in one node’s feature map can directly interact with the same location in its neighbors’ feature maps).
|
|
105
|
+
- Preserves the dimension `[C, H, W]`, avoiding vector flattening.
|
|
106
|
+
- Optionally incorporates edge features `Eij` for more advanced relational cues (e.g., distances, bounding-box overlaps).
|
|
107
|
+
|
|
108
|
+
### Deep CNN Aggregator with Residuals
|
|
109
|
+
Messages from neighbors are aggregated (summed or averaged) and then passed to a **deep CNN aggregator** that uses multiple `3Ă—3` convolutions with *residual skips*. This design:
|
|
110
|
+
- Prevents the overwriting of original features by always adding `Aggregator(mj)` to the old node state `Xj`.
|
|
111
|
+
- Facilitates stable gradient flow in deep GNN stacks.
|
|
112
|
+
- Broadens the effective receptive field in feature space, capturing both local patches and more distant interactions.
|
|
113
|
+
|
|
114
|
+
### End-to-End Differentiability
|
|
115
|
+
Every stage of TGraphX—patch extraction, optional pre-encoder, CNN encoder, graph construction, message passing, aggregation, and classification—remains **fully differentiable** in PyTorch. This end-to-end design simplifies model development, parameter tuning, and experimentation with novel GNN layers.
|
|
116
|
+
|
|
117
|
+
---
|
|
118
|
+
|
|
119
|
+
## Future Works
|
|
120
|
+
|
|
121
|
+
- **Scalability and Data Requirements**
|
|
122
|
+
Adapting TGraphX to higher-resolution inputs or massive datasets (e.g., MS COCO) may require further optimizations, including efficient graph construction or pruning strategies.
|
|
123
|
+
|
|
124
|
+
- **Domain-Specific Customization**
|
|
125
|
+
Some tasks might not need full spatial fidelity at every message-passing step. Researchers could explore ways to selectively reduce resolution or apply specialized convolutions to different node subsets.
|
|
126
|
+
|
|
127
|
+
- **Alternative Edge Definitions**
|
|
128
|
+
Learned adjacency or richer spatial features (e.g., IoU or geometric cues) can further improve performance in complex scenes.
|
|
129
|
+
|
|
130
|
+
- **Multimodal and Real-Time Extensions**
|
|
131
|
+
Integrating TGraphX with sensor data or text embeddings could enable richer reasoning for applications like autonomous driving or real-time video surveillance.
|
|
132
|
+
|
|
133
|
+
---
|
|
134
|
+
## Installation
|
|
135
|
+
|
|
136
|
+
1. **Clone the Repository**
|
|
137
|
+
```bash
|
|
138
|
+
git clone https://github.com/YourUsername/TGraphX.git
|
|
139
|
+
cd TGraphX
|
|
140
|
+
```
|
|
141
|
+
|
|
142
|
+
2. **Set Up the Environment**
|
|
143
|
+
Use the provided `environment.yml` to create a conda environment:
|
|
144
|
+
```bash
|
|
145
|
+
conda env create -f environment.yml
|
|
146
|
+
conda activate tgraphx
|
|
147
|
+
```
|
|
148
|
+
|
|
149
|
+
3. **Install PyTorch**
|
|
150
|
+
Install a recent version of [PyTorch](https://pytorch.org/) (GPU version if possible).
|
|
151
|
+
|
|
152
|
+
4. **Install Additional Dependencies**
|
|
153
|
+
```bash
|
|
154
|
+
pip install -r requirements.txt
|
|
155
|
+
```
|
|
156
|
+
|
|
157
|
+
5. **Editable Mode (Optional)**
|
|
158
|
+
```bash
|
|
159
|
+
pip install -e .
|
|
160
|
+
```
|
|
161
|
+
|
|
162
|
+
---
|
|
163
|
+
|
|
164
|
+
## Folder Structure
|
|
165
|
+
|
|
166
|
+
```
|
|
167
|
+
TGraphX/
|
|
168
|
+
├── __init__.py
|
|
169
|
+
├── core/
|
|
170
|
+
│ ├── dataloader.py
|
|
171
|
+
│ ├── graph.py
|
|
172
|
+
│ └── utils.py
|
|
173
|
+
├── layers/
|
|
174
|
+
│ ├── aggregator.py
|
|
175
|
+
│ ├── attention_message.py
|
|
176
|
+
│ ├── base.py
|
|
177
|
+
│ ├── conv_message.py
|
|
178
|
+
│ └── safe_pool.py
|
|
179
|
+
├── models/
|
|
180
|
+
│ ├── cnn_encoder.py
|
|
181
|
+
│ ├── cnn_gnn_model.py
|
|
182
|
+
│ ├── graph_classifier.py
|
|
183
|
+
│ ├── node_classifier.py
|
|
184
|
+
│ └── pre_encoder.py
|
|
185
|
+
├── environment.yml
|
|
186
|
+
└── README.md
|
|
187
|
+
```
|
|
188
|
+
|
|
189
|
+
---
|
|
190
|
+
|
|
191
|
+
## Core Components
|
|
192
|
+
|
|
193
|
+
### Graph and Data Loading
|
|
194
|
+
|
|
195
|
+
- **`Graph` & `GraphBatch`**
|
|
196
|
+
Represent individual graphs (nodes, edges) and batches of graphs. The batch version offsets node indices to avoid collisions, allowing parallel processing in PyTorch.
|
|
197
|
+
|
|
198
|
+
- **`GraphDataset` & `GraphDataLoader`**
|
|
199
|
+
Custom dataset and data loader classes that streamline the creation of graph batches from a set of images, patches, or other structured data.
|
|
200
|
+
|
|
201
|
+
### Utility Functions
|
|
202
|
+
|
|
203
|
+
- **`load_config`**
|
|
204
|
+
Load YAML/JSON configuration files to keep hyperparameters consistent across experiments.
|
|
205
|
+
|
|
206
|
+
- **`get_device`**
|
|
207
|
+
Utility to automatically detect and return the correct device (GPU or CPU).
|
|
208
|
+
|
|
209
|
+
---
|
|
210
|
+
|
|
211
|
+
## Layers
|
|
212
|
+
|
|
213
|
+
### Base Layer
|
|
214
|
+
|
|
215
|
+
- **`TensorMessagePassingLayer`**
|
|
216
|
+
An abstract base class that defines the interface (message, aggregate, update steps) for all message passing. Crucially, it handles multi-dimensional node features (e.g., `[C, H, W]`).
|
|
217
|
+
|
|
218
|
+
### Convolution-Based Message Passing
|
|
219
|
+
|
|
220
|
+
- **`ConvMessagePassing`**
|
|
221
|
+
Concatenates source and target node feature maps (plus optional edge features) along the channel dimension and applies a `1Ă—1` convolution:
|
|
222
|
+
```python
|
|
223
|
+
Mij = Conv1Ă—1(Concat(Xi, Xj, Eij))
|
|
224
|
+
```
|
|
225
|
+
- **Message Phase**: Each pair `(i, j)` of nodes exchanges messages computed by a `1Ă—1` conv.
|
|
226
|
+
- **Aggregation + Residual Update**: After summing messages from neighbors, a deep CNN aggregator processes the sum, and the original node features are updated via a **residual skip**.
|
|
227
|
+
|
|
228
|
+
### Deep CNN Aggregator
|
|
229
|
+
|
|
230
|
+
- **`DeepCNNAggregator`**
|
|
231
|
+
A stack of `3Ă—3` convolutional layers with batch normalization, ReLU, and dropout. It refines the aggregated messages:
|
|
232
|
+
```python
|
|
233
|
+
X'_j = X_j + A( m_j )
|
|
234
|
+
```
|
|
235
|
+
where `m_j = sum of messages to node j`. Residual connections ensure stable gradient flow.
|
|
236
|
+
|
|
237
|
+
### Attention-Based Message Passing
|
|
238
|
+
|
|
239
|
+
- **`AttentionMessagePassing`**
|
|
240
|
+
An alternative that uses `1Ă—1` convolutions to compute query, key, and value maps for each node. Spatial alignment is preserved while attention weights scale incoming messages. Useful for tasks needing dynamic connectivity or weighting.
|
|
241
|
+
|
|
242
|
+
### Safe Pooling
|
|
243
|
+
|
|
244
|
+
- **`SafeMaxPool2d`**
|
|
245
|
+
A robust pooling module that checks if spatial dimensions `[H, W]` are large enough before applying max pooling. Prevents dimension mismatch errors in deeper aggregator stacks.
|
|
246
|
+
|
|
247
|
+
---
|
|
248
|
+
|
|
249
|
+
## Models
|
|
250
|
+
|
|
251
|
+
### CNN Encoder and Pre-Encoder
|
|
252
|
+
|
|
253
|
+
- **`CNNEncoder`**
|
|
254
|
+
Converts raw patches (`[C_in, patch_H, patch_W]`) into *spatial feature maps* (e.g., `[C_out, H2, W2]`). Includes:
|
|
255
|
+
- Multiple 3Ă—3 conv blocks with BN, ReLU, and dropout.
|
|
256
|
+
- Optional residual connections.
|
|
257
|
+
- Safe max pooling if the spatial size remains large.
|
|
258
|
+
|
|
259
|
+
- **Optional Pre‑Encoder**
|
|
260
|
+
- If `use_preencoder` is `True`, a **ResNet‑like** (or fully custom) module first processes each patch, returning refined features with the same spatial structure.
|
|
261
|
+
- `pretrained_resnet` can load weights from a standard ResNet‑18 for transfer learning.
|
|
262
|
+
|
|
263
|
+
### Unified CNN‑GNN Model
|
|
264
|
+
|
|
265
|
+
- **`CNN_GNN_Model`**
|
|
266
|
+
A single pipeline that:
|
|
267
|
+
1. Splits the image into patches, optionally uses `PreEncoder`.
|
|
268
|
+
2. Feeds patches into `CNNEncoder` to get `[C, H, W]` maps.
|
|
269
|
+
3. Builds a graph where each node holds a `[C, H, W]` map.
|
|
270
|
+
4. Applies multiple GNN layers (like `ConvMessagePassing` + `DeepCNNAggregator`).
|
|
271
|
+
5. Optionally uses a skip connection to combine CNN outputs with GNN outputs.
|
|
272
|
+
6. Performs final spatial pooling before classification.
|
|
273
|
+
|
|
274
|
+
### Graph & Node Classification Models
|
|
275
|
+
|
|
276
|
+
- **`GraphClassifier`**
|
|
277
|
+
Intended for graph-level tasks (e.g., classification of an entire image or object ensemble). Combines message passing with a final pooling layer (mean, max, or attention) over nodes, then feeds the result into a classifier.
|
|
278
|
+
|
|
279
|
+
- **`NodeClassifier`**
|
|
280
|
+
Suitable for node-level tasks (e.g., labeling each patch or region). Stacks simpler message passing layers for classification on each node separately.
|
|
281
|
+
|
|
282
|
+
---
|
|
283
|
+
|
|
284
|
+
## Configuration Options
|
|
285
|
+
|
|
286
|
+
TGraphX is highly configurable. Some key parameters include:
|
|
287
|
+
|
|
288
|
+
```python
|
|
289
|
+
config = {
|
|
290
|
+
"cnn_params": {
|
|
291
|
+
"in_channels": 3,
|
|
292
|
+
"out_features": 64,
|
|
293
|
+
"num_layers": 2,
|
|
294
|
+
"hidden_channels": 64,
|
|
295
|
+
"dropout_prob": 0.3,
|
|
296
|
+
"use_batchnorm": True,
|
|
297
|
+
"use_residual": True,
|
|
298
|
+
"pool_layers": 2,
|
|
299
|
+
"debug": False,
|
|
300
|
+
"return_feature_map": True
|
|
301
|
+
},
|
|
302
|
+
"use_preencoder": False,
|
|
303
|
+
"pretrained_resnet": False,
|
|
304
|
+
"preencoder_params": {
|
|
305
|
+
"in_channels": 3,
|
|
306
|
+
"out_channels": 32,
|
|
307
|
+
"hidden_channels": 32
|
|
308
|
+
},
|
|
309
|
+
"gnn_in_dim": (64, 5, 5),
|
|
310
|
+
"gnn_hidden_dim": (128, 5, 5),
|
|
311
|
+
"num_classes": 10,
|
|
312
|
+
"num_gnn_layers": 4,
|
|
313
|
+
"gnn_dropout": 0.3,
|
|
314
|
+
"residual": True,
|
|
315
|
+
"aggregator_params": {
|
|
316
|
+
"num_layers": 4,
|
|
317
|
+
"dropout_prob": 0.3,
|
|
318
|
+
"use_batchnorm": True
|
|
319
|
+
}
|
|
320
|
+
}
|
|
321
|
+
```
|
|
322
|
+
|
|
323
|
+
- **`cnn_params`**: Controls the CNN encoder architecture (e.g., channels, dropout, pooling).
|
|
324
|
+
- **`use_preencoder`**: Boolean indicating whether to preprocess patches with a custom or pretrained module.
|
|
325
|
+
- **`pretrained_resnet`**: If `True`, loads pretrained ResNet-18 weights in the pre-encoder.
|
|
326
|
+
- **`gnn_in_dim`, `gnn_hidden_dim`**: Shapes of the node features in GNN layers. Each dimension can be `[C, H, W]`.
|
|
327
|
+
- **`num_gnn_layers`**: How many message passing layers to stack.
|
|
328
|
+
- **`aggregator_params`**: Depth, dropout, and BN usage in the aggregator.
|
|
329
|
+
- **`residual`**: Enables skip connections in the GNN layers.

---

## Advanced Topics

### Theoretical Insights

1. **Universal Approximation via Deep CNN**
   Stacking multiple convolutional layers with residual skips (in both the CNN encoder and the aggregator) enlarges the effective receptive field and helps approximate complex local feature maps.

2. **Residual Learning for Gradient Flow**
   Residual connections in both the CNN encoder and the aggregator mitigate vanishing gradients, allowing deeper structures to train effectively end to end.

3. **Spatial vs. Flattened Features**
   Preserving the `[C, H, W]` layout at each node addresses a key limitation of conventional GNNs: the loss of local spatial semantics. TGraphX's design is grounded in the observation that many vision tasks require capturing fine-grained local details alongside global relational structure.

### Possible Extensions

- **Adaptive Edge Construction**
  Dynamically compute adjacency from patch similarity or learned attention rather than from fixed proximity thresholds.

- **Mixed Modalities**
  Combine image data with textual or numerical features by storing them as separate channels or separate GNN streams.

- **Task-Specific Losses**
  Add auxiliary losses (e.g., bounding-box IoU or segmentation overlap) for detection or segmentation tasks, integrated into the GNN training loop.

- **Performance Optimizations**
  Use group convolutions or low-rank factorization in the aggregator to reduce memory and computational overhead.
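The adaptive edge construction idea above could be sketched as follows. This is a hypothetical helper under our own assumptions (`similarity_edges` is our name, not a TGraphX API): it mean-pools each `[C, H, W]` node map into a descriptor and connects node pairs whose cosine similarity exceeds a threshold.

```python
import torch
import torch.nn.functional as F

def similarity_edges(node_feats, threshold=0.5):
    """Build a dynamic [2, E] edge index from pairwise patch similarity."""
    # node_feats: [N, C, H, W] -> mean-pool spatial dims to [N, C] descriptors
    desc = node_feats.mean(dim=(2, 3))
    # Pairwise cosine similarity via broadcasting: [N, 1, C] vs [1, N, C] -> [N, N]
    sim = F.cosine_similarity(desc.unsqueeze(1), desc.unsqueeze(0), dim=-1)
    sim.fill_diagonal_(0.0)  # suppress self-loops
    src, dst = (sim > threshold).nonzero(as_tuple=True)
    return torch.stack([src, dst])  # [2, E] edge index

x = torch.randn(6, 16, 5, 5)        # six patch nodes with [16, 5, 5] features
edge_index = similarity_edges(x, threshold=0.2)
```

Swapping the threshold for a learned attention score over descriptor pairs would give the fully learned variant mentioned above.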

---

## Novelties and Contributions

TGraphX departs from traditional GNN designs in several ways:

1. **Full Spatial Fidelity**
   Each node in the graph remains a *multi-dimensional* feature map rather than a flattened vector, preserving the local spatial relationships crucial for vision tasks.

2. **Convolution-Based Message Passing**
   Employing `1×1` convolutions on `[C, H, W]` feature maps lets neighboring patches exchange information at *every pixel location*, ensuring alignment and detail retention.

3. **Deep Residual Aggregation**
   Multiple `3×3` CNN layers in the aggregator, complete with batch normalization, ReLU, dropout, and skip connections, allow the model to fuse multi-hop messages in a stable, expressive manner.

4. **End-to-End Differentiability**
   From raw image patches to final classification or detection outputs, **all** steps (CNN feature extraction, graph construction, message passing, and aggregator updates) are trained jointly, strengthening the synergy between local feature extraction and relational reasoning.

5. **Modular & Extensible**
   - Allows easy substitution of the aggregator or of attention-based message-passing layers.
   - Accommodates multiple data modalities (image, volumetric, or otherwise).
   - Scales from small graphs (a few patches) to larger patch partitions for high-resolution images.

These innovations build on earlier GNN research while pushing further to **retain** the valuable local details that are typically lost in flattened GNN nodes.
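The convolution-based message passing and residual aggregation described in items 2 and 3 can be sketched as a single PyTorch layer. This is an illustrative sketch under our own assumptions, not TGraphX's actual implementation (`ConvMessagePassing` is our name): per-edge messages come from a `1×1` conv on the `[C, H, W]` node maps, and a small residual `3×3` aggregator with batch normalization, ReLU, and dropout fuses the summed messages back into each node.

```python
import torch
import torch.nn as nn

class ConvMessagePassing(nn.Module):
    """Sketch of conv-based message passing over [N, C, H, W] node maps."""
    def __init__(self, channels, dropout=0.1):
        super().__init__()
        self.msg = nn.Conv2d(channels, channels, kernel_size=1)  # per-pixel message
        self.agg = nn.Sequential(                                 # residual aggregator body
            nn.Conv2d(channels, channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(channels),
            nn.ReLU(inplace=True),
            nn.Dropout2d(dropout),
            nn.Conv2d(channels, channels, kernel_size=3, padding=1),
        )

    def forward(self, x, edge_index):
        # x: [N, C, H, W]; edge_index: [2, E], messages flow src -> dst
        src, dst = edge_index
        messages = self.msg(x[src])                        # [E, C, H, W]
        summed = torch.zeros_like(x).index_add_(0, dst, messages)  # sum per node
        return x + self.agg(summed)                        # skip connection

x = torch.randn(4, 8, 5, 5)                    # four patch nodes
edge_index = torch.tensor([[0, 1, 2, 3],       # a simple ring graph
                           [1, 2, 3, 0]])
layer = ConvMessagePassing(8)
out = layer(x, edge_index)                     # same [4, 8, 5, 5] shape as x
```

Because every operation is a convolution or an index-based sum, spatial alignment between neighboring patches is preserved at each pixel, which is the core point of the design.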

---

## Conclusion

We have presented **TGraphX**, an architecture that integrates convolutional neural networks (CNNs) and graph neural networks (GNNs) in a way that preserves spatial fidelity. By retaining multi-dimensional CNN feature maps as node representations and employing convolution-based message passing, TGraphX captures both local and global spatial context. Our experiments, particularly those involving detection refinement, demonstrate its potential to resolve detection discrepancies and refine localization accuracy in challenging vision tasks.

While we do not claim it is universally optimal for all computer vision scenarios, TGraphX offers a flexible framework that other researchers can adapt or extend. This integration of CNN-based feature extraction with GNN-based relational reasoning is a promising direction for future AI and vision research.

---

## Citations

```bibtex
@misc{sajjadi2025tgraphxtensorawaregraphneural,
  title={TGraphX: Tensor-Aware Graph Neural Network for Multi-Dimensional Feature Learning},
  author={Arash Sajjadi and Mark Eramian},
  year={2025},
  eprint={2504.03953},
  archivePrefix={arXiv},
  primaryClass={cs.CV},
  url={https://arxiv.org/abs/2504.03953},
}
```

---

## License

TGraphX is released under the [MIT License](https://opensource.org/licenses/MIT). See the `LICENSE` file for more details.

---

**Enjoy exploring and developing your spatially aware graph neural networks with TGraphX!**
If you have any questions or suggestions, or want to contribute, feel free to open an issue or submit a pull request.

@@ -0,0 +1 @@
+

@@ -0,0 +1 @@
+tgraphx