GeNN  4.9.0
GPU enhanced Neuronal Networks (GeNN)
customUpdate.h
Go to the documentation of this file.
1 #pragma once
2 
3 // Standard includes
4 #include <string>
5 #include <vector>
6 
7 // GeNN includes
8 #include "gennExport.h"
9 #include "gennUtils.h"
10 #include "customUpdateModels.h"
11 #include "variableMode.h"
12 
13 //------------------------------------------------------------------------
14 // CustomUpdateBase
15 //------------------------------------------------------------------------
17 {
18 public:
19  CustomUpdateBase(const CustomUpdateBase &) = delete;
20  CustomUpdateBase() = delete;
21 
22  //------------------------------------------------------------------------
23  // Public methods
24  //------------------------------------------------------------------------
26 
27  void setVarLocation(const std::string &varName, VarLocation loc);
28 
29  //------------------------------------------------------------------------
30  // Public const methods
31  //------------------------------------------------------------------------
32  const std::string &getName() const{ return m_Name; }
33  const std::string &getUpdateGroupName() const { return m_UpdateGroupName; }
34 
36  const CustomUpdateModels::Base *getCustomUpdateModel() const{ return m_CustomUpdateModel; }
37 
38  const std::vector<double> &getParams() const{ return m_Params; }
39  const std::vector<Models::VarInit> &getVarInitialisers() const{ return m_VarInitialisers; }
40 
41  const std::vector<Models::EGPReference> &getEGPReferences() const{ return m_EGPReferences; }
42 
44  VarLocation getVarLocation(const std::string &varName) const;
45 
47  VarLocation getVarLocation(size_t index) const{ return m_VarLocation.at(index); }
48 
50  bool isVarInitRequired() const;
51 
52 protected:
53  CustomUpdateBase(const std::string &name, const std::string &updateGroupName,
54  const CustomUpdateModels::Base *customUpdateModel, const std::vector<double> &params,
55  const std::vector<Models::VarInit> &varInitialisers, const std::vector<Models::EGPReference> &egpReferences,
56  VarLocation defaultVarLocation, VarLocation defaultExtraGlobalParamLocation);
57 
58  //------------------------------------------------------------------------
59  // Protected methods
60  //------------------------------------------------------------------------
61  void initDerivedParams(double dt);
62 
63  //------------------------------------------------------------------------
64  // Protected const methods
65  //------------------------------------------------------------------------
66  const std::vector<double> &getDerivedParams() const{ return m_DerivedParams; }
67 
69  bool isInitRNGRequired() const;
70 
71  bool isZeroCopyEnabled() const;
72 
74  bool isBatched() const { return m_Batched; }
75 
77 
78  void updateHash(boost::uuids::detail::sha1 &hash) const;
79 
81 
82  void updateInitHash(boost::uuids::detail::sha1 &hash) const;
83 
84  boost::uuids::detail::sha1::digest_type getVarLocationHashDigest() const;
85 
86  template<typename V>
87  bool isReduction(const std::vector<V> &varRefs, VarAccessDuplication duplication) const
88  {
89  // Return true if any variables have REDUCE flag in their access mode and have correct duplication flag
90  const auto vars = getCustomUpdateModel()->getVars();
91  if(std::any_of(vars.cbegin(), vars.cend(),
92  [duplication](const Models::Base::Var &v)
93  {
94  return (v.access & VarAccessModeAttribute::REDUCE) && (v.access & duplication);
95  }))
96  {
97  return true;
98  }
99 
100  // Loop through all variable references
101  const auto modelVarRefs = getCustomUpdateModel()->getVarRefs();
102  for (size_t i = 0; i < varRefs.size(); i++) {
103  const auto varRef = varRefs.at(i);
104  const auto modelVarRef = modelVarRefs.at(i);
105 
106  // If custom update model reduces into this variable reference and the variable it targets has correct duplication flag
107  if ((modelVarRef.access & VarAccessModeAttribute::REDUCE) & (varRef.getVar().access & duplication)) {
108  return true;
109  }
110  }
111 
112  return false;
113  }
114 
116  template<typename V>
117  void checkVarReferences(const std::vector<V> &varRefs)
118  {
119  // Loop through all variable references
120  const auto modelVarRefs = getCustomUpdateModel()->getVarRefs();
121  for(size_t i = 0; i < varRefs.size(); i++) {
122  const auto varRef = varRefs.at(i);
123  const auto modelVarRef = modelVarRefs.at(i);
124 
125  // Check types of variable references against those specified in model
126  // **THINK** due to GeNN's current string-based type system this is rather conservative
127  if(varRef.getVar().type != modelVarRef.type) {
128  throw std::runtime_error("Incompatible type for variable reference '" + modelVarRef.name + "'");
129  }
130 
131  // Check that no reduction targets reference duplicated variables
132  if((varRef.getVar().access & VarAccessDuplication::DUPLICATE)
133  && (modelVarRef.access & VarAccessModeAttribute::REDUCE))
134  {
135  throw std::runtime_error("Reduction target variable reference must be to SHARED or SHARED_NEURON variables.");
136  }
137  }
138  }
139 
141  template<typename V>
142  void checkVarReferenceBatching(const std::vector<V>& varRefs, unsigned int batchSize)
143  {
144  // If target of any variable references is not shared across batches, custom update should be batched
145  if(batchSize > 1) {
146  m_Batched = std::any_of(varRefs.cbegin(), varRefs.cend(),
147  [](const V& v)
148  {
149  return (v.isBatched() && !(v.getVar().access & VarAccessDuplication::SHARED));
150  });
151  }
152  else {
153  m_Batched = false;
154  }
155 
156  // Loop through all variable references
157  const auto modelVarRefs = getCustomUpdateModel()->getVarRefs();
158  for (size_t i = 0; i < varRefs.size(); i++) {
159  const auto varRef = varRefs.at(i);
160  const auto modelVarRef = modelVarRefs.at(i);
161 
162  // If custom update is batched, check that any variable references to shared variables are read-only
163  // **NOTE** if custom update isn't batched, it's totally fine to write to shared variables
164  if(m_Batched && (varRef.getVar().access & VarAccessDuplication::SHARED)
165  && (modelVarRef.access == VarAccessMode::READ_WRITE))
166  {
167  throw std::runtime_error("Variable references to SHARED variables in batched custom updates cannot be read-write.");
168  }
169  }
170  }
171 
172 private:
173  //------------------------------------------------------------------------
174  // Members
175  //------------------------------------------------------------------------
176  const std::string m_Name;
177  const std::string m_UpdateGroupName;
178 
179  const CustomUpdateModels::Base *m_CustomUpdateModel;
180  const std::vector<double> m_Params;
181  std::vector<double> m_DerivedParams;
182  std::vector<Models::VarInit> m_VarInitialisers;
183 
184  std::vector<Models::EGPReference> m_EGPReferences;
185 
187  std::vector<VarLocation> m_VarLocation;
188 
190  std::vector<VarLocation> m_ExtraGlobalParamLocation;
191 
193  bool m_Batched;
194 };
195 
196 //------------------------------------------------------------------------
197 // CustomUpdate
198 //------------------------------------------------------------------------
200 {
201 public:
202  //------------------------------------------------------------------------
203  // Public const methods
204  //------------------------------------------------------------------------
205  const std::vector<Models::VarReference> &getVarReferences() const{ return m_VarReferences; }
206  unsigned int getSize() const { return m_Size; }
207 
208 protected:
209  CustomUpdate(const std::string &name, const std::string &updateGroupName,
210  const CustomUpdateModels::Base *customUpdateModel, const std::vector<double> &params,
211  const std::vector<Models::VarInit> &varInitialisers, const std::vector<Models::VarReference> &varReferences,
212  const std::vector<Models::EGPReference> &egpReferences, VarLocation defaultVarLocation,
213  VarLocation defaultExtraGlobalParamLocation);
214 
215  //------------------------------------------------------------------------
216  // Protected methods
217  //------------------------------------------------------------------------
218  void finalize(unsigned int batchSize);
219 
220  //------------------------------------------------------------------------
221  // Protected const methods
222  //------------------------------------------------------------------------
223  bool isBatchReduction() const { return isReduction(getVarReferences(), VarAccessDuplication::SHARED); }
224  bool isNeuronReduction() const { return isReduction(getVarReferences(), VarAccessDuplication::SHARED_NEURON); }
225  bool isPerNeuron() const{ return m_PerNeuron; }
226 
228 
229  boost::uuids::detail::sha1::digest_type getHashDigest() const;
230 
232 
233  boost::uuids::detail::sha1::digest_type getInitHashDigest() const;
234 
235  const NeuronGroup *getDelayNeuronGroup() const { return m_DelayNeuronGroup; }
236 
237 private:
238  //------------------------------------------------------------------------
239  // Members
240  //------------------------------------------------------------------------
241  const std::vector<Models::VarReference> m_VarReferences;
242  const unsigned int m_Size;
243  const NeuronGroup *m_DelayNeuronGroup;
244 
246  bool m_PerNeuron;
247 };
248 
249 //------------------------------------------------------------------------
250 // CustomUpdateWU
251 //------------------------------------------------------------------------
253 {
254 public:
255  //------------------------------------------------------------------------
256  // Public const methods
257  //------------------------------------------------------------------------
258  const std::vector<Models::WUVarReference> &getVarReferences() const{ return m_VarReferences; }
259 
260 protected:
261  CustomUpdateWU(const std::string &name, const std::string &updateGroupName,
262  const CustomUpdateModels::Base *customUpdateModel, const std::vector<double> &params,
263  const std::vector<Models::VarInit> &varInitialisers, const std::vector<Models::WUVarReference> &varReferences,
264  const std::vector<Models::EGPReference> &egpReferences, VarLocation defaultVarLocation,
265  VarLocation defaultExtraGlobalParamLocation);
266 
267  //------------------------------------------------------------------------
268  // Protected methods
269  //------------------------------------------------------------------------
270  void finalize(unsigned int batchSize);
271 
272  //------------------------------------------------------------------------
273  // Protected const methods
274  //------------------------------------------------------------------------
275  bool isBatchReduction() const { return isReduction(getVarReferences(), VarAccessDuplication::SHARED); }
276  bool isTransposeOperation() const;
277 
278  const SynapseGroupInternal *getSynapseGroup() const { return m_SynapseGroup; }
279 
281 
282  boost::uuids::detail::sha1::digest_type getHashDigest() const;
283 
285 
286  boost::uuids::detail::sha1::digest_type getInitHashDigest() const;
287 
288 private:
289  //------------------------------------------------------------------------
290  // Members
291  //------------------------------------------------------------------------
292  const std::vector<Models::WUVarReference> m_VarReferences;
293  const SynapseGroupInternal *m_SynapseGroup;
294 };
Definition: neuronGroup.h:21
Definition: customUpdate.h:16
VarLocation
Flags defining which memory space variables should be allocated in.
Definition: variableMode.h:10
const std::vector< double > & getDerivedParams() const
Definition: customUpdate.h:66
bool isBatched() const
Is this custom update batched, i.e. run in parallel across model batches?
Definition: customUpdate.h:74
const std::string & getName() const
Definition: customUpdate.h:32
bool isBatchReduction() const
Definition: customUpdate.h:223
Definition: customUpdate.h:199
const std::vector< double > & getParams() const
Definition: customUpdate.h:38
const std::vector< Models::VarReference > & getVarReferences() const
Definition: customUpdate.h:205
void checkVarReferences(const std::vector< V > &varRefs)
Helper function to check if variable reference types match those specified in model.
Definition: customUpdate.h:117
#define GENN_EXPORT
Definition: gennExport.h:13
VarLocation getVarLocation(size_t index) const
Get variable location for custom update model state variable.
Definition: customUpdate.h:47
Definition: customUpdate.h:252
Base class for all custom update models.
Definition: customUpdateModels.h:31
bool isNeuronReduction() const
Definition: customUpdate.h:224
const std::vector< Models::VarInit > & getVarInitialisers() const
Definition: customUpdate.h:39
Definition: synapseGroupInternal.h:9
const SynapseGroupInternal * getSynapseGroup() const
Definition: customUpdate.h:278
const NeuronGroup * getDelayNeuronGroup() const
Definition: customUpdate.h:235
void updateHash(const T &value, boost::uuids::detail::sha1 &hash)
Hash arithmetic types and enums.
Definition: gennUtils.h:128
A variable has a name, a type and an access type.
Definition: models.h:58
This variable should be shared between batches.
unsigned int getSize() const
Definition: customUpdate.h:206
const CustomUpdateModels::Base * getCustomUpdateModel() const
Gets the custom update model used by this group.
Definition: customUpdate.h:36
const std::vector< Models::EGPReference > & getEGPReferences() const
Definition: customUpdate.h:41
VarAccessDuplication
Flags defining how variables should be duplicated across multiple batches.
Definition: varAccess.h:28
void checkVarReferenceBatching(const std::vector< V > &varRefs, unsigned int batchSize)
Helper function to determine whether this custom update is batched, and to check that batched custom updates do not have read-write variable references to SHARED variables.
Definition: customUpdate.h:142
const std::vector< Models::WUVarReference > & getVarReferences() const
Definition: customUpdate.h:258
bool isBatchReduction() const
Definition: customUpdate.h:275
bool isReduction(const std::vector< V > &varRefs, VarAccessDuplication duplication) const
Definition: customUpdate.h:87
This variable is read-write.
const std::string & getUpdateGroupName() const
Definition: customUpdate.h:33
bool isPerNeuron() const
Definition: customUpdate.h:225
This variable should be duplicated in each batch.