GeNN  3.3.0
GPU enhanced Neuronal Networks (GeNN)
newWeightUpdateModels.h
Go to the documentation of this file.
1 #pragma once
2 
3 // GeNN includes
4 #include "newModels.h"
5 #include "synapseModels.h"
6 
7 //----------------------------------------------------------------------------
8 // Macros
9 //----------------------------------------------------------------------------
//! Declares the boilerplate every weight update model class needs:
//! DECLARE_SNIPPET(TYPE, NUM_PARAMS) plus three container typedefs sized by
//! the number of per-synapse (VarValues), presynaptic (PreVarValues) and
//! postsynaptic (PostVarValues) variables.
//! NOTE: the expansion already ends in ';', so an invocation written as
//! DECLARE_WEIGHT_UPDATE_MODEL(...); leaves an extra empty declaration
//! (harmless, but may trip -Wextra-semi).
10 #define DECLARE_WEIGHT_UPDATE_MODEL(TYPE, NUM_PARAMS, NUM_VARS, NUM_PRE_VARS, NUM_POST_VARS) \
11  DECLARE_SNIPPET(TYPE, NUM_PARAMS) \
12  typedef NewModels::VarInitContainerBase<NUM_VARS> VarValues; \
13  typedef NewModels::VarInitContainerBase<NUM_PRE_VARS> PreVarValues; \
14  typedef NewModels::VarInitContainerBase<NUM_POST_VARS> PostVarValues;
15 
// Code-string hooks. Each SET_* macro expands to an inline override of the
// same-named WeightUpdateModels::Base virtual and returns its argument
// verbatim. Because of `override`, a misspelt or mis-signed hook is a compile
// error, so these must be used inside the body of a class derived from Base.
// NOTE(review): keep these definitions token-for-token as they are. If another
// model header defines a macro with the same name, any divergence (even a
// renamed macro parameter) makes including both headers an invalid macro
// redefinition.
16 #define SET_SIM_CODE(SIM_CODE) virtual std::string getSimCode() const override{ return SIM_CODE; }
17 #define SET_EVENT_CODE(EVENT_CODE) virtual std::string getEventCode() const override{ return EVENT_CODE; }
18 #define SET_LEARN_POST_CODE(LEARN_POST_CODE) virtual std::string getLearnPostCode() const override{ return LEARN_POST_CODE; }
19 #define SET_SYNAPSE_DYNAMICS_CODE(SYNAPSE_DYNAMICS_CODE) virtual std::string getSynapseDynamicsCode() const override{ return SYNAPSE_DYNAMICS_CODE; }
20 #define SET_EVENT_THRESHOLD_CONDITION_CODE(EVENT_THRESHOLD_CONDITION_CODE) virtual std::string getEventThresholdConditionCode() const override{ return EVENT_THRESHOLD_CONDITION_CODE; }
21 
// Support-code and per-spike code hooks.
// NOTE: getSynapseDynamicsSuppportCode (three p's) deliberately matches the
// spelling of the Base virtual it overrides. Do not "fix" it here alone.
22 #define SET_SIM_SUPPORT_CODE(SIM_SUPPORT_CODE) virtual std::string getSimSupportCode() const override{ return SIM_SUPPORT_CODE; }
23 #define SET_LEARN_POST_SUPPORT_CODE(LEARN_POST_SUPPORT_CODE) virtual std::string getLearnPostSupportCode() const override{ return LEARN_POST_SUPPORT_CODE; }
24 #define SET_SYNAPSE_DYNAMICS_SUPPORT_CODE(SYNAPSE_DYNAMICS_SUPPORT_CODE) virtual std::string getSynapseDynamicsSuppportCode() const override{ return SYNAPSE_DYNAMICS_SUPPORT_CODE; }
25 #define SET_PRE_SPIKE_CODE(PRE_SPIKE_CODE) virtual std::string getPreSpikeCode() const override{ return PRE_SPIKE_CODE; }
26 #define SET_POST_SPIKE_CODE(POST_SPIKE_CODE) virtual std::string getPostSpikeCode() const override{ return POST_SPIKE_CODE; }
27 
// Presynaptic / postsynaptic variable lists (StringPairVec initialisers).
// They are variadic because the preprocessor splits arguments on commas
// inside braces. __VA_ARGS__ reassembles an initialiser such as
// {{"a", "scalar"}, {"b", "scalar"}} intact.
28 #define SET_PRE_VARS(...) virtual StringPairVec getPreVars() const override{ return __VA_ARGS__; }
29 #define SET_POST_VARS(...) virtual StringPairVec getPostVars() const override{ return __VA_ARGS__; }
30 
// Extra global parameters, also given as a variadic StringPairVec initialiser.
31 #define SET_EXTRA_GLOBAL_PARAMS(...) virtual StringPairVec getExtraGlobalParams() const override{ return __VA_ARGS__; }
32 
// Overrides for isPreSpikeTimeRequired / isPostSpikeTimeRequired, i.e. whether
// presynaptic / postsynaptic spike times are needed by the model.
33 #define SET_NEEDS_PRE_SPIKE_TIME(PRE_SPIKE_TIME_REQUIRED) virtual bool isPreSpikeTimeRequired() const override{ return PRE_SPIKE_TIME_REQUIRED; }
34 #define SET_NEEDS_POST_SPIKE_TIME(POST_SPIKE_TIME_REQUIRED) virtual bool isPostSpikeTimeRequired() const override{ return POST_SPIKE_TIME_REQUIRED; }
35 
36 //----------------------------------------------------------------------------
37 // WeightUpdateModels::Base
38 //----------------------------------------------------------------------------
40 {
42 class Base : public NewModels::Base
43 {
44 public:
45  //----------------------------------------------------------------------------
46  // Declared virtuals
47  //----------------------------------------------------------------------------
49  virtual std::string getSimCode() const{ return ""; }
50 
52  virtual std::string getEventCode() const{ return ""; }
53 
55 
57  virtual std::string getLearnPostCode() const{ return ""; }
58 
60  virtual std::string getSynapseDynamicsCode() const{ return ""; }
61 
63  virtual std::string getEventThresholdConditionCode() const{ return ""; }
64 
66 
71  virtual std::string getSimSupportCode() const{ return ""; }
72 
74 
77  virtual std::string getLearnPostSupportCode() const{ return ""; }
78 
80 
83  virtual std::string getSynapseDynamicsSuppportCode() const{ return ""; }
84 
87 
89  virtual std::string getPreSpikeCode() const{ return ""; }
90 
93 
95  virtual std::string getPostSpikeCode() const{ return ""; }
96 
99  virtual StringPairVec getPreVars() const{ return {}; }
100 
103  virtual StringPairVec getPostVars() const{ return {}; }
104 
107  virtual StringPairVec getExtraGlobalParams() const{ return {}; }
108 
110  virtual bool isPreSpikeTimeRequired() const{ return false; }
111 
113  virtual bool isPostSpikeTimeRequired() const{ return false; }
114 
115  //------------------------------------------------------------------------
116  // Public methods
117  //------------------------------------------------------------------------
119  size_t getPreVarIndex(const std::string &varName) const
120  {
121  return getVarIndex(varName, getPreVars());
122  }
123 
125  size_t getPostVarIndex(const std::string &varName) const
126  {
127  return getVarIndex(varName, getPostVars());
128  }
129 
130 };
131 
132 //----------------------------------------------------------------------------
133 // WeightUpdateModels::LegacyWrapper
134 //----------------------------------------------------------------------------
//! Wrapper around legacy weight update models stored in the global
//! weightUpdateModels array of weightUpdateModel objects.
/*! The legacy model is selected by an index passed to the constructor.
    Each Base virtual redeclared below is only declared here; its definition
    lives outside this listing. Hooks that are not redeclared (e.g.
    getPreSpikeCode, getPostSpikeCode, getPreVars, getPostVars) fall back
    to Base's empty defaults.
    NOTE(review): none of these redeclarations carry `override`, unlike the
    SET_* macros in this header. Adding it would let the compiler catch
    signature drift against Base.
    NOTE(review): the single-argument constructor is not `explicit`, so an
    unsigned int converts implicitly to a LegacyWrapper. */
136 class LegacyWrapper : public NewModels::LegacyWrapper<Base, weightUpdateModel, weightUpdateModels>
137 {
138 public:
 //! legacyTypeIndex is forwarded unchanged to NewModels::LegacyWrapper
 //! (presumably an index into weightUpdateModels; confirm there).
139  LegacyWrapper(unsigned int legacyTypeIndex) : NewModels::LegacyWrapper<Base, weightUpdateModel, weightUpdateModels>(legacyTypeIndex)
140  {
141  }
142 
143  //----------------------------------------------------------------------------
144  // Base virtuals
145  //----------------------------------------------------------------------------
147  virtual std::string getSimCode() const;
148 
150  virtual std::string getEventCode() const;
151 
153  virtual std::string getLearnPostCode() const;
154 
156  virtual std::string getSynapseDynamicsCode() const;
157 
159  virtual std::string getEventThresholdConditionCode() const;
160 
162  virtual std::string getSimSupportCode() const;
163 
165  virtual std::string getLearnPostSupportCode() const;
166 
168  virtual std::string getSynapseDynamicsSuppportCode() const;
169 
 // NOTE(review): this listing skips two source lines here. They likely held
 // another Base-virtual redeclaration (possibly getExtraGlobalParams); confirm
 // against the original header before editing this class.
172 
174  virtual bool isPreSpikeTimeRequired() const;
175 
177  virtual bool isPostSpikeTimeRequired() const;
178 };
179 
180 //----------------------------------------------------------------------------
181 // WeightUpdateModels::StaticPulse
182 //----------------------------------------------------------------------------
184 
//! Pulse-coupled, static synapse.
/*! One per-synapse variable, g (type "scalar"), and no parameters in the
    visible code. Its sim code (run when 'true' spikes are received) passes
    g to $(addToInSyn, ...), presumably accumulating it into the
    postsynaptic input. g is never modified, hence "static".
    NOTE(review): this listing omits the source line after `public:`. In the
    original header it presumably holds this class's
    DECLARE_WEIGHT_UPDATE_MODEL(...) invocation; confirm before relying on
    this fragment. */
195 class StaticPulse : public Base
196 {
197 public:
199 
200  SET_VARS({{"g", "scalar"}});
201 
202  SET_SIM_CODE("$(addToInSyn, $(g));\n");
203 };
204 
205 //----------------------------------------------------------------------------
206 // WeightUpdateModels::StaticPulseDendriticDelay
207 //----------------------------------------------------------------------------
209 
// NOTE(review): this listing omits the class header preceding this body
// (per the section banner, presumably
// `class StaticPulseDendriticDelay : public Base`) and the line after
// `public:` (presumably its DECLARE_WEIGHT_UPDATE_MODEL); confirm.
//! Pulse-coupled, static synapse with heterogeneous dendritic delays.
/*! Per-synapse variables: g ("scalar") and d ("uint8_t"). The sim code
    passes both to $(addToInSynDelay, ...). This presumably delivers g to
    the postsynaptic input after a per-synapse dendritic delay d; confirm
    the delay units against the code generator. */
222 {
223 public:
225 
226  SET_VARS({{"g", "scalar"},{"d", "uint8_t"}});
227 
228  SET_SIM_CODE("$(addToInSynDelay, $(g), $(d));\n");
229 };
230 
231 //----------------------------------------------------------------------------
232 // WeightUpdateModels::StaticGraded
233 //----------------------------------------------------------------------------
235 
//! Graded-potential, static synapse.
/*! Parameters: Epre, Vslope. Per-synapse variable: g ("scalar").
    The event threshold condition is $(V_pre) > $(Epre), so events occur
    while the presynaptic V exceeds Epre. Each event passes
    max(0, g * tanh((V_pre - Epre) / Vslope) * DT) to $(addToInSyn, ...),
    so the contribution is clamped to be non-negative.
    NOTE(review): the source line after `public:` (presumably this class's
    DECLARE_WEIGHT_UPDATE_MODEL) is absent from this listing; confirm.
    NOTE(review): `max(0.0, <expr>)` mixes a double literal with an
    expression built from "scalar" values. If the generated code resolves
    this to std::max and scalar is float, the two arguments will not deduce
    one type. fmax, or a scalar-typed zero, would avoid that; check how the
    generator handles `max`. */
257 class StaticGraded : public Base
258 {
259 public:
261 
262  SET_PARAM_NAMES({"Epre", "Vslope"});
263  SET_VARS({{"g", "scalar"}});
264 
265  SET_EVENT_CODE("$(addToInSyn, max(0.0, $(g) * tanh(($(V_pre) - $(Epre)) / $(Vslope))* DT));\n");
266 
267  SET_EVENT_THRESHOLD_CONDITION_CODE("$(V_pre) > $(Epre)");
268 };
269 
270 //----------------------------------------------------------------------------
271 // PiecewiseSTDP
272 //----------------------------------------------------------------------------
274 
//! Simple STDP rule including a time delay for the finite transmission
//! speed of the synapse.
/*! Parameters in SET_PARAM_NAMES order (the pars[] index used by the
    derived-parameter lambdas below):
      0 tLrn, 1 tChng, 2 tDecay, 3 tPunish10, 4 tPunish01,
      5 gMax, 6 gMid, 7 gSlope, 8 tauShift, 9 gSyn0
    Per-synapse variables: g and gRaw (both "scalar").

    Both code strings apply the same piecewise rule to a spike-time
    difference dt, evaluated in this order:
      if dt > lim0        -> dg = -off0
      else if dt > 0      -> dg = slope0 * dt + off1
      else if dt > lim1   -> dg = slope1 * dt + off1
      else                -> dg = -off2
    then gRaw += dg and g = gMax/2 * (tanh(gSlope * (gRaw - gMid)) + 1).
    - The sim code first calls $(addToInSyn, $(g)), then uses
      dt = sT_post - t - tauShift.
    - The second string uses dt = t - sT_pre - tauShift.
    NOTE(review): both middle branches add off1, and off2 appears only in the
    final else. Confirm this asymmetry is intended.
    NOTE(review): the sim code divides by 2 but the second string by 2.0. The
    numeric result is the same if gMax is substituted as a floating literal;
    confirm. */
327 class PiecewiseSTDP : public Base
328 {
329 public:
331 
332  SET_PARAM_NAMES({"tLrn", "tChng", "tDecay", "tPunish10", "tPunish01",
333  "gMax", "gMid", "gSlope", "tauShift", "gSyn0"});
334  SET_VARS({{"g", "scalar"}, {"gRaw", "scalar"}});
335 
336  SET_SIM_CODE(
337  "$(addToInSyn, $(g));\n"
338  "scalar dt = $(sT_post) - $(t) - ($(tauShift)); \n"
339  "scalar dg = 0;\n"
340  "if (dt > $(lim0)) \n"
341  " dg = -($(off0)) ; \n"
342  "else if (dt > 0) \n"
343  " dg = $(slope0) * dt + ($(off1)); \n"
344  "else if (dt > $(lim1)) \n"
345  " dg = $(slope1) * dt + ($(off1)); \n"
346  "else dg = - ($(off2)) ; \n"
347  "$(gRaw) += dg; \n"
348  "$(g)=$(gMax)/2 *(tanh($(gSlope)*($(gRaw) - ($(gMid))))+1); \n");
 // NOTE(review): the source line that opens the following string-literal
 // argument is missing from this listing. It is presumably
 // `SET_LEARN_POST_CODE(`; as shown, these literals are not valid C++.
350  "scalar dt = $(t) - ($(sT_pre)) - ($(tauShift)); \n"
351  "scalar dg =0; \n"
352  "if (dt > $(lim0)) \n"
353  " dg = -($(off0)) ; \n"
354  "else if (dt > 0) \n"
355  " dg = $(slope0) * dt + ($(off1)); \n"
356  "else if (dt > $(lim1)) \n"
357  " dg = $(slope1) * dt + ($(off1)); \n"
358  "else dg = -($(off2)) ; \n"
359  "$(gRaw) += dg; \n"
360  "$(g)=$(gMax)/2.0 *(tanh($(gSlope)*($(gRaw) - ($(gMid))))+1); \n");
361 
 // NOTE(review): the line opening this initialiser list is missing from the
 // listing. It is presumably `SET_DERIVED_PARAMS({`, which this file
 // references. Each lambda maps the parameter vector to one derived value and
 // ignores its second double argument (presumably the timestep).
 //   lim0   = (1/tPunish01 + 1/tChng) * tLrn / (2/tChng)
 //   lim1   = -(1/tPunish10 + 1/tChng) * tLrn / (2/tChng)
 //   slope0 = -2*gMax / (tChng*tLrn),  slope1 = 2*gMax / (tChng*tLrn)
 //   off0   = gMax/tPunish01,  off1 = gMax/tChng,  off2 = gMax/tPunish10
 // The lambdas use unqualified `vector`, relying on a using-declaration or
 // directive made elsewhere.
363  {"lim0", [](const vector<double> &pars, double){ return (1/pars[4] + 1/pars[1]) * pars[0] / (2/pars[1]); }},
364  {"lim1", [](const vector<double> &pars, double){ return -((1/pars[3] + 1/pars[1]) * pars[0] / (2/pars[1])); }},
365  {"slope0", [](const vector<double> &pars, double){ return -2*pars[5]/(pars[1]*pars[0]); }},
366  {"slope1", [](const vector<double> &pars, double){ return 2*pars[5]/(pars[1]*pars[0]); }},
367  {"off0", [](const vector<double> &pars, double){ return pars[5] / pars[4]; }},
368  {"off1", [](const vector<double> &pars, double){ return pars[5] / pars[1]; }},
369  {"off2", [](const vector<double> &pars, double){ return pars[5] / pars[3]; }}});
370 
 // NOTE(review): two source lines are missing here. Since the code reads
 // $(sT_post) and $(sT_pre), they are presumably SET_NEEDS_PRE_SPIKE_TIME(true)
 // and SET_NEEDS_POST_SPIKE_TIME(true); confirm.
