GeNN  3.3.0
GPU enhanced Neuronal Networks (GeNN)
snippet.h
Go to the documentation of this file.
1 #pragma once
2 
3 // Standard C++ includes
#include <cstddef>
#include <functional>
#include <string>
#include <type_traits>
#include <utility>
#include <vector>
7 
8 //----------------------------------------------------------------------------
9 // Macros
10 //----------------------------------------------------------------------------
//! Injects singleton boilerplate and a ParamValues typedef into a snippet class.
/*! Place inside the class body of TYPE. The expansion adds:
      - a private static pointer \c s_Instance (which must be defined in exactly
        one translation unit, see IMPLEMENT_SNIPPET),
      - a public static getInstance() which default-constructs TYPE with \c new
        the first time it is called and returns the same pointer afterwards,
      - a public typedef \c ParamValues for Snippet::ValueBase<NUM_PARAMS>.
    \param TYPE         name of the class the macro is used in
    \param NUM_PARAMS   number of values a ParamValues object must be built from
    **NOTE** the lazy creation in getInstance() is not synchronised and nothing in
    this header deletes the instance.
    **NOTE** the expansion leaves the class in a \c public: section. */
#define DECLARE_SNIPPET(TYPE, NUM_PARAMS)                   \
private:                                                    \
    static TYPE *s_Instance;                                \
public:                                                     \
    static const TYPE *getInstance()                        \
    {                                                       \
        if(s_Instance == nullptr)                           \
        {                                                   \
            s_Instance = new TYPE;                          \
        }                                                   \
        return s_Instance;                                  \
    }                                                       \
    typedef Snippet::ValueBase<NUM_PARAMS> ParamValues;
24 
25 
//! Defines the static instance pointer declared by DECLARE_SNIPPET, initially null.
/*! Expands to the out-of-class definition of TYPE::s_Instance without a trailing
    semicolon, so the invocation must be followed by one. */
#define IMPLEMENT_SNIPPET(TYPE) TYPE *TYPE::s_Instance = nullptr
27 
//! Overrides getParamNames() so that it returns the expression passed to the macro.
/*! The arguments are pasted verbatim after \c return, so a braced list such as
    {"a", "b"} may be used. Requires \c StringVec to be visible in the enclosing class. */
#define SET_PARAM_NAMES(...) StringVec getParamNames() const override { return __VA_ARGS__; }

//! Overrides getDerivedParams() so that it returns the expression passed to the macro.
/*! The arguments are pasted verbatim after \c return.
    Requires \c DerivedParamVec to be visible in the enclosing class. */
