GeNN  3.3.0
GPU enhanced Neuronal Networks (GeNN)
newModels.h
Go to the documentation of this file.
1 #pragma once
2 
3 // Standard C++ includes
4 #include <algorithm>
5 #include <string>
6 #include <vector>
7 
8 // Standard C includes
9 #include <cassert>
10 
11 // GeNN includes
12 #include "snippet.h"
13 #include "initVarSnippet.h"
14 
15 //----------------------------------------------------------------------------
16 // Macros
17 //----------------------------------------------------------------------------
//! Declares the boilerplate shared by all model classes.
/*! Expands DECLARE_SNIPPET(TYPE, NUM_PARAMS) (defined elsewhere, presumably in snippet.h)
    and then adds three member type aliases:
    - VarValues: container holding exactly NUM_VARS variable initialisers
    - PreVarValues / PostVarValues: containers holding zero initialisers */
#define DECLARE_MODEL(TYPE, NUM_PARAMS, NUM_VARS)                   \
    DECLARE_SNIPPET(TYPE, NUM_PARAMS)                               \
    using VarValues = NewModels::VarInitContainerBase<NUM_VARS>;    \
    using PreVarValues = NewModels::VarInitContainerBase<0>;        \
    using PostVarValues = NewModels::VarInitContainerBase<0>;

//! Implements a model declared with DECLARE_MODEL by forwarding to IMPLEMENT_SNIPPET.
#define IMPLEMENT_MODEL(TYPE) IMPLEMENT_SNIPPET(TYPE)