373 };
374 } // WeightUpdateModels
virtual StringPairVec getExtraGlobalParams() const
Definition: newWeightUpdateModels.h:107
virtual std::string getLearnPostSupportCode() const
Gets support code to be made available within the learnSynapsesPost kernel/function.
Definition: newWeightUpdateModels.h:77
virtual std::string getEventCode() const
Gets code run when events (all the instances where event threshold condition is met) are received...
Definition: newWeightUpdateModels.h:52
virtual std::string getSynapseDynamicsSuppportCode() const
Gets support code to be made available within the synapse dynamics kernel/function.
Definition: newWeightUpdateModels.h:83
#define DECLARE_WEIGHT_UPDATE_MODEL(TYPE, NUM_PARAMS, NUM_VARS, NUM_PRE_VARS, NUM_POST_VARS)
Definition: newWeightUpdateModels.h:10
Definition: newWeightUpdateModels.h:39
virtual bool isPreSpikeTimeRequired() const
Whether presynaptic spike times are needed or not.
Definition: newWeightUpdateModels.h:110
Wrapper around old-style models stored in global arrays and referenced by index.
Definition: newModels.h:170
This is a simple STDP rule including a time delay for the finite transmission speed of the synapse...
Definition: newWeightUpdateModels.h:327
Base class for all weight update models.
Definition: newWeightUpdateModels.h:42
#define SET_SIM_CODE(SIM_CODE)
Definition: newWeightUpdateModels.h:16
size_t getPreVarIndex(const std::string &varName) const
Find the index of a named presynaptic variable.
Definition: newWeightUpdateModels.h:119
Pulse-coupled, static synapse with heterogeneous dendritic delays.
Definition: newWeightUpdateModels.h:221
#define SET_EVENT_CODE(EVENT_CODE)
Definition: newWeightUpdateModels.h:17
virtual StringPairVec getPostVars() const
Definition: newWeightUpdateModels.h:103
Base class for all models - in addition to the parameters snippets have, models can have state variab...
Definition: newModels.h:132
#define SET_NEEDS_PRE_SPIKE_TIME(PRE_SPIKE_TIME_REQUIRED)
Definition: newWeightUpdateModels.h:33
virtual std::string getSimCode() const
Gets simulation code run when 'true' spikes are received.
Definition: newWeightUpdateModels.h:49
#define SET_DERIVED_PARAMS(...)
Definition: snippet.h:29
virtual std::string getEventThresholdConditionCode() const
Gets code to test for events.
Definition: newWeightUpdateModels.h:63
Class to hold the information that defines a weight update model (a model of how spikes affect synapti...
Definition: synapseModels.h:16
Pulse-coupled, static synapse.
Definition: newWeightUpdateModels.h:195
#define DECLARE_MODEL(TYPE, NUM_PARAMS, NUM_VARS)
Definition: newModels.h:18
size_t getPostVarIndex(const std::string &varName) const
Find the index of a named postsynaptic variable.
Definition: newWeightUpdateModels.h:125
virtual std::string getLearnPostCode() const
Gets code to include in the learnSynapsesPost kernel/function.
Definition: newWeightUpdateModels.h:57
virtual StringPairVec getPreVars() const
Definition: newWeightUpdateModels.h:99
LegacyWrapper(unsigned int legacyTypeIndex)
Definition: newWeightUpdateModels.h:139
Wrapper around legacy weight update models stored in weightUpdateModels array of weightUpdateModel ob...
Definition: newWeightUpdateModels.h:136
#define SET_EVENT_THRESHOLD_CONDITION_CODE(EVENT_THRESHOLD_CONDITION_CODE)
Definition: newWeightUpdateModels.h:20
virtual std::string getPostSpikeCode() const
Definition: newWeightUpdateModels.h:95
#define SET_PARAM_NAMES(...)
Definition: snippet.h:28
virtual std::string getSynapseDynamicsCode() const
Gets code for synapse dynamics which are independent of spike detection.
Definition: newWeightUpdateModels.h:60
vector< weightUpdateModel > weightUpdateModels
Global C++ vector containing all weight update model descriptions.
Definition: synapseModels.cc:31
#define SET_VARS(...)
Definition: newModels.h:26
#define SET_LEARN_POST_CODE(LEARN_POST_CODE)
Definition: newWeightUpdateModels.h:18
std::vector< std::pair< std::string, std::string > > StringPairVec
Definition: snippet.h:117
Definition: codeGenUtils.h:24
size_t getVarIndex(const std::string &varName) const
Find the index of a named variable.
Definition: newModels.h:145
virtual std::string getPreSpikeCode() const
Definition: newWeightUpdateModels.h:89
#define SET_NEEDS_POST_SPIKE_TIME(POST_SPIKE_TIME_REQUIRED)
Definition: newWeightUpdateModels.h:34
Graded-potential, static synapse.
Definition: newWeightUpdateModels.h:257
virtual std::string getSimSupportCode() const
Gets support code to be made available within the synapse kernel/function.
Definition: newWeightUpdateModels.h:71
virtual bool isPostSpikeTimeRequired() const
Whether postsynaptic spike times are needed or not.
Definition: newWeightUpdateModels.h:113