MNIST classification using an insect-inspired mushroom body model
This example learns to classify MNIST digits using STDP and an insect-inspired mushroom body model.
It can be used as follows:
[ ]:
import mnist
import numpy as np
from copy import copy
from argparse import ArgumentParser
from pygenn import (create_current_source_model, create_custom_update_model,
                    create_neuron_model, create_out_post_var_ref,
                    create_spike_time_var_ref, create_var_ref,
                    create_weight_update_model, init_sparse_connectivity,
                    init_postsynaptic, init_weight_update, GeNNModel)
from tqdm.auto import tqdm
# ----------------------------------------------------------------------------
# Parameters
# ----------------------------------------------------------------------------
# Simulation time step
DT = 0.1
# Scaling factor for converting normalised image pixels to input currents (nA)
INPUT_SCALE = 80.0
# Size of current to use to stimulate correct MBON when training (nA)
MBON_STIMULUS_CURRENT = 5.0
# Number of Projection Neurons in model (should match image size)
NUM_PN = 784
# Number of Kenyon Cells in model (defines memory capacity)
NUM_KC = 20000
# Number of output neurons in model
NUM_MBON = 10
# How long to present each image to model
PRESENT_TIME_MS = 20.0
# Standard LIF neurons parameters
LIF_PARAMS = {
    "C": 0.2,
    "TauM": 20.0,
    "Vrest": -60.0,
    "Vreset": -60.0,
    "Vthresh": -50.0,
    "Ioffset": 0.0,
    "TauRefrac": 2.0}
# We only want PNs to spike once
PN_PARAMS = copy(LIF_PARAMS)
PN_PARAMS["TauRefrac"] = 100.0
# Weight of each synaptic connection
PN_KC_WEIGHT = 0.2
# Time constant of synaptic integration
PN_KC_TAU_SYN = 3.0
# How many projection neurons should be connected to each Kenyon Cell
PN_KC_FAN_IN = 20
# We will use weights of 1.0 for KC->GGN connections and
# want the GGN to inhibit the KCs after 200 spikes
GGN_PARAMS = {"Vthresh": 200.0}
# Time constant of KC->MBON synaptic integration
KC_MBON_TAU_SYN = 3.0
# Parameters of the symmetric STDP rule used on KC->MBON synapses
KC_MBON_PARAMS = {"tau": 15.0,
                  "rho": 0.01,
                  "eta": 0.00002,
                  "wMin": 0.0,
                  "wMax": 0.0233}
# ----------------------------------------------------------------------------
# Custom models
# ----------------------------------------------------------------------------
# Current source model, allowing current to be injected into neuron from variable
cs_model = create_current_source_model(
    "cs_model",
    vars=[("magnitude", "scalar")],
    injection_code="injectCurrent(magnitude);")
# Minimal integrate and fire neuron model
if_model = create_neuron_model(
    "IF",
    params=["Vthresh"],
    vars=[("V", "scalar")],
    sim_code=
    """
    V += Isyn;
    """,
    threshold_condition_code=
    """
    V >= Vthresh
    """,
    reset_code=
    """
    V = 0.0;
    """)
# Symmetric STDP learning rule
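# On a presynaptic (KC) spike, the weight changes by eta * (exp(-dt / tau) - rho), where dt
# is the time since the last postsynaptic (MBON) spike, so closely-paired spikes potentiate
# the synapse while widely-separated ones cause weak depression. The postsynaptic update is
# the same, except that its depression is clamped at -0.1 * rho. Weights stay in [wMin, wMax].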
symmetric_stdp = create_weight_update_model(
    "symmetric_stdp",
    params=["tau", "rho", "eta", "wMin", "wMax"],
    vars=[("g", "scalar")],
    pre_spike_syn_code=
    """
    const scalar dt = t - st_post;
    const scalar timing = exp(-dt / tau) - rho;
    const scalar newWeight = g + (eta * timing);
    g = fmin(wMax, fmax(wMin, newWeight));
    """,
    post_spike_syn_code=
    """
    const scalar dt = t - st_pre;
    const scalar timing = fmax(exp(-dt / tau) - rho, -0.1 * rho);
    const scalar newWeight = g + (eta * timing);
    g = fmin(wMax, fmax(wMin, newWeight));
    """)
# Custom update for resetting neuron state
pn_reset = create_custom_update_model(
    "pn_reset",
    params=["Vreset"],
    var_refs=[("V", "scalar"), ("RefracTime", "scalar")],
    update_code=
    """
    V = Vreset;
    RefracTime = 0.0;
    """)
kc_reset = create_custom_update_model(
    "kc_reset",
    params=["Vreset"],
    var_refs=[("PNOutPost", "scalar"), ("GGNOutPost", "scalar"),
              ("V", "scalar"), ("RefracTime", "scalar")],
    update_code=
    """
    PNOutPost = 0.0;
    GGNOutPost = 0.0;
    V = Vreset;
    RefracTime = 0.0;
    """)
ggn_reset = create_custom_update_model(
    "ggn_reset",
    params=["Vreset"],
    var_refs=[("V", "scalar")],
    update_code=
    """
    V = Vreset;
    """)
mbon_reset = create_custom_update_model(
    "mbon_reset",
    params=["Vreset"],
    var_refs=[("OutPost", "scalar"), ("V", "scalar"),
              ("RefracTime", "scalar")],
    update_code=
    """
    OutPost = 0.0;
    V = Vreset;
    RefracTime = 0.0;
    """)
# Custom update for resetting spike times
reset_st = create_custom_update_model(
    "reset_st",
    var_refs=[("SpikeTimes", "scalar")],
    update_code=
    f"""
    SpikeTimes = {-np.finfo(np.float32).max};
    """)
# ----------------------------------------------------------------------------
# CLI
# ----------------------------------------------------------------------------
def get_parser():
    parser = ArgumentParser()
    parser.add_argument("--test", action="store_true", help="Load saved weights (rather than training)")
    parser.add_argument("--plot-weight-distribution", action="store_true", help="Plot weight distribution after training")
    return parser
if __name__ == "__main__":
    args = get_parser().parse_args()
    # Set the download URL for the MNIST dataset
    mnist.datasets_url = "https://storage.googleapis.com/cvdf-datasets/mnist/"
    # Reshape and normalise data
    images = mnist.test_images() if args.test else mnist.train_images()
    images = np.reshape(images, (images.shape[0], -1)).astype(np.float32)
    images /= np.sum(images, axis=1)[:, np.newaxis]
    labels = mnist.test_labels() if args.test else mnist.train_labels()
    # Create model
    model = GeNNModel("float", "mnist_mb")
    model.dt = DT
    # Create neuron populations
    lif_init = {"V": PN_PARAMS["Vreset"], "RefracTime": 0.0}
    if_init = {"V": 0.0}
    pn = model.add_neuron_population("pn", NUM_PN, "LIF", PN_PARAMS, lif_init)
    kc = model.add_neuron_population("kc", NUM_KC, "LIF", LIF_PARAMS, lif_init)
    ggn = model.add_neuron_population("ggn", 1, if_model, GGN_PARAMS, if_init)
    mbon = model.add_neuron_population("mbon", NUM_MBON, "LIF", LIF_PARAMS, lif_init)
    # Turn on spike recording
    pn.spike_recording_enabled = True
    kc.spike_recording_enabled = True
    mbon.spike_recording_enabled = True
    # Create current source to deliver input to network
    pn_input = model.add_current_source("pn_input", cs_model, pn, {}, {"magnitude": 0.0})
    # When training, create current source to deliver supervision signal to network
    if not args.test:
        mbon_input = model.add_current_source("mbon_input", cs_model, mbon, {}, {"magnitude": 0.0})
    # Create synapse populations
    pn_kc_connectivity = None if args.test else init_sparse_connectivity("FixedNumberPreWithReplacement", {"num": PN_KC_FAN_IN})
    pn_kc = model.add_synapse_population("pn_kc", "SPARSE",
                                         pn, kc,
                                         init_weight_update("StaticPulseConstantWeight", {"g": PN_KC_WEIGHT}),
                                         init_postsynaptic("ExpCurr", {"tau": PN_KC_TAU_SYN}),
                                         pn_kc_connectivity)
    # Load saved connectivity if testing
    if args.test:
        pn_kc_ind = np.load("pn_kc_ind.npy")
        pn_kc.set_sparse_connections(pn_kc_ind[0], pn_kc_ind[1])
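    # The KC->GGN and GGN->KC connections implement feedback inhibition: once the GGN has
    # integrated 200 KC spikes it fires and inhibits the whole KC population, keeping KC
    # activity sparse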
    kc_ggn = model.add_synapse_population("kc_ggn", "DENSE",
                                          kc, ggn,
                                          init_weight_update("StaticPulseConstantWeight", {"g": 1.0}),
                                          init_postsynaptic("DeltaCurr"))
    ggn_kc = model.add_synapse_population("ggn_kc", "DENSE",
                                          ggn, kc,
                                          init_weight_update("StaticPulseConstantWeight", {"g": -5.0}),
                                          init_postsynaptic("ExpCurr", {"tau": 5.0}))
    kc_mbon_weight_update = (init_weight_update("StaticPulse", {}, {"g": np.load("kc_mbon_g.npy")}) if args.test
                             else init_weight_update(symmetric_stdp, KC_MBON_PARAMS, {"g": 0.0}))
    kc_mbon = model.add_synapse_population("kc_mbon", "DENSE",
                                           kc, mbon,
                                           kc_mbon_weight_update,
                                           init_postsynaptic("ExpCurr", {"tau": KC_MBON_TAU_SYN}))
    # Add custom updates to reset model state between examples
    model.add_custom_update("pn_reset", "Reset", pn_reset,
                            {"Vreset": LIF_PARAMS["Vreset"]},
                            var_refs={"V": create_var_ref(pn, "V"),
                                      "RefracTime": create_var_ref(pn, "RefracTime")})
    model.add_custom_update("kc_reset", "Reset", kc_reset,
                            {"Vreset": LIF_PARAMS["Vreset"]},
                            var_refs={"V": create_var_ref(kc, "V"),
                                      "RefracTime": create_var_ref(kc, "RefracTime"),
                                      "PNOutPost": create_out_post_var_ref(pn_kc),
                                      "GGNOutPost": create_out_post_var_ref(ggn_kc)})
    model.add_custom_update("ggn_reset", "Reset", ggn_reset,
                            {"Vreset": 0.0},
                            var_refs={"V": create_var_ref(ggn, "V")})
    model.add_custom_update("mbon_reset", "Reset", mbon_reset,
                            {"Vreset": LIF_PARAMS["Vreset"]},
                            var_refs={"V": create_var_ref(mbon, "V"),
                                      "RefracTime": create_var_ref(mbon, "RefracTime"),
                                      "OutPost": create_out_post_var_ref(kc_mbon)})
    if not args.test:
        model.add_custom_update("kc_reset_st", "ResetST", reset_st,
                                var_refs={"SpikeTimes": create_spike_time_var_ref(kc)})
        model.add_custom_update("mbon_reset_st", "ResetST", reset_st,
                                var_refs={"SpikeTimes": create_spike_time_var_ref(mbon)})
    # Convert present time into timesteps
    present_timesteps = int(round(PRESENT_TIME_MS / DT))
    # Build model and load it
    model.build()
    model.load(num_recording_timesteps=present_timesteps)
    # Present images
    num_correct = 0
    for s in tqdm(range(images.shape[0])):
        # Set input image
        pn_input.vars["magnitude"].view[:] = images[s] * INPUT_SCALE
        pn_input.vars["magnitude"].push_to_device()
        # When training, stimulate the correct output neuron
        if not args.test:
            mbon_input.vars["magnitude"].view[:] = 0
            mbon_input.vars["magnitude"].view[labels[s]] = MBON_STIMULUS_CURRENT
            mbon_input.vars["magnitude"].push_to_device()
        # Simulate present timesteps
        for i in range(present_timesteps):
            model.step_time()
        # Reset neuron state
        model.custom_update("Reset")
        # Reset spike times
        if not args.test:
            model.custom_update("ResetST")
        if args.test:
            # Download spikes from GPU
            model.pull_recording_buffers_from_device()
            # Classify the image as the digit whose MBON spiked first and count correct classifications
            mbon_spike_times, mbon_spike_ids = mbon.spike_recording_data[0]
            if len(mbon_spike_times) > 0:
                if mbon_spike_ids[np.argmin(mbon_spike_times)] == labels[s]:
                    num_correct += 1
    if args.test:
        print(f"\n{num_correct}/{images.shape[0]} correct ({(num_correct * 100.0) / images.shape[0]} %)")
    else:
        pn_kc.pull_connectivity_from_device()
        kc_mbon.vars["g"].pull_from_device()
        # Save weights and connectivity
        kc_mbon_g_view = kc_mbon.vars["g"].view
        np.save("kc_mbon_g.npy", kc_mbon_g_view)
        np.save("pn_kc_ind.npy", np.vstack((pn_kc.get_sparse_pre_inds(),
                                            pn_kc.get_sparse_post_inds())))
        # Plot weight distribution
        if args.plot_weight_distribution:
            from matplotlib import pyplot as plt
            fig, axis = plt.subplots(figsize=(10, 5))
            axis.hist(kc_mbon_g_view, bins=100)
            axis.axvline(np.average(kc_mbon_g_view), linestyle="--")
            axis.set_xlabel("Weight [nA]")
            axis.set_ylabel("Count")
            plt.show()
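By default, the script trains on the MNIST training set, saving the learned KC->MBON weights to kc_mbon_g.npy and the PN->KC connectivity to pn_kc_ind.npy. Running it again with --test loads these files and reports classification accuracy on the MNIST test set, while --plot-weight-distribution plots a histogram of the learned weights after training.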