GeNN  3.3.0
GPU enhanced Neuronal Networks (GeNN)
initSparseConnectivitySnippet.h
Go to the documentation of this file.
1 #pragma once
2 
3 // Standard C++ includes
4 #include <functional>
5 #include <vector>
6 
7 // Standard C includes
8 #include <cassert>
9 #include <cmath>
10 
11 // GeNN includes
12 #include "binomial.h"
13 #include "snippet.h"
14 
15 //----------------------------------------------------------------------------
16 // Macros - each expands to a const override of the like-named InitSparseConnectivitySnippet::Base virtual, returning the macro argument
17 //----------------------------------------------------------------------------
18 #define SET_ROW_BUILD_CODE(CODE) virtual std::string getRowBuildCode() const override{ return CODE; }
19 #define SET_ROW_BUILD_STATE_VARS(...) virtual NameTypeValVec getRowBuildStateVars() const override{ return __VA_ARGS__; } // variadic so brace-initialiser lists containing commas pass through intact
20 
21 #define SET_CALC_MAX_ROW_LENGTH_FUNC(FUNC) virtual CalcMaxLengthFunc getCalcMaxRowLengthFunc() const override{ return FUNC; }
22 #define SET_CALC_MAX_COL_LENGTH_FUNC(FUNC) virtual CalcMaxLengthFunc getCalcMaxColLengthFunc() const override{ return FUNC; }
23 
24 #define SET_EXTRA_GLOBAL_PARAMS(...) virtual StringPairVec getExtraGlobalParams() const override{ return __VA_ARGS__; } // variadic for the same reason as SET_ROW_BUILD_STATE_VARS
25 
26 //----------------------------------------------------------------------------
27 // InitSparseConnectivitySnippet::Base
28 //----------------------------------------------------------------------------
31 {
32 class Base : public Snippet::Base // Base class for all sparse connectivity initialisation snippets
33 {
34 public:
35  //----------------------------------------------------------------------------
36  // Typedefs
37  //----------------------------------------------------------------------------
38  typedef std::function<unsigned int(unsigned int, unsigned int, const std::vector<double> &)> CalcMaxLengthFunc; // called as (numPre, numPost, params); returns a maximum row or column length
39 
40  //----------------------------------------------------------------------------
41  // Declared virtuals - defaults: no row build code, no row build state vars, no extra global params
42  //----------------------------------------------------------------------------
43  virtual std::string getRowBuildCode() const{ return ""; }
44  virtual NameTypeValVec getRowBuildStateVars() const{ return {}; }
45 
48 
51 
54  virtual StringPairVec getExtraGlobalParams() const{ return {}; }
55 };
56 
57 //----------------------------------------------------------------------------
58 // Init - pairs a sparse connectivity snippet with the parameter values it should be evaluated with
59 //----------------------------------------------------------------------------
60 class Init : public Snippet::Init<Base>
61 {
62 public:
63  Init(const Base *snippet, const std::vector<double> &params) // forwards both to Snippet::Init<Base>; NOTE(review): snippet is a raw pointer, presumably non-owning - confirm it outlives this Init
64  : Snippet::Init<Base>(snippet, params)
65  {
66  }
67 };
68 
69 //----------------------------------------------------------------------------
70 // InitSparseConnectivitySnippet::Uninitialised - marks connectivity as uninitialised; no initialisation code will be run
71 //----------------------------------------------------------------------------
73 class Uninitialised : public Base
74 {
75 public:
77 };
78 
79 //----------------------------------------------------------------------------
80 // InitSparseConnectivitySnippet::OneToOne - row id_pre gets a single synapse to postsynaptic index id_pre (diagonal matrix), so every row and column holds exactly one synapse
81 //----------------------------------------------------------------------------
83 class OneToOne : public Base
84 {
85 public:
87 
89  "$(addSynapse, $(id_pre));\n"
90  "$(endRow);\n");
91 
93  [](unsigned int numPre, unsigned int numPost, const std::vector<double> &)
94  {
95  assert(numPre == numPost); // diagonal connectivity needs equal population sizes; only checked when NDEBUG is not defined
96  return 1;
97  });
99  [](unsigned int numPre, unsigned int numPost, const std::vector<double> &)
100  {
101  assert(numPre == numPost); // as above - compiled out when NDEBUG is defined
102  return 1;
103  });
104 };
105 
106 //----------------------------------------------------------------------------
107 // InitSparseConnectivitySnippet::FixedProbabilityBase - shared 'prob' parameter, derived 'probLogRecip', 'prevJ' row state and max row/column length bounds; subclasses supply getRowBuildCode()
108 //----------------------------------------------------------------------------
112 {
113 public:
114  virtual std::string getRowBuildCode() const override = 0;
115 
116  SET_ROW_BUILD_STATE_VARS({{"prevJ", {"int", -1}}}); // per-row state: last chosen postsynaptic index, starting one before index 0
117 
118  SET_PARAM_NAMES({"prob"});
119  SET_DERIVED_PARAMS({{"probLogRecip", [](const std::vector<double> &pars, double){ return 1.0 / log(1.0 - pars[0]); }}}); // 1 / log(1 - prob), used to scale log(u) into a gap length; NOTE(review): prob == 0 gives 1 / log(1) = 1 / 0 - confirm this is handled
120 
122  [](unsigned int numPre, unsigned int numPost, const std::vector<double> &pars)
123  {
124  // Calculate suitable quantile for 0.9999 chance when drawing numPre times: quantile^numPre == 0.9999
125  const double quantile = pow(0.9999, 1.0 / (double)numPre);
126 
127  return binomialInverseCDF(quantile, numPost, pars[0]); // presumably the Binomial(numPost, prob) count at this CDF quantile - see binomialInverseCDF in binomial.cc
128  });
130  [](unsigned int numPre, unsigned int numPost, const std::vector<double> &pars)
131  {
132  // Calculate suitable quantile for 0.9999 chance when drawing numPost times: quantile^numPost == 0.9999
133  const double quantile = pow(0.9999, 1.0 / (double)numPost);
134 
135  return binomialInverseCDF(quantile, numPre, pars[0]); // as above, with the roles of numPre and numPost swapped
136  });
137 };
138 
139 //----------------------------------------------------------------------------
140 // InitSparseConnectivitySnippet::FixedProbability - builds each row by advancing prevJ by a random gap 1 + (int)(log(u) * probLogRecip), adding a synapse while prevJ < num_post and ending the row otherwise
141 //----------------------------------------------------------------------------
144 
154 {
155 public:
157 
159  "const scalar u = $(gennrand_uniform);\n"
160  "prevJ += (1 + (int)(log(u) * $(probLogRecip)));\n"
161  "if(prevJ < $(num_post)) {\n"
162  " $(addSynapse, prevJ);\n"
163  "}\n"
164  "else {\n"
165  " $(endRow);\n"
166  "}\n");
167 };
168 
169 //----------------------------------------------------------------------------
170 // InitSparseConnectivitySnippet::FixedProbabilityNoAutapse - like FixedProbability, but redraws the gap whenever it would land on $(id_pre), so no row connects to the postsynaptic index equal to its own
171 //----------------------------------------------------------------------------
176 
186 {
187 public:
189 
191  "int nextJ;\n"
192  "do {\n"
193  " const scalar u = $(gennrand_uniform);\n"
194  " nextJ = prevJ + (1 + (int)(log(u) * $(probLogRecip)));\n"
195  "} while(nextJ == $(id_pre));\n"
196  "prevJ = nextJ;\n"
197  "if(prevJ < $(num_post)) {\n"
198  " $(addSynapse, prevJ);\n"
199  "}\n"
200  "else {\n"
201  " $(endRow);\n"
202  "}\n");
203 };
204 
205 } // namespace InitSparseConnectivitySnippet
virtual StringPairVec getExtraGlobalParams() const
Definition: initSparseConnectivitySnippet.h:54
virtual CalcMaxLengthFunc getCalcMaxColLengthFunc() const
Get function to calculate the maximum column length of this connector based on the parameters and the sizes of the pre- and postsynaptic populations.
Definition: initSparseConnectivitySnippet.h:50
unsigned int binomialInverseCDF(double cdf, unsigned int n, double p)
Definition: binomial.cc:106
Definition: initSparseConnectivitySnippet.h:60
Definition: snippet.h:139
std::vector< std::pair< std::string, std::pair< std::string, double > > > NameTypeValVec
Definition: snippet.h:118
virtual CalcMaxLengthFunc getCalcMaxRowLengthFunc() const
Get function to calculate the maximum row length of this connector based on the parameters and the sizes of the pre- and postsynaptic populations.
Definition: initSparseConnectivitySnippet.h:47
#define SET_ROW_BUILD_STATE_VARS(...)
Definition: initSparseConnectivitySnippet.h:19
#define DECLARE_SNIPPET(TYPE, NUM_PARAMS)
Definition: snippet.h:11
#define SET_CALC_MAX_COL_LENGTH_FUNC(FUNC)
Definition: initSparseConnectivitySnippet.h:22
Definition: initSparseConnectivitySnippet.h:32
virtual NameTypeValVec getRowBuildStateVars() const
Definition: initSparseConnectivitySnippet.h:44
Base class for all sparse connectivity initialisation snippets.
Definition: initSparseConnectivitySnippet.h:30
#define SET_DERIVED_PARAMS(...)
Definition: snippet.h:29
Initialises connectivity to a 'one-to-one' diagonal matrix.
Definition: initSparseConnectivitySnippet.h:83
virtual std::string getRowBuildCode() const
Definition: initSparseConnectivitySnippet.h:43
std::function< unsigned int(unsigned int, unsigned int, const std::vector< double > &)> CalcMaxLengthFunc
Definition: initSparseConnectivitySnippet.h:38
#define SET_ROW_BUILD_CODE(CODE)
Definition: initSparseConnectivitySnippet.h:18
Used to mark connectivity as uninitialised - no initialisation code will be run.
Definition: initSparseConnectivitySnippet.h:73
Definition: initSparseConnectivitySnippet.h:153
Base class for all code snippets.
Definition: snippet.h:105
#define SET_PARAM_NAMES(...)
Definition: snippet.h:28
Init(const Base *snippet, const std::vector< double > &params)
Definition: initSparseConnectivitySnippet.h:63
std::vector< std::pair< std::string, std::string > > StringPairVec
Definition: snippet.h:117
Definition: snippet.h:36
#define SET_CALC_MAX_ROW_LENGTH_FUNC(FUNC)
Definition: initSparseConnectivitySnippet.h:21
Definition: initSparseConnectivitySnippet.h:111
Definition: initSparseConnectivitySnippet.h:185