PyGeNN implementation of SuperSpike

This example model is a reimplementation of the model developed by Friedemann Zenke and Surya Ganguli [Zenke2018]. It uses the SuperSpike learning rule to learn the transformation between fixed spike trains of Poisson noise and a target spiking output (by default the Radcliffe Camera at Oxford).

This example can be used as follows:

[ ]:
import numpy as np

from argparse import ArgumentParser
from pygenn import (create_custom_update_model, create_neuron_model,
                    create_postsynaptic_model, create_var_ref,
                    create_weight_update_model, create_wu_var_ref,
                    init_postsynaptic, init_var, init_weight_update)
from pygenn import GeNNModel

# ----------------------------------------------------------------------------
# Parameters
# ----------------------------------------------------------------------------
# Simulation timestep
TIMESTEP_MS = 0.1

# Network structure (number of neurons in each population)
NUM_INPUT = 200
NUM_OUTPUT = 200
NUM_HIDDEN = 256

# Model parameters
TAU_RISE_MS = 5.0
TAU_DECAY_MS = 10.0
TAU_RMS_MS = 30000.0
TAU_AVG_ERR_MS = 10000.0
R0 = 0.001 * 1000.0
EPSILON = 1E-32

# Second-based copies of the time constants, used only to derive the
# scale factor applied to the reported error
TAU_DECAY_S = TAU_DECAY_MS / 1000.0
TAU_RISE_S = TAU_RISE_MS / 1000.0
TAU_AVG_ERR_S = TAU_AVG_ERR_MS / 1000.0
SCALE_TR_ERR_FLT = 1.0 / (
    pow((TAU_DECAY_S * TAU_RISE_S) / (TAU_DECAY_S - TAU_RISE_S), 2)
    * (TAU_DECAY_S / 2 + TAU_RISE_S / 2
       - 2 * (TAU_DECAY_S * TAU_RISE_S) / (TAU_DECAY_S + TAU_RISE_S))
) / TAU_AVG_ERR_S

# Weights
# **NOTE** Auryn units are volts, seconds etc so essentially 1000x GeNN parameters
W_MIN = -0.1 * 1000.0
W_MAX = 0.1 * 1000.0
W0 = 0.05 * 1000.0

# Experiment parameters
INPUT_FREQ_HZ = 5.0
UPDATE_TIME_MS = 500.0
TRIAL_MS = 1890.0

# Convert parameters to timesteps
# **NOTE** round before converting to int: a floating point quotient that
# lands just below a whole number would otherwise be truncated to one
# timestep too few
UPDATE_TIMESTEPS = int(round(UPDATE_TIME_MS / TIMESTEP_MS))
TRIAL_TIMESTEPS = int(round(TRIAL_MS / TIMESTEP_MS))

# ----------------------------------------------------------------------------
# Helper functions
# ----------------------------------------------------------------------------
def calc_t_peak(tau_rise, tau_decay):
    """Return the peak time for a pair of rise/decay time constants.

    Evaluates ``(tau_decay * tau_rise) / (tau_decay - tau_rise)
    * log(tau_decay / tau_rise)``. The result is in the same units as the
    two time constants, which must differ (equal values divide by zero).
    """
    time_scale = (tau_decay * tau_rise) / (tau_decay - tau_rise)
    log_ratio = np.log(tau_decay / tau_rise)
    return time_scale * log_ratio

def write_spike_file(filename, data):
    """Write recorded spikes to a comma-separated text file.

    Only ``data[0]`` is used: it must be a pair of equal-length sequences
    (spike times, neuron ids) which are written as two columns, one spike
    per row, below a "Time [ms], Neuron ID" header line.
    """
    # Pair up each spike time with its neuron id, one row per spike
    spike_table = np.column_stack(data[0])
    np.savetxt(filename, spike_table, delimiter=",", fmt=["%f","%d"],
               header="Time [ms], Neuron ID")

