SALSA Unlearning Library is a PyTorch framework for machine unlearning in classification. It is the official codebase accompanying the paper "SALSA: A Secure, Adaptive and Label-Agnostic Scalable Algorithm for Machine Unlearning", and provides a unified interface for evaluating multiple state-of-the-art unlearning algorithms, including SFRon, SCRUB, SalUn, and more.
With salsa-unlearn, researchers can pretrain models, run various unlearning strategies, measure retained performance (accuracy), and assess how effective the unlearning was via standard Membership Inference Attacks (MIA).
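To make the MIA-based evaluation concrete, here is a minimal sketch of one common attack style: a loss-threshold test. It is a generic illustration in plain Python, not the library's built-in attacker; the function names and the toy loss values are assumptions made up for this example. The idea is to calibrate a threshold on samples whose membership is known (retained training data vs. unseen test data) and then check how many forget samples still look like training members.

```python
from typing import Sequence


def best_loss_threshold(member_losses: Sequence[float],
                        nonmember_losses: Sequence[float]) -> float:
    """Pick the loss threshold that best separates members from non-members.

    A sample is predicted to be a training member when its loss is at or
    below the threshold, since models usually fit training data more tightly.
    """
    candidates = sorted(set(member_losses) | set(nonmember_losses))
    best_t, best_acc = candidates[0], -1.0
    for t in candidates:
        tpr = sum(x <= t for x in member_losses) / len(member_losses)
        tnr = sum(x > t for x in nonmember_losses) / len(nonmember_losses)
        balanced_acc = 0.5 * (tpr + tnr)
        if balanced_acc > best_acc:
            best_t, best_acc = t, balanced_acc
    return best_t


def forget_member_rate(forget_losses: Sequence[float], threshold: float) -> float:
    """Fraction of forget samples the attack still labels as training members."""
    return sum(x <= threshold for x in forget_losses) / len(forget_losses)


# Toy per-sample losses (hypothetical numbers, for illustration only).
retain_losses = [0.05, 0.10, 0.08, 0.12]   # known members
test_losses = [0.90, 1.40, 0.75, 1.10]     # known non-members
forget_losses_before = [0.07, 0.09, 0.11]  # forget set, original model
forget_losses_after = [0.95, 1.20, 0.80]   # forget set, unlearned model

threshold = best_loss_threshold(retain_losses, test_losses)
rate_before = forget_member_rate(forget_losses_before, threshold)
rate_after = forget_member_rate(forget_losses_after, threshold)
```

In this toy setup, a lower `rate_after` than `rate_before` suggests the forget samples have become harder to distinguish from unseen data, which is the behaviour an unlearning method aims for.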
Machine unlearning selectively removes the influence of certain training samples or classes from a trained model without retraining it from scratch. This is vital for maintaining data privacy, regulatory compliance (e.g., the "Right to be Forgotten" under GDPR), and mitigating the cost of model retraining.
You can install the library from PyPI:

pip install unlearn-lib

Alternatively, install it from source:

git clone https://github.com/your-username/salsa-unlearn.git
cd salsa-unlearn
pip install -e .

Install the dependencies listed in requirements.txt:

pip install -r requirements.txt

The library currently supports:

- SALSA: A Secure, Adaptive and Label-Agnostic Scalable Algorithm for Machine Unlearning (Ours)
- SFRon
- SCRUB
- SalUn
- BadTeacher
- GradAscent
- RandomLabel
- Finetune
- Retrain
- Baseline
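Two of the simpler ideas behind the baselines listed above, gradient ascent on the forget data and fine-tuning on the retain data, can be illustrated generically in a few lines of PyTorch. The sketch below is not the library's GradAscent or Finetune implementation; the helper names, the toy two-class data, and the hyperparameters are all assumptions chosen for illustration.

```python
import torch
import torch.nn as nn


def mean_loss(model, loss_fn, x, y):
    """Full-batch loss, computed without tracking gradients."""
    with torch.no_grad():
        return loss_fn(model(x), y).item()


def descend(model, loss_fn, x, y, lr=0.05, steps=50):
    """Plain gradient descent, used here for pretraining and fine-tuning."""
    optimizer = torch.optim.SGD(model.parameters(), lr=lr)
    for _ in range(steps):
        optimizer.zero_grad()
        loss_fn(model(x), y).backward()
        optimizer.step()


def ascend(model, loss_fn, x, y, lr=0.1, steps=20):
    """Gradient ascent: negate the loss so each optimizer step increases it."""
    optimizer = torch.optim.SGD(model.parameters(), lr=lr)
    for _ in range(steps):
        optimizer.zero_grad()
        (-loss_fn(model(x), y)).backward()
        optimizer.step()


torch.manual_seed(0)
# Toy data: class 0 plays the role of the forget set, class 1 the retain set.
forget_x = torch.randn(32, 2) - 2.0
forget_y = torch.zeros(32, dtype=torch.long)
retain_x = torch.randn(32, 2) + 2.0
retain_y = torch.ones(32, dtype=torch.long)

model = nn.Linear(2, 2)
loss_fn = nn.CrossEntropyLoss()

# "Pretrain" on all data, then measure the loss on the forget set.
descend(model, loss_fn, torch.cat([forget_x, retain_x]), torch.cat([forget_y, retain_y]))
forget_loss_before = mean_loss(model, loss_fn, forget_x, forget_y)

# Unlearn by ascending the loss on the forget set.
ascend(model, loss_fn, forget_x, forget_y)
forget_loss_after = mean_loss(model, loss_fn, forget_x, forget_y)

# Repair utility by fine-tuning on the retain set only.
descend(model, loss_fn, retain_x, retain_y, steps=20)
retain_loss_final = mean_loss(model, loss_fn, retain_x, retain_y)
```

Gradient ascent alone tends to damage accuracy on retained data if run for too long, which is one reason stronger methods combine a forgetting step with some form of repair on the retain set.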
The library comes with driver scripts (main.py, main_pretrain.py, main_random.py) that demonstrate how to use the framework.
To pretrain a model (e.g., ResNet-18 on CIFAR-10) before executing unlearning:
python main_pretrain.py --dataset CIFAR10 --model resnet18

You can use main.py, which demonstrates how to instantiate the unlearning methods, run them, and evaluate them automatically using the built-in MIA attacker.

python main.py

You can also use salsa-unlearn inside your own Python projects. Everything is packaged under unlearn.
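The usage snippet that follows expects two objects it does not define: `args` and `unlearn_dataloaders`. Below is a minimal sketch of how they might be built for class-wise forgetting. The `UnlearnArgs` dataclass and both helper functions are hypothetical names introduced for this example, and the interpretation of `forget_valid` / `retain_valid` as the validation-set counterparts of the training splits is an assumption; only the five dictionary keys and the `batch_size` / `num_classes` fields come from this README.

```python
from dataclasses import dataclass
from typing import Dict, Iterable, List, Sequence


@dataclass
class UnlearnArgs:
    # Hypothetical container; the library may expect additional fields.
    batch_size: int = 128
    num_classes: int = 10


def split_by_class(labels: Sequence[int],
                   forget_classes: Iterable[int]) -> Dict[str, List[int]]:
    """Return sample indices split into forget and retain groups by label."""
    forget_set = set(forget_classes)
    forget = [i for i, y in enumerate(labels) if y in forget_set]
    retain = [i for i, y in enumerate(labels) if y not in forget_set]
    return {"forget": forget, "retain": retain}


def build_unlearn_dataloaders(train_set, valid_set, train_labels, valid_labels,
                              forget_classes, batch_size):
    """Wrap the index splits in DataLoaders under the five expected keys."""
    # Imported here so the index-splitting helper above stays dependency-free.
    from torch.utils.data import DataLoader, Subset

    tr = split_by_class(train_labels, forget_classes)
    va = split_by_class(valid_labels, forget_classes)
    return {
        "forget_train": DataLoader(Subset(train_set, tr["forget"]),
                                   batch_size=batch_size, shuffle=True),
        "retain_train": DataLoader(Subset(train_set, tr["retain"]),
                                   batch_size=batch_size, shuffle=True),
        "forget_valid": DataLoader(Subset(valid_set, va["forget"]),
                                   batch_size=batch_size),
        "retain_valid": DataLoader(Subset(valid_set, va["retain"]),
                                   batch_size=batch_size),
        "train": DataLoader(train_set, batch_size=batch_size, shuffle=True),
    }


# Quick check of the splitting logic on toy labels.
example_split = split_by_class([0, 1, 2, 0, 1], forget_classes=[0])
args = UnlearnArgs(batch_size=64, num_classes=3)
```

With a real dataset, `unlearn_dataloaders = build_unlearn_dataloaders(train_set, valid_set, train_labels, valid_labels, [0], args.batch_size)` would produce the dictionary passed to `prepare_unlearn`. With `args` and `unlearn_dataloaders` in hand, the library calls look like this: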
import torch
import torch.nn as nn
from unlearn.models import create_model
from unlearn import create_unlearn_method
# Load your base model
model = create_model("ResNet18", num_classes=10)
model.load_state_dict(torch.load("pretrained.pth")["state_dict"])
model.cuda()
# Define the unlearning method
method_name = "SALSA" # Or "SFRon", "SCRUB", etc.
loss_fn = nn.CrossEntropyLoss()
# args can be a simple dataclass holding hyperparameters (batch_size, num_classes, etc.)
unlearn_method = create_unlearn_method(method_name)(model, loss_fn, "./results", args)
# Prepare dataloaders
# Requires a dictionary containing: 'forget_train', 'retain_train', 'forget_valid', 'retain_valid', 'train'
unlearn_method.prepare_unlearn(unlearn_dataloaders)
# Execute unlearning and get the modified model
unlearn_model = unlearn_method.get_unlearned_model()

You can define your own unlearning method by subclassing UnlearnMethod and implementing the get_unlearned_model function.
import torch
from unlearn.unlearn.unlearn_method import UnlearnMethod


class MyCustomUnlearn(UnlearnMethod):
    def __init__(self, model, loss_function, save_path, args) -> None:
        super().__init__(model, loss_function, save_path, args)
        # Initialize custom parameters or optimizers here

    def get_unlearned_model(self) -> torch.nn.Module:
        # Access data from self.unlearn_dataloaders
        forget_loader = self.unlearn_dataloaders["forget_train"]
        retain_loader = self.unlearn_dataloaders["retain_train"]
        # Implement your unlearning logic here
        # e.g., fine-tuning on retain data, gradient ascent on forget data, etc.
        # Return the modified model
        return self.model

├── src/unlearn/ # Core unlearning library
│ ├── dataset/ # Dataset loading and splitting logic
│ ├── evaluation/ # Evaluation metrics (JS Divergence, MIA)
│ ├── models/ # Supported network architectures (ResNet, Swin, ViT)
│ ├── trainer/ # Base training and validation loops
│ ├── unlearn/ # Implementations of unlearning algorithms (SALSA, SFRon, etc.)
│ ├── attack.py # Attacker models for evaluating forgetting (MIA)
│ └── utils.py # Helpful utilities
├── scripts/ # Shell scripts for standard experiments
├── main.py # Example script to unlearn and evaluate
├── main_pretrain.py # Example script to pretrain models
└── setup.py # Python package setup configuration

If you use this library or the SALSA algorithm in your work, please cite:
@inproceedings{makroosalsa,
title={SALSA: A Secure, Adaptive and Label-Agnostic Scalable Algorithm for Machine Unlearning},
author={Makroo, Owais and Hassan, Atif and Khare, Swanand},
booktitle={The 41st Conference on Uncertainty in Artificial Intelligence}
}