//! Overrides getVars() so that it returns the brace-enclosed list passed as the macro arguments.
#define SET_VARS(...) StringPairVec getVars() const override { return __VA_ARGS__; }
27 
28 //----------------------------------------------------------------------------
29 // NewModels::VarInit
30 //----------------------------------------------------------------------------
31 namespace NewModels
32 {
36 class VarInit : public Snippet::Init<InitVarSnippet::Base>
37 {
38 public:
39  VarInit(const InitVarSnippet::Base *snippet, const std::vector<double> &params)
40  : Snippet::Init<InitVarSnippet::Base>(snippet, params)
41  {
42  }
43 
44  VarInit(double constant)
45  : Snippet::Init<InitVarSnippet::Base>(InitVarSnippet::Constant::getInstance(), {constant})
46  {
47  }
48 };
49 
50 //----------------------------------------------------------------------------
51 // NewModels::VarInitContainerBase
52 //----------------------------------------------------------------------------
55 template<size_t NumVars>
57 {
58 private:
59  //----------------------------------------------------------------------------
60  // Typedefines
61  //----------------------------------------------------------------------------
62  typedef std::vector<VarInit> InitialiserArray;
63 
64 public:
65  // **NOTE** other less terrifying forms of constructor won't complain at compile time about
66  // number of parameters e.g. std::array<VarInit, 4> can be initialized with <= 4 elements
67  template<typename... T>
68  VarInitContainerBase(T&&... initialisers) : m_Initialisers(InitialiserArray{{std::forward<const VarInit>(initialisers)...}})
69  {
70  static_assert(sizeof...(initialisers) == NumVars, "Wrong number of initialisers");
71  }
72 
73  //----------------------------------------------------------------------------
74  // Public API
75  //----------------------------------------------------------------------------
77  const std::vector<VarInit> &getInitialisers() const
78  {
79  return m_Initialisers;
80  }
81 
82  //----------------------------------------------------------------------------
83  // Operators
84  //----------------------------------------------------------------------------
85  const VarInit &operator[](size_t pos) const
86  {
87  return m_Initialisers[pos];
88  }
89 
90 private:
91  //----------------------------------------------------------------------------
92  // Members
93  //----------------------------------------------------------------------------
94  InitialiserArray m_Initialisers;
95 };
96 
97 //----------------------------------------------------------------------------
98 // NewModels::VarInitContainerBase<0>
99 //----------------------------------------------------------------------------
102 template<>
104 {
105 public:
106  // **NOTE** other less terrifying forms of constructor won't complain at compile time about
107  // number of parameters e.g. std::array<VarInit, 4> can be initialized with <= 4 elements
108  template<typename... T>
109  VarInitContainerBase(T&&... initialisers)
110  {
111  static_assert(sizeof...(initialisers) == 0, "Wrong number of initialisers");
112  }
113 
115  {
116  }
117 
118  //----------------------------------------------------------------------------
119  // Public API
120  //----------------------------------------------------------------------------
122  std::vector<VarInit> getInitialisers() const
123  {
124  return {};
125  }
126 };
127 
128 //----------------------------------------------------------------------------
129 // NewModels::Base
130 //----------------------------------------------------------------------------
132 class Base : public Snippet::Base
133 {
134 public:
135  //----------------------------------------------------------------------------
136  // Declared virtuals
137  //------------------------------------------------------------------------
139  virtual StringPairVec getVars() const{ return {}; }
140 
141  //------------------------------------------------------------------------
142  // Public methods
143  //------------------------------------------------------------------------
145  size_t getVarIndex(const std::string &varName) const
146  {
147  return getVarIndex(varName, getVars());
148  }
149 
150 protected:
151  //------------------------------------------------------------------------
152  // Protected static helpers
153  //------------------------------------------------------------------------
154  static size_t getVarIndex(const std::string &varName, const StringPairVec &vars)
155  {
156  auto varIter = std::find_if(vars.begin(), vars.end(),
157  [varName](const StringPairVec::value_type &v){ return (v.first == varName); });
158  assert(varIter != vars.end());
159 
160  // Return flag corresponding to variable
161  return distance(vars.begin(), varIter);
162  }
163 };
164 
165 //----------------------------------------------------------------------------
166 // NewModels::LegacyWrapper
167 //----------------------------------------------------------------------------
//! Wrapper around old-style models stored in global arrays and referenced by index.
/*! \tparam ModelBase        new-style base class to derive from; must provide the
                             DerivedParamFunc, StringVec, StringPairVec and DerivedParamVec types
    \tparam LegacyModelType  type of the legacy model records; must expose pNames, dpNames,
                             varNames, varTypes and a dps pointer
    \tparam ModelArray       reference to the global vector holding the legacy models */
template<typename ModelBase, typename LegacyModelType, const std::vector<LegacyModelType> &ModelArray>
class LegacyWrapper : public ModelBase
{
private:
    typedef typename ModelBase::DerivedParamFunc DerivedParamFunc;
    typedef typename ModelBase::StringVec StringVec;
    typedef typename ModelBase::StringPairVec StringPairVec;
    typedef typename ModelBase::DerivedParamVec DerivedParamVec;

public:
    //! Wrap the legacy model stored at position \p legacyTypeIndex of ModelArray.
    LegacyWrapper(unsigned int legacyTypeIndex) : m_LegacyTypeIndex(legacyTypeIndex)
    {
    }

    //----------------------------------------------------------------------------
    // ModelBase virtuals
    //----------------------------------------------------------------------------
    //! Gets names of (independent) model parameters
    virtual StringVec getParamNames() const
    {
        const auto &nm = ModelArray[m_LegacyTypeIndex];
        return nm.pNames;
    }

    //! Gets one (name, function) pair per derived parameter of the legacy model.
    /*! Each function forwards to calculateDerivedParameter() on the legacy model's dps object. */
    virtual DerivedParamVec getDerivedParams() const
    {
        const auto &m = ModelArray[m_LegacyTypeIndex];

        // Copy the index into a local so the lambdas below can capture it by value.
        // Capturing 'this' would leave the returned functions dangling if they
        // outlive this wrapper; ModelArray is a template parameter so needs no capture.
        const unsigned int legacyTypeIndex = m_LegacyTypeIndex;

        // Reserve vector to hold derived parameters
        DerivedParamVec derivedParams;
        derivedParams.reserve(m.dpNames.size());

        // Loop through derived parameters
        for(size_t p = 0; p < m.dpNames.size(); p++)
        {
            // Add pair consisting of parameter name and lambda function which calls
            // through to the DPS object associated with the legacy model
            derivedParams.push_back(std::make_pair(
                m.dpNames[p],
                [legacyTypeIndex, p](const std::vector<double> &pars, double dt)
                {
                    return ModelArray[legacyTypeIndex].dps->calculateDerivedParameter(p, pars, dt);
                }
            ));
        }

        return derivedParams;
    }

    //! Gets names and types (as strings) of model variables
    virtual StringPairVec getVars() const
    {
        const auto &nm = ModelArray[m_LegacyTypeIndex];
        return zipStringVectors(nm.varNames, nm.varTypes);
    }

protected:
    //----------------------------------------------------------------------------
    // Static methods
    //----------------------------------------------------------------------------
    //! Combines two equally sized string vectors into a vector of (a[i], b[i]) pairs.
    /*! The sizes are only checked by an assert. */
    static StringPairVec zipStringVectors(const StringVec &a, const StringVec &b)
    {
        assert(a.size() == b.size());

        // Reserve vector to hold the pairs
        StringPairVec zip;
        zip.reserve(a.size());

        // Pair up corresponding elements
        for(size_t v = 0; v < a.size(); v++)
        {
            zip.push_back(std::make_pair(a[v], b[v]));
        }

        return zip;
    }

    //----------------------------------------------------------------------------
    // Members
    //----------------------------------------------------------------------------
    //! Index into the array of legacy models
    const unsigned int m_LegacyTypeIndex;
};
254 } // NewModels
const unsigned int m_LegacyTypeIndex
Index into the array of legacy models.
Definition: newModels.h:252
VarInit(double constant)
Definition: newModels.h:44
virtual StringPairVec getVars() const
Gets names and types (as strings) of model variables.
Definition: newModels.h:221
virtual DerivedParamVec getDerivedParams() const
Definition: newModels.h:195
VarInitContainerBase(T &&... initialisers)
Definition: newModels.h:68
Definition: snippet.h:139
Definition: newModels.h:36
Init(const InitVarSnippet::Base *snippet, const std::vector< double > &params)
Definition: snippet.h:142
Wrapper around old-style models stored in global arrays and referenced by index.
Definition: newModels.h:170
VarInitContainerBase(const Snippet::ValueBase< 0 > &)
Definition: newModels.h:114
Definition: snippet.h:80
LegacyWrapper(unsigned int legacyTypeIndex)
Definition: newModels.h:179
Definition: initVarSnippet.h:17
Definition: newModels.h:56
Base class for all models - in addition to the parameters snippets have, models can have state variables.
Definition: newModels.h:132
static StringPairVec zipStringVectors(const StringVec &a, const StringVec &b)
Definition: newModels.h:231
virtual StringPairVec getVars() const
Gets names and types (as strings) of model variables.
Definition: newModels.h:139
const std::vector< VarInit > & getInitialisers() const
Gets initialisers as a vector of Values.
Definition: newModels.h:77
Base class for all value initialisation snippets.
Definition: initVarSnippet.h:15
static size_t getVarIndex(const std::string &varName, const StringPairVec &vars)
Definition: newModels.h:154
VarInitContainerBase(T &&... initialisers)
Definition: newModels.h:109
Base class for all code snippets.
Definition: snippet.h:105
std::vector< VarInit > getInitialisers() const
Gets initialisers as a vector of Values.
Definition: newModels.h:122
VarInit(const InitVarSnippet::Base *snippet, const std::vector< double > &params)
Definition: newModels.h:39
std::vector< std::pair< std::string, std::string > > StringPairVec
Definition: snippet.h:117
Definition: codeGenUtils.h:24
size_t getVarIndex(const std::string &varName) const
Find the index of a named variable.
Definition: newModels.h:145
Definition: snippet.h:36
virtual StringVec getParamNames() const
Gets names of (independent) model parameters.
Definition: newModels.h:187
m
Definition: build/lib.linux-x86_64-3.6/pygenn/genn_model.py:76
const VarInit & operator[](size_t pos) const
Definition: newModels.h:85