GeNN  4.0.0
GPU enhanced Neuronal Networks (GeNN)
weightUpdateModels.h
Go to the documentation of this file.
1 #pragma once
2 
3 // GeNN includes
4 #include "models.h"
5 
6 //----------------------------------------------------------------------------
7 // Macros
8 //----------------------------------------------------------------------------
// Declares the common boilerplate of a weight update model class TYPE.
// Expands to DECLARE_SNIPPET(TYPE, NUM_PARAMS) followed by three typedefs:
//   VarValues     - Models::VarInitContainerBase<NUM_VARS>
//   PreVarValues  - Models::VarInitContainerBase<NUM_PRE_VARS>
//   PostVarValues - Models::VarInitContainerBase<NUM_POST_VARS>
// The final typedef deliberately has no trailing semicolon, so the macro
// invocation itself must be terminated with ';'.
9 #define DECLARE_WEIGHT_UPDATE_MODEL(TYPE, NUM_PARAMS, NUM_VARS, NUM_PRE_VARS, NUM_POST_VARS) \
10  DECLARE_SNIPPET(TYPE, NUM_PARAMS); \
11  typedef Models::VarInitContainerBase<NUM_VARS> VarValues; \
12  typedef Models::VarInitContainerBase<NUM_PRE_VARS> PreVarValues; \
13  typedef Models::VarInitContainerBase<NUM_POST_VARS> PostVarValues
14 
// Convenience macros for use inside a class derived from WeightUpdateModels::Base.
// Each one expands to an override of the matching virtual getter that simply
// returns the macro argument.

