Skip to the content.

Saliency Methods

Overview

This repository contains code for the following saliency techniques:

1. Vanilla Gradients

gradient

2. Guided Backpropogation (with Smoothgrad)

guided

3. Integrated Gradients

integrated

4. (Guided) Grad-CAM

gradcam

5. XRAI

xrai

Remarks

The methods should work with all models from the torchvision package. Tested models so far are:

*In order for Guided Backpropagation and Grad-CAM to work properly with the Inception and GoogLeNet models, they need to by modified slightly, such that all ReLUs are modules of the model rather than function calls.

# This class can be found at the very end of inception.py and googlenet.py respectively.
class BasicConv2d(nn.Module):

    def __init__(self, in_channels, out_channels, **kwargs):
        super(BasicConv2d, self).__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, bias=False, **kwargs)
        self.bn = nn.BatchNorm2d(out_channels, eps=0.001)
        self.relu = nn.ReLU(inplace=True)  # Add this line

    def forward(self, x):
        x = self.conv(x)
        x = self.bn(x)
        return self.relu(x)  # Replaces F.relu(x, inplace=True)

Examples

For a brief overview on how to use the package, please have a look at this short tutorial notebook. The bare minimum is summarized below.

# Standard imports 
import torchvision

# Import desired utils and methods
from ml_utils import load_image, show_mask
from guided_backprop import GuidedBackprop

# Load model and image
model = torchvision.models.resnet50(pretrained=True)
doberman = load_image('images/doberman.png', size=224)

# Construct a saliency object and compute the saliency mask.
guided_backprop = GuidedBackprop(model)
rgb_mask = guided_backprop.get_mask(image_tensor=doberman)

# Visualize the result
show_mask(rgb_mask, title='Guided Backprop')

Credits

The implementation follows closely that of the corresponding TensorFlow saliency repository, reusing its code were applicable (mostly for the XRAI method).

Further inspiration has been taken from this repository.