GeNN  4.9.0
GPU enhanced Neuronal Networks (GeNN)
weightUpdateModels.h
Go to the documentation of this file.
1 #pragma once
2 
3 // GeNN includes
4 #include "models.h"
5 
6 //----------------------------------------------------------------------------
7 // Macros
8 //----------------------------------------------------------------------------
// DECLARE_WEIGHT_UPDATE_MODEL
// Boilerplate shared by every weight update model. It:
//  - declares the snippet itself via DECLARE_SNIPPET(TYPE, NUM_PARAMS);
//  - defines VarValues, PreVarValues and PostVarValues as
//    Models::VarInitContainerBase instantiations sized NUM_VARS,
//    NUM_PRE_VARS and NUM_POST_VARS respectively.
// The expansion has no trailing semicolon, so the invocation must supply one.
#define DECLARE_WEIGHT_UPDATE_MODEL(TYPE, NUM_PARAMS, NUM_VARS, NUM_PRE_VARS, NUM_POST_VARS) \
    DECLARE_SNIPPET(TYPE, NUM_PARAMS);                                                         \
    using VarValues = Models::VarInitContainerBase<NUM_VARS>;                                  \
    using PreVarValues = Models::VarInitContainerBase<NUM_PRE_VARS>;                           \
    using PostVarValues = Models::VarInitContainerBase<NUM_POST_VARS>
14 
// Code-string setters. Each macro expands to an override of the matching
// WeightUpdateModels::Base getter that returns the supplied code string.
// The expansions carry `override`, so use them only inside a class that
// derives from a base declaring those virtual getters.

//! Code run when 'true' spikes are received.
#define SET_SIM_CODE(CODE) \
    virtual std::string getSimCode() const override { return CODE; }

//! Code run when events (instances where the event threshold condition is met) are received.
#define SET_EVENT_CODE(CODE) \
    virtual std::string getEventCode() const override { return CODE; }

//! Code included in the learnSynapsesPost kernel/function.
#define SET_LEARN_POST_CODE(CODE) \
    virtual std::string getLearnPostCode() const override { return CODE; }

//! Synapse dynamics code, independent of spike detection.
#define SET_SYNAPSE_DYNAMICS_CODE(CODE) \
    virtual std::string getSynapseDynamicsCode() const override { return CODE; }

//! Condition code used to test for events.
#define SET_EVENT_THRESHOLD_CONDITION_CODE(CODE) \
    virtual std::string getEventThresholdConditionCode() const override { return CODE; }
20 
// Support-code and pre/post hook setters. Each macro overrides the matching
// WeightUpdateModels::Base getter with one that returns the given string.

//! Support code made available within the synapse kernel/function.
#define SET_SIM_SUPPORT_CODE(CODE) \
    virtual std::string getSimSupportCode() const override { return CODE; }

//! Support code made available within the learnSynapsesPost kernel/function.
#define SET_LEARN_POST_SUPPORT_CODE(CODE) \
    virtual std::string getLearnPostSupportCode() const override { return CODE; }

//! Support code made available within the synapse dynamics kernel/function.
//! NOTE: the overridden getter is spelled getSynapseDynamicsSuppportCode (three 'p's).
//! The name must match the base-class declaration exactly, so do not "fix" it here alone.
#define SET_SYNAPSE_DYNAMICS_SUPPORT_CODE(CODE) \
    virtual std::string getSynapseDynamicsSuppportCode() const override { return CODE; }

//! Overrides getPreSpikeCode().
#define SET_PRE_SPIKE_CODE(CODE) \
    virtual std::string getPreSpikeCode() const override { return CODE; }

//! Overrides getPostSpikeCode().
#define SET_POST_SPIKE_CODE(CODE) \
    virtual std::string getPostSpikeCode() const override { return CODE; }

//! Code run after the presynaptic neuron update.
#define SET_PRE_DYNAMICS_CODE(CODE) \
    virtual std::string getPreDynamicsCode() const override { return CODE; }

//! Code run after the postsynaptic neuron update.
#define SET_POST_DYNAMICS_CODE(CODE) \
    virtual std::string getPostDynamicsCode() const override { return CODE; }
28 
// Variable-list setters. They are variadic so that a braced list containing
// top-level commas, e.g. {{"a", "scalar"}, {"b", "scalar"}}, passes through the
// preprocessor as a single argument. That list becomes the returned VarVec.
#define SET_PRE_VARS(...) \
    virtual VarVec getPreVars() const override { return __VA_ARGS__; }
