GeNN  4.9.0
GPU enhanced Neuronal Networks (GeNN)
models.h
Go to the documentation of this file.
1 #pragma once
2 
3 // Standard C++ includes
4 #include <algorithm>
5 #include <functional>
6 #include <string>
7 #include <vector>
8 
9 // GeNN includes
10 #include "snippet.h"
11 #include "initVarSnippet.h"
12 #include "varAccess.h"
13 
14 // Forward declarations
15 class CustomUpdate;
16 class CustomUpdateWU;
17 class NeuronGroup;
18 class SynapseGroup;
19 class CurrentSource;
23 namespace CodeGenerator
24 {
25 class BackendBase;
26 }
27 
28 //----------------------------------------------------------------------------
29 // Macros
30 //----------------------------------------------------------------------------
//! Declares the boilerplate every model class needs: the snippet declaration
//! (via DECLARE_SNIPPET) plus the VarValues container sized for NUM_VARS
//! variables and empty (zero-sized) PreVarValues / PostVarValues containers.
//! The final alias has no trailing semicolon, so the caller supplies it:
//!     DECLARE_MODEL(MyModel, 2, 3);
#define DECLARE_MODEL(TYPE, NUM_PARAMS, NUM_VARS)                    \
    DECLARE_SNIPPET(TYPE, NUM_PARAMS);                               \
    using VarValues = Models::VarInitContainerBase<NUM_VARS>;        \
    using PreVarValues = Models::VarInitContainerBase<0>;            \
    using PostVarValues = Models::VarInitContainerBase<0>

//! Counterpart of DECLARE_MODEL; simply forwards to IMPLEMENT_SNIPPET
#define IMPLEMENT_MODEL(TYPE) IMPLEMENT_SNIPPET(TYPE)

