Spiking Neural Network (SNN) with PyTorch: towards bridging the gap between deep learning and the human brain by Guillaume Chevalier

An interesting discovery: Hebbian learning appears to occur naturally during the backpropagation of SNNs. In other words, backpropagation in Spiking Neural Networks (SNNs) engenders Hebbian learning behaviour that resembles Spike-Timing-Dependent Plasticity (STDP).


  • At the start, I just thought, “hey, what about coding a Spiking Neural Network using an automatic differentiation framework?” Here it is.
  • Then I started reading about how to achieve that, which led me to Hebbian learning. Quickly explained: Hebbian learning is, in a way, the saying that “neurons that fire together, wire together.”
  • Then came the interesting part: what if, when performing backpropagation on a Spiking Neural Network (SNN), Hebbian learning occurred naturally as a consequence of adding that refractory time axis?
  • I had the opportunity to discuss the idea with Yoshua Bengio at a conference, and I couldn’t get it out of my head after that, so I coded it as described below.
  • In summary, I believe that the link between deep learning and the human brain is closer than we previously thought: backpropagation is akin to Hebbian learning.

If you are not familiar with SNNs, take a look at this SNN animation, which will quickly give you an idea of the concept. Pay particular attention to how neurons get activated gradually over time, a bit like a storm, rather than statically from their inputs.

SNN animation

Spiking Neural Networks (SNNs) vs. Artificial Neural Networks (ANNs)

In SNNs, there is a time axis and the neural network sees data over time. Activation functions are spikes that are raised once a pre-activation threshold is exceeded. Pre-activation values continually fade if neurons aren’t excited enough. You can think of it as a time-distributed ReLU that outputs either a spike or nothing at each time step.

Spiking Neural Networks (SNNs) are neural networks that are closer to what happens in the brain than what people usually code in machine learning and deep learning. In an SNN, a neuron accumulates input activation until a threshold is reached; when it is reached, the neuron empties itself of its activation and fires. Once empty, it must go through a refractory period before it can fire again, as happens in the brain.

This implies adding a time axis to Artificial Neural Networks (ANNs), where the signal is accumulated over time in a pre-activation stage and, once a threshold is reached, is raised to the neurons above as a firing activation. At every time step, including those where the threshold hasn’t been reached yet, the pre-activation value fades.

Using PyTorch, I approximately replicated this refractory firing behaviour. I coded it without reading existing implementations, as a challenge to come up with a solution by myself from first principles, without being biased by what others do. I assumed that, over time, the perceptron’s readings of the input would be noisy, so as to add randomization. In hindsight, I’m happy with what I came up with.

Spiking Neural Networks (SNNs) vs. Recurrent Neural Networks (RNNs)

The SNN is not an RNN, even though it also evolves through time. For this SNN to be an RNN, I believe it would need additional connections, such as from the outputs back into the inputs. RNNs are in fact defined as a function of some inputs and of the state of the neurons at the previous time step.
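For reference, a common textbook form of that recurrence is shown here. The notation is an assumption for illustration and is not taken from the original post:

```latex
h_t = f\left(W_x\, x_t + W_h\, h_{t-1} + b\right)
```

Here $h_{t-1}$ gathers the activations of the layer's neurons at the previous time step, and $W_h$ is the recurrent weight matrix that feeds them back in, so every neuron can be influenced by the past values of the other neurons.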

In our case, we keep some state, but it is nothing comparable to having a connection back to other neurons in the past. Forgive me in advance for having appended the suffix “RNN” to the SNN PyTorch class below: I use it like an RNN, with a time axis. The concept is theoretically different, although the two share a common “forward time axis” structure.

How does it function?

Okay, let’s take an in-depth dive into the details.

We define a neuron’s firing method through the following steps, where the argument x is an input:

First of all, we need to initialize (or empty) the state of every neuron before beginning predictions.

    self.prev_inner = torch.zeros([batch_size, self.n_hidden]).to(self.device)
    self.prev_outer = torch.zeros([batch_size, self.n_hidden]).to(self.device)

Next, a weight matrix multiplies the input x, which in our case is the handwritten MNIST digits. Note that x was modified to flicker randomly through time (depending on the intensity of the original input x multiplied by a random uniform noise mask), or else x is already the output of a lower deep spiking layer:

input_excitation = self.fully_connected(x)
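To make the flickering concrete, here is a minimal standard-library sketch of the noise mask applied to a few pixels. The pixel values, the `flicker` helper and the number of time steps are assumptions for illustration, and intensities are taken in [0, 1] rather than in the normalized range used by the full listing:

```python
import random

def flicker(pixels, rng):
    """Return one noisy time-step view of the input: each pixel times U(0, 1) noise."""
    return [p * rng.random() for p in pixels]

rng = random.Random(0)
pixels = [0.0, 0.25, 1.0]   # hypothetical pixel intensities
n_time_steps = 1000
frames = [flicker(pixels, rng) for _ in range(n_time_steps)]

# Averaged over time, each pixel is seen at about half its intensity,
# because the uniform mask has a mean of 0.5.
mean_over_time = [
    sum(frame[i] for frame in frames) / n_time_steps
    for i in range(len(pixels))
]
print(mean_over_time)
```

Brighter pixels therefore excite the first layer more strongly on average, while the exact drive changes from tick to tick.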

The result is then added to a decayed version of the information the neuron already held at the previous time step / time tick (Δt time elapsed). The decay_multiplier gradually fades the inner activation so that stimulus doesn’t accumulate for too long and neurons can rest. It could have a value of 0.9, for example. This kind of decay is also called exponential decay, and it yields the effect of an exponential moving average over time on the most recent values seen, which also affects the gradients during backpropagation. By being repeatedly multiplied by 0.9, the inner activation decays over time, and neurons that aren’t excited enough lose their excitation before they can fire.

From this perspective, it really is true that “neurons that fire together, wire together”: when a pre-synaptic input is received closer to the moment of emitting an output, that recent value will not have had time to decay. The gradient of the recent neurons that took part in exciting the neuron that just fired will therefore be strong, and learning can take place through gradient descent according to the decay’s weightings. Conversely, a stimulus that happened too long ago will suffer from vanishing gradients, since it has been exponentially decayed. It will not contribute to learning through backprop, which is what we want, and which respects the “neurons that fire together, wire together” idiom of Hebbian learning.

inner_excitation = input_excitation + self.prev_inner * self.decay_multiplier
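The effect of the line above can be checked on a single scalar neuron. The following is a small sketch using only plain Python; the helper name `leaky_accumulate` and the pulse inputs are assumptions made for the example:

```python
decay_multiplier = 0.9

def leaky_accumulate(inputs, decay):
    """Apply inner = input + previous_inner * decay at every time step."""
    inner = 0.0
    trace = []
    for value in inputs:
        inner = value + inner * decay
        trace.append(inner)
    return trace

# A single unit pulse at t=0, then silence: the stored value fades by 0.9 per tick.
trace = leaky_accumulate([1.0] + [0.0] * 9, decay_multiplier)

# The same pulse contributes more to the final state when it arrives later.
early = leaky_accumulate([1.0] + [0.0] * 9, decay_multiplier)[-1]  # about 0.9 ** 9
late = leaky_accumulate([0.0] * 9 + [1.0], decay_multiplier)[-1]   # 0.9 ** 0
print(trace, early, late)
```

An input seen k ticks ago is weighted by 0.9 ** k in the current state, which is the exponential moving average behaviour described above.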


Now, we compute the activation of the neurons to find their output value. There is a threshold to reach before a neuron activates. The ReLU function might not be the most appropriate here (more on that later), but I wanted to get a working prototype quickly:

outer_excitation = F.relu(inner_excitation - self.threshold)


Now, the magic happens. If the neuron fires, its activation is subtracted from its inner state to reset the neuron. First, this returns it to a resting position so that it doesn’t fire constantly after being activated once. Second, resetting it this way clips the gradient through time and isolates each firing event from the others. SNNs really are inspired by the brain here, as natural neurons also have a refractory period: after firing, a neuron has to wait a bit before firing again, even if it is fully excited by the lower neurons it takes as inputs. So here, I even subtract a second penalty, named penalty_threshold, on top of each reset.

Disclaimer: I wasn’t sure whether the negative part of the biological refractory period is on the outputs of the neurons or inside the neurons (e.g., axon vs. body?), so here I’ve simply put it inside. Here is how I subtract it only when the neuron fires, to give it a refractory period.

do_penalize_gate = (outer_excitation > 0).float()

inner_excitation = inner_excitation - (self.penalty_threshold + outer_excitation) * do_penalize_gate


Finally, I return the previous output, simulating a small firing delay. This is useless for now, but it could be interesting if the SNN were ever to have recurrent connections, which would need time offsets in the connections from top layers near the outputs back into bottom layers near the input:

    delayed_return_state = self.prev_inner
    delayed_return_output = self.prev_outer
    self.prev_inner = inner_excitation
    self.prev_outer = outer_excitation
    return delayed_return_state, delayed_return_output
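The steps above can be put together for one scalar neuron without PyTorch. This is a sketch under stated assumptions: the `make_neuron` helper, the constant drive of 0.4 and the parameter values are illustrative, and the reset follows the subtraction shown in the walkthrough snippets rather than the variant used in the full listing further down:

```python
def make_neuron(decay=0.9, threshold=1.0, penalty_threshold=1.5):
    """Scalar version of the walkthrough's steps (illustrative parameter values)."""
    state = {"prev_inner": 0.0, "prev_outer": 0.0}

    def step(input_excitation):
        # Accumulate the input onto the decayed previous state.
        inner = input_excitation + state["prev_inner"] * decay
        # ReLU above the threshold gives the output spike.
        outer = max(0.0, inner - threshold)
        # On firing, subtract the spike plus an extra refractory penalty.
        if outer > 0:
            inner -= penalty_threshold + outer
        # Return the previous step's values (one-tick delay), then store the new ones.
        delayed = (state["prev_inner"], state["prev_outer"])
        state["prev_inner"], state["prev_outer"] = inner, outer
        return delayed

    return step

step = make_neuron()
outputs = [step(0.4)[1] for _ in range(12)]  # constant drive of 0.4 per tick
print(outputs)
```

With these values, each firing leaves the inner state at threshold minus penalty_threshold (here -0.5), so the neuron needs a few ticks of input before it can cross the threshold again, and each spike shows up in the returned output one tick after it was produced.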


Beyond that, to perform classification, the values of the output layer’s spiking neurons are averaged over the time axis so as to obtain one number per class, which is plugged into the usual softmax cross-entropy loss for classification, and we backpropagate. This means that the present SNN PyTorch class is reusable within any other feedforward neural network, as it repeats inputs over time with random noisy masks and averages outputs over time.
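As a rough illustration of that readout, the sketch below averages hypothetical per-time-step output values into one score per class and feeds them to a softmax cross-entropy. The helper names and the numbers are assumptions; the full listing does the equivalent with PyTorch tensors and log_softmax followed by nll_loss:

```python
import math

def mean_over_time(outputs_per_step):
    """Average a list of per-time-step score vectors into one vector of class scores."""
    n_steps = len(outputs_per_step)
    n_classes = len(outputs_per_step[0])
    return [
        sum(step[c] for step in outputs_per_step) / n_steps
        for c in range(n_classes)
    ]

def softmax_cross_entropy(scores, target):
    """Negative log-probability of the target class under softmax(scores)."""
    top = max(scores)  # subtract the max for numerical stability
    log_norm = top + math.log(sum(math.exp(s - top) for s in scores))
    return log_norm - scores[target]

# Hypothetical output-layer values for 3 classes over 4 time steps.
outputs_per_step = [
    [0.0, 2.0, 0.0],
    [0.0, 0.0, 0.0],
    [1.0, 2.0, 0.0],
    [0.0, 0.0, 0.0],
]
scores = mean_over_time(outputs_per_step)  # one number per class
loss_if_class_1 = softmax_cross_entropy(scores, target=1)
loss_if_class_2 = softmax_cross_entropy(scores, target=2)
print(scores, loss_if_class_1, loss_if_class_2)
```

The class whose output neuron was most active over time gets the lowest loss when it is the target, which is what lets an ordinary classification loss train the spiking layers.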

Surprisingly, it worked on the first try once the dimension mismatch errors were fixed. The accuracy was about the same as that of a simple non-spiking feedforward neural network with the same number of neurons, and the threshold wasn’t even tuned. In the end, I realized that coding and training a Spiking Neural Network (SNN) with PyTorch is simple enough, as shown above: it can be coded in an evening.

Essentially, the neuron’s activation must decay through time and fire only when getting beyond a specific threshold. So I’ve gated the output of the


Scroll on! Nice visuals await you.

import os
import matplotlib.pyplot as plt
import torchvision.datasets
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision.transforms as transforms
from torch.autograd import Variable
def train(model, device, train_set_loader, optimizer, epoch, logging_interval=100):
    # This method is derived from:
    # https://github.com/pytorch/examples/blob/master/mnist/main.py
    # Was licensed BSD-3-clause
    model.train()
    for batch_idx, (data, target) in enumerate(train_set_loader):
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()  # clear gradients left over from the previous batch
        output = model(data)
        loss = F.nll_loss(output, target)
        loss.backward()   # backpropagate (through the time steps, for the spiking model)
        optimizer.step()  # apply the SGD update
        if batch_idx % logging_interval == 0:
            pred = output.max(1, keepdim=True)[1]  # get the index of the max log-probability
            correct = pred.eq(target.view_as(pred)).float().mean().item()
            print('Train Epoch: {} [{}/{} ({:.0f}%)] Loss: {:.6f} Accuracy: {:.2f}%'.format(
                epoch, batch_idx * len(data), len(train_set_loader.dataset),
                100. * batch_idx / len(train_set_loader), loss.item(),
                100. * correct))
def train_many_epochs(model):
    epoch = 1
    optimizer = optim.SGD(model.parameters(), lr=0.1, momentum=0.5)
    train(model, device, train_set_loader, optimizer, epoch, logging_interval=10)
    test(model, device, test_set_loader)
    epoch = 2
    optimizer = optim.SGD(model.parameters(), lr=0.05, momentum=0.5)
    train(model, device, train_set_loader, optimizer, epoch, logging_interval=10)
    test(model, device, test_set_loader)
    epoch = 3
    optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.5)
    train(model, device, train_set_loader, optimizer, epoch, logging_interval=10)
    test(model, device, test_set_loader)
def test(model, device, test_set_loader):
    # This method is derived from:
    # https://github.com/pytorch/examples/blob/master/mnist/main.py
    # Was licensed BSD-3-clause
    model.eval()
    test_loss = 0
    correct = 0
    with torch.no_grad():
        for data, target in test_set_loader:
            data, target = data.to(device), target.to(device)
            output = model(data)
            # Note: `reduction='mean'` is the current spelling of the deprecated
            # `reduce=True`. It returns each batch's mean loss, so the value printed
            # below is the sum of batch means divided by the dataset size.
            # I'm not sure what would happen with a final batch size that would be
            # smaller than regular previous batch sizes. For now it works.
            test_loss += F.nll_loss(output, target, reduction='mean').item()
            pred = output.max(1, keepdim=True)[1]  # get the index of the max log-probability
            correct += pred.eq(target.view_as(pred)).sum().item()
    test_loss /= len(test_set_loader.dataset)
    print('Test set: Average loss: {:.4f}, Accuracy: {}/{} ({:.2f}%)'.format(
        test_loss,
        correct, len(test_set_loader.dataset),
        100. * correct / len(test_set_loader.dataset)))
def download_mnist(data_path):
    if not os.path.exists(data_path):
        os.mkdir(data_path)
    transformation = transforms.Compose(
        [transforms.ToTensor(), transforms.Normalize((0.5,), (1.0,))])
    training_set = torchvision.datasets.MNIST(
        data_path, train=True, transform=transformation, download=True)
    testing_set = torchvision.datasets.MNIST(
        data_path, train=False, transform=transformation, download=True)
    return training_set, testing_set
batch_size = 1000
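DATA_PATH = './data'
training_set, testing_set = download_mnist(DATA_PATH)
# The loader arguments below are assumed: the sets and batch_size defined above,
# shuffling the training data only.
train_set_loader = torch.utils.data.DataLoader(
    dataset=training_set, batch_size=batch_size, shuffle=True)
test_set_loader = torch.utils.data.DataLoader(
    dataset=testing_set, batch_size=batch_size, shuffle=False)
# Use GPU whenever possible!
use_cuda = torch.cuda.is_available()
device = torch.device("cuda" if use_cuda else "cpu")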
class SpikingNeuronLayerRNN(nn.Module):
    def __init__(
        self, device, n_inputs=28*28, n_hidden=100,
        decay_multiplier=0.9, threshold=2.0, penalty_threshold=2.5
    ):
        super(SpikingNeuronLayerRNN, self).__init__()
        self.device = device
        self.n_inputs = n_inputs
        self.n_hidden = n_hidden
        self.decay_multiplier = decay_multiplier
        self.threshold = threshold
        self.penalty_threshold = penalty_threshold
        self.fc = nn.Linear(n_inputs, n_hidden)
        self.init_parameters()
        self.reset_state()
        self.to(self.device)
    def init_parameters(self):
        for param in self.parameters():
            if param.dim() >= 2:
                # Assumed initializer: Xavier/Glorot uniform for weight matrices.
                nn.init.xavier_uniform_(param)
    def reset_state(self):
        self.prev_inner = torch.zeros([self.n_hidden]).to(self.device)
        self.prev_outer = torch.zeros([self.n_hidden]).to(self.device)
    def forward(self, x):
        """
        Call the neuron at every time step.

        x: activated_neurons_below

        return: a tuple of (state, output) for each time step. Each item in the tuple
        is itself of shape (batch_size, n_hidden) and is a PyTorch tensor, such
        that the whole returned value would be of shape (2, batch_size, n_hidden) if cast.
        """
        if self.prev_inner.dim() == 1:
            # Adding batch_size dimension directly after doing a `self.reset_state()`:
            batch_size = x.shape[0]
            self.prev_inner = torch.stack(batch_size * [self.prev_inner])
            self.prev_outer = torch.stack(batch_size * [self.prev_outer])
        # 1. Weight matrix multiplies the input x
        input_excitation = self.fc(x)
        # 2. We add the result to a decayed version of the information we already had.
        inner_excitation = input_excitation + self.prev_inner * self.decay_multiplier
        # 3. We compute the activation of the neuron to find its output value,
        #    but before the activation, there is also a negative bias
        #    that keeps the neuron from firing too much.
        outer_excitation = F.relu(inner_excitation - self.threshold)
        # 4. If the neuron fires, the activation of the
        #    neuron is subtracted from its inner state
        #    (with an extra penalty to increase refractory time),
        #    because it discharges naturally so it shouldn't fire twice.
        do_penalize_gate = (outer_excitation > 0).float()
        # TODO: remove following /2?
        inner_excitation = inner_excitation - do_penalize_gate * (
            self.penalty_threshold/self.threshold * inner_excitation)
        # 5. The outer excitation has a negative part after the positive part.
        outer_excitation = outer_excitation
        # + torch.abs(self.prev_outer) * self.decay_multiplier / 2.0
        # 6. Setting internal values before returning.
        #    And the returning value is the one of the previous time step to delay
        #    activation of 1 time step of “processing” time.
        #    For logits, we don’t take activation.
        delayed_return_state = self.prev_inner
        delayed_return_output = self.prev_outer
        self.prev_inner = inner_excitation
        self.prev_outer = outer_excitation
        return delayed_return_state, delayed_return_output
class InputDataToSpikingPerceptronLayer(nn.Module):
    def __init__(self, device):
        super(InputDataToSpikingPerceptronLayer, self).__init__()
        self.device = device
    def reset_state(self):
        #     self.prev_state = torch.zeros([self.n_hidden]).to(self.device)
        pass
    def forward(self, x, is_2D=True):
        x = x.view(x.size(0), -1)  # Flatten 2D image to 1D for FC
        random_activation_perceptron = torch.rand(x.shape).to(self.device)
        return random_activation_perceptron * x
class OutputDataToSpikingPerceptronLayer(nn.Module):
    def __init__(self, average_output=True):
        """
        average_output: might be needed if this is used within a
        regular neural net as a layer. Otherwise, sum may be numerically
        more stable for gradients with setting average_output=False.
        """
        super(OutputDataToSpikingPerceptronLayer, self).__init__()
        # NOTE: as listed, average_output=True selects the sum and False the mean,
        # which is the reverse of what the flag name and docstring suggest.
        if average_output:
            self.reducer = lambda x, dim: x.sum(dim=dim)
        else:
            self.reducer = lambda x, dim: x.mean(dim=dim)
    def forward(self, x):
        if isinstance(x, list):
            x = torch.stack(x)
        return self.reducer(x, 0)
class SpikingNet(nn.Module):
    def __init__(self, device, n_time_steps, begin_eval):
        super(SpikingNet, self).__init__()
        assert (0 <= begin_eval and begin_eval < n_time_steps)
        self.device = device
        self.n_time_steps = n_time_steps
        self.begin_eval = begin_eval
        self.input_conversion = InputDataToSpikingPerceptronLayer(device)
        self.layer1 = SpikingNeuronLayerRNN(
            device, n_inputs=28*28, n_hidden=100,
            decay_multiplier=0.9, threshold=1.0, penalty_threshold=1.5
        )
        self.layer2 = SpikingNeuronLayerRNN(
            device, n_inputs=100, n_hidden=10,
            decay_multiplier=0.9, threshold=1.0, penalty_threshold=1.5
        )
        self.output_conversion = OutputDataToSpikingPerceptronLayer(
            average_output=False)  # Sum on outputs.
        self.to(self.device)
    def forward_through_time(self, x):
        """
        This acts as a layer. Its input is non-time-related, and its output too.
        So the time iterations happen inside, and the returned layer is thus
        passed through global average pooling on the time axis before the return
        such as to be able to mix this pipeline with regular backprop layers such
        as the input data and the output data.
        """
        # Empty every layer's state before processing a new batch.
        self.input_conversion.reset_state()
        self.layer1.reset_state()
        self.layer2.reset_state()
        out = []
        all_layer1_states = []
        all_layer1_outputs = []
        all_layer2_states = []
        all_layer2_outputs = []
        for _ in range(self.n_time_steps):
            xi = self.input_conversion(x)
            # For layer 1, we take the regular output.
            layer1_state, layer1_output = self.layer1(xi)
            # We take inner state of layer 2 because it's
            # pre-activation and thus acts as our logits.
            layer2_state, layer2_output = self.layer2(layer1_output)
            # Record every time step, for the loss and for the visualizations.
            all_layer1_states.append(layer1_state)
            all_layer1_outputs.append(layer1_output)
            all_layer2_states.append(layer2_state)
            all_layer2_outputs.append(layer2_output)
            out.append(layer2_state)
        out = self.output_conversion(out[self.begin_eval:])
        return out, [[all_layer1_states, all_layer1_outputs], [
            all_layer2_states, all_layer2_outputs]]
    def forward(self, x):
        out, _ = self.forward_through_time(x)
        return F.log_softmax(out, dim=-1)
    def visualize_all_neurons(self, x):
        assert x.shape[0] == 1 and len(x.shape) == 4, (
            "Pass only 1 example to SpikingNet.visualize(x) with outer dimension shape of 1.")
        _, layers_state = self.forward_through_time(x)
        for i, (all_layer_states, all_layer_outputs) in enumerate(layers_state):
            # Stack over time, then reshape to (n_neurons, n_time_steps) for plotting.
            # The tail of these two calls is a reconstruction.
            layer_state = torch.stack(all_layer_states).data.cpu().numpy().squeeze().transpose()
            layer_output = torch.stack(all_layer_outputs).data.cpu().numpy().squeeze().transpose()
            self.plot_layer(layer_state, title="Inner state values of neurons for layer {}".format(i))
            self.plot_layer(layer_output, title="Output spikes (activation) values of neurons for layer {}".format(i))
    def visualize_neuron(self, x, layer_idx, neuron_idx):
        assert x.shape[0] == 1 and len(x.shape) == 4, (
            "Pass only 1 example to SpikingNet.visualize(x) with outer dimension shape of 1.")
        _, layers_state = self.forward_through_time(x)
        all_layer_states, all_layer_outputs = layers_state[layer_idx]
        # Same reconstruction as above: arrays of shape (n_neurons, n_time_steps).
        layer_state = torch.stack(all_layer_states).data.cpu().numpy().squeeze().transpose()
        layer_output = torch.stack(all_layer_outputs).data.cpu().numpy().squeeze().transpose()
        self.plot_neuron(
            layer_state[neuron_idx],
            title="Inner state values neuron {} of layer {}".format(neuron_idx, layer_idx))
        self.plot_neuron(
            layer_output[neuron_idx],
            title="Output spikes (activation) values of neuron {} of layer {}".format(neuron_idx, layer_idx))
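    def plot_layer(self, layer_values, title):
        """
        This function is derived from:
        Which was released under the MIT License.
        """
        width = max(16, layer_values.shape[0] / 8)
        height = max(4, layer_values.shape[1] / 8)
        plt.figure(figsize=(width, height))
        # The drawing calls are a minimal reconstruction (heat map of values over time).
        plt.imshow(layer_values, interpolation="nearest", cmap=plt.cm.rainbow)
        plt.title(title)
        plt.colorbar()
        plt.xlabel("Time")
        plt.ylabel("Neurons of layer")
        plt.show()
    def plot_neuron(self, neuron_through_time, title):
        width = max(16, len(neuron_through_time) / 8)
        height = 4
        plt.figure(figsize=(width, height))
        # The drawing calls are a minimal reconstruction (line plot of one neuron over time).
        plt.title(title)
        plt.plot(neuron_through_time)
        plt.xlabel("Time")
        plt.ylabel("Neuron's activation")
        plt.show()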
class NonSpikingNet(nn.Module):
    def __init__(self):
        super(NonSpikingNet, self).__init__()
        self.layer1 = nn.Linear(28*28, 100)
        self.layer2 = nn.Linear(100, 10)
    def forward(self, x, is_2D=True):
        x = x.view(x.size(0), -1)  # Flatten 2D image to 1D for FC
        x = F.relu(self.layer1(x))
        x =        self.layer2(x)
        return F.log_softmax(x, dim=-1)


Training a Spiking Neural Network (SNN)

Let’s leverage our SpikingNet!

spiking_model = SpikingNet(device, n_time_steps=128, begin_eval=0)
train_many_epochs(spiking_model)
Train Epoch: 1 [0/60000 (0%)] Loss: 2.460052 Accuracy: 9.90%
Train Epoch: 1 [10000/60000 (17%)] Loss: 1.811235 Accuracy: 30.00%
Train Epoch: 1 [20000/60000 (33%)] Loss: 1.797833 Accuracy: 38.60%
Train Epoch: 1 [30000/60000 (50%)] Loss: 0.645438 Accuracy: 83.30%
Train Epoch: 1 [40000/60000 (67%)] Loss: 0.522837 Accuracy: 83.50%
Train Epoch: 1 [50000/60000 (83%)] Loss: 0.528960 Accuracy: 81.80%
Test set: Average loss: 0.0004, Accuracy: 8955/10000 (89.55%)
Train Epoch: 2 [0/60000 (0%)] Loss: 0.405339 Accuracy: 87.80%
Train Epoch: 2 [10000/60000 (17%)] Loss: 0.357420 Accuracy: 88.80%
Train Epoch: 2 [20000/60000 (33%)] Loss: 0.326266 Accuracy: 90.10%
Train Epoch: 2 [30000/60000 (50%)] Loss: 0.377100 Accuracy: 89.60%
Train Epoch: 2 [40000/60000 (67%)] Loss: 0.335625 Accuracy: 90.60%
Train Epoch: 2 [50000/60000 (83%)] Loss: 0.359532 Accuracy: 88.90%
Test set: Average loss: 0.0003, Accuracy: 9061/10000 (90.61%)
Train Epoch: 3 [0/60000 (0%)] Loss: 0.342230 Accuracy: 90.40%
Train Epoch: 3 [10000/60000 (17%)] Loss: 0.347210 Accuracy: 89.90%
Train Epoch: 3 [20000/60000 (33%)] Loss: 0.346477 Accuracy: 89.60%
Train Epoch: 3 [30000/60000 (50%)] Loss: 0.317255 Accuracy: 90.70%
Train Epoch: 3 [40000/60000 (67%)] Loss: 0.329143 Accuracy: 90.40%
Train Epoch: 3 [50000/60000 (83%)] Loss: 0.310708 Accuracy: 90.70%
Test set: Average loss: 0.0003, Accuracy: 9065/10000 (90.65%)

Training a Feedforward Neural Network

It has the same number of layers and neurons, and also uses ReLU activation, but it is not an SNN: it is a regular network, defined in the code above as the class NonSpikingNet.

non_spiking_model = NonSpikingNet().to(device)
train_many_epochs(non_spiking_model)
Train Epoch: 1 [0/60000 (0%)] Loss: 2.300953 Accuracy: 9.50%
Train Epoch: 1 [10000/60000 (17%)] Loss: 1.908515 Accuracy: 62.40%
Train Epoch: 1 [20000/60000 (33%)] Loss: 1.259780 Accuracy: 72.20%
Train Epoch: 1 [30000/60000 (50%)] Loss: 0.861031 Accuracy: 83.00%
Train Epoch: 1 [40000/60000 (67%)] Loss: 0.652988 Accuracy: 85.40%
Train Epoch: 1 [50000/60000 (83%)] Loss: 0.609710 Accuracy: 84.40%
Test set: Average loss: 0.0005, Accuracy: 8691/10000 (86.91%)
Train Epoch: 2 [0/60000 (0%)] Loss: 0.469882 Accuracy: 88.30%
Train Epoch: 2 [10000/60000 (17%)] Loss: 0.479579 Accuracy: 85.80%
Train Epoch: 2 [20000/60000 (33%)] Loss: 0.466115 Accuracy: 88.20%
Train Epoch: 2 [30000/60000 (50%)] Loss: 0.479764 Accuracy: 87.10%
Train Epoch: 2 [40000/60000 (67%)] Loss: 0.472486 Accuracy: 85.50%
Train Epoch: 2 [50000/60000 (83%)] Loss: 0.443070 Accuracy: 88.20%
Test set: Average loss: 0.0004, Accuracy: 8880/10000 (88.80%)
Train Epoch: 3 [0/60000 (0%)] Loss: 0.432652 Accuracy: 88.20%
Train Epoch: 3 [10000/60000 (17%)] Loss: 0.472320 Accuracy: 86.80%
Train Epoch: 3 [20000/60000 (33%)] Loss: 0.443402 Accuracy: 88.60%
Train Epoch: 3 [30000/60000 (50%)] Loss: 0.401267 Accuracy: 90.00%
Train Epoch: 3 [40000/60000 (67%)] Loss: 0.428927 Accuracy: 88.40%
Train Epoch: 3 [50000/60000 (83%)] Loss: 0.383301 Accuracy: 90.10%
Test set: Average loss: 0.0004, Accuracy: 8897/10000 (88.97%)


Let’s see how the neurons spiked:

data, target = next(iter(test_set_loader))
# taking 1st testing example:
x = torch.stack([data[0]]).to(device)
y = target.data.numpy()[0]
# Showing the input image (the imshow call is assumed from the title that follows):
plt.figure()
plt.imshow(x.data.cpu().numpy()[0, 0])
plt.title("Input image x of label y={}:".format(y))
plt.show()
# plotting neuron's activations:
print("A hidden neuron that looks excited:")
spiking_model.visualize_neuron(x, layer_idx=0, neuron_idx=0)
print("The output neuron of the label:")
spiking_model.visualize_neuron(x, layer_idx=1, neuron_idx=y)


A hidden neuron that looks excited:


The output neuron of the label:


Well, we’ve trained only a little here. My goal is not to break benchmarks: I just wanted to make a comparison and see whether it could train. It turns out that the results are about the same, although the SNN seems to perform slightly better, at the cost of taking much longer to train.

Full disclaimer: little to no hyperparameter tuning has been done yet, as this was coded mostly in one shot and uploaded, so performance may vary a lot with further tinkering. It would be worth trying a few more things to see how it goes (don’t hesitate to fork the repo).

Using SNNs should act as a regularizer, just like dropout, as I wouldn’t expect the neurons to all fire at the same time. That said, it is an interesting path to explore, as brain rhythms seem to play an important role in the brain, whereas in deep learning, on the other hand, no such thing occurs.


I think I’ve discovered that backpropagation can entail Hebbian learning, given a time axis and refractory firing behaviour on that time axis in Deep Neural Networks (DNNs). I’ve searched online and haven’t found anybody explaining this link concretely, which is why I want to explain it as follows: it might be the first time that this has been written down in words.

An interesting idea: Hebbian Learning naturally takes place in the backpropagation of SNNs

Before explaining what was discovered and how all of this relates, a little bit of theory is needed.

Introduction to Hebbian Theory and Spike-Timing Dependent Plasticity (STDP)

First, let’s look at what Hebbian theory is. In layman’s terms: neurons that fire together, wire together. More precisely, here is how it’s explained on Wikipedia:

Hebbian Theory

Hebbian theory is a neuroscientific theory claiming that an increase in synaptic efficacy arises from a presynaptic cell’s repeated and persistent stimulation of a postsynaptic cell. It is an attempt to explain synaptic plasticity, the adaptation of brain neurons during the learning process. It was introduced by Donald Hebb in his 1949 book The Organization of Behavior. The theory is also called Hebb’s rule, Hebb’s postulate, and cell assembly theory. Hebb states it as follows:

Let us assume that the persistence or repetition of a reverberatory activity (or “trace”) tends to induce lasting cellular changes that add to its stability. […] When an axon of cell A is near enough to excite a cell B and repeatedly or persistently takes part in firing it, some growth process or metabolic change takes place in one or both cells such that A’s efficiency, as one of the cells firing B, is increased.

The theory is often summarized as “Cells that fire together wire together.” This summary, however, should not be taken literally. Hebb emphasized that cell A needs to take part in firing cell B, and such causality can occur only if cell A fires just before, not at the same time as, cell B. This important aspect of causation in Hebb’s work foreshadowed what is now known about spike-timing-dependent plasticity, which requires temporal precedence.

Let’s pause here. What is Spike-Timing-Dependent Plasticity (STDP)? Here is how Wikipedia describes it:

Spike-timing-dependent plasticity

Spike-timing-dependent plasticity (STDP) is a biological process that adjusts the strength of connections between neurons in the brain. The process adjusts the connection strengths based on the relative timing of a particular neuron’s output and input action potentials (or spikes). The STDP process partially explains the activity-dependent development of nervous systems, especially with regard to long-term potentiation and long-term depression.


In 1973, M. M. Taylor suggested that if synapses were strengthened when a presynaptic spike occurred just before a postsynaptic spike more often than the reverse (Hebbian learning), while synapses were weakened with the opposite timing or in the absence of a closely timed presynaptic spike (anti-Hebbian learning), the result would be an informationally efficient recoding of input patterns. This proposal apparently went unnoticed in the neuroscientific community, and subsequent experimentation was conceived independently of these early suggestions.

Studies on neuromuscular synapses carried out by Y. Dan and Mu-ming Poo in 1992, and on the hippocampus by D. Debanne, B. Gahwiler, and S. Thompson in 1994, showed that asynchronous pairing of postsynaptic and synaptic activity induced long-term synaptic depression.

Several reasons for timing-dependent plasticity have been suggested. For example, STDP might provide a substrate for Hebbian learning during development or, as suggested by Taylor in 1973, the associated Hebbian and anti-Hebbian learning rules might create informationally efficient coding in bundles of related neurons.


Owing to their high permeability for calcium, postsynaptic NMDA receptors generate a local chemical signal that is largest when the back-propagating action potential in the dendrite arrives shortly after the synapse was active (pre-post spiking). Large postsynaptic calcium transients are known to trigger synaptic potentiation (long-term potentiation). The mechanism for spike-timing-dependent depression is less well understood, but often involves either postsynaptic voltage-dependent calcium entry/mGluR activation […]

From Hebbian Rule to STDP

According to the Hebbian rule, synapses increase their efficiency if the synapse persistently takes part in firing the postsynaptic target neuron. An often-used simplification is “those who fire together, wire together”, but if two neurons fire at exactly the same time, then one cannot have caused, or taken part in firing, the other. Instead, to take part in firing the postsynaptic neuron, the presynaptic neuron needs to fire just before the postsynaptic neuron. Experiments that stimulated two connected neurons with varying interstimulus asynchrony confirmed the importance of the temporal precedence implicit in Hebb’s principle: the presynaptic neuron has to fire just before the postsynaptic neuron for the synapse to be potentiated. In addition, it has become evident that the presynaptic neural firing needs to consistently predict the postsynaptic firing for synaptic plasticity to occur robustly, mirroring at a synaptic level what is known about the importance of contingency in classical conditioning, where zero-contingency procedures prevent the association between two stimuli.
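The timing dependence described above is often modelled with a simple exponential window over the delay between the two spikes. The sketch below is one common simplified form (pair-based STDP) and is not taken from the original post; the function name and every constant are illustrative assumptions:

```python
import math

def stdp_weight_change(delta_t, a_plus=0.01, a_minus=0.012, tau_plus=20.0, tau_minus=20.0):
    """Pair-based STDP window.

    delta_t = t_post - t_pre, in milliseconds.
    Positive delta_t (pre fires before post) potentiates the synapse,
    negative delta_t (pre fires after post) depresses it.
    All constants are illustrative assumptions.
    """
    if delta_t > 0:
        return a_plus * math.exp(-delta_t / tau_plus)
    if delta_t < 0:
        return -a_minus * math.exp(delta_t / tau_minus)
    return 0.0  # simultaneous spikes: no causal ordering, no change in this sketch

for delta_t in (-40.0, -5.0, 0.0, 5.0, 40.0):
    print("t_post - t_pre = {:+6.1f} ms -> dw = {:+.5f}".format(
        delta_t, stdp_weight_change(delta_t)))
```

The change is largest when the presynaptic spike comes just before the postsynaptic one, shrinks as the gap grows, and flips sign when the order is reversed.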

You might now see where I’m going with all of these definitions.

Backpropagation engenders Hebbian Learning

Okay, let’s go on with the explanations. One element is missing from the title above. We should instead say “Backpropagation engenders Hebbian learning when dealing with SNNs”. SNNs are the mathematical tool that makes it possible to do the proof and make the link. Why? Because the concept of timing in STDP is critical for moving from backpropagation to Hebbian learning.

Just think about it: from the perspective of a neuron that spikes, when backpropagation is performed, the neuron backpropagates mostly to the inputs whose signal was received most recently, and therefore to the signals that really contributed to its firing. Why? Because older signals are decayed exponentially, and their gradients vanish towards zero. So with SNNs, from the perspective of a neuron, the gradients are mostly transferred to the inputs that fired just before the neuron itself fired an output.
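To make the decay argument concrete, here is a minimal sketch of a single leaky neuron. It assumes a toy update rule, state[t] = decay * state[t-1] + x[t]; the function name, the decay value, and the number of steps are hypothetical choices and are not taken from the article's PyTorch code. The reverse pass is written by hand in plain Python so that it runs without dependencies; an autograd framework would be expected to give the same values for this toy rule.

```python
def leaky_neuron_input_gradients(n_steps, decay):
    """Return d(final inner state) / d(input at step t) for a leaky integrator.

    Assumed toy forward pass: state[t] = decay * state[t-1] + x[t].
    """
    grads = [0.0] * n_steps
    upstream = 1.0  # gradient of the final inner state with respect to itself
    # Walk backward through time, as backpropagation through time would.
    for t in reversed(range(n_steps)):
        grads[t] = upstream  # d state[t] / d x[t] is 1, scaled by the upstream gradient
        upstream *= decay    # each step further into the past costs one decay factor
    return grads


# Pretend the neuron fires at the last step and look at who gets the credit.
grads = leaky_neuron_input_gradients(n_steps=6, decay=0.8)
for t, g in enumerate(grads):
    print(f"input at step {t}: gradient {g:.3f}")
```

The input received at the very last step gets a gradient of 1.0, while the one received five steps earlier only gets 0.8 ** 5 (about 0.328): under this assumption, credit concentrates on the inputs that arrived just before the spike.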

However, there still remains a mystery.

  • To start with, the gradient could still be negative even though the neurons fired together. Firing together doesn’t imply that the behaviour of the neuron that caused another neuron to fire will necessarily be reinforced positively.
  • Secondly, if the input neuron fires just after rather than just before, its gradient stays unaffected or will have strongly decayed later on. To address this second point, I have at least two ideas on how a neuron that fired too late could be penalized:
  • Use weight decay (L2) regularization, or
  • Have neurons output a small negative signal just after the positive signal, let this negative signal last longer than the positive one (so that the integral of the whole spike and its negative counterpart is null or close to null), and let this negative signal be triggered by the late neuron’s signal, so that the late signal is the one penalized during backpropagation.
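The shape of such a spike, a short positive pulse followed by a longer and weaker negative tail whose total cancels out, can be sketched as follows. This only illustrates the zero-integral constraint in discrete time. The function name, the one-step positive pulse, and the flat tail are assumptions made for the example, and the sketch does not cover how the tail would be triggered by the late neuron.

```python
def biphasic_spike(negative_steps=4, amplitude=1.0):
    """Return a spike shape: one positive step, then a longer, weaker negative tail.

    The tail is scaled so that the sum (the discrete integral) of the whole
    shape is zero. Both parameters are illustrative assumptions.
    """
    if negative_steps < 2:
        raise ValueError("the negative tail should outlast the one-step positive pulse")
    tail_value = -amplitude / negative_steps
    return [amplitude] + [tail_value] * negative_steps


shape = biphasic_spike(negative_steps=4, amplitude=1.0)
print(shape)       # [1.0, -0.25, -0.25, -0.25, -0.25]
print(sum(shape))  # 0.0
```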

All of that being said, this is a really fascinating link between the saying of “cells that fire together wire together” (Hebbian Learning) and backpropagation.

How is backpropagation implemented in the brain? Could it be Contrastive Divergence (CD) between the brain’s rhythm cycles?

This question was asked by Yoshua Bengio.

First, let me point you to the 35th minute of this video of Geoffrey Hinton. STDP is discussed, and earlier in the talk, Hinton talks about autoencoders performing cycles. To me, those cycles oddly resemble brain rhythms. This leads me to my next idea: what if brain rhythms were those cyclic autoencoder passes (as discussed by Hinton), between which the past and new data are contrasted to produce essentially local gradients? The link between the two concepts appears tenuous, but I feel like they could somehow be very connected, and that this is a critical topic.

Cool idea: what if, in the brain, gradients come from Contrastive Divergence (CD) between each brain rhythm iteration?

That seems like an interesting conjecture to me. It would be fascinating to try. The differences in activation values between brain rhythms (like CD in autoencoders) could function as a learning signal that helps auto-encode data in an unsupervised fashion. At first, we humans are born as unsupervised learning machines (or perhaps semi-supervised ones, learning to minimize pain).

CD is the algorithm used in autoencoder-like models such as Restricted Boltzmann Machines (RBMs): passing the information through the neurons as a cycle helps differentiate signals so as to structure data into clusters and categories. Have you ever put somebody new into one of your stereotype boxes? That might be the outcome of clusters formed by unsupervised learning.
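For readers who have not seen CD before, here is a minimal sketch of a single CD-1 update on a tiny RBM, written in plain Python. It is a simplified illustration, not the article's SNN code: biases are omitted, the layer sizes, learning rate, and training pattern are arbitrary, and all function names are hypothetical. The learning signal is the contrast between the activity driven by the data and the activity after one up-down-up cycle.

```python
import math
import random


def sigmoid(x):
    return 1.0 / (1.0 + math.exp(-x))


def hidden_probabilities(visible, weights):
    # weights[i][j] links visible unit i to hidden unit j (biases omitted for brevity).
    return [
        sigmoid(sum(visible[i] * weights[i][j] for i in range(len(weights))))
        for j in range(len(weights[0]))
    ]


def visible_probabilities(hidden, weights):
    return [
        sigmoid(sum(hidden[j] * weights[i][j] for j in range(len(weights[0]))))
        for i in range(len(weights))
    ]


def cd1_step(visible, weights, rng, learning_rate=0.1):
    """Apply one CD-1 update to `weights` in place and return the reconstruction."""
    # Positive phase: hidden activity driven by the data.
    h_data = hidden_probabilities(visible, weights)
    # One cycle: sample binary hidden states, go down to the visible layer, then back up.
    h_sample = [1.0 if rng.random() < p else 0.0 for p in h_data]
    v_model = visible_probabilities(h_sample, weights)
    h_model = hidden_probabilities(v_model, weights)
    # The learning signal is the contrast between the two passes.
    for i in range(len(weights)):
        for j in range(len(weights[0])):
            weights[i][j] += learning_rate * (visible[i] * h_data[j] - v_model[i] * h_model[j])
    return v_model


rng = random.Random(0)
weights = [[0.0, 0.0] for _ in range(4)]  # 4 visible units, 2 hidden units
pattern = [1.0, 1.0, 0.0, 0.0]
for _ in range(200):
    reconstruction = cd1_step(pattern, weights, rng)
print([round(p, 2) for p in reconstruction])
```

In this toy setting, the weights attached to the two active visible units grow while those attached to the two silent units shrink, without any supervised target being involved.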

Imagine for a second a deep SNN whose top-most layer feeds back into the lower-most layer, with CD used to help learning find patterns, just as in Restricted Boltzmann Machines (RBMs). This way, you might already get powerful gradients from CD, which may help the neural network learn by itself just by observing data, even before a precise supervised learning objective (gradient) is provided.

There could also be cycles of cycles. You could have high-frequency SNN CD cycles, and low-frequency ones involving more layers between the cycles. This way, every small layer can tune itself, while the big picture also syncs up, with the top-most layers feeding back into other portions of the network.

Other tracks that were highlighted by Yoshua Bengio to explore

I have talked to Yoshua Bengio about my discovery of how backprop could entail Hebbian learning in the brain. Among other things, when I asked what to read on the topic, he pointed me to a few learning resources on this exact subject (and also advised me to read a paper a day).

  • First, he named Blake Richards.
  • Second, he named Walter Senn.
  • He also named Konrad Kording.

In a video of Blake Richards, he says that he has been pondering for a long time how the brain could do backpropagation or something like it. He says that in Deep Learning we have stripped the complexity out of the neurons, and that actual neurons are far more complex than the neurons used in Deep Learning. He even says that perhaps by stripping away the complexity, we have made things harder to understand. Later, when he describes his research, he speaks of approximating, in his algorithm, the weight updates that would be prescribed by backpropagation. From what he says about ensembles later in the talk, using them appears to be a good idea, particularly when dealing with SNNs. Those ensembles are like having several neurons play the part of one neuron (somewhat like trees vs. forests). I had thought about that merely as a way to provide more spikes per “neuron”, where a “neuron” would now be an ensemble of neurons with similar weights, but I didn’t see it as a way to share the gradient across the whole group of neurons. It could be fascinating to attribute a backpropagated gradient to such an ensemble of neurons rather than to just one, as SNN spikes are rarer than continuous sigmoid-activated outputs. This could avoid dead neurons (think of dead ReLUs, but in SNNs, due to a lack of spiking). It would somehow be doable to “cheat” the gradients by normalizing the total of all inner states to the sum of the amplitudes of the output activations, so as to attribute some of the gradient back to all neurons of the ensemble if one of its neurons fires. Fascinating fact: this also makes me think of capsule networks with their voting groups of neurons.

Although the work presented here trained successfully on the first try once the code and dimensions were debugged, I was tempted to code an ensemble myself. What strikes me is that I had coded one in the present SNN before watching Blake’s video, but in a different way (which I have since rolled back, as it added too much complexity to this already complex code and would have made it too unreadable for you open-source readers). To start with, I tried to use the output neurons’ activations instead of their inner states as logits.

One thing that struck me is that the network would never learn what to do for a label if that label was never firing (like a dead neuron): the gradients would never reach the teaching signal if the neural net was badly initialized at random. This led me to develop an ensemble of neurons for every label, which, as of writing, makes me think of the Mixture of Softmaxes (MoS), in which several softmax layers are used as the output, and which is like having an ensemble of neurons at the output. With these changes, the network could learn. I also tried to add a third layer, so as to have one more hidden layer than at present. This didn’t work at first: I needed to carefully adjust the firing thresholds so that neurons were not already dead at the beginning (particularly in the top layers, where no data was received; see the SNN animation). SNNs probably require some tuning to make sure neurons don’t wind up dead easily.
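As a rough illustration of the "ensemble of neurons for every label" idea, the sketch below groups output neurons by label and sums each group into one logit. The text does not say how the ensemble outputs were combined, so the summation, the function name, and the group size are all assumptions. The point is only that a label receives a non-zero signal as soon as any one of its neurons fires.

```python
def ensemble_logits(output_activations, neurons_per_label):
    """Sum the activations of each label's group of output neurons into one logit."""
    if neurons_per_label <= 0 or len(output_activations) % neurons_per_label != 0:
        raise ValueError("output size must be a positive multiple of neurons_per_label")
    return [
        sum(output_activations[start:start + neurons_per_label])
        for start in range(0, len(output_activations), neurons_per_label)
    ]


# 3 labels, 4 output neurons per label; only a few neurons spiked at this time step.
spikes = [0, 0, 0, 0,   0, 1, 0, 1,   0, 0, 1, 0]
print(ensemble_logits(spikes, neurons_per_label=4))  # [0, 2, 1]
```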


I am certainly not the first person to program an SNN, even though this kind of neural network is rarely coded. However, I expose here some important behaviour that emerges from using SNNs in which every neuron’s excitation evolves over time before spiking, I link to fascinating resources, and I provide dynamic PyTorch code as an example of how to work with SNNs.

I lay out the foundations for thinking that STDP and Hebbian learning have a close relationship with backpropagation in SNNs. Backpropagation in SNNs could engender the STDP rule as in Hebbian learning, because in SNNs the inner pre-activation value keeps fading until it reaches the threshold and the neuron fires, which makes old pre-activation values fade with a vanishing gradient and so enforces STDP.

It is also fascinating that brain rhythms could occur in SNNs, and that with Contrastive Divergence (CD), gradients could be obtained between every cycle of the rhythm to improve the formation of clusters in the neurons’ representation of the data, unsupervised learners that we humans are. I also point to the concept of having a negative spike right after every positive spike, where negative spikes are inversely correlated to pre-synaptic neurons that fired too late for the impulse to be released, so as to improve timing, or to better respect the STDP rule that neurons late to the party should be penalized and should rather fire before their post-synaptic target.


Here is the BibTeX citation code:

  title={Spiking Neural Network (SNN) with PyTorch where Backpropagation Engenders Spike-Timing-Dependent Plasticity (STDP)},
  author={Chevalier, Guillaume},

Note on BibTeX: if the URL field above isn’t supported in your LaTeX/BibTeX document, you could instead try replacing the url=https://… field with note="\url{https://github.com/guillaume-chevalier/Spiking-Neural-Network-SNN-with-PyTorch-where-Backpropagation-engenders-STDP}".


Copyright (c) 2019 Guillaume Chevalier.

The present article’s text is available under CC BY-SA 4.0

My source code in the present article is not available yet under a license, but it’s super fine to use it for educational or research purposes as long as you cite.


Guillaume Chevalier is a Machine Learning expert and founder of Neuraxio, an AI startup backing other AI startups and delivering deep learning and machine learning services and products to businesses (B2B). With 9 years of rich experience in the technology domain, Guillaume also has experience in the education sector. He has worked on more than 100 projects and is a deep learning specialist. In addition, he is one of the esteemed speakers at AICoreSpot events.