#define SET_POST_VARS(...) \
    virtual VarVec getPostVars() const override { return __VA_ARGS__; }
31 
// Flag setters. Each macro overrides the matching WeightUpdateModels::Base
// predicate (false by default) so that it returns the supplied boolean expression.

//! Whether presynaptic spike times are needed.
#define SET_NEEDS_PRE_SPIKE_TIME(REQUIRED) \
    virtual bool isPreSpikeTimeRequired() const override { return REQUIRED; }
//! Whether postsynaptic spike times are needed.
#define SET_NEEDS_POST_SPIKE_TIME(REQUIRED) \
    virtual bool isPostSpikeTimeRequired() const override { return REQUIRED; }
//! Whether presynaptic spike-like-event times are needed.
#define SET_NEEDS_PRE_SPIKE_EVENT_TIME(REQUIRED) \
    virtual bool isPreSpikeEventTimeRequired() const override { return REQUIRED; }

//! Whether PREVIOUS presynaptic spike times are needed.
#define SET_NEEDS_PREV_PRE_SPIKE_TIME(REQUIRED) \
    virtual bool isPrevPreSpikeTimeRequired() const override { return REQUIRED; }
//! Whether PREVIOUS postsynaptic spike times are needed.
#define SET_NEEDS_PREV_POST_SPIKE_TIME(REQUIRED) \
    virtual bool isPrevPostSpikeTimeRequired() const override { return REQUIRED; }
//! Whether PREVIOUS presynaptic spike-like-event times are needed.
#define SET_NEEDS_PREV_PRE_SPIKE_EVENT_TIME(REQUIRED) \
    virtual bool isPrevPreSpikeEventTimeRequired() const override { return REQUIRED; }
