论文:ICCV 2017《Grad-CAM: Visual Explanations from Deep Networks via Gradient-based Localization》
代码:https://github.com/yizt/Grad-CAM.pytorch/blob/master/main.py
另一参考实现:https://github.com/jacobgil/pytorch-grad-cam/blob/master/grad-cam.py
1、首先定义并训练好 CNN 网络,原网络结构无需调整。假设网络已训练好,得到模型 best_net。
class GradCAM(object):
"""
1: gradients update when input
2: backpropatation by the high scores of class
"""
def __init__(self, net, layer_name):
self.net = net
self.layer_name = layer_name
self.feature = None
self.gradient = None
self.net.eval()
self.handlers = []
self._register_hook()
def _get_features_hook(self, module, input, output):
self.feature = output
#print("feature shape:{}".format(output.size()))
def _get_grads_hook(self, module, input_grad, output_grad):
"""
:param input_grad: tuple, input_grad[0]: None
input_grad[1]: weight
input_grad[2]: bias
:param output_grad:tuple,length