SharvenRane/molecular-property-prediction

molecular property prediction

A compact message passing network written in plain PyTorch. It predicts a scalar property of a molecule from its graph structure. There is no torch-geometric and no rdkit. Molecules are represented as small synthetic graphs, each one a node feature matrix paired with a dense adjacency matrix, so the whole thing runs on CPU in seconds.

What is here

The model is a graph neural network built from three pieces:

  1. A message passing layer that aggregates neighbor features through a normalized adjacency, then applies a learned linear map and a ReLU.
  2. A stack of those layers, so information can travel several hops across the graph.
  3. A sum readout that pools node states into one vector per graph, followed by a linear head that produces the predicted property.
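The three pieces above can be sketched roughly as follows. This is a minimal illustration, not the code in src/model.py: the class names are borrowed from the Layout section, but the constructor arguments, the hidden size, and the forward signature are assumptions made for the example.

```python
import torch
import torch.nn as nn


class MessagePassingLayer(nn.Module):
    """One step: aggregate over a normalized adjacency, then linear map + ReLU."""

    def __init__(self, in_dim, out_dim):
        super().__init__()
        self.linear = nn.Linear(in_dim, out_dim)

    def forward(self, x, adj_norm):
        # x: (num_nodes, in_dim), adj_norm: (num_nodes, num_nodes)
        return torch.relu(self.linear(adj_norm @ x))


class GraphRegressor(nn.Module):
    """A stack of layers, a sum readout, and a linear head (hypothetical signature)."""

    def __init__(self, in_dim, hidden_dim=32, num_layers=3):
        super().__init__()
        dims = [in_dim] + [hidden_dim] * num_layers
        self.layers = nn.ModuleList(
            MessagePassingLayer(a, b) for a, b in zip(dims[:-1], dims[1:])
        )
        self.head = nn.Linear(hidden_dim, 1)

    def forward(self, x, adj_norm, batch_index=None):
        for layer in self.layers:
            x = layer(x, adj_norm)
        if batch_index is None:
            # Single graph: sum every node state into one vector.
            pooled = x.sum(dim=0, keepdim=True)
        else:
            # Batched graphs: sum node states per graph id.
            num_graphs = int(batch_index.max()) + 1
            pooled = x.new_zeros(num_graphs, x.shape[1]).index_add(0, batch_index, x)
        return self.head(pooled).squeeze(-1)
```

Stacking k layers lets a node's state depend on nodes up to k hops away, which is why the depth is exposed as a parameter in this sketch.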

The normalization is the standard symmetric form, D^{-1/2} (A + I) D^{-1/2}, where A is the adjacency matrix and D is the diagonal degree matrix of A + I. Adding the identity gives every node a self loop, which lets a node keep part of its own signal while it blends in messages from its neighbors. The symmetric scaling keeps feature magnitudes stable no matter how many edges a node has.
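As a small sketch of that formula, assuming a dense float adjacency with no self loops on input (the function name here is illustrative and may not match the helper in src/graph.py):

```python
import torch


def normalized_adjacency(adj):
    """Return D^{-1/2} (A + I) D^{-1/2} for a dense adjacency matrix."""
    a_hat = adj + torch.eye(adj.shape[0], dtype=adj.dtype)  # add self loops
    deg = a_hat.sum(dim=1)          # degrees of A + I, always >= 1
    d_inv_sqrt = deg.pow(-0.5)
    # Scale rows and columns; same result as multiplying by diagonal matrices.
    return d_inv_sqrt[:, None] * a_hat * d_inv_sqrt[None, :]


# Two connected nodes: every entry of the result is 0.5.
pair = normalized_adjacency(torch.tensor([[0.0, 1.0], [1.0, 0.0]]))

# A single isolated node: the result is [[1.0]], so it keeps only its own signal.
alone = normalized_adjacency(torch.zeros(1, 1))
```

Because the self loop makes every degree at least one, the inverse square root is always defined, even for isolated nodes.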

The synthetic task

Real chemistry is not the point. Each graph is a random graph with random Gaussian node features, and the target property is the number of edges in the graph. Edge count is a genuine structural signal: you cannot read it off any single node, so a model has to actually aggregate over the graph to predict it. That makes it a fair test of whether message passing and the readout are doing real work.
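One way such a sample could be generated is sketched below. The function name, the edge probability parameter, and the use of a torch.Generator for seeding are assumptions for illustration; the actual generator in src/data.py may differ.

```python
import torch


def random_graph(num_nodes, num_features, edge_prob, generator):
    """Return (features, adjacency, edge_count) for one random undirected graph."""
    # Sample the strict upper triangle, then mirror it so the graph is undirected.
    coin = torch.rand(num_nodes, num_nodes, generator=generator) < edge_prob
    upper = torch.triu(coin.float(), diagonal=1)
    adj = upper + upper.T
    # Gaussian node features carry no information about the target.
    x = torch.randn(num_nodes, num_features, generator=generator)
    # Each undirected edge appears twice in a symmetric adjacency.
    target = adj.sum() / 2
    return x, adj, target


gen = torch.Generator().manual_seed(1)
x, adj, target = random_graph(num_nodes=6, num_features=4, edge_prob=0.4, generator=gen)
```

Since the features are pure noise, the only route to the target is through the adjacency, which is what makes the task a check on aggregation.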

Layout

src/
  graph.py   MolGraph container, normalized adjacency, block diagonal batching
  data.py    random graph generator and the edge count dataset
  model.py   MessagePassingLayer and GraphRegressor
  train.py   a small full batch training loop and a predict helper
tests/
  test_message_passing.py   exact checks on one aggregation step
  test_model_learns.py      the trained model beats the mean baseline

How batching works

Several graphs are combined into one block diagonal adjacency. Because there are no edges between the blocks, running message passing on the combined graph is the same as running each graph on its own. A batch index records which graph each node belongs to, and the sum readout uses it to scatter node states back into per graph vectors. One of the tests confirms that no information leaks between batched graphs.
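The batching scheme can be sketched like this, assuming each graph is given as a (features, adjacency) pair. The helper names are hypothetical and are not taken from src/graph.py.

```python
import torch


def batch_graphs(graphs):
    """Combine (x, adj) pairs into one feature matrix, one block diagonal
    adjacency, and a batch index mapping each node to its graph."""
    xs = [x for x, _ in graphs]
    adjs = [adj for _, adj in graphs]
    batch_index = torch.cat(
        [torch.full((x.shape[0],), i, dtype=torch.long) for i, x in enumerate(xs)]
    )
    return torch.cat(xs, dim=0), torch.block_diag(*adjs), batch_index


def sum_readout(h, batch_index, num_graphs):
    """Sum node states into one row per graph."""
    return h.new_zeros(num_graphs, h.shape[1]).index_add(0, batch_index, h)


g1 = (torch.randn(2, 3), torch.tensor([[0.0, 1.0], [1.0, 0.0]]))
g2 = (torch.randn(3, 3), torch.ones(3, 3) - torch.eye(3))
x, adj, batch_index = batch_graphs([g1, g2])

# One aggregation step on the batch matches each graph aggregated on its own.
batched = adj @ x
separate = torch.cat([g1[1] @ g1[0], g2[1] @ g2[0]], dim=0)
```

The off-diagonal blocks of the combined adjacency are zero, so a node's aggregated state never includes a term from another graph.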

Running the tests

The tests are property and behavior checks. The message passing tests verify that one aggregation step equals the normalized adjacency times the features, that two connected nodes pull toward each other, that an isolated node keeps only its own signal, and that disconnected graphs stay independent. The learning test trains the regressor on edge count and asserts that its held out error is well under half the error of a constant predictor that always outputs the mean target, with a strong correlation between predictions and true counts.
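The comparison in the learning test might look roughly like the sketch below. The choice of mean absolute error, the helper name, and the toy numbers are all assumptions; the repository's test may use a different metric or thresholds.

```python
import torch


def compare_to_mean_baseline(preds, targets, train_targets):
    """Return (model_error, baseline_error, correlation) on held out data."""
    model_err = (preds - targets).abs().mean()
    # The baseline always predicts the mean of the training targets.
    baseline_err = (train_targets.mean() - targets).abs().mean()
    # Pearson correlation between predictions and true values.
    p = preds - preds.mean()
    t = targets - targets.mean()
    corr = (p * t).sum() / (p.norm() * t.norm())
    return model_err.item(), baseline_err.item(), corr.item()


# Toy numbers, not results from the repository.
targets = torch.tensor([2.0, 4.0, 6.0, 8.0])
preds = torch.tensor([2.5, 3.5, 6.5, 7.5])
train_targets = torch.tensor([3.0, 5.0, 7.0])
model_err, baseline_err, corr = compare_to_mean_baseline(preds, targets, train_targets)
```

A test built on this helper would then assert something like model_err < 0.5 * baseline_err together with a lower bound on corr.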

python -m pytest tests/ -q

On a recent run all eight tests passed in about three seconds on CPU.

Using the model

import torch
from src.data import make_dataset
from src.train import train_regressor, predict

graphs, targets = make_dataset(num_graphs=120, num_features=4, seed=1)
model, losses = train_regressor(graphs, targets, in_dim=4, epochs=250)

preds = predict(model, graphs)

You can also call the model on a single molecule by passing its node features and adjacency with batch_index=None.