//! Overrides getVars() to return the given initialiser list of variables.
//! Variadic because a brace-enclosed list contains commas, which would
//! otherwise be split into separate macro arguments.
#define SET_VARS(...) VarVec getVars() const override { return __VA_ARGS__; }
40 
41 
42 //----------------------------------------------------------------------------
43 // Models::Base
44 //----------------------------------------------------------------------------
46 namespace Models
47 {
49 {
50 public:
51  //----------------------------------------------------------------------------
52  // Structs
53  //----------------------------------------------------------------------------
55 
58  struct Var
59  {
60  Var(const std::string &n, const std::string &t, VarAccess a) : name(n), type(t), access(a)
61  {}
62  Var(const std::string &n, const std::string &t) : Var(n, t, VarAccess::READ_WRITE)
63  {}
64  Var() : Var("", "", VarAccess::READ_WRITE)
65  {}
66 
67  bool operator == (const Var &other) const
68  {
69  return ((name == other.name) && (type == other.type) && (access == other.access));
70  }
71 
72  std::string name;
73  std::string type;
75  };
76 
77  struct VarRef
78  {
79  VarRef(const std::string &n, const std::string &t, VarAccessMode a) : name(n), type(t), access(a)
80  {}
81  VarRef(const std::string &n, const std::string &t) : VarRef(n, t, VarAccessMode::READ_WRITE)
82  {}
84  {}
85 
86  bool operator == (const VarRef &other) const
87  {
88  return ((name == other.name) && (type == other.type) && (access == other.access));
89  }
90 
91  std::string name;
92  std::string type;
94  };
95 
//! Describes an extra global parameter reference by name and type
struct EGPRef
{
    EGPRef(const std::string &n, const std::string &t)
    :   name(n), type(t)
    {
    }

    //! Default-constructed references have an empty name and type
    EGPRef()
    :   EGPRef("", "")
    {
    }

    //! References compare equal only when both name and type match
    bool operator == (const EGPRef &other) const
    {
        if(name != other.name) {
            return false;
        }
        return (type == other.type);
    }

    std::string name;
    std::string type;
};
111 
112  //----------------------------------------------------------------------------
113  // Typedefines
114  //----------------------------------------------------------------------------
115  typedef std::vector<Var> VarVec;
116  typedef std::vector<VarRef> VarRefVec;
117  typedef std::vector<EGPRef> EGPRefVec;
118 
119  //----------------------------------------------------------------------------
120  // Declared virtuals
121  //------------------------------------------------------------------------
123  virtual VarVec getVars() const{ return {}; }
124 
125  //------------------------------------------------------------------------
126  // Public methods
127  //------------------------------------------------------------------------
129  size_t getVarIndex(const std::string &varName) const
130  {
131  return getNamedVecIndex(varName, getVars());
132  }
133 
134 protected:
135  //------------------------------------------------------------------------
136  // Protected methods
137  //------------------------------------------------------------------------
138  void updateHash(boost::uuids::detail::sha1 &hash) const;
139 
141  void validate() const;
142 };
143 
144 
145 //----------------------------------------------------------------------------
146 // Models::VarInit
147 //----------------------------------------------------------------------------
151 class VarInit : public Snippet::Init<InitVarSnippet::Base>
152 {
153 public:
154  VarInit(const InitVarSnippet::Base *snippet, const std::vector<double> &params)
155  : Snippet::Init<InitVarSnippet::Base>(snippet, params)
156  {
157  }
158 
159  VarInit(double constant)
160  : Snippet::Init<InitVarSnippet::Base>(InitVarSnippet::Constant::getInstance(), {constant})
161  {
162  }
163 };
164 
165 //----------------------------------------------------------------------------
166 // Models::VarInitContainerBase
167 //----------------------------------------------------------------------------
// Container template used by DECLARE_MODEL for VarValues (NUM_VARS entries) and
// for the empty PreVarValues / PostVarValues (NumVars = 0); presumably it holds
// NumVars VarInit initialisers - confirm against the original header.
// NOTE(review): the declaration this template header introduces (source line 169)
// is missing from this listing - numbering jumps from 168 to 170.
168 template<size_t NumVars>
170 
171 //----------------------------------------------------------------------------
172 // Models::VarReferenceBase
172 // Common base for references to model variables: stores the index of the
172 // referenced variable, a copy of its Var descriptor, and callbacks that supply
172 // the target's name and batching state when asked.
172 // NOTE(review): the class-head line (source line 174) is missing from this listing.
173 //----------------------------------------------------------------------------
175 {
176 public:
177  //------------------------------------------------------------------------
178  // Public API
179  //------------------------------------------------------------------------
180  const Models::Base::Var &getVar() const { return m_Var; }
181  size_t getVarIndex() const { return m_VarIndex; }
     // These two invoke the stored callbacks on every call; nothing is cached.
182  std::string getTargetName() const { return m_GetTargetName(); }
183  bool isBatched() const{ return m_IsBatched(); }
184 
185 protected:
186  //------------------------------------------------------------------------
187  // Typedefines
188  //------------------------------------------------------------------------
189  typedef std::function<std::string(void)> GetTargetNameFn;
190  typedef std::function<bool(void)> IsBatchedFn;
191 
     // Protected, so only derived reference types can construct this base.
     // varVec.at(varIndex) throws std::out_of_range for an invalid index.
192  VarReferenceBase(size_t varIndex, const Models::Base::VarVec &varVec,
193  GetTargetNameFn getTargetName, IsBatchedFn isBatched)
194  : m_VarIndex(varIndex), m_Var(varVec.at(varIndex)), m_GetTargetName(getTargetName), m_IsBatched(isBatched)
195  {}
196 
197 private:
198  //------------------------------------------------------------------------
199  // Members
200  //------------------------------------------------------------------------
201  size_t m_VarIndex;
     // Copy of the descriptor, so it does not depend on varVec outliving this object.
202  Models::Base::Var m_Var;
203  GetTargetNameFn m_GetTargetName;
204  IsBatchedFn m_IsBatched;
205 };
206 
207 //----------------------------------------------------------------------------
208 // Models::VarReference
// Reference to a variable of a neuron group, current source or custom update, or
// to a synapse group's postsynaptic-model / weight-update pre- or postsynaptic
// variable (see the static factories below).
// NOTE(review): the class-head line (source line 210) is missing from this listing.
// GetTargetNameFn / IsBatchedFn are used unqualified below, which suggests this class
// derives from VarReferenceBase - confirm against the original header.
209 //----------------------------------------------------------------------------
211 {
212 public:
213  //------------------------------------------------------------------------
214  // Public API
215  //------------------------------------------------------------------------
     // Returns the size captured at construction; its exact meaning is not visible here.
216  unsigned int getSize() const { return m_Size; }
     // Invokes the stored callback on every call.
217  const NeuronGroup *getDelayNeuronGroup() const { return m_GetDelayNeuronGroup(); }
218 
219  //------------------------------------------------------------------------
220  // Static API
221  //------------------------------------------------------------------------
     // All constructors are private, so these factories are the only way to obtain
     // a VarReference.
     // NOTE(review): the CustomUpdate overload names its parameter `su`, whereas the
     // matching private constructor uses `cu`.
222  static VarReference createVarRef(const NeuronGroup *ng, const std::string &varName);
223  static VarReference createVarRef(const CurrentSource *cs, const std::string &varName);
224  static VarReference createVarRef(const CustomUpdate *su, const std::string &varName);
225  static VarReference createPSMVarRef(const SynapseGroup *sg, const std::string &varName);
226  static VarReference createWUPreVarRef(const SynapseGroup *sg, const std::string &varName);
227  static VarReference createWUPostVarRef(const SynapseGroup *sg, const std::string &varName);
228 
229 private:
230  //------------------------------------------------------------------------
231  // Typedefines
232  //------------------------------------------------------------------------
233  typedef std::function<const NeuronGroup*(void)> GetDelayNeuronGroupFn;
234 
     // The public factories take NeuronGroup / CurrentSource while these constructors
     // take the *Internal types; presumably the factories convert - defined out of line.
235  VarReference(const NeuronGroupInternal *ng, const std::string &varName);
236  VarReference(const CurrentSourceInternal *cs, const std::string &varName);
237  VarReference(const CustomUpdate *cu, const std::string &varName);
238  VarReference(unsigned int size, GetDelayNeuronGroupFn getDelayNeuronGroup,
239  size_t varIndex, const Models::Base::VarVec &varVec,
240  GetTargetNameFn getTargetName, IsBatchedFn isBatched);
241 
242  //------------------------------------------------------------------------
243  // Members
244  //------------------------------------------------------------------------
245  unsigned int m_Size;
246  GetDelayNeuronGroupFn m_GetDelayNeuronGroup;
247 };
248 
249 //----------------------------------------------------------------------------
250 // Models::VarReferenceContainerBase
251 //----------------------------------------------------------------------------
// Container template parameterised on the number of variable references.
// NOTE(review): the declaration this template header introduces (source line 253)
// is missing from this listing - numbering jumps from 252 to 254.
252 template<size_t NumVars>
254 
255 //----------------------------------------------------------------------------
256 // Models::WUVarReference
// Reference to a weight update model variable of a synapse group, or to a variable
// of a CustomUpdateWU, optionally paired with a variable in a "transpose" synapse group.
// NOTE(review): the class-head line (source line 258) is missing from this listing.
// m_GetTransposeTargetName uses GetTargetNameFn (declared protected in VarReferenceBase),
// which suggests this class derives from VarReferenceBase - confirm.
257 //----------------------------------------------------------------------------
259 {
260 public:
     // transposeSG defaults to nullptr and transposeVarName to "", i.e. no transpose.
261  WUVarReference(const SynapseGroup *sg, const std::string &varName,
262  const SynapseGroup *transposeSG = nullptr, const std::string &transposeVarName = "");
263  WUVarReference(const CustomUpdateWU *cu, const std::string &varName);
264 
265  //------------------------------------------------------------------------
266  // Public API
267  //------------------------------------------------------------------------
     // Defined out of line; the members hold SynapseGroupInternal pointers, so
     // presumably these convert them to SynapseGroup.
268  const SynapseGroup *getSynapseGroup() const;
269 
270  const SynapseGroup *getTransposeSynapseGroup() const;
271  const Models::Base::Var &getTransposeVar() const { return m_TransposeVar; }
272  size_t getTransposeVarIndex() const { return m_TransposeVarIndex; }
     // Invokes the stored callback on every call.
273  std::string getTransposeTargetName() const { return m_GetTransposeTargetName(); }
274 
275 private:
276  //------------------------------------------------------------------------
277  // Members
278  //------------------------------------------------------------------------
279  const SynapseGroupInternal *m_SG;
280  const SynapseGroupInternal *m_TransposeSG;
281  size_t m_TransposeVarIndex;
282  Models::Base::Var m_TransposeVar;
283  GetTargetNameFn m_GetTransposeTargetName;
284 };
285 
286 //----------------------------------------------------------------------------
287 // Models::WUVarReferenceContainerBase
288 //----------------------------------------------------------------------------
// Container template parameterised on the number of weight update variable references.
// NOTE(review): the declaration this template header introduces (source line 290)
// is missing from this listing - numbering jumps from 289 to 291.
289 template<size_t NumVars>
291 
292 //----------------------------------------------------------------------------
293 // Models::EGPReference
// Reference to an extra global parameter of a neuron group, current source, custom
// update, or a synapse group's postsynaptic / weight update model.
// NOTE(review): the class-head line (source line 295) is missing from this listing.
294 //----------------------------------------------------------------------------
296 {
297 public:
298  //------------------------------------------------------------------------
299  // Public API
300  //------------------------------------------------------------------------
301  const Models::Base::EGP &getEGP() const { return m_EGP; }
302  size_t getEGPIndex() const { return m_EGPIndex; }
     // Unlike VarReferenceBase, the target name is a plain string captured at
     // construction rather than a callback.
303  std::string getTargetName() const { return m_TargetName; }
304 
305  //------------------------------------------------------------------------
306  // Static API
307  //------------------------------------------------------------------------
     // The only constructor is private, so these factories are the only way to
     // obtain an EGPReference.
308  static EGPReference createEGPRef(const NeuronGroup *ng, const std::string &egpName);
309  static EGPReference createEGPRef(const CurrentSource *cs, const std::string &egpName);
310  static EGPReference createEGPRef(const CustomUpdate *cu, const std::string &egpName);
311  static EGPReference createEGPRef(const CustomUpdateWU *cu, const std::string &egpName);
312  static EGPReference createPSMEGPRef(const SynapseGroup *sg, const std::string &egpName);
313  static EGPReference createWUEGPRef(const SynapseGroup *sg, const std::string &egpName);
314 
315 private:
     // egpVec.at(egpIndex) throws std::out_of_range for an invalid index; the EGP
     // descriptor is copied, so it does not depend on egpVec's lifetime.
316  EGPReference(size_t egpIndex, const Models::Base::EGPVec &egpVec,
317  const std::string &targetName)
318  : m_EGPIndex(egpIndex), m_EGP(egpVec.at(egpIndex)), m_TargetName(targetName)
319  {}
320  //------------------------------------------------------------------------
321  // Members
322  //------------------------------------------------------------------------
323  size_t m_EGPIndex;
324  Models::Base::EGP m_EGP;
325  std::string m_TargetName;
326 };
327 
328 //----------------------------------------------------------------------------
329 // Models::EGPReferenceContainerBase
330 //----------------------------------------------------------------------------
// Container template parameterised on the number of extra global parameter references.
// NOTE(review): the declaration this template header introduces (source line 332)
// is missing from this listing - numbering jumps from 331 to 333.
331 template<size_t NumEGPs>
333 
334 //----------------------------------------------------------------------------
335 // updateHash overrides
336 //----------------------------------------------------------------------------
337 GENN_EXPORT void updateHash(const Base::Var &v, boost::uuids::detail::sha1 &hash);
338 GENN_EXPORT void updateHash(const Base::VarRef &v, boost::uuids::detail::sha1 &hash);
339 GENN_EXPORT void updateHash(const Base::EGPRef &e, boost::uuids::detail::sha1 &hash);
340 GENN_EXPORT void updateHash(const VarReference &v, boost::uuids::detail::sha1 &hash);
341 GENN_EXPORT void updateHash(const WUVarReference &v, boost::uuids::detail::sha1 &hash);
342 GENN_EXPORT void updateHash(const EGPReference &v, boost::uuids::detail::sha1 &hash);
343 } // Models
std::string type
Definition: models.h:73
Definition: neuronGroup.h:21
Definition: neuronGroupInternal.h:9
VarAccessMode
Supported combinations of VarAccessModeAttribute values.
Definition: varAccess.h:19
std::vector< EGPRef > EGPRefVec
Definition: models.h:117
std::string getTargetName() const
Definition: models.h:182
const Models::Base::Var & getTransposeVar() const
Definition: models.h:271
std::vector< Var > VarVec
Definition: models.h:115
This variable is read only.
Definition: models.h:77
Definition: customUpdate.h:199
VarInit(const InitVarSnippet::Base *snippet, const std::vector< double > &params)
Definition: models.h:154
Definition: models.h:48
#define GENN_EXPORT
Definition: gennExport.h:13
std::function< bool(void)> IsBatchedFn
Definition: models.h:190
EGPRef(const std::string &n, const std::string &t)
Definition: models.h:98
Definition: snippet.h:250
unsigned int getSize() const
Definition: models.h:216
size_t getTransposeVarIndex() const
Definition: models.h:272
std::string type
Definition: models.h:109
VarRef(const std::string &n, const std::string &t, VarAccessMode a)
Definition: models.h:79
Definition: customUpdate.h:252
VarAccess
Supported combinations of VarAccessMode and VarAccessDuplication.
Definition: varAccess.h:36
Helper class for generating code - automatically inserts brackets, indents etc.
Definition: backendBase.h:30
Models::VarReference createWUPreVarRef(const SynapseGroup *sg, const std::string &varName)
Creates a reference to a weight update model presynaptic variable.
Definition: modelSpec.h:156
std::string getTransposeTargetName() const
Definition: models.h:273
std::function< std::string(void)> GetTargetNameFn
Definition: models.h:189
Definition: synapseGroupInternal.h:9
Models::EGPReference createEGPRef(const NeuronGroup *ng, const std::string &egpName)
Creates a reference to a neuron group extra global parameter.
Definition: modelSpec.h:181
VarRef(const std::string &n, const std::string &t)
Definition: models.h:81
Models::EGPReference createWUEGPRef(const SynapseGroup *sg, const std::string &egpName)
Creates a reference to a weight update model extra global parameter.
Definition: modelSpec.h:211
VarInit(double constant)
Definition: models.h:159
A variable has a name, a type and an access type.
Definition: models.h:58
Definition: synapseGroup.h:24
std::string getTargetName() const
Definition: models.h:303
Models::VarReference createVarRef(const NeuronGroup *ng, const std::string &varName)
Creates a reference to a neuron group variable.
Definition: modelSpec.h:132
Definition: models.h:151
Definition: initVarSnippet.h:17
virtual VarVec getVars() const
Gets names and types (as strings) of model variables.
Definition: models.h:123
size_t getVarIndex(const std::string &varName) const
Find the index of a named variable.
Definition: models.h:129
size_t getEGPIndex() const
Definition: models.h:302
VarRef()
Definition: models.h:83
Models::EGPReference createPSMEGPRef(const SynapseGroup *sg, const std::string &egpName)
Creates a reference to a postsynaptic model extra global parameter.
Definition: modelSpec.h:205
const Models::Base::Var & getVar() const
Definition: models.h:180
VarAccessMode access
Definition: models.h:93
const NeuronGroup * getDelayNeuronGroup() const
Definition: models.h:217
VarAccess access
Definition: models.h:74
Base class for all value initialisation snippets.
Definition: initVarSnippet.h:15
Models::VarReference createPSMVarRef(const SynapseGroup *sg, const std::string &varName)
Creates a reference to a postsynaptic model variable.
Definition: modelSpec.h:150
std::string name
Definition: models.h:108
Definition: models.h:210
Var(const std::string &n, const std::string &t)
Definition: models.h:62
Definition: models.h:295
Definition: snippet.h:48
EGPRef()
Definition: models.h:100
std::string type
Definition: models.h:92
An extra global parameter has a name and a type.
Definition: snippet.h:131
Definition: models.h:258
Base class for all code snippets.
Definition: snippet.h:120
std::vector< VarRef > VarRefVec
Definition: models.h:116
VarReferenceBase(size_t varIndex, const Models::Base::VarVec &varVec, GetTargetNameFn getTargetName, IsBatchedFn isBatched)
Definition: models.h:192
Definition: models.h:96
Base class for all models - in addition to the parameters snippets have, models can have state variables.
Definition: gennUtils.h:22
Var()
Definition: models.h:64
const Models::Base::EGP & getEGP() const
Definition: models.h:301
size_t getVarIndex() const
Definition: models.h:181
bool isBatched() const
Definition: models.h:183
std::string name
Definition: models.h:72
GENN_EXPORT void updateHash(const EGPReference &v, boost::uuids::detail::sha1 &hash)
Definition: models.cc:259
Definition: snippet.h:45
Var(const std::string &n, const std::string &t, VarAccess a)
Definition: models.h:60
Definition: models.h:174
Models::VarReference createWUPostVarRef(const SynapseGroup *sg, const std::string &varName)
Creates a reference to a weight update model postsynaptic variable.
Definition: modelSpec.h:162
std::string name
Definition: models.h:91
Definition: currentSourceInternal.h:9
Definition: currentSource.h:20
std::vector< EGP > EGPVec
Definition: snippet.h:179