{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "\n",
    "# MNIST classification using an insect-inspired mushroom body model\n",
    "This example learns MNIST digits using STDP and an insect-inspired mushroom body model.\n",
    "\n",
    "This example can be used as follows:\n",
    "\n",
    ".. argparse::\n",
    "   :filename: ../userproject/mnist_mb_classifier.py\n",
    "   :func: get_parser\n",
    "   :prog: mnist_mb_classifier\n",
    "\n",
    "When run inside Jupyter, the command-line options keep their defaults, i.e. the model is trained and the learned weights and connectivity are saved.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "import mnist\n",
    "import numpy as np\n",
    "from copy import copy\n",
    "from argparse import ArgumentParser\n",
    "from pygenn import (create_current_source_model, create_custom_update_model,\n",
    "                    create_neuron_model, create_out_post_var_ref,\n",
    "                    create_spike_time_var_ref, create_var_ref,\n",
    "                    create_weight_update_model, init_sparse_connectivity,\n",
    "                    init_postsynaptic, init_weight_update, GeNNModel)\n",
    "from tqdm.auto import tqdm"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "# ----------------------------------------------------------------------------\n",
    "# Parameters\n",
    "# ----------------------------------------------------------------------------\n",
    "# Simulation time step\n",
    "DT = 0.1\n",
    "\n",
    "# Scaling factor for converting normalised image pixels to input currents (nA)\n",
    "INPUT_SCALE = 80.0\n",
    "\n",
    "# Size of current to use to stimulate correct MBON when training (nA)\n",
    "MBON_STIMULUS_CURRENT = 5.0\n",
    "\n",
    "# Number of Projection Neurons in model (should match image size)\n",
    "NUM_PN = 784\n",
    "\n",
    "# Number of Kenyon Cells in model (defines memory capacity)\n",
    "NUM_KC = 20000\n",
    "\n",
    "# Number of output neurons in model\n",
    "NUM_MBON = 10\n",
    "\n",
    "# How long to present each image to model\n",
    "PRESENT_TIME_MS = 20.0\n",
    "\n",
    "# Standard LIF neurons parameters\n",
    "LIF_PARAMS = {\n",
    "    \"C\": 0.2,\n",
    "    \"TauM\": 20.0,\n",
    "    \"Vrest\": -60.0,\n",
    "    \"Vreset\": -60.0,\n",
    "    \"Vthresh\": -50.0,\n",
    "    \"Ioffset\": 0.0,\n",
    "    \"TauRefrac\": 2.0}\n",
    "\n",
    "# We only want PNs to spike once: a refractory period longer than\n",
    "# PRESENT_TIME_MS stops a PN firing again during the same image\n",
    "PN_PARAMS = copy(LIF_PARAMS)\n",
    "PN_PARAMS[\"TauRefrac\"] = 100.0\n",
    "\n",
    "# Weight of each synaptic connection\n",
    "PN_KC_WEIGHT = 0.2\n",
    "\n",
    "# Time constant of synaptic integration\n",
    "PN_KC_TAU_SYN = 3.0\n",
    "\n",
    "# How many projection neurons should be connected to each Kenyon Cell\n",
    "PN_KC_FAN_IN = 20\n",
    "\n",
    "# We will use weights of 1.0 for KC->GGN connections and\n",
    "# want the GGN to inhibit the KCs after 200 spikes\n",
    "KC_GGN_WEIGHT = 1.0\n",
    "GGN_PARAMS = {\"Vthresh\": 200.0}\n",
    "\n",
    "# Inhibitory weight and synaptic time constant of GGN->KC feedback\n",
    "GGN_KC_WEIGHT = -5.0\n",
    "GGN_KC_TAU_SYN = 5.0\n",
    "\n",
    "# Time constant of KC->MBON synaptic integration\n",
    "KC_MBON_TAU_SYN = 3.0\n",
    "\n",
    "# Parameters of the STDP rule used to learn KC->MBON weights\n",
    "KC_MBON_PARAMS = {\"tau\": 15.0,\n",
    "                  \"rho\": 0.01,\n",
    "                  \"eta\": 0.00002,\n",
    "                  \"wMin\": 0.0,\n",
    "                  \"wMax\": 0.0233}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "# ----------------------------------------------------------------------------\n",
    "# Custom models\n",
    "# ----------------------------------------------------------------------------\n",
    "# Current source model, allowing current to be injected into neuron from variable\n",
    "cs_model = create_current_source_model(\n",
    "    \"cs_model\",\n",
    "    vars=[(\"magnitude\", \"scalar\")],\n",
    "    injection_code=\"injectCurrent(magnitude);\")\n",
    "\n",
    "# Minimal integrate and fire neuron model: accumulates input without leak,\n",
    "# spikes once V reaches Vthresh and then resets to zero\n",
    "if_model = create_neuron_model(\n",
    "    \"IF\",\n",
    "    params=[\"Vthresh\"],\n",
    "    vars=[(\"V\", \"scalar\")],\n",
    "    sim_code=\n",
    "        \"\"\"\n",
    "        V += Isyn;\n",
    "        \"\"\",\n",
    "    threshold_condition_code=\n",
    "        \"\"\"\n",
    "        V >= Vthresh\n",
    "        \"\"\",\n",
    "    reset_code=\n",
    "        \"\"\"\n",
    "        V = 0.0;\n",
    "        \"\"\")\n",
    "\n",
    "# Symmetric STDP learning rule: on each spike the weight changes by\n",
    "# eta * (exp(-dt / tau) - rho), where dt is the time since the most recent\n",
    "# spike on the other side of the synapse. Near-coincident spikes potentiate,\n",
    "# distant ones depress; depression triggered by postsynaptic spikes is\n",
    "# limited to 0.1 * rho and weights are always clamped to [wMin, wMax]\n",
    "symmetric_stdp = create_weight_update_model(\n",
    "    \"symmetric_stdp\",\n",
    "    params=[\"tau\", \"rho\", \"eta\", \"wMin\", \"wMax\"],\n",
    "    vars=[(\"g\", \"scalar\")],\n",
    "    pre_spike_syn_code=\n",
    "        \"\"\"\n",
    "        const scalar dt = t - st_post;\n",
    "        const scalar timing = exp(-dt / tau) - rho;\n",
    "        const scalar newWeight = g + (eta * timing);\n",
    "        g = fmin(wMax, fmax(wMin, newWeight));\n",
    "        \"\"\",\n",
    "    post_spike_syn_code=\n",
    "        \"\"\"\n",
    "        const scalar dt = t - st_pre;\n",
    "        const scalar timing = fmax(exp(-dt / tau) - rho, -0.1 * rho);\n",
    "        const scalar newWeight = g + (eta * timing);\n",
    "        g = fmin(wMax, fmax(wMin, newWeight));\n",
    "        \"\"\")\n",
    "\n",
    "# Custom updates for resetting neuron (and, for KCs and MBONs, incoming\n",
    "# synaptic current) state between presented images\n",
    "pn_reset = create_custom_update_model(\n",
    "    \"pn_reset\",\n",
    "    params=[\"Vreset\"],\n",
    "    var_refs=[(\"V\", \"scalar\"), (\"RefracTime\", \"scalar\")],\n",
    "    update_code=\n",
    "        \"\"\"\n",
    "        V = Vreset;\n",
    "        RefracTime = 0.0;\n",
    "        \"\"\")\n",
    "\n",
    "kc_reset = create_custom_update_model(\n",
    "    \"kc_reset\",\n",
    "    params=[\"Vreset\"],\n",
    "    var_refs=[(\"PNOutPost\", \"scalar\"), (\"GGNOutPost\", \"scalar\"),\n",
    "              (\"V\", \"scalar\"), (\"RefracTime\", \"scalar\")],\n",
    "    update_code=\n",
    "        \"\"\"\n",
    "        PNOutPost = 0.0;\n",
    "        GGNOutPost = 0.0;\n",
    "        V = Vreset;\n",
    "        RefracTime = 0.0;\n",
    "        \"\"\")\n",
    "\n",
    "ggn_reset = create_custom_update_model(\n",
    "    \"ggn_reset\",\n",
    "    params=[\"Vreset\"],\n",
    "    var_refs=[(\"V\", \"scalar\")],\n",
    "    update_code=\n",
    "        \"\"\"\n",
    "        V = Vreset;\n",
    "        \"\"\")\n",
    "\n",
    "mbon_reset = create_custom_update_model(\n",
    "    \"mbon_reset\",\n",
    "    params=[\"Vreset\"],\n",
    "    var_refs=[(\"OutPost\", \"scalar\"), (\"V\", \"scalar\"),\n",
    "              (\"RefracTime\", \"scalar\")],\n",
    "    update_code=\n",
    "        \"\"\"\n",
    "        OutPost = 0.0;\n",
    "        V = Vreset;\n",
    "        RefracTime = 0.0;\n",
    "        \"\"\")\n",
    "\n",
    "# Custom update for resetting spike times to the most negative float32, so\n",
    "# that after a reset exp(-dt / tau) in the STDP rule is effectively zero,\n",
    "# i.e. spikes from the previous image no longer count as recent\n",
    "reset_st = create_custom_update_model(\n",
    "    \"reset_st\",\n",
    "    var_refs=[(\"SpikeTimes\", \"scalar\")],\n",
    "    update_code=\n",
    "        f\"\"\"\n",
    "        SpikeTimes = {-np.finfo(np.float32).max};\n",
    "        \"\"\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "# ----------------------------------------------------------------------------\n",
    "# CLI\n",
    "# ----------------------------------------------------------------------------\n",
    "def get_parser():\n",
    "    \"\"\"Build the command-line parser for this example.\n",
    "\n",
    "    Returns:\n",
    "        ArgumentParser accepting the ``--test`` and\n",
    "        ``--plot-weight-distribution`` flags.\n",
    "    \"\"\"\n",
    "    parser = ArgumentParser()\n",
    "    parser.add_argument(\"--test\", action=\"store_true\", help=\"Load saved weights (rather than training)\")\n",
    "    parser.add_argument(\"--plot-weight-distribution\", action=\"store_true\", help=\"Plot weight distribution after training\")\n",
    "    return parser"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "if __name__ == \"__main__\":\n",
    "    # parse_known_args (rather than parse_args) so that arguments the Jupyter\n",
    "    # kernel launcher passes to the process do not abort the run; inside a\n",
    "    # notebook the options therefore keep their defaults (training)\n",
    "    args, _ = get_parser().parse_known_args()\n",
    "\n",
    "    # Set the download url for the MNIST dataset\n",
    "    mnist.datasets_url = \"https://storage.googleapis.com/cvdf-datasets/mnist/\"\n",
    "\n",
    "    # Flatten each image into a vector and normalise it so its pixels sum to 1\n",
    "    images = mnist.test_images() if args.test else mnist.train_images()\n",
    "    images = np.reshape(images, (images.shape[0], -1)).astype(np.float32)\n",
    "    images /= np.sum(images, axis=1)[:, np.newaxis]\n",
    "    labels = mnist.test_labels() if args.test else mnist.train_labels()\n",
    "\n",
    "    # Create model\n",
    "    model = GeNNModel(\"float\", \"mnist_mb\")\n",
    "    model.dt = DT\n",
    "\n",
    "    # Create neuron populations\n",
    "    lif_init = {\"V\": LIF_PARAMS[\"Vreset\"], \"RefracTime\": 0.0}\n",
    "    if_init = {\"V\": 0.0}\n",
    "    pn = model.add_neuron_population(\"pn\", NUM_PN, \"LIF\", PN_PARAMS, lif_init)\n",
    "    kc = model.add_neuron_population(\"kc\", NUM_KC, \"LIF\", LIF_PARAMS, lif_init)\n",
    "    ggn = model.add_neuron_population(\"ggn\", 1, if_model, GGN_PARAMS, if_init)\n",
    "    mbon = model.add_neuron_population(\"mbon\", NUM_MBON, \"LIF\", LIF_PARAMS, lif_init)\n",
    "\n",
    "    # Only MBON spikes are ever read back (to classify test images), so only\n",
    "    # record those, and only when testing\n",
    "    mbon.spike_recording_enabled = args.test\n",
    "\n",
    "    # Create current source to deliver input to network\n",
    "    pn_input = model.add_current_source(\"pn_input\", cs_model, pn, {}, {\"magnitude\": 0.0})\n",
    "\n",
    "    # Create current source to deliver supervision to network when training\n",
    "    if not args.test:\n",
    "        mbon_input = model.add_current_source(\"mbon_input\", cs_model, mbon, {}, {\"magnitude\": 0.0})\n",
    "\n",
    "    # Create synapse populations; when testing, PN->KC connectivity is\n",
    "    # loaded from disk below rather than initialised randomly\n",
    "    pn_kc_connectivity = (None if args.test\n",
    "                          else init_sparse_connectivity(\"FixedNumberPreWithReplacement\",\n",
    "                                                        {\"num\": PN_KC_FAN_IN}))\n",
    "    pn_kc = model.add_synapse_population(\"pn_kc\", \"SPARSE\",\n",
    "        pn, kc,\n",
    "        init_weight_update(\"StaticPulseConstantWeight\", {\"g\": PN_KC_WEIGHT}),\n",
    "        init_postsynaptic(\"ExpCurr\", {\"tau\": PN_KC_TAU_SYN}),\n",
    "        pn_kc_connectivity)\n",
    "\n",
    "    # Load saved connectivity if testing\n",
    "    if args.test:\n",
    "        pn_kc_ind = np.load(\"pn_kc_ind.npy\")\n",
    "        pn_kc.set_sparse_connections(pn_kc_ind[0], pn_kc_ind[1])\n",
    "\n",
    "    kc_ggn = model.add_synapse_population(\"kc_ggn\", \"DENSE\",\n",
    "        kc, ggn,\n",
    "        init_weight_update(\"StaticPulseConstantWeight\", {\"g\": KC_GGN_WEIGHT}),\n",
    "        init_postsynaptic(\"DeltaCurr\"))\n",
    "\n",
    "    ggn_kc = model.add_synapse_population(\"ggn_kc\", \"DENSE\",\n",
    "        ggn, kc,\n",
    "        init_weight_update(\"StaticPulseConstantWeight\", {\"g\": GGN_KC_WEIGHT}),\n",
    "        init_postsynaptic(\"ExpCurr\", {\"tau\": GGN_KC_TAU_SYN}))\n",
    "\n",
    "    # Testing uses the saved, fixed weights; training learns them with STDP from zero\n",
    "    if args.test:\n",
    "        kc_mbon_weight_update = init_weight_update(\"StaticPulse\", {},\n",
    "                                                   {\"g\": np.load(\"kc_mbon_g.npy\")})\n",
    "    else:\n",
    "        kc_mbon_weight_update = init_weight_update(symmetric_stdp, KC_MBON_PARAMS,\n",
    "                                                   {\"g\": 0.0})\n",
    "    kc_mbon = model.add_synapse_population(\"kc_mbon\", \"DENSE\",\n",
    "        kc, mbon,\n",
    "        kc_mbon_weight_update,\n",
    "        init_postsynaptic(\"ExpCurr\", {\"tau\": KC_MBON_TAU_SYN}))\n",
    "\n",
    "    # Add custom updates to reset model state between examples\n",
    "    model.add_custom_update(\"pn_reset\", \"Reset\", pn_reset,\n",
    "                            {\"Vreset\": PN_PARAMS[\"Vreset\"]},\n",
    "                            var_refs={\"V\": create_var_ref(pn, \"V\"),\n",
    "                                      \"RefracTime\": create_var_ref(pn, \"RefracTime\")})\n",
    "    model.add_custom_update(\"kc_reset\", \"Reset\", kc_reset,\n",
    "                            {\"Vreset\": LIF_PARAMS[\"Vreset\"]},\n",
    "                            var_refs={\"V\": create_var_ref(kc, \"V\"),\n",
    "                                      \"RefracTime\": create_var_ref(kc, \"RefracTime\"),\n",
    "                                      \"PNOutPost\": create_out_post_var_ref(pn_kc),\n",
    "                                      \"GGNOutPost\": create_out_post_var_ref(ggn_kc)})\n",
    "    model.add_custom_update(\"ggn_reset\", \"Reset\", ggn_reset,\n",
    "                            {\"Vreset\": 0.0},\n",
    "                            var_refs={\"V\": create_var_ref(ggn, \"V\")})\n",
    "    model.add_custom_update(\"mbon_reset\", \"Reset\", mbon_reset,\n",
    "                            {\"Vreset\": LIF_PARAMS[\"Vreset\"]},\n",
    "                            var_refs={\"V\": create_var_ref(mbon, \"V\"),\n",
    "                                      \"RefracTime\": create_var_ref(mbon, \"RefracTime\"),\n",
    "                                      \"OutPost\": create_out_post_var_ref(kc_mbon)})\n",
    "\n",
    "    # Spike times only matter to the STDP rule, so they are only reset when training\n",
    "    if not args.test:\n",
    "        model.add_custom_update(\"kc_reset_st\", \"ResetST\", reset_st,\n",
    "                                var_refs={\"SpikeTimes\": create_spike_time_var_ref(kc)})\n",
    "        model.add_custom_update(\"mbon_reset_st\", \"ResetST\", reset_st,\n",
    "                                var_refs={\"SpikeTimes\": create_spike_time_var_ref(mbon)})\n",
    "\n",
    "    # Convert present time into timesteps\n",
    "    present_timesteps = int(round(PRESENT_TIME_MS / DT))\n",
    "\n",
    "    # Build model and load it; the recording buffer holds exactly one presentation\n",
    "    model.build()\n",
    "    model.load(num_recording_timesteps=present_timesteps)\n",
    "\n",
    "    # Present images\n",
    "    num_correct = 0\n",
    "    for s in tqdm(range(images.shape[0])):\n",
    "        # Set input image\n",
    "        pn_input.vars[\"magnitude\"].view[:] = images[s] * INPUT_SCALE\n",
    "        pn_input.vars[\"magnitude\"].push_to_device()\n",
    "\n",
    "        # Turn on correct output neuron\n",
    "        if not args.test:\n",
    "            mbon_input.vars[\"magnitude\"].view[:] = 0\n",
    "            mbon_input.vars[\"magnitude\"].view[labels[s]] = MBON_STIMULUS_CURRENT\n",
    "            mbon_input.vars[\"magnitude\"].push_to_device()\n",
    "\n",
    "        # Simulate present timesteps\n",
    "        for _ in range(present_timesteps):\n",
    "            model.step_time()\n",
    "\n",
    "        # Reset neuron state\n",
    "        model.custom_update(\"Reset\")\n",
    "\n",
    "        if args.test:\n",
    "            # Download spikes from GPU\n",
    "            model.pull_recording_buffers_from_device()\n",
    "\n",
    "            # Classify as the MBON that spiked first; images which\n",
    "            # produce no MBON spikes count as incorrect\n",
    "            mbon_spike_times, mbon_spike_ids = mbon.spike_recording_data[0]\n",
    "            if len(mbon_spike_times) > 0:\n",
    "                if mbon_spike_ids[np.argmin(mbon_spike_times)] == labels[s]:\n",
    "                    num_correct += 1\n",
    "        else:\n",
    "            # Reset spike times\n",
    "            model.custom_update(\"ResetST\")\n",
    "\n",
    "    if args.test:\n",
    "        accuracy = (num_correct * 100.0) / images.shape[0]\n",
    "        print(f\"\\n{num_correct}/{images.shape[0]} correct ({accuracy} %)\")\n",
    "    else:\n",
    "        pn_kc.pull_connectivity_from_device()\n",
    "        kc_mbon.vars[\"g\"].pull_from_device()\n",
    "\n",
    "        # Save weights and connectivity (loaded again when run with --test)\n",
    "        kc_mbon_g_view = kc_mbon.vars[\"g\"].view\n",
    "        np.save(\"kc_mbon_g.npy\", kc_mbon_g_view)\n",
    "        np.save(\"pn_kc_ind.npy\", np.vstack((pn_kc.get_sparse_pre_inds(),\n",
    "                                            pn_kc.get_sparse_post_inds())))\n",
    "\n",
    "        # Plot weight distribution\n",
    "        if args.plot_weight_distribution:\n",
    "            # Imported lazily so matplotlib is only required when plotting\n",
    "            from matplotlib import pyplot as plt\n",
    "\n",
    "            fig, axis = plt.subplots(figsize=(10, 5))\n",
    "            axis.hist(kc_mbon_g_view, bins=100)\n",
    "            axis.axvline(np.average(kc_mbon_g_view), linestyle=\"--\")\n",
    "            axis.set_title(\"Learned KC->MBON weights (dashed line: mean)\")\n",
    "            axis.set_xlabel(\"Weight [nA]\")\n",
    "            axis.set_ylabel(\"Count\")\n",
    "            plt.show()"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 0
}