# ----------------------------------------------------------------------------
# Custom models
# ----------------------------------------------------------------------------
# Custom update model which applies the accumulated gradient 'm' to a
# referenced variable using a per-element adaptive learning rate.
# Each time the update runs it:
#   1. divides the accumulated 'm' by the number of timesteps in one update
#      period to obtain the gradient
#   2. updates 'upsilon' to the larger of its decayed previous value
#      (scaled by expRMS) and the squared gradient
#   3. derives the learning rate r = r0 / (sqrt(upsilon) + epsilon)
#   4. adds r * gradient to 'variable' and clamps it to [wMin, wMax]
#   5. zeroes 'm' ready for the next accumulation period
r_max_prop_model = create_custom_update_model(
    "r_max_prop",
    params=["updateTime", "tauRMS", "epsilon", "wMin", "wMax", "r0"],
    vars=[("upsilon", "scalar")],
    # updateTimesteps: update period expressed in timesteps
    # expRMS: decay applied to upsilon over one update period
    derived_params=[("updateTimesteps", lambda pars, dt: pars["updateTime"] / dt),
                    ("expRMS", lambda pars, dt: np.exp(-pars["updateTime"] / pars["tauRMS"]))],
    # m: accumulated gradient (reset by this update)
    # variable: the value being optimised
    var_refs=[("m", "scalar"), ("variable", "scalar")],
    update_code="""
    // Get gradients
    const scalar gradient = m / updateTimesteps;
    // Calculate learning rate r
    upsilon = fmax(upsilon * expRMS, gradient * gradient);
    const scalar r = r0 / (sqrt(upsilon) + epsilon);
    // Update synaptic parameter
    variable += r * gradient;
    variable = fmin(wMax, fmax(wMin, variable));
    m = 0.0;
    """)

# Weight update model used for the learned (feedforward) connections.
# Per-synapse state: weight 'w', eligibility trace 'e', its filtered version
# 'lambda' and the accumulated update 'm'.
# Per-presynaptic-neuron state: spike trace 'z' and its filtered version
# 'zTilda'. Per-postsynaptic-neuron state: 'sigmaPrime'.
# 'V' and 'errTilda' are read from the postsynaptic neuron via references.
superspike_model = create_weight_update_model(
    "superspike",
    params=["tauRise", "tauDecay", "beta", "Vthresh"],
    vars=[("w", "scalar"), ("e", "scalar"),
          ("lambda", "scalar"), ("m", "scalar")],
    pre_vars=[("z", "scalar"), ("zTilda", "scalar")],
    post_vars=[("sigmaPrime", "scalar")],
    post_neuron_var_refs=[("V", "scalar"), ("errTilda", "scalar")],

    # On a presynaptic spike, deliver the weight to the postsynaptic neuron
    pre_spike_syn_code="""
    addToPost(w);
    """,

    # On a presynaptic spike, bump the presynaptic trace
    pre_spike_code="""
    z += 1.0;
    """,
    # Every timestep: z decays with tauRise and zTilda relaxes towards z
    # with tauDecay (both advanced by one step of size dt)
    pre_dynamics_code="""
    // filtered presynaptic trace
    z += (-z / tauRise) * dt;
    zTilda += ((-zTilda + z) / tauDecay) * dt;
    """,

    # Every timestep: sigmaPrime is zero when V is below -80, otherwise
    # beta / (1 + |beta * 0.001 * (V - Vthresh)|)^2
    # NOTE(review): the 0.001 factor presumably converts between the mV used
    # here and the volt-based units of the reference implementation - confirm
    post_dynamics_code="""
    // filtered partial derivative
    if(V < -80.0) {
       sigmaPrime = 0.0;
    }
    else {
       const scalar onePlusHi = 1.0 + fabs(beta * 0.001 * (V - Vthresh));
       sigmaPrime = beta / (onePlusHi * onePlusHi);
    }
    """,

    # Every timestep: e is driven by zTilda * sigmaPrime and decays with
    # tauRise, lambda relaxes towards e with tauDecay, and lambda * errTilda
    # is accumulated into m
    synapse_dynamics_code="""
    // Filtered eligibility trace
    e += (zTilda * sigmaPrime - e / tauRise) * dt;
    lambda += ((-lambda + e) / tauDecay) * dt;

    // Get error from neuron model and compute full
    // expression under integral and calculate m
    m += lambda * errTilda;
    """)

# Weight update model used for the feedback connections: every timestep each
# synapse delivers its weight multiplied by the presynaptic neuron's
# 'errTilda' (read via a variable reference) to the postsynaptic neuron.
# It defines no spike-triggered code.
feedback_model = create_weight_update_model(
    "feedback",
    vars=[("w", "scalar")],
    pre_neuron_var_refs=[("errTilda", "scalar")],
    synapse_dynamics_code="""
    addToPost(w * errTilda);
    """)