39 
40 //----------------------------------------------------------------------------
41 // WeightUpdateModels::Base
42 //----------------------------------------------------------------------------
44 {
47 {
48 public:
49  //----------------------------------------------------------------------------
50  // Declared virtuals
51  //----------------------------------------------------------------------------
53  virtual std::string getSimCode() const{ return ""; }
54 
56  virtual std::string getEventCode() const{ return ""; }
57 
59 
61  virtual std::string getLearnPostCode() const{ return ""; }
62 
64  virtual std::string getSynapseDynamicsCode() const{ return ""; }
65 
67  virtual std::string getEventThresholdConditionCode() const{ return ""; }
68 
70 
75  virtual std::string getSimSupportCode() const{ return ""; }
76 
78 
81  virtual std::string getLearnPostSupportCode() const{ return ""; }
82 
84 
87  virtual std::string getSynapseDynamicsSuppportCode() const{ return ""; }
88 
91 
93  virtual std::string getPreSpikeCode() const{ return ""; }
94 
97 
99  virtual std::string getPostSpikeCode() const{ return ""; }
100 
102 
104  virtual std::string getPreDynamicsCode() const{ return ""; }
105 
107 
109  virtual std::string getPostDynamicsCode() const{ return ""; }
110 
113  virtual VarVec getPreVars() const{ return {}; }
114 
117  virtual VarVec getPostVars() const{ return {}; }
118 
120  virtual bool isPreSpikeTimeRequired() const{ return false; }
121 
123  virtual bool isPostSpikeTimeRequired() const{ return false; }
124 
126  virtual bool isPreSpikeEventTimeRequired() const { return false; }
127 
129  virtual bool isPrevPreSpikeTimeRequired() const{ return false; }
130 
132  virtual bool isPrevPostSpikeTimeRequired() const{ return false; }
133 
135  virtual bool isPrevPreSpikeEventTimeRequired() const { return false; }
136 
137  //------------------------------------------------------------------------
138  // Public methods
139  //------------------------------------------------------------------------
141  size_t getPreVarIndex(const std::string &varName) const
142  {
143  return getNamedVecIndex(varName, getPreVars());
144  }
145 
147  size_t getPostVarIndex(const std::string &varName) const
148  {
149  return getNamedVecIndex(varName, getPostVars());
150  }
151 
153  boost::uuids::detail::sha1::digest_type getHashDigest() const;
154 
156  void validate() const;
157 };
158 
159 //----------------------------------------------------------------------------
160 // WeightUpdateModels::StaticPulse
161 //----------------------------------------------------------------------------
163 
// StaticPulse: pulse-coupled, static synapse.
// On every received 'true' spike (SET_SIM_CODE), the read-only variable g is passed
// to $(addToInSyn, ...), presumably adding it to the postsynaptic input.
// g is never assigned by the model code, so the weight stays fixed.
// NOTE(review): listing line 177 is absent from this extract (numbering jumps 176 -> 178).
// It presumably holds the DECLARE_WEIGHT_UPDATE_MODEL invocation; confirm against the header.
174 class StaticPulse : public Base
175 {
176 public:
178 
179  SET_VARS({{"g", "scalar", VarAccess::READ_ONLY}});
180 
181  SET_SIM_CODE("$(addToInSyn, $(g));\n");
182 };
183 
184 //----------------------------------------------------------------------------
185 // WeightUpdateModels::StaticPulseDendriticDelay
186 //----------------------------------------------------------------------------
188 
201 {
// StaticPulseDendriticDelay: pulse-coupled, static synapse with heterogeneous
// dendritic delays.
// Two read-only variables:
//   g - the weight passed on each received 'true' spike;
//   d - a uint8_t dendritic delay.
// Both are handed to $(addToInSynDelay, ...). Presumably g is delivered to the
// postsynaptic input after d timesteps; confirm against the code generator.
// NOTE(review): the class header (listing line 200) and line 203 are missing from this extract.
202 public:
204 
205  SET_VARS({{"g", "scalar", VarAccess::READ_ONLY}, {"d", "uint8_t", VarAccess::READ_ONLY}});
206 
207  SET_SIM_CODE("$(addToInSynDelay, $(g), $(d));\n");
208 };
209 
210 //----------------------------------------------------------------------------
211 // WeightUpdateModels::StaticGraded
212 //----------------------------------------------------------------------------
214 
// StaticGraded: graded-potential, static synapse.
// Parameters (in order): Epre, Vslope. Read-only variable: g.
// An event is detected whenever V_pre > Epre (SET_EVENT_THRESHOLD_CONDITION_CODE).
// For each event, fmax(0.0, g * tanh((V_pre - Epre) / Vslope) * DT) is passed to
// $(addToInSyn, ...), so the contribution is never negative.
// DT is presumably the simulation timestep; confirm.
// NOTE(review): listing line 239 is absent from this extract. It presumably holds the
// DECLARE_WEIGHT_UPDATE_MODEL invocation; confirm against the header.
236 class StaticGraded : public Base
237 {
238 public:
240 
241  SET_PARAM_NAMES({"Epre", "Vslope"});
242  SET_VARS({{"g", "scalar", VarAccess::READ_ONLY}});
243 
244  SET_EVENT_CODE("$(addToInSyn, fmax(0.0, $(g) * tanh(($(V_pre) - $(Epre)) / $(Vslope))* DT));\n");
245 
246  SET_EVENT_THRESHOLD_CONDITION_CODE("$(V_pre) > $(Epre)");
247 };
248 
249 //----------------------------------------------------------------------------
250 // PiecewiseSTDP
251 //----------------------------------------------------------------------------
253 
// PiecewiseSTDP: a simple STDP rule including a time delay for the finite
// transmission speed of the synapse.
//
// Parameters (index used by the derived-parameter lambdas in brackets):
//   [0] tLrn, [1] tChng, [2] tDecay, [3] tPunish10, [4] tPunish01,
//   [5] gMax, [6] gMid, [7] gSlope, [8] tauShift, [9] gSyn0
//
// Variables:
//   gRaw - accumulates the weight changes dg;
//   g    - the weight actually delivered, a sigmoid of gRaw:
//            g = gMax/2 * (tanh(gSlope * (gRaw - gMid)) + 1),
//          which lies in (0, gMax) for positive gMax.
//
// Both code strings apply the same piecewise-linear window to a time difference dt:
//   dt >  lim0        : dg = -off0
//   0  <  dt <= lim0  : dg = slope0 * dt + off1
//   lim1 < dt <= 0    : dg = slope1 * dt + off1
//   dt <= lim1        : dg = -off2
//
// NOTE(review): tDecay and gSyn0 are not referenced anywhere in the lines visible here.
// NOTE(review): listing lines 309, 328, 341 and 350-351 are missing from this extract.
// Presumably they hold DECLARE_WEIGHT_UPDATE_MODEL(...), the SET_LEARN_POST_CODE( and
// SET_DERIVED_PARAMS({ openers, and the spike-time requirements; confirm against the header.
306 class PiecewiseSTDP : public Base
307 {
308 public:
310 
311  SET_PARAM_NAMES({"tLrn", "tChng", "tDecay", "tPunish10", "tPunish01",
312  "gMax", "gMid", "gSlope", "tauShift", "gSyn0"});
313  SET_VARS({{"g", "scalar"}, {"gRaw", "scalar"}});
314 
// Sim code (run on received 'true' spikes): deliver g, then apply the window with
// dt = sT_post - t - tauShift (presumably the last postsynaptic spike time relative
// to now, shifted by tauShift).
315  SET_SIM_CODE(
316  "$(addToInSyn, $(g));\n"
317  "scalar dt = $(sT_post) - $(t) - ($(tauShift)); \n"
318  "scalar dg = 0;\n"
319  "if (dt > $(lim0)) \n"
320  " dg = -($(off0)) ; \n"
321  "else if (dt > 0) \n"
322  " dg = $(slope0) * dt + ($(off1)); \n"
323  "else if (dt > $(lim1)) \n"
324  " dg = $(slope1) * dt + ($(off1)); \n"
325  "else dg = - ($(off2)) ; \n"
326  "$(gRaw) += dg; \n"
327  "$(g)=$(gMax)/2 *(tanh($(gSlope)*($(gRaw) - ($(gMid))))+1); \n");
// The following strings presumably form the SET_LEARN_POST_CODE( argument (its opener
// is one of the missing lines). They apply the same window with dt = t - sT_pre - tauShift.
// NOTE(review): this copy divides by 2.0 while the sim code above divides by 2. The results
// agree as long as $(gMax) is substituted as a floating-point literal, but the two forms
// should be made consistent; confirm.
329  "scalar dt = $(t) - ($(sT_pre)) - ($(tauShift)); \n"
330  "scalar dg =0; \n"
331  "if (dt > $(lim0)) \n"
332  " dg = -($(off0)) ; \n"
333  "else if (dt > 0) \n"
334  " dg = $(slope0) * dt + ($(off1)); \n"
335  "else if (dt > $(lim1)) \n"
336  " dg = $(slope1) * dt + ($(off1)); \n"
337  "else dg = -($(off2)) ; \n"
338  "$(gRaw) += dg; \n"
339  "$(g)=$(gMax)/2.0 *(tanh($(gSlope)*($(gRaw) - ($(gMid))))+1); \n");
340 
// Derived parameters. These lambdas appear to be the SET_DERIVED_PARAMS argument; the
// unnamed double argument of each lambda is unused.
//   lim0   = (1/tPunish01 + 1/tChng) * tLrn / (2/tChng)
//   lim1   = -((1/tPunish10 + 1/tChng) * tLrn / (2/tChng))
//   slope0 = -2*gMax / (tChng*tLrn),  slope1 = 2*gMax / (tChng*tLrn)
//   off0   = gMax/tPunish01,  off1 = gMax/tChng,  off2 = gMax/tPunish10
// With these values the window is continuous:
//   - it peaks at off1 when dt == 0;
//   - it reaches exactly -off0 at dt == lim0;
//   - it reaches exactly -off2 at dt == lim1.
342  {"lim0", [](const std::vector<double> &pars, double){ return (1/pars[4] + 1/pars[1]) * pars[0] / (2/pars[1]); }},
343  {"lim1", [](const std::vector<double> &pars, double){ return -((1/pars[3] + 1/pars[1]) * pars[0] / (2/pars[1])); }},
344  {"slope0", [](const std::vector<double> &pars, double){ return -2*pars[5]/(pars[1]*pars[0]); }},
345  {"slope1", [](const std::vector<double> &pars, double){ return 2*pars[5]/(pars[1]*pars[0]); }},
346  {"off0", [](const std::vector<double> &pars, double){ return pars[5] / pars[4]; }},
347  {"off1", [](const std::vector<double> &pars, double){ return pars[5] / pars[1]; }},
348  {"off2", [](const std::vector<double> &pars, double){ return pars[5] / pars[3]; }}});
349 
352 };
353 } // WeightUpdateModels
virtual std::string getPostDynamicsCode() const
Gets code to be run after postsynaptic neuron update.
Definition: weightUpdateModels.h:109
virtual bool isPrevPreSpikeTimeRequired() const
Whether PREVIOUS presynaptic spike times are needed or not.
Definition: weightUpdateModels.h:129
virtual std::string getLearnPostSupportCode() const
Gets support code to be made available within the learnSynapsesPost kernel/function.
Definition: weightUpdateModels.h:81
virtual std::string getEventCode() const
Gets code run when events (all the instances where the event threshold condition is met) are received...
Definition: weightUpdateModels.h:56
#define SET_VARS(...)
Definition: models.h:39
#define SET_SIM_CODE(SIM_CODE)
Definition: weightUpdateModels.h:15
#define SET_NEEDS_PRE_SPIKE_TIME(PRE_SPIKE_TIME_REQUIRED)
Definition: weightUpdateModels.h:32
virtual std::string getSynapseDynamicsSuppportCode() const
Gets support code to be made available within the synapse dynamics kernel/function.
Definition: weightUpdateModels.h:87
virtual VarVec getPreVars() const
Definition: weightUpdateModels.h:113
std::vector< Var > VarVec
Definition: models.h:115
virtual std::string getPreDynamicsCode() const
Gets code to be run after presynaptic neuron update.
Definition: weightUpdateModels.h:104
#define SET_EVENT_CODE(EVENT_CODE)
Definition: weightUpdateModels.h:16
Definition: weightUpdateModels.h:43
virtual bool isPrevPreSpikeEventTimeRequired() const
Whether PREVIOUS presynaptic spike-like-event times are needed or not.
Definition: weightUpdateModels.h:135
Definition: models.h:48
#define GENN_EXPORT
Definition: gennExport.h:13
virtual VarVec getPostVars() const
Definition: weightUpdateModels.h:117
#define SET_LEARN_POST_CODE(LEARN_POST_CODE)
Definition: weightUpdateModels.h:17
virtual bool isPreSpikeTimeRequired() const
Whether presynaptic spike times are needed or not.
Definition: weightUpdateModels.h:120
This is a simple STDP rule including a time delay for the finite transmission speed of the synapse...
Definition: weightUpdateModels.h:306
virtual bool isPrevPostSpikeTimeRequired() const
Whether PREVIOUS postsynaptic spike times are needed or not.
Definition: weightUpdateModels.h:132
virtual bool isPreSpikeEventTimeRequired() const
Whether presynaptic spike-like-event times are needed or not.
Definition: weightUpdateModels.h:126
Base class for all weight update models.
Definition: weightUpdateModels.h:46
#define SET_EVENT_THRESHOLD_CONDITION_CODE(EVENT_THRESHOLD_CONDITION_CODE)
Definition: weightUpdateModels.h:19
size_t getPreVarIndex(const std::string &varName) const
Find the index of a named presynaptic variable.
Definition: weightUpdateModels.h:141
Pulse-coupled, static synapse with heterogeneous dendritic delays.
Definition: weightUpdateModels.h:200
#define SET_NEEDS_POST_SPIKE_TIME(POST_SPIKE_TIME_REQUIRED)
Definition: weightUpdateModels.h:33
virtual std::string getSimCode() const
Gets simulation code run when 'true' spikes are received.
Definition: weightUpdateModels.h:53
#define SET_DERIVED_PARAMS(...)
Definition: snippet.h:37
virtual std::string getEventThresholdConditionCode() const
Gets code to test for events.
Definition: weightUpdateModels.h:67
Pulse-coupled, static synapse.
Definition: weightUpdateModels.h:174
size_t getPostVarIndex(const std::string &varName) const
Find the index of a named postsynaptic variable.
Definition: weightUpdateModels.h:147
virtual std::string getLearnPostCode() const
Gets code to include in the learnSynapsesPost kernel/function.
Definition: weightUpdateModels.h:61
virtual std::string getPostSpikeCode() const
Definition: weightUpdateModels.h:99
#define SET_PARAM_NAMES(...)
Definition: snippet.h:36
virtual std::string getSynapseDynamicsCode() const
Gets code for synapse dynamics which are independent of spike detection.
Definition: weightUpdateModels.h:64
virtual std::string getPreSpikeCode() const
Definition: weightUpdateModels.h:93
#define DECLARE_WEIGHT_UPDATE_MODEL(TYPE, NUM_PARAMS, NUM_VARS, NUM_PRE_VARS, NUM_POST_VARS)
Definition: weightUpdateModels.h:9
Graded-potential, static synapse.
Definition: weightUpdateModels.h:236
virtual std::string getSimSupportCode() const
Gets support code to be made available within the synapse kernel/function.
Definition: weightUpdateModels.h:75
virtual bool isPostSpikeTimeRequired() const
Whether postsynaptic spike times are needed or not.
Definition: weightUpdateModels.h:123