GeNN  4.9.0
GPU enhanced Neuronal Networks (GeNN)
snippet.h
Go to the documentation of this file.
1 #pragma once
2 
// Standard C++ includes
#include <algorithm>
#include <functional>
#include <iterator>
#include <stdexcept>
#include <string>
#include <vector>

// Standard C includes
#include <cassert>

// GeNN includes
#include "gennExport.h"
#include "gennUtils.h"
15 
16 //----------------------------------------------------------------------------
17 // Macros
18 //----------------------------------------------------------------------------
//! Declares the singleton boilerplate for a snippet class TYPE:
//! a private static instance pointer, a public getInstance() accessor and a
//! ParamValues typedef holding exactly NUM_PARAMS values.
//! Must be paired with IMPLEMENT_SNIPPET(TYPE), which defines the static pointer.
//! **NOTE** the instance is allocated on the first getInstance() call and is never
//! freed; that first call is not synchronised, so it must not race with another thread.
#define DECLARE_SNIPPET(TYPE, NUM_PARAMS)                   \
private:                                                    \
    GENN_EXPORT static TYPE *s_Instance;                    \
public:                                                     \
    static const TYPE *getInstance()                        \
    {                                                       \
        if(s_Instance == nullptr)                           \
        {                                                   \
            s_Instance = new TYPE;                          \
        }                                                   \
        return s_Instance;                                  \
    }                                                       \
    typedef Snippet::ValueBase<NUM_PARAMS> ParamValues

//! Defines (and null-initialises) the static instance pointer declared by DECLARE_SNIPPET
#define IMPLEMENT_SNIPPET(TYPE) TYPE *TYPE::s_Instance = nullptr