# Neuron model for the hidden layer.
# State: membrane voltage 'V', remaining refractory time 'refracTime' and
# the error signal 'errTilda', which is simply copied each timestep from the
# additional input variable 'ISynFeedback'.
hidden_neuron_model = create_neuron_model(
    "hidden",
    params=["C", "tauMem", "Vrest", "Vthresh", "tauRefrac"],
    vars=[("V", "scalar"), ("refracTime", "scalar"), ("errTilda", "scalar")],
    # Separate input accumulator (initialised to 0.0) for feedback input
    additional_input_vars=[("ISynFeedback", "scalar", 0.0)],
    # ExpTC: per-timestep membrane decay factor
    # Rmembrane: membrane resistance derived from tauMem and C
    derived_params=[("ExpTC", lambda pars, dt: np.exp(-dt / pars["tauMem"])),
                    ("Rmembrane", lambda pars, dt: pars["tauMem"] / pars["C"])],

    # Each timestep:
    #  - if the refractory timer still holds its full value (i.e. the neuron
    #    spiked on the previous step), V is set to Vrest
    #  - outside the refractory period V relaxes exponentially towards
    #    (Isyn * Rmembrane) + Vrest; inside it the timer counts down by dt
    sim_code="""
    // membrane potential dynamics
    if (refracTime == tauRefrac) {
        V = Vrest;
    }
    if (refracTime <= 0.0) {
        scalar alpha = (Isyn * Rmembrane) + Vrest;
        V = alpha - (ExpTC * (alpha - V));
    }
    else {
        refracTime -= dt;
    }
    // error
    errTilda = ISynFeedback;
    """,
    # A spike only (re)starts the refractory timer; V is reset by sim_code
    reset_code="""
    refracTime = tauRefrac;
    """,
    # Spike when not refractory and V has reached threshold
    threshold_condition_code="""
    refracTime <= 0.0 && V >= Vthresh
    """)

# Neuron model for the output layer.
# Uses the same membrane and refractory dynamics as the hidden model and
# additionally compares its own spiking against a list of target spike times
# to produce the error signal 'errTilda' and a running average of its square.
# State: 'V', 'refracTime', the two error filter traces 'errRise' and
# 'errDecay', 'errTilda', 'avgSqrErr', and the indices 'startSpike' and
# 'endSpike' delimiting this neuron's remaining target spikes in the
# 'spikeTimes' extra global parameter.
output_neuron_model = create_neuron_model(
    "output",
    params=["C", "tauMem", "Vrest", "Vthresh", "tauRefrac",
            "tauRise", "tauDecay", "tauAvgErr"],
    vars=[("V", "scalar"), ("refracTime", "scalar"), ("errRise", "scalar"),
          ("errTilda", "scalar"), ("avgSqrErr", "scalar"), ("errDecay", "scalar"),
           ("startSpike", "unsigned int"), ("endSpike", "unsigned int")],
    extra_global_params=[("spikeTimes", "scalar*")],
    # ExpTC / Rmembrane: as in the hidden model
    # normFactor: reciprocal of exp(-tPeak / tauDecay) - exp(-tPeak / tauRise),
    #             used to scale errTilda
    # tRiseMult / tDecayMult: per-timestep decay factors of the error traces
    # tPeak: value of calc_t_peak for tauRise and tauDecay (not referenced by
    #        the code strings of this model)
    # mulAvgErr: per-timestep decay factor of avgSqrErr
    derived_params=[("ExpTC", lambda pars, dt: np.exp(-dt / pars["tauMem"])),
                    ("Rmembrane", lambda pars, dt: pars["tauMem"] / pars["C"]),
                    ("normFactor", lambda pars, dt: 1.0 / (-np.exp(-calc_t_peak(pars["tauRise"], pars["tauDecay"]) / pars["tauRise"]) + np.exp(-calc_t_peak(pars["tauRise"], pars["tauDecay"]) / pars["tauDecay"]))),
                    ("tRiseMult", lambda pars, dt: np.exp(-dt / pars["tauRise"])),
                    ("tDecayMult", lambda pars, dt: np.exp(-dt / pars["tauDecay"])),
                    ("tPeak", lambda pars, dt: calc_t_peak(pars["tauRise"], pars["tauDecay"])),
                    ("mulAvgErr", lambda pars, dt: np.exp(-dt / pars["tauAvgErr"]))],

    # Each timestep, after the membrane update:
    #  - sPred is 1 if the next target spike time has been reached (which
    #    also advances startSpike), otherwise 0
    #  - sReal is 1 if this neuron meets its own threshold condition
    #  - the mismatch (sPred - sReal) is added to both decaying error traces
    #    and errTilda is their normalised difference
    #  - avgSqrErr is decayed and errTilda^2 * dt * 0.001 is added to it
    sim_code="""
    // membrane potential dynamics
    if (refracTime == tauRefrac) {
        V = Vrest;
    }
    if (refracTime <= 0.0) {
        scalar alpha = (Isyn * Rmembrane) + Vrest;
        V = alpha - (ExpTC * (alpha - V));
    }
    else {
        refracTime -= dt;
    }
    // error
    scalar sPred = 0.0;
    if (startSpike != endSpike && t >= spikeTimes[startSpike]) {
        startSpike++;
        sPred = 1.0;
    }
    const scalar sReal = (refracTime <= 0.0 && V >= Vthresh) ? 1.0 : 0.0;
    const scalar mismatch = sPred - sReal;
    errRise = (errRise * tRiseMult) + mismatch;
    errDecay = (errDecay * tDecayMult) + mismatch;
    errTilda = (errDecay - errRise) * normFactor;
    // calculate average error trace
    const scalar temp = errTilda * errTilda * dt * 0.001;
    avgSqrErr *= mulAvgErr;
    avgSqrErr += temp;
    """,
    # A spike only (re)starts the refractory timer; V is reset by sim_code
    reset_code="""
    refracTime = tauRefrac;
    """,
    # Spike when not refractory and V has reached threshold
    threshold_condition_code="""
    refracTime <= 0.0 && V >= Vthresh
    """)