// Code-string overrides: simulation, event, post-synaptic learning,
// synapse dynamics and event-threshold-condition code.
15 #define SET_SIM_CODE(SIM_CODE) virtual std::string getSimCode() const override{ return SIM_CODE; }
16 #define SET_EVENT_CODE(EVENT_CODE) virtual std::string getEventCode() const override{ return EVENT_CODE; }
17 #define SET_LEARN_POST_CODE(LEARN_POST_CODE) virtual std::string getLearnPostCode() const override{ return LEARN_POST_CODE; }
18 #define SET_SYNAPSE_DYNAMICS_CODE(SYNAPSE_DYNAMICS_CODE) virtual std::string getSynapseDynamicsCode() const override{ return SYNAPSE_DYNAMICS_CODE; }
19 #define SET_EVENT_THRESHOLD_CONDITION_CODE(EVENT_THRESHOLD_CONDITION_CODE) virtual std::string getEventThresholdConditionCode() const override{ return EVENT_THRESHOLD_CONDITION_CODE; }
20 
// Support-code and per-spike code overrides.
21 #define SET_SIM_SUPPORT_CODE(SIM_SUPPORT_CODE) virtual std::string getSimSupportCode() const override{ return SIM_SUPPORT_CODE; }
22 #define SET_LEARN_POST_SUPPORT_CODE(LEARN_POST_SUPPORT_CODE) virtual std::string getLearnPostSupportCode() const override{ return LEARN_POST_SUPPORT_CODE; }
// NOTE: "Suppport" (three p's) is not a typo introduced here - it has to match
// the spelling of the virtual declared in Base, otherwise `override` would fail.
23 #define SET_SYNAPSE_DYNAMICS_SUPPORT_CODE(SYNAPSE_DYNAMICS_SUPPORT_CODE) virtual std::string getSynapseDynamicsSuppportCode() const override{ return SYNAPSE_DYNAMICS_SUPPORT_CODE; }
24 #define SET_PRE_SPIKE_CODE(PRE_SPIKE_CODE) virtual std::string getPreSpikeCode() const override{ return PRE_SPIKE_CODE; }
25 #define SET_POST_SPIKE_CODE(POST_SPIKE_CODE) virtual std::string getPostSpikeCode() const override{ return POST_SPIKE_CODE; }
26 
// Variadic so that a braced initializer list containing commas can be passed
// straight through as the returned VarVec.
27 #define SET_PRE_VARS(...) virtual VarVec getPreVars() const override{ return __VA_ARGS__; }
28 #define SET_POST_VARS(...) virtual VarVec getPostVars() const override{ return __VA_ARGS__; }
29 
// Boolean overrides stating whether pre-/postsynaptic spike times are needed.
30 #define SET_NEEDS_PRE_SPIKE_TIME(PRE_SPIKE_TIME_REQUIRED) virtual bool isPreSpikeTimeRequired() const override{ return PRE_SPIKE_TIME_REQUIRED; }
31 #define SET_NEEDS_POST_SPIKE_TIME(POST_SPIKE_TIME_REQUIRED) virtual bool isPostSpikeTimeRequired() const override{ return POST_SPIKE_TIME_REQUIRED; }
32 
33 //----------------------------------------------------------------------------
34 // WeightUpdateModels::Base
35 //----------------------------------------------------------------------------
37 {
// Base class for all weight update models.
// Every virtual getter below has a neutral default (empty string, empty
// VarVec or false); derived models override only what they need, typically
// through the SET_* macros defined earlier in this header.
39 class Base : public Models::Base
40 {
41 public:
42  //----------------------------------------------------------------------------
43  // Declared virtuals
44  //----------------------------------------------------------------------------
 // Gets simulation code run when 'true' spikes are received.
46  virtual std::string getSimCode() const{ return ""; }
47 
 // Gets code run when events (all the instances where the event threshold
 // condition is met) are received.
49  virtual std::string getEventCode() const{ return ""; }
50 
52 
 // Gets code to include in the learnSynapsesPost kernel/function.
54  virtual std::string getLearnPostCode() const{ return ""; }
55 
 // Gets code for synapse dynamics which are independent of spike detection.
57  virtual std::string getSynapseDynamicsCode() const{ return ""; }
58 
 // Gets code to test for events.
60  virtual std::string getEventThresholdConditionCode() const{ return ""; }
61 
63 
 // Gets support code to be made available within the synapse kernel/function.
68  virtual std::string getSimSupportCode() const{ return ""; }
69 
71 
 // Gets support code to be made available within the learnSynapsesPost kernel/function.
74  virtual std::string getLearnPostSupportCode() const{ return ""; }
75 
77 
 // Gets support code to be made available within the synapse dynamics kernel/function.
 // NOTE: the name is spelled "Suppport" (three p's); SET_SYNAPSE_DYNAMICS_SUPPORT_CODE
 // overrides it under exactly this spelling, so renaming would break existing models.
80  virtual std::string getSynapseDynamicsSuppportCode() const{ return ""; }
81 
84 
 // Code string overridden by SET_PRE_SPIKE_CODE; empty by default.
 // NOTE(review): presumably run when a presynaptic spike occurs - confirm.
86  virtual std::string getPreSpikeCode() const{ return ""; }
87 
90 
 // Code string overridden by SET_POST_SPIKE_CODE; empty by default.
 // NOTE(review): presumably run when a postsynaptic spike occurs - confirm.
92  virtual std::string getPostSpikeCode() const{ return ""; }
93 
 // Variables overridden by SET_PRE_VARS; none by default.
96  virtual VarVec getPreVars() const{ return {}; }
97 
 // Variables overridden by SET_POST_VARS; none by default.
100  virtual VarVec getPostVars() const{ return {}; }
101 
 // Whether presynaptic spike times are needed or not.
103  virtual bool isPreSpikeTimeRequired() const{ return false; }
104 
 // Whether postsynaptic spike times are needed or not.
106  virtual bool isPostSpikeTimeRequired() const{ return false; }
107 
108  //------------------------------------------------------------------------
109  // Public methods
110  //------------------------------------------------------------------------
 // Find the index of a named presynaptic variable by looking varName up
 // in getPreVars() via getVarVecIndex.
112  size_t getPreVarIndex(const std::string &varName) const
113  {
114  return getVarVecIndex(varName, getPreVars());
115  }
116 
 // Find the index of a named postsynaptic variable by looking varName up
 // in getPostVars() via getVarVecIndex.
118  size_t getPostVarIndex(const std::string &varName) const
119  {
120  return getVarVecIndex(varName, getPostVars());
121  }
122 
123 };
124 
125 //----------------------------------------------------------------------------
126 // WeightUpdateModels::StaticPulse
127 //----------------------------------------------------------------------------
129 
// Pulse-coupled, static synapse.
// Declares a single variable "g" of type scalar; the simulation code passes
// $(g) to $(addToInSyn, ...) and does nothing else.
140 class StaticPulse : public Base
141 {
142 public:
// NOTE(review): listing line 143 is absent from this dump; presumably it holds
// the model declaration macro (the page index references DECLARE_MODEL) -
// confirm against the original header.
144 
145  SET_VARS({{"g", "scalar"}});
146 
147  SET_SIM_CODE("$(addToInSyn, $(g));\n");
148 };
149 
150 //----------------------------------------------------------------------------
151 // WeightUpdateModels::StaticPulseDendriticDelay
152 //----------------------------------------------------------------------------
154 
// Pulse-coupled, static synapse with heterogeneous dendritic delays.
// Declares variables "g" (scalar) and "d" (uint8_t); the simulation code
// passes both to $(addToInSynDelay, ...).
// NOTE(review): the class header (listing line 166) and listing line 169 are
// absent from this dump; the header is presumably
// `class StaticPulseDendriticDelay : public Base` - confirm against the
// original header.
167 {
168 public:
170 
171  SET_VARS({{"g", "scalar"},{"d", "uint8_t"}});
172 
173  SET_SIM_CODE("$(addToInSynDelay, $(g), $(d));\n");
174 };
175 
176 //----------------------------------------------------------------------------
177 // WeightUpdateModels::StaticGraded
178 //----------------------------------------------------------------------------
180 
// Graded-potential, static synapse.
// Parameters: "Epre" and "Vslope". Variable: "g" (scalar).
// An event is raised while $(V_pre) > $(Epre); the event code then adds
// max(0.0, g * tanh((V_pre - Epre) / Vslope) * DT) via $(addToInSyn, ...).
202 class StaticGraded : public Base
203 {
204 public:
// NOTE(review): listing line 205 is absent from this dump; presumably it holds
// the model declaration macro - confirm against the original header.
206 
207  SET_PARAM_NAMES({"Epre", "Vslope"});
208  SET_VARS({{"g", "scalar"}});
209 
210  SET_EVENT_CODE("$(addToInSyn, max(0.0, $(g) * tanh(($(V_pre) - $(Epre)) / $(Vslope))* DT));\n");
211 
212  SET_EVENT_THRESHOLD_CONDITION_CODE("$(V_pre) > $(Epre)");
213 };
214 
215 //----------------------------------------------------------------------------
216 // PiecewiseSTDP
217 //----------------------------------------------------------------------------
219 
// A simple STDP rule including a time delay for the finite transmission speed
// of the synapse.
// Ten parameters (see SET_PARAM_NAMES) and two variables: "g" and "gRaw".
// Both code strings below compute a piecewise-linear increment dg from a
// spike-time difference dt, add it to $(gRaw), and then recompute $(g) as
// gMax/2 * (tanh(gSlope * (gRaw - gMid)) + 1).
272 class PiecewiseSTDP : public Base
273 {
274 public:
// NOTE(review): listing line 275 is absent from this dump; presumably it holds
// the model declaration macro - confirm against the original header.
276 
277  SET_PARAM_NAMES({"tLrn", "tChng", "tDecay", "tPunish10", "tPunish01",
278  "gMax", "gMid", "gSlope", "tauShift", "gSyn0"});
279  SET_VARS({{"g", "scalar"}, {"gRaw", "scalar"}});
280 
// Simulation code: applies the current weight via $(addToInSyn, $(g)), then
// updates gRaw/g using dt = sT_post - t - tauShift.
281  SET_SIM_CODE(
282  "$(addToInSyn, $(g));\n"
283  "scalar dt = $(sT_post) - $(t) - ($(tauShift)); \n"
284  "scalar dg = 0;\n"
285  "if (dt > $(lim0)) \n"
286  " dg = -($(off0)) ; \n"
287  "else if (dt > 0) \n"
288  " dg = $(slope0) * dt + ($(off1)); \n"
289  "else if (dt > $(lim1)) \n"
290  " dg = $(slope1) * dt + ($(off1)); \n"
291  "else dg = - ($(off2)) ; \n"
292  "$(gRaw) += dg; \n"
293  "$(g)=$(gMax)/2 *(tanh($(gSlope)*($(gRaw) - ($(gMid))))+1); \n");
// NOTE(review): listing line 294 is absent from this dump, leaving the string
// literals below without an enclosing macro call. They presumably form the
// argument of SET_LEARN_POST_CODE( (the page index references that macro) -
// confirm against the original header.
// This code mirrors the update above but uses dt = t - sT_pre - tauShift and
// does not call $(addToInSyn, ...).
295  "scalar dt = $(t) - ($(sT_pre)) - ($(tauShift)); \n"
296  "scalar dg =0; \n"
297  "if (dt > $(lim0)) \n"
298  " dg = -($(off0)) ; \n"
299  "else if (dt > 0) \n"
300  " dg = $(slope0) * dt + ($(off1)); \n"
301  "else if (dt > $(lim1)) \n"
302  " dg = $(slope1) * dt + ($(off1)); \n"
303  "else dg = -($(off2)) ; \n"
304  "$(gRaw) += dg; \n"
305  "$(g)=$(gMax)/2.0 *(tanh($(gSlope)*($(gRaw) - ($(gMid))))+1); \n");
306 
// NOTE(review): listing line 307 is absent from this dump; the name/lambda
// pairs below presumably sit inside SET_DERIVED_PARAMS({ (the page index
// references that macro) - confirm against the original header.
// Each lambda computes one derived value ($(lim0), $(slope0), ...) used by the
// code strings above. Assuming `pars` follows the SET_PARAM_NAMES order, the
// indices used are: [0] tLrn, [1] tChng, [3] tPunish10, [4] tPunish01,
// [5] gMax - TODO confirm. The second (double) argument is unused.
308  {"lim0", [](const std::vector<double> &pars, double){ return (1/pars[4] + 1/pars[1]) * pars[0] / (2/pars[1]); }},
309  {"lim1", [](const std::vector<double> &pars, double){ return -((1/pars[3] + 1/pars[1]) * pars[0] / (2/pars[1])); }},
310  {"slope0", [](const std::vector<double> &pars, double){ return -2*pars[5]/(pars[1]*pars[0]); }},
311  {"slope1", [](const std::vector<double> &pars, double){ return 2*pars[5]/(pars[1]*pars[0]); }},
312  {"off0", [](const std::vector<double> &pars, double){ return pars[5] / pars[4]; }},
313  {"off1", [](const std::vector<double> &pars, double){ return pars[5] / pars[1]; }},
314  {"off2", [](const std::vector<double> &pars, double){ return pars[5] / pars[3]; }}});
315 
// NOTE(review): listing lines 316-317 are absent from this dump. Since the code
// strings read $(sT_pre) and $(sT_post), they presumably contain
// SET_NEEDS_PRE_SPIKE_TIME / SET_NEEDS_POST_SPIKE_TIME invocations (both
// macros appear in the page index) - confirm against the original header.
318 };
319 } // WeightUpdateModels
virtual std::string getLearnPostSupportCode() const
Gets support code to be made available within the learnSynapsesPost kernel/function.
Definition: weightUpdateModels.h:74
virtual std::string getEventCode() const
Gets code run when events (all the instances where event threshold condition is met) are received...
Definition: weightUpdateModels.h:49
#define SET_VARS(...)
Definition: models.h:22
#define SET_SIM_CODE(SIM_CODE)
Definition: weightUpdateModels.h:15
#define SET_NEEDS_PRE_SPIKE_TIME(PRE_SPIKE_TIME_REQUIRED)
Definition: weightUpdateModels.h:30
virtual std::string getSynapseDynamicsSuppportCode() const
Gets support code to be made available within the synapse dynamics kernel/function.
Definition: weightUpdateModels.h:80
virtual VarVec getPreVars() const
Definition: weightUpdateModels.h:96
#define SET_EVENT_CODE(EVENT_CODE)
Definition: weightUpdateModels.h:16
Definition: weightUpdateModels.h:36
Base class for all models - in addition to the parameters snippets have, models can have state variab...
Definition: models.h:129
virtual VarVec getPostVars() const
Definition: weightUpdateModels.h:100
static size_t getVarVecIndex(const std::string &varName, const VarVec &vars)
Definition: snippet.h:167
#define SET_LEARN_POST_CODE(LEARN_POST_CODE)
Definition: weightUpdateModels.h:17
virtual bool isPreSpikeTimeRequired() const
Whether presynaptic spike times are needed or not.
Definition: weightUpdateModels.h:103
This is a simple STDP rule including a time delay for the finite transmission speed of the synapse...
Definition: weightUpdateModels.h:272
#define DECLARE_MODEL(TYPE, NUM_PARAMS, NUM_VARS)
Definition: models.h:14
Base class for all weight update models.
Definition: weightUpdateModels.h:39
#define SET_EVENT_THRESHOLD_CONDITION_CODE(EVENT_THRESHOLD_CONDITION_CODE)
Definition: weightUpdateModels.h:19
size_t getPreVarIndex(const std::string &varName) const
Find the index of a named presynaptic variable.
Definition: weightUpdateModels.h:112
Pulse-coupled, static synapse with heterogeneous dendritic delays.
Definition: weightUpdateModels.h:166
#define SET_NEEDS_POST_SPIKE_TIME(POST_SPIKE_TIME_REQUIRED)
Definition: weightUpdateModels.h:31
virtual std::string getSimCode() const
Gets simulation code run when 'true' spikes are received.
Definition: weightUpdateModels.h:46
#define SET_DERIVED_PARAMS(...)
Definition: snippet.h:36
virtual std::string getEventThresholdConditionCode() const
Gets code to test for events.
Definition: weightUpdateModels.h:60
Pulse-coupled, static synapse.
Definition: weightUpdateModels.h:140
size_t getPostVarIndex(const std::string &varName) const
Find the index of a named postsynaptic variable.
Definition: weightUpdateModels.h:118
virtual std::string getLearnPostCode() const
Gets code to include in the learnSynapsesPost kernel/function.
Definition: weightUpdateModels.h:54
std::vector< Var > VarVec
Definition: snippet.h:148
virtual std::string getPostSpikeCode() const
Definition: weightUpdateModels.h:92
#define SET_PARAM_NAMES(...)
Definition: snippet.h:35
virtual std::string getSynapseDynamicsCode() const
Gets code for synapse dynamics which are independent of spike detection.
Definition: weightUpdateModels.h:57
virtual std::string getPreSpikeCode() const
Definition: weightUpdateModels.h:86
#define DECLARE_WEIGHT_UPDATE_MODEL(TYPE, NUM_PARAMS, NUM_VARS, NUM_PRE_VARS, NUM_POST_VARS)
Definition: weightUpdateModels.h:9
Graded-potential, static synapse.
Definition: weightUpdateModels.h:202
virtual std::string getSimSupportCode() const
Gets support code to be made available within the synapse kernel/function.
Definition: weightUpdateModels.h:68
virtual bool isPostSpikeTimeRequired() const
Whether postsynaptic spike times are needed or not.
Definition: weightUpdateModels.h:106