A compact message passing network written in plain PyTorch. It predicts a scalar property of a molecule from its graph structure. There is no torch-geometric and no rdkit. Molecules are represented as small synthetic graphs, each one a node feature matrix paired with a dense adjacency matrix, so the whole thing runs on CPU in seconds.
The model is a graph neural network built from three pieces:
- A message passing layer that aggregates neighbor features through a normalized adjacency, then applies a learned linear map and a ReLU.
- A stack of those layers, so information can travel several hops across the graph.
- A sum readout that pools node states into one vector per graph, followed by a linear head that produces the predicted property.
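The three pieces above can be sketched in PyTorch roughly as follows. This is a minimal illustration, not the code in src/model.py. The class names mirror the ones listed for that file, but several details are assumptions:

- the constructor arguments (hidden_dim, num_layers) and their defaults
- the use of index_add_ for the per-graph sum readout

```python
import torch
import torch.nn as nn


class MessagePassingLayer(nn.Module):
    """One aggregation step: ReLU(Linear(A_hat @ H))."""

    def __init__(self, in_dim: int, out_dim: int):
        super().__init__()
        self.linear = nn.Linear(in_dim, out_dim)

    def forward(self, x: torch.Tensor, a_hat: torch.Tensor) -> torch.Tensor:
        # Aggregate neighbor features through the normalized adjacency,
        # then apply the learned linear map and a ReLU.
        return torch.relu(self.linear(a_hat @ x))


class GraphRegressor(nn.Module):
    """Stacked message passing layers, a sum readout, and a linear head.

    hidden_dim and num_layers are illustrative assumptions.
    """

    def __init__(self, in_dim: int, hidden_dim: int = 32, num_layers: int = 3):
        super().__init__()
        dims = [in_dim] + [hidden_dim] * num_layers
        self.layers = nn.ModuleList(
            [MessagePassingLayer(d_in, d_out) for d_in, d_out in zip(dims[:-1], dims[1:])]
        )
        self.head = nn.Linear(hidden_dim, 1)

    def forward(self, x, a_hat, batch_index=None):
        h = x
        # Each layer lets information travel one more hop.
        for layer in self.layers:
            h = layer(h, a_hat)
        if batch_index is None:
            # A single graph: sum every node state into one vector.
            pooled = h.sum(dim=0, keepdim=True)
        else:
            # A batch: add each node state into the row of its graph.
            num_graphs = int(batch_index.max()) + 1
            pooled = torch.zeros(num_graphs, h.size(1), dtype=h.dtype, device=h.device)
            pooled.index_add_(0, batch_index, h)
        # One scalar prediction per graph.
        return self.head(pooled).squeeze(-1)
```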
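The normalization is the standard symmetric form, D^{-1/2} (A + I) D^{-1/2}. Here A is the adjacency matrix, I is the identity matrix, and D is the diagonal degree matrix of A + I. Adding the identity gives every node a self loop, which lets a node keep part of its own signal while it blends in messages from its neighbors. It also makes every degree at least 1, so an isolated node never causes a division by zero. The symmetric scaling keeps feature magnitudes stable no matter how many edges a node has.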
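A minimal sketch of that normalization for a dense adjacency follows. The function name normalized_adjacency is hypothetical, and the sketch assumes A is a symmetric 0/1 float tensor with no self loops. Broadcasting the inverse square root degrees over rows and columns is equivalent to multiplying by the diagonal matrix D^{-1/2} on both sides, without building it.

```python
import torch


def normalized_adjacency(adj: torch.Tensor) -> torch.Tensor:
    """Return D^{-1/2} (A + I) D^{-1/2} for a dense, symmetric 0/1 adjacency A."""
    a_tilde = adj + torch.eye(adj.size(0), dtype=adj.dtype)
    # Degrees come from A + I, so every node has degree >= 1
    # and the inverse square root is always finite.
    d_inv_sqrt = a_tilde.sum(dim=1).pow(-0.5)
    # Row scaling times column scaling equals D^{-1/2} @ A_tilde @ D^{-1/2}.
    return d_inv_sqrt[:, None] * a_tilde * d_inv_sqrt[None, :]
```

Consider two connected nodes plus one isolated node. The result is [[0.5, 0.5, 0], [0.5, 0.5, 0], [0, 0, 1]], so the isolated node's row simply copies its own features.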
Real chemistry is not the point. Each graph is a random graph with random Gaussian node features, and the target property is the number of edges in the graph. Edge count is a genuine structural signal: you cannot read it off any single node, so a model has to actually aggregate over the graph to predict it. That makes it a fair test of whether message passing and the readout are doing real work.
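One way to generate such a graph is sketched below. The README does not say which random graph model src/data.py uses. This sketch assumes an Erdős-Rényi G(n, p) model, in which each pair of nodes is connected independently with probability edge_prob. The name random_graph and its parameters are hypothetical.

```python
import torch


def random_graph(num_nodes: int, num_features: int, edge_prob: float,
                 generator: torch.Generator):
    """Sample (x, adj, target) for one synthetic graph.

    x is Gaussian node features, adj is a symmetric 0/1 adjacency with no
    self loops, and target is the number of undirected edges.
    """
    x = torch.randn(num_nodes, num_features, generator=generator)
    # Sample the strict upper triangle, then mirror it so adj is symmetric.
    upper = (torch.rand(num_nodes, num_nodes, generator=generator) < edge_prob).float()
    upper = torch.triu(upper, diagonal=1)
    adj = upper + upper.T
    # A symmetric adjacency stores each undirected edge twice.
    target = adj.sum() / 2
    return x, adj, target
```

The features are pure noise, so the only route to the target runs through adj. That is why a model that ignores the graph structure cannot beat the mean baseline.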
```
src/
  graph.py                  MolGraph container, normalized adjacency, block diagonal batching
  data.py                   random graph generator and the edge count dataset
  model.py                  MessagePassingLayer and GraphRegressor
  train.py                  a small full batch training loop and a predict helper
tests/
  test_message_passing.py   exact checks on one aggregation step
  test_model_learns.py      the trained model beats the mean baseline
```
Several graphs are combined into one block diagonal adjacency. Because there are no edges between the blocks, running message passing on the combined graph is the same as running each graph on its own. A batch index records which graph each node belongs to, and the sum readout uses it to scatter node states back into per graph vectors. One of the tests confirms that no information leaks between batched graphs.
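The batching and readout described above can be sketched with torch.block_diag and Tensor.index_add_. The helper names batch_graphs and sum_readout are hypothetical and are not taken from src/graph.py. Because the symmetric normalization only reads each node's own degree, normalizing the combined matrix gives the same blocks as normalizing each graph separately.

```python
import torch


def batch_graphs(xs, adjs):
    """Combine graphs into one block diagonal graph plus a batch index."""
    x = torch.cat(xs, dim=0)
    # block_diag puts each adjacency on the diagonal and zeros elsewhere,
    # so no edge connects nodes from different graphs.
    adj = torch.block_diag(*adjs)
    # batch_index[i] is the id of the graph that node i came from.
    batch_index = torch.cat(
        [torch.full((xi.size(0),), g, dtype=torch.long) for g, xi in enumerate(xs)]
    )
    return x, adj, batch_index


def sum_readout(h, batch_index, num_graphs):
    """Scatter node states back into one summed vector per graph."""
    out = torch.zeros(num_graphs, h.size(1), dtype=h.dtype)
    # index_add_ adds row i of h into row batch_index[i] of out.
    return out.index_add_(0, batch_index, h)
```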
The tests are property and behavior checks. The message passing tests verify four things:

- one aggregation step equals the normalized adjacency times the features
- two connected nodes pull toward each other
- an isolated node keeps only its own signal
- disconnected graphs stay independent

The learning test trains the regressor on edge count. It asserts that the held out error is well under half that of the constant mean predictor, and that predictions correlate strongly with the true counts.
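The message passing checks can be written as plain pytest-style functions along these lines. This is a sketch, not the contents of tests/test_message_passing.py. It assumes the unweighted aggregation step is just the normalized adjacency times the features, and all names are hypothetical.

```python
import torch


def normalized_adjacency(adj):
    # D^{-1/2} (A + I) D^{-1/2}, with D the degree matrix of A + I.
    a_tilde = adj + torch.eye(adj.size(0), dtype=adj.dtype)
    d_inv_sqrt = a_tilde.sum(dim=1).pow(-0.5)
    return d_inv_sqrt[:, None] * a_tilde * d_inv_sqrt[None, :]


# Nodes 0 and 1 share an edge; node 2 is isolated.
ADJ = torch.tensor([[0.0, 1.0, 0.0], [1.0, 0.0, 0.0], [0.0, 0.0, 0.0]])
X = torch.tensor([[0.0], [1.0], [5.0]])


def test_step_matches_dense_formula():
    # Build D^{-1/2} as an explicit diagonal matrix and compare.
    a_tilde = ADJ + torch.eye(3)
    d_inv_sqrt = torch.diag(a_tilde.sum(dim=1).pow(-0.5))
    expected = d_inv_sqrt @ a_tilde @ d_inv_sqrt @ X
    assert torch.allclose(normalized_adjacency(ADJ) @ X, expected)


def test_connected_nodes_pull_together():
    out = normalized_adjacency(ADJ) @ X
    # The gap between the two neighbors shrinks after one step.
    assert (out[0] - out[1]).abs() < (X[0] - X[1]).abs()


def test_isolated_node_keeps_own_signal():
    out = normalized_adjacency(ADJ) @ X
    # With only a self loop, the node's row of A_hat is a one-hot 1.
    assert torch.allclose(out[2], X[2])
```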
```
python -m pytest tests/ -q
```
On a recent run all eight tests passed in about three seconds on CPU.
```python
import torch
from src.data import make_dataset
from src.train import train_regressor, predict

graphs, targets = make_dataset(num_graphs=120, num_features=4, seed=1)
model, losses = train_regressor(graphs, targets, in_dim=4, epochs=250)
preds = predict(model, graphs)
```

You can also call the model on a single molecule by passing its node features and adjacency with batch_index=None.
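Below is a hedged sketch of what a small full batch loop can look like. It is not the train_regressor or predict from src/train.py, and several choices are assumptions:

- the names train_full_batch and predict_all
- the Adam optimizer, the MSE loss, and the learning rate

It expects a model called as model(x, a_hat, batch_index) on an already batched block diagonal graph, where batch_index=None means a single graph.

```python
import torch
import torch.nn as nn


def train_full_batch(model, x, a_hat, batch_index, targets, epochs=250, lr=1e-2):
    """Minimal full batch regression loop; returns the per-epoch losses."""
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    loss_fn = nn.MSELoss()
    losses = []
    model.train()
    for _ in range(epochs):
        optimizer.zero_grad()
        # Every graph in the batch contributes to one gradient step.
        loss = loss_fn(model(x, a_hat, batch_index), targets)
        loss.backward()
        optimizer.step()
        losses.append(loss.item())
    return losses


@torch.no_grad()
def predict_all(model, x, a_hat, batch_index=None):
    """Predict without tracking gradients; leaves the model in eval mode."""
    model.eval()
    return model(x, a_hat, batch_index)
```

Full batch training fits here because the whole synthetic dataset is one small block diagonal matrix. With larger datasets, the same loop would iterate over mini-batches built the same way.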