# ----------------------------------------------------------------------------
# CLI
# ----------------------------------------------------------------------------
def get_parser():
    """Build the command line parser for this example.

    Options: ``--record-trial`` (required; zero or more integer trial
    indices), ``--target-file``, ``--num-trials``, ``--kernel-profiling``
    and ``--save-data``.
    """
    # (flag, keyword arguments) for each supported option
    option_specs = (
        ("--record-trial", dict(type=int, nargs="*", required=True,
                                help="Index of trial(s) to record")),
        ("--target-file", dict(type=str, default="oxford-target.ras",
                               help="Filename of spike file to train model on")),
        ("--num-trials", dict(type=int, default=600,
                              help="Number of trials to train for")),
        ("--kernel-profiling", dict(action="store_true",
                                    help="Output kernel profiling data")),
        ("--save-data", dict(action="store_true",
                             help="Save spike data (rather than plotting it)")),
    )

    cli = ArgumentParser()
    for flag, options in option_specs:
        cli.add_argument(flag, **options)
    return cli


# ----------------------------------------------------------------------------
# Entry point
# ----------------------------------------------------------------------------
if __name__ == "__main__":
    # Script entry point: loads the target spike trains, generates frozen
    # Poisson input, builds the three layer network with its learning rules,
    # trains it for the requested number of trials and finally either saves
    # or plots the spikes recorded during the selected trials
    args = get_parser().parse_args()

    # Sort trial indices to record, dropping duplicates so that each trial
    # is recorded (and later plotted) exactly once
    args.record_trial = sorted(set(args.record_trial))
    trials_to_record = set(args.record_trial)

    # ----------------------------------------------------------------------------
    # Load target data
    # ----------------------------------------------------------------------------
    # Each row of the file holds a spike time followed by a neuron id
    target_spikes = np.loadtxt(args.target_file,
                               dtype={"names": ("time", "neuron_id"),
                                      "formats": (float, int)})

    # Make neuron IDs zero-based
    target_spikes["neuron_id"] -= 1

    # Convert times to milliseconds
    target_spikes["time"] *= 1000.0

    # Sort first by neuron id and then by time
    target_spikes = np.sort(target_spikes, order=["neuron_id", "time"])

    # Turn per-neuron spike counts into [start, end) indices into the
    # sorted array of target spike times
    target_neuron_end_times = np.cumsum(np.bincount(target_spikes["neuron_id"], minlength=NUM_OUTPUT))
    target_neuron_start_times = np.concatenate(([0], target_neuron_end_times[:-1]))

    # ----------------------------------------------------------------------------
    # Generate frozen poisson input
    # ----------------------------------------------------------------------------
    input_isi_ms = 1000.0 / INPUT_FREQ_HZ

    # Each row holds one spike time per input neuron;
    # the first row is the time of each neuron's first spike
    spike_time_rows = [input_isi_ms * np.random.exponential(size=NUM_INPUT)]
    while True:
        # Draw the next inter-spike intervals and add previous times
        next_times = input_isi_ms * np.random.exponential(size=NUM_INPUT)
        next_times += spike_time_rows[-1]

        # Stop once all neurons have reached end of trial
        if np.all(next_times >= TRIAL_MS):
            break
        spike_time_rows.append(next_times)

    # Stack once at the end rather than re-allocating on every iteration
    input_spike_times = np.vstack(spike_time_rows)

    # Count spikes per input neuron
    # **NOTE** times increase down each column so the spikes
    # within the trial always form a prefix of the column
    input_spikes_per_neuron = np.sum(input_spike_times < TRIAL_MS, axis=0)

    # Concatenate spikes within trial together, neuron by neuron
    input_spike_train = np.concatenate([input_spike_times[:num_spikes, i]
                                        for i, num_spikes in enumerate(input_spikes_per_neuron)])

    # Calculate [start, end) indices into the concatenated spike train
    input_neuron_end_times = np.cumsum(input_spikes_per_neuron)
    input_neuron_start_times = np.concatenate(([0], input_neuron_end_times[:-1]))

    # ----------------------------------------------------------------------------
    # Neuron initialisation
    # ----------------------------------------------------------------------------
    input_init_vars = {"startSpike": input_neuron_start_times, "endSpike": input_neuron_end_times}

    hidden_params = {"C" : 10.0, "tauMem": 10.0, "Vrest": -60.0,
                     "Vthresh": -50.0 , "tauRefrac": 5.0}
    hidden_init_vars = {"V": -60.0, "refracTime": 0.0, "errTilda": 0.0}

    output_params = {"C": 10.0, "tauMem": 10.0, "Vrest": -60.0,
                     "Vthresh": -50.0, "tauRefrac": 5.0, "tauRise": TAU_RISE_MS,
                     "tauDecay": TAU_DECAY_MS, "tauAvgErr": TAU_AVG_ERR_MS}
    output_init_vars = {"V": -60.0, "refracTime": 0.0, "errRise": 0.0,
                        "errTilda": 0.0, "errDecay": 0.0, "avgSqrErr": 0.0,
                        "startSpike": target_neuron_start_times, "endSpike": target_neuron_end_times}

    # ----------------------------------------------------------------------------
    # Synapse initialisation
    # ----------------------------------------------------------------------------
    superspike_params = {"tauRise": TAU_RISE_MS, "tauDecay": TAU_DECAY_MS, "beta": 1000.0, "Vthresh": -50.0}
    superspike_pre_init_vars = {"z": 0.0, "zTilda": 0.0}
    superspike_post_init_vars = {"sigmaPrime": 0.0}

    # Initial weights are drawn from a clipped normal distribution whose
    # standard deviation shrinks with the size of the presynaptic population
    input_hidden_weight_dist_params = {"mean": 0.0, "sd": W0 / np.sqrt(float(NUM_INPUT)),
                                       "min": W_MIN, "max": W_MAX}
    input_hidden_init_vars = {"w": init_var("NormalClipped", input_hidden_weight_dist_params),
                              "e": 0.0, "lambda": 0.0, "m": 0.0}

    hidden_output_weight_dist_params = {"mean": 0.0, "sd": W0 / np.sqrt(float(NUM_HIDDEN)),
                                        "min": W_MIN, "max": W_MAX}
    hidden_output_init_vars = {"w": init_var("NormalClipped", hidden_output_weight_dist_params),
                               "e": 0.0, "lambda": 0.0, "m": 0.0}

    # ----------------------------------------------------------------------------
    # Custom update initialisation
    # ----------------------------------------------------------------------------
    r_max_prop_params = {"updateTime": UPDATE_TIME_MS, "tauRMS": TAU_RMS_MS,
                         "epsilon": EPSILON, "wMin": W_MIN, "wMax": W_MAX, "r0": R0}

    # ----------------------------------------------------------------------------
    # Model description
    # ----------------------------------------------------------------------------
    model = GeNNModel("float", "superspike_demo", generateLineInfo=True)
    model.dt = TIMESTEP_MS
    model.timing_enabled = args.kernel_profiling

    # Add neuron populations
    # **NOTE** named *_pop so the 'input' builtin is not shadowed
    input_pop = model.add_neuron_population("Input", NUM_INPUT, "SpikeSourceArray",
                                            {}, input_init_vars)
    hidden_pop = model.add_neuron_population("Hidden", NUM_HIDDEN, hidden_neuron_model,
                                             hidden_params, hidden_init_vars)
    output_pop = model.add_neuron_population("Output", NUM_OUTPUT, output_neuron_model,
                                             output_params, output_init_vars)

    input_pop.extra_global_params["spikeTimes"].set_init_values(input_spike_train)
    output_pop.extra_global_params["spikeTimes"].set_init_values(target_spikes["time"])

    # Turn on recording only if at least one trial is to be recorded
    any_recording = bool(args.record_trial)
    input_pop.spike_recording_enabled = any_recording
    hidden_pop.spike_recording_enabled = any_recording
    output_pop.spike_recording_enabled = any_recording

    # Add synapse populations
    input_hidden = model.add_synapse_population(
        "InputHidden", "DENSE",
        input_pop, hidden_pop,
        init_weight_update(superspike_model, superspike_params, input_hidden_init_vars, superspike_pre_init_vars, superspike_post_init_vars,
                           post_var_refs={"V": create_var_ref(hidden_pop, "V"), "errTilda": create_var_ref(hidden_pop, "errTilda")}),
        init_postsynaptic("ExpCurr", {"tau": 5.0}))

    hidden_output = model.add_synapse_population(
        "HiddenOutput", "DENSE",
        hidden_pop, output_pop,
        init_weight_update(superspike_model, superspike_params, hidden_output_init_vars, superspike_pre_init_vars, superspike_post_init_vars,
                           post_var_refs={"V": create_var_ref(output_pop, "V"), "errTilda": create_var_ref(output_pop, "errTilda")}),
        init_postsynaptic("ExpCurr", {"tau": 5.0}))

    # Feedback connections deliver the output error into the
    # hidden neurons' separate 'ISynFeedback' input variable
    output_hidden = model.add_synapse_population(
        "OutputHidden", "DENSE",
        output_pop, hidden_pop,
        init_weight_update(feedback_model, {}, {"w": 0.0}, pre_var_refs={"errTilda": create_var_ref(output_pop, "errTilda")}),
        init_postsynaptic("DeltaCurr"))
    output_hidden.post_target_var = "ISynFeedback"

    # Add custom update for calculating initial transpose weights
    # (feedback weights are the transpose of the hidden->output weights)
    model.add_custom_update("input_hidden_transpose", "CalculateTranspose", "Transpose",
                            {}, {}, {"variable": create_wu_var_ref(hidden_output, "w", output_hidden, "w")})

    # Add custom updates for gradient update
    # **NOTE** r0 is made dynamic so the learning rate can be changed at runtime
    input_hidden_optimiser_var_refs = {"m": create_wu_var_ref(input_hidden, "m"),
                                       "variable": create_wu_var_ref(input_hidden, "w")}
    input_hidden_optimiser = model.add_custom_update("input_hidden_optimiser", "GradientLearn", r_max_prop_model,
                                                     r_max_prop_params, {"upsilon": 0.0}, input_hidden_optimiser_var_refs)
    input_hidden_optimiser.set_param_dynamic("r0")

    hidden_output_optimiser_var_refs = {"m": create_wu_var_ref(hidden_output, "m"),
                                        "variable": create_wu_var_ref(hidden_output, "w", output_hidden, "w")}
    hidden_output_optimiser = model.add_custom_update("hidden_output_optimiser", "GradientLearn", r_max_prop_model,
                                                      r_max_prop_params, {"upsilon": 0.0}, hidden_output_optimiser_var_refs)
    hidden_output_optimiser.set_param_dynamic("r0")

    # Build and load model
    model.build()
    model.load(num_recording_timesteps=TRIAL_TIMESTEPS)

    # Calculate initial transpose feedback weights
    model.custom_update("CalculateTranspose")

    # Loop through trials
    output_avg_sqr_err_var = output_pop.vars["avgSqrErr"]
    current_r0 = R0

    # Total number of timesteps simulated across all trials
    timestep = 0

    # Indices of the trials actually recorded for plotting and, in the
    # same order, the spikes recorded from each population
    recorded_trials = []
    recorded_input_spikes = []
    recorded_hidden_spikes = []
    recorded_output_spikes = []
    for trial in range(args.num_trials):
        # Reduce learning rate every 400 trials
        if trial != 0 and (trial % 400) == 0:
            current_r0 *= 0.1

            input_hidden_optimiser.set_dynamic_param_value("r0", current_r0)
            hidden_output_optimiser.set_dynamic_param_value("r0", current_r0)

        # Display trial number periodically
        if trial != 0 and (trial % 10) == 0:
            # Get average square error
            output_avg_sqr_err_var.pull_from_device()

            # Calculate mean error across output neurons and rescale it,
            # compensating for how long the running average has been accumulating
            time_s = timestep * TIMESTEP_MS / 1000.0
            mean_error = np.sum(output_avg_sqr_err_var.view) / float(NUM_OUTPUT)
            mean_error *= SCALE_TR_ERR_FLT / (1.0 - np.exp(-time_s / TAU_AVG_ERR_S) + 1.0E-9)

            print("Trial %u (r0 = %f, error = %f)" % (trial, current_r0, mean_error))

        # Reset model timestep
        model.timestep = 0

        # Loop through timesteps within trial
        for _ in range(TRIAL_TIMESTEPS):
            model.step_time()

            # If it's time to update weights
            if timestep != 0 and (timestep % UPDATE_TIMESTEPS) == 0:
                model.custom_update("GradientLearn")
            timestep += 1

        # Reset spike sources by re-uploading starting spike indices
        # **TODO** build repeating spike source array
        input_pop.vars["startSpike"].push_to_device()
        output_pop.vars["startSpike"].push_to_device()

        if trial in trials_to_record:
            model.pull_recording_buffers_from_device()

            if args.save_data:
                write_spike_file("input_spikes_%u.csv" % trial, input_pop.spike_recording_data)
                write_spike_file("hidden_spikes_%u.csv" % trial, hidden_pop.spike_recording_data)
                write_spike_file("output_spikes_%u.csv" % trial, output_pop.spike_recording_data)
            else:
                # Remember which trial these spikes belong to so the plot
                # does not depend on the order or validity of the requested indices
                recorded_trials.append(trial)
                recorded_input_spikes.append(input_pop.spike_recording_data[0])
                recorded_hidden_spikes.append(hidden_pop.spike_recording_data[0])
                recorded_output_spikes.append(output_pop.spike_recording_data[0])

    if args.kernel_profiling:
        print("Init: %f" % model.init_time)
        print("Init sparse: %f" % model.init_sparse_time)
        print("Neuron update: %f" % model.neuron_update_time)
        print("Presynaptic update: %f" % model.presynaptic_update_time)
        print("Synapse dynamics: %f" % model.synapse_dynamics_time)
        print("Gradient learning custom update: %f" % model.get_custom_update_time("GradientLearn"))
        print("Gradient learning custom update transpose: %f" % model.get_custom_update_transpose_time("GradientLearn"))

    if not args.save_data:
        if not recorded_trials:
            # A figure with zero columns cannot be created
            print("No trials were recorded - nothing to plot")
        else:
            import matplotlib.pyplot as plt

            # One row per population, one column per recorded trial
            population_spikes = (recorded_input_spikes, recorded_hidden_spikes, recorded_output_spikes)
            fig, axes = plt.subplots(len(population_spikes), len(recorded_trials),
                                     sharex="col", sharey="row", squeeze=False)

            trial_s = TRIAL_MS / 1000.0
            for col, trial in enumerate(recorded_trials):
                # Offset spike times by the time at which the trial started
                start_time_s = float(trial) * trial_s
                for row, spikes in enumerate(population_spikes):
                    # Each recording is indexed as [0] = spike times in ms, [1] = neuron ids
                    axes[row, col].scatter(start_time_s + (spikes[col][0] / 1000.0), spikes[col][1], s=2, edgecolors="none")

                axes[-1, col].set_xlabel("Time [s]")

            for row in range(len(population_spikes)):
                axes[row, 0].set_ylabel("Neuron number")

            # Show plot
            plt.show()