#define SET_DERIVED_PARAMS(...) DerivedParamVec getDerivedParams() const override { return __VA_ARGS__; }
30 
31 //----------------------------------------------------------------------------
32 // Snippet::ValueBase
33 //----------------------------------------------------------------------------
36 namespace Snippet
37 {
//! Compile-time check that every type in a pack is implicitly convertible to double.
/*! Used to constrain the variadic ValueBase constructors: without the constraint
    a forwarding constructor is a better match than the implicit copy constructor
    for a non-const lvalue, so copying a ValueBase would fail to compile. */
template<typename... T>
struct AllConvertibleToDouble : std::true_type
{
};

template<typename Head, typename... Tail>
struct AllConvertibleToDouble<Head, Tail...>
:   std::integral_constant<bool, std::is_convertible<Head, double>::value && AllConvertibleToDouble<Tail...>::value>
{
};

//! Holds exactly NumVars double values, with the count checked at compile time.
template<size_t NumVars>
class ValueBase
{
public:
    //! Builds the value list from exactly NumVars arguments, each converted to double.
    /*! **NOTE** other less terrifying forms of constructor won't complain at compile time about
        number of parameters e.g. std::array<double, 4> can be initialized with <= 4 elements.
        The constructor only takes part in overload resolution when every argument is
        implicitly convertible to double, which leaves copy and move construction
        to the implicitly-generated constructors. */
    template<typename... T, typename = typename std::enable_if<AllConvertibleToDouble<T...>::value>::type>
    ValueBase(T&&... vals) : m_Values{static_cast<double>(std::forward<T>(vals))...}
    {
        static_assert(sizeof...(T) == NumVars, "Wrong number of values");
    }

    //----------------------------------------------------------------------------
    // Public API
    //----------------------------------------------------------------------------
    //! Gets values as a vector of doubles
    const std::vector<double> &getValues() const
    {
        return m_Values;
    }

    //----------------------------------------------------------------------------
    // Operators
    //----------------------------------------------------------------------------
    //! Gets the value at index pos; the index is not bounds-checked
    double operator[](size_t pos) const
    {
        return m_Values[pos];
    }

private:
    //----------------------------------------------------------------------------
    // Members
    //----------------------------------------------------------------------------
    std::vector<double> m_Values;
};

//----------------------------------------------------------------------------
// Snippet::ValueBase<0>
//----------------------------------------------------------------------------
//! Specialisation for an empty value list: stores nothing and accepts no arguments.
template<>
class ValueBase<0>
{
public:
    //! Accepts only an empty argument list; any other count is a compile-time error.
    /*! Constrained in the same way as the primary template so that copying a
        ValueBase<0> selects the implicit copy constructor. */
    template<typename... T, typename = typename std::enable_if<AllConvertibleToDouble<T...>::value>::type>
    ValueBase(T&&...)
    {
        static_assert(sizeof...(T) == 0, "Wrong number of values");
    }

    //----------------------------------------------------------------------------
    // Public API
    //----------------------------------------------------------------------------
    //! Gets values as a vector of doubles (always empty; returned by value)
    std::vector<double> getValues() const
    {
        return {};
    }
};
100 
101 //----------------------------------------------------------------------------
102 // Snippet::Base
103 //----------------------------------------------------------------------------
//! Base class for all code snippets
class Base
{
public:
    //----------------------------------------------------------------------------
    // Typedefines
    //----------------------------------------------------------------------------
    //! Callable that maps (vector of doubles, double) to a double
    using DerivedParamFunc = std::function<double(const std::vector<double> &, double)>;

    //! List of strings
    using StringVec = std::vector<std::string>;

    //! List of (string, string) pairs
    using StringPairVec = std::vector<std::pair<std::string, std::string>>;

    //! List of (string, (string, double)) entries
    using NameTypeValVec = std::vector<std::pair<std::string, std::pair<std::string, double>>>;

    //! List of (string, DerivedParamFunc) pairs
    using DerivedParamVec = std::vector<std::pair<std::string, DerivedParamFunc>>;

    //! Virtual so that derived snippets are destroyed correctly through a Base pointer
    virtual ~Base() = default;

    //----------------------------------------------------------------------------
    // Declared virtuals
    //----------------------------------------------------------------------------
    //! Gets names of (independent) model parameters; empty unless overridden
    virtual StringVec getParamNames() const
    {
        return StringVec();
    }

    //! Gets the (string, function) pairs describing derived parameters; empty unless overridden
    virtual DerivedParamVec getDerivedParams() const
    {
        return DerivedParamVec();
    }
};
131 
132 //----------------------------------------------------------------------------
133 // Snippet::Init
134 //----------------------------------------------------------------------------
//! Pairs a snippet with a concrete set of parameter values and the derived
//! parameter values calculated from them.
/*! \tparam SnippetBase type of snippet pointed to; it must provide a
                        getDerivedParams() const method returning an iterable of
                        pairs whose \c second member is callable as
                        second(const std::vector<double>&, double). */
template<typename SnippetBase>
class Init
{
public:
    //! Stores the snippet pointer (not owned) and a copy of the parameter values.
    /*! Derived parameters stay empty until initDerivedParams() is called. */
    Init(const SnippetBase *snippet, const std::vector<double> &params)
    :   m_Snippet(snippet), m_Params(params)
    {
    }

    //----------------------------------------------------------------------------
    // Public API
    //----------------------------------------------------------------------------
    //! Gets the snippet pointer passed to the constructor
    const SnippetBase *getSnippet() const{ return m_Snippet; }

    //! Gets the parameter values passed to the constructor
    const std::vector<double> &getParams() const{ return m_Params; }

    //! Gets the values calculated by the most recent call to initDerivedParams()
    const std::vector<double> &getDerivedParams() const{ return m_DerivedParams; }

    //! Calculates one derived parameter value per entry returned by the snippet's getDerivedParams().
    /*! \param dt   passed as the second argument to each derived parameter function
        Values are stored in the order in which the snippet lists them. Dereferences
        the snippet pointer, so it must not be null. */
    void initDerivedParams(double dt)
    {
        const auto derivedParams = m_Snippet->getDerivedParams();

        // Drop the results of any earlier call so that re-initialising
        // (e.g. with a different dt) replaces the values instead of appending to them
        m_DerivedParams.clear();

        // Reserve vector to hold derived parameters
        m_DerivedParams.reserve(derivedParams.size());

        // Evaluate each derived parameter function against the stored parameters
        for(const auto &d : derivedParams) {
            m_DerivedParams.push_back(d.second(m_Params, dt));
        }
    }

private:
    //----------------------------------------------------------------------------
    // Members
    //----------------------------------------------------------------------------
    const SnippetBase *m_Snippet;
    std::vector<double> m_Params;
    std::vector<double> m_DerivedParams;
};
175 } // namespace Snippet
std::vector< std::pair< std::string, DerivedParamFunc > > DerivedParamVec
Definition: snippet.h:119
Definition: snippet.h:139
std::vector< std::pair< std::string, std::pair< std::string, double > > > NameTypeValVec
Definition: snippet.h:118
Init(const SnippetBase *snippet, const std::vector< double > &params)
Definition: snippet.h:142
ValueBase(T &&... vals)
Definition: snippet.h:86
std::vector< double > getValues() const
Gets values as a vector of doubles.
Definition: snippet.h:95
std::function< double(const std::vector< double > &, double)> DerivedParamFunc
Definition: snippet.h:115
ValueBase(T &&... vals)
Definition: snippet.h:45
void initDerivedParams(double dt)
Definition: snippet.h:154
virtual ~Base()
Definition: snippet.h:108
Definition: snippet.h:39
double operator[](size_t pos) const
Definition: snippet.h:62
virtual StringVec getParamNames() const
Gets names of (independent) model parameters.
Definition: snippet.h:125
const SnippetBase * getSnippet() const
Definition: snippet.h:150
const std::vector< double > & getParams() const
Definition: snippet.h:151
std::vector< std::string > StringVec
Definition: snippet.h:116
Base class for all code snippets.
Definition: snippet.h:105
const std::vector< double > & getValues() const
Gets values as a vector of doubles.
Definition: snippet.h:54
std::vector< std::pair< std::string, std::string > > StringPairVec
Definition: snippet.h:117
Definition: snippet.h:36
virtual DerivedParamVec getDerivedParams() const
Definition: snippet.h:129
const std::vector< double > & getDerivedParams() const
Definition: snippet.h:152