//! Override getParamNames() so it returns the given brace-enclosed list of names
#define SET_PARAM_NAMES(...) virtual StringVec getParamNames() const override{ return __VA_ARGS__; }
//! Override getDerivedParams() so it returns the given brace-enclosed list of DerivedParams
#define SET_DERIVED_PARAMS(...) virtual DerivedParamVec getDerivedParams() const override{ return __VA_ARGS__; }
//! Override getExtraGlobalParams() so it returns the given brace-enclosed list of EGPs
#define SET_EXTRA_GLOBAL_PARAMS(...) virtual EGPVec getExtraGlobalParams() const override{ return __VA_ARGS__; }
39 
40 //----------------------------------------------------------------------------
41 // Snippet::InitialiserContainerBase
42 //----------------------------------------------------------------------------
45 namespace Snippet
46 {
//! Wrapper which ensures, at compile time, that exactly NumVars values are
//! supplied when it is constructed (e.g. a snippet's parameter values).
//! @tparam V       type of each stored value
//! @tparam NumVars number of values the container must be constructed with
template<typename V, size_t NumVars>
class InitialiserContainerBase
{
public:
    // **NOTE** other less terrifying forms of constructor won't complain at compile time about
    // number of parameters e.g. std::array<V, 4> can be initialized with <= 4 elements
    // **NOTE(review)** this greedy forwarding constructor is also selected when copying a
    // NON-const lvalue container (it beats the implicit copy constructor), which then
    // typically fails to compile; copy from a const reference or a temporary instead
    template<typename... T>
    InitialiserContainerBase(T&&... vals) : m_Values(std::vector<V>{{std::forward<const V>(vals)...}})
    {
        static_assert(sizeof...(vals) == NumVars, "Wrong number of values");
    }

    //------------------------------------------------------------------------
    // Public API
    //------------------------------------------------------------------------
    //! Gets values as a vector
    const std::vector<V> &getInitialisers() const
    {
        return m_Values;
    }

    //------------------------------------------------------------------------
    // Operators
    //------------------------------------------------------------------------
    //! Gets the value at index pos (unchecked, like std::vector::operator[])
    const V &operator[](size_t pos) const
    {
        return m_Values[pos];
    }

private:
    //------------------------------------------------------------------------
    // Members
    //------------------------------------------------------------------------
    std::vector<V> m_Values;
};
82 
83 //----------------------------------------------------------------------------
84 // Snippet::InitialiserContainerBase<0>
85 //----------------------------------------------------------------------------
88 template<typename V>
90 {
91 public:
92  // **NOTE** other less terrifying forms of constructor won't complain at compile time about
93  // number of parameters e.g. std::array<double, 4> can be initialized with <= 4 elements
94  template<typename... T>
96  {
97  static_assert(sizeof...(vals) == 0, "Wrong number of values");
98  }
99 
100  //----------------------------------------------------------------------------
101  // Public API
102  //----------------------------------------------------------------------------
104  std::vector<V> getInitialisers() const
105  {
106  return {};
107  }
108 };
109 
110 //----------------------------------------------------------------------------
111 // Snippet::ValueBase
112 //----------------------------------------------------------------------------
113 template<size_t NumVars>
115 
116 //----------------------------------------------------------------------------
117 // Snippet::Base
118 //----------------------------------------------------------------------------
121 {
122 public:
123  virtual ~Base()
124  {
125  }
126 
127  //----------------------------------------------------------------------------
128  // Structs
129  //----------------------------------------------------------------------------
//! An extra global parameter: a name and a type, both held as strings
struct EGP
{
    //! Two EGPs are equal only when both their names and their types match
    bool operator == (const EGP &other) const
    {
        if(name != other.name) {
            return false;
        }
        return (type == other.type);
    }

    std::string name;
    std::string type;
};
141 
143  struct ParamVal
144  {
145  ParamVal(const std::string &n, const std::string &t, const std::string &v) : name(n), type(t), value(v)
146  {}
147  ParamVal(const std::string &n, const std::string &t, double v) : ParamVal(n, t, Utils::writePreciseString(v))
148  {}
149  ParamVal() : ParamVal("", "", "0.0")
150  {}
151 
152  bool operator == (const ParamVal &other) const
153  {
154  return ((name == other.name) && (type == other.type) && (value == other.value));
155  }
156 
157  std::string name;
158  std::string type;
159  std::string value;
160  };
161 
//! A derived parameter: a name and a function which computes its value from the
//! (independent) parameter values and a double which Init::initDerivedParams
//! passes as the timestep dt
struct DerivedParam
{
    //! Equality compares names only; std::function objects cannot be compared
    bool operator == (const DerivedParam &other) const
    {
        return (name == other.name);
    }

    std::string name;
    std::function<double(const std::vector<double> &, double)> func;
};
173 
174 
175  //----------------------------------------------------------------------------
176  // Typedefines
177  //----------------------------------------------------------------------------
178  typedef std::vector<std::string> StringVec;
179  typedef std::vector<EGP> EGPVec;
180  typedef std::vector<ParamVal> ParamValVec;
181  typedef std::vector<DerivedParam> DerivedParamVec;
182 
183  //----------------------------------------------------------------------------
184  // Declared virtuals
185  //----------------------------------------------------------------------------
187  virtual StringVec getParamNames() const{ return {}; }
188 
191  virtual DerivedParamVec getDerivedParams() const{ return {}; }
192 
195  virtual EGPVec getExtraGlobalParams() const { return {}; }
196 
197  //------------------------------------------------------------------------
198  // Public methods
199  //------------------------------------------------------------------------
201  size_t getExtraGlobalParamIndex(const std::string &paramName) const
202  {
203  return getNamedVecIndex(paramName, getExtraGlobalParams());
204  }
205 
206 protected:
207  //------------------------------------------------------------------------
208  // Protected methods
209  //------------------------------------------------------------------------
210  void updateHash(boost::uuids::detail::sha1 &hash) const
211  {
212  Utils::updateHash(getParamNames(), hash);
213  Utils::updateHash(getDerivedParams(), hash);
214  Utils::updateHash(getExtraGlobalParams(), hash);
215  }
216 
218  void validate() const
219  {
220  Utils::validateParamNames(getParamNames());
221  Utils::validateVecNames(getDerivedParams(), "Derived parameter");
222  Utils::validateVecNames(getExtraGlobalParams(), "Derived parameter");
223  }
224 
225  //------------------------------------------------------------------------
226  // Protected static helpers
227  //------------------------------------------------------------------------
//! Find the index of the first element of vec whose name member equals name
//! @tparam T element type; must have a `name` member comparable with std::string
//! @throws std::runtime_error if no element has the requested name
template<typename T>
static size_t getNamedVecIndex(const std::string &name, const std::vector<T> &vec)
{
    // The predicate never outlives this call, so capture name by reference rather than copying it
    const auto iter = std::find_if(vec.cbegin(), vec.cend(),
                                   [&name](const T &v){ return (v.name == name); });

    if(iter == vec.cend()) {
        throw std::runtime_error("Cannot find variable '" + name + "'");
    }

    // Index is the distance from the first entry; qualify std::distance rather than relying on ADL
    return static_cast<size_t>(std::distance(vec.cbegin(), iter));
}
241 };
242 
243 //----------------------------------------------------------------------------
244 // Snippet::Init
245 //----------------------------------------------------------------------------
249 template<typename SnippetBase>
250 class Init
251 {
252 public:
253  Init(const SnippetBase *snippet, const std::vector<double> &params)
254  : m_Snippet(snippet), m_Params(params)
255  {
256  // Validate names
257  getSnippet()->validate();
258  }
259 
260  //----------------------------------------------------------------------------
261  // Public API
262  //----------------------------------------------------------------------------
263  const SnippetBase *getSnippet() const{ return m_Snippet; }
264  const std::vector<double> &getParams() const{ return m_Params; }
265  const std::vector<double> &getDerivedParams() const{ return m_DerivedParams; }
266 
267  void initDerivedParams(double dt)
268  {
269  auto derivedParams = m_Snippet->getDerivedParams();
270 
271  // Reserve vector to hold derived parameters
272  m_DerivedParams.reserve(derivedParams.size());
273 
274  // Loop through derived parameters
275  for(const auto &d : derivedParams) {
276  m_DerivedParams.push_back(d.func(m_Params, dt));
277  }
278  }
279 
280  boost::uuids::detail::sha1::digest_type getHashDigest() const
281  {
282  return getSnippet()->getHashDigest();
283  }
284 
285 private:
286  //----------------------------------------------------------------------------
287  // Members
288  //----------------------------------------------------------------------------
289  const SnippetBase *m_Snippet;
290  std::vector<double> m_Params;
291  std::vector<double> m_DerivedParams;
292 };
293 
294 //----------------------------------------------------------------------------
295 // updateHash overrides
296 //----------------------------------------------------------------------------
297 inline void updateHash(const Base::EGP &e, boost::uuids::detail::sha1 &hash)
298 {
299  Utils::updateHash(e.name, hash);
300  Utils::updateHash(e.type, hash);
301 }
302 
303 inline void updateHash(const Base::ParamVal &p, boost::uuids::detail::sha1 &hash)
304 {
305  Utils::updateHash(p.name, hash);
306  Utils::updateHash(p.type, hash);
307  Utils::updateHash(p.value, hash);
308 }
309 
310 inline void updateHash(const Base::DerivedParam &d, boost::uuids::detail::sha1 &hash)
311 {
312  Utils::updateHash(d.name, hash);
313 }
314 } // namespace Snippet
ParamVal()
Definition: snippet.h:149
std::vector< V > getInitialisers() const
Gets values as a vector of doubles.
Definition: snippet.h:104
std::string type
Definition: snippet.h:158
std::string name
Definition: snippet.h:170
ParamVal(const std::string &n, const std::string &t, double v)
Definition: snippet.h:147
std::string name
Definition: snippet.h:157
#define GENN_EXPORT
Definition: gennExport.h:13
ParamVal(const std::string &n, const std::string &t, const std::string &v)
Definition: snippet.h:145
Definition: snippet.h:250
STL namespace.
Init(const SnippetBase *snippet, const std::vector< double > &params)
Definition: snippet.h:253
void writePreciseString(std::ostream &os, T value)
This function writes a floating point value to a stream -setting the precision so no digits are lost...
Definition: gennUtils.h:92
size_t getExtraGlobalParamIndex(const std::string &paramName) const
Find the index of a named extra global parameter.
Definition: snippet.h:201
std::string value
Definition: snippet.h:159
boost::uuids::detail::sha1::digest_type getHashDigest() const
Definition: snippet.h:280
void updateHash(const T &value, boost::uuids::detail::sha1 &hash)
Hash arithmetic types and enums.
Definition: gennUtils.h:128
A derived parameter has a name and a function for obtaining its value.
Definition: snippet.h:163
void updateHash(boost::uuids::detail::sha1 &hash) const
Definition: snippet.h:210
std::string type
Definition: snippet.h:139
void initDerivedParams(double dt)
Definition: snippet.h:267
virtual ~Base()
Definition: snippet.h:123
const std::vector< V > & getInitialisers() const
Gets values as a vector.
Definition: snippet.h:63
virtual StringVec getParamNames() const
Gets names of (independent) model parameters.
Definition: snippet.h:187
const SnippetBase * getSnippet() const
Definition: snippet.h:263
const V & operator[](size_t pos) const
Definition: snippet.h:71
const std::vector< double > & getParams() const
Definition: snippet.h:264
std::vector< DerivedParam > DerivedParamVec
Definition: snippet.h:181
void updateHash(const Base::EGP &e, boost::uuids::detail::sha1 &hash)
Definition: snippet.h:297
std::vector< std::string > StringVec
Definition: snippet.h:178
virtual EGPVec getExtraGlobalParams() const
Definition: snippet.h:195
Definition: snippet.h:48
InitialiserContainerBase(T &&... vals)
Definition: snippet.h:54
An extra global parameter has a name and a type.
Definition: snippet.h:131
Base class for all code snippets.
Definition: snippet.h:120
Additional input variables, row state variables and other things have a name, a type and an initial value.
Definition: snippet.h:143
Definition: gennUtils.h:30
static size_t getNamedVecIndex(const std::string &name, const std::vector< T > &vec)
Definition: snippet.h:229
Definition: snippet.h:45
InitialiserContainerBase(T &&... vals)
Definition: snippet.h:95
void validate() const
Validate names of parameters etc.
Definition: snippet.h:218
virtual DerivedParamVec getDerivedParams() const
Definition: snippet.h:191
void validateVecNames(const std::vector< T > &vec, const std::string &description)
Are the 'name' fields of all structs in vector valid? GeNN variables and population names must obey C...
Definition: gennUtils.h:81
std::string name
Definition: snippet.h:138
GENN_EXPORT void validateParamNames(const std::vector< std::string > &paramNames)
Are all the parameter names in vector valid? GeNN variables and population names must obey C variable...
Definition: gennUtils.cc:142
const std::vector< double > & getDerivedParams() const
Definition: snippet.h:265
std::vector< ParamVal > ParamValVec
Definition: snippet.h:180
std::vector< EGP > EGPVec
Definition: snippet.h:179