GeNN  4.9.0
GPU enhanced Neuronal Networks (GeNN)
initToeplitzConnectivitySnippet.h
Go to the documentation of this file.
1 #pragma once
2 
3 // Standard C++ includes
4 #include <functional>
5 #include <vector>
6 
7 // Standard C includes
8 #include <cassert>
9 #include <cmath>
10 
11 // GeNN includes
12 #include "binomial.h"
13 #include "snippet.h"
14 
15 //----------------------------------------------------------------------------
16 // Macros
17 //----------------------------------------------------------------------------
18 #define SET_DIAGONAL_BUILD_CODE(CODE) virtual std::string getDiagonalBuildCode() const override{ return CODE; } // Overrides getDiagonalBuildCode() so that it returns CODE
19 #define SET_DIAGONAL_BUILD_STATE_VARS(...) virtual ParamValVec getDiagonalBuildStateVars() const override{ return __VA_ARGS__; } // Variadic: the argument is a braced list whose commas would otherwise be split into separate macro arguments
20 // The next two macros override the corresponding getter to return a caller-supplied callable (FUNC / __VA_ARGS__)
21 #define SET_CALC_MAX_ROW_LENGTH_FUNC(FUNC) virtual CalcMaxLengthFunc getCalcMaxRowLengthFunc() const override{ return FUNC; }
22 #define SET_CALC_KERNEL_SIZE_FUNC(...) virtual CalcKernelSizeFunc getCalcKernelSizeFunc() const override{ return __VA_ARGS__; }
23 // Convenience form of SET_CALC_MAX_ROW_LENGTH_FUNC: installs a lambda that ignores all three of its arguments and always returns MAX_ROW_LENGTH
24 #define SET_MAX_ROW_LENGTH(MAX_ROW_LENGTH) virtual CalcMaxLengthFunc getCalcMaxRowLengthFunc() const override{ return [](unsigned int, unsigned int, const std::vector<double> &){ return MAX_ROW_LENGTH; }; }
25 
26 //----------------------------------------------------------------------------
27 // InitToeplitzConnectivitySnippet::Base
28 //----------------------------------------------------------------------------
31 {
33 {
34 public:
35  //----------------------------------------------------------------------------
36  // Typedefines
37  //----------------------------------------------------------------------------
38  typedef std::function<unsigned int(unsigned int, unsigned int, const std::vector<double> &)> CalcMaxLengthFunc;
39  typedef std::function<std::vector<unsigned int>(const std::vector<double> &)> CalcKernelSizeFunc;
40 
41  //----------------------------------------------------------------------------
42  // Declared virtuals
43  //----------------------------------------------------------------------------
44  virtual std::string getDiagonalBuildCode() const{ return ""; }
45  virtual ParamValVec getDiagonalBuildStateVars() const { return {}; }
46 
49 
52 
53  //------------------------------------------------------------------------
54  // Public methods
55  //------------------------------------------------------------------------
57  boost::uuids::detail::sha1::digest_type getHashDigest() const;
58 
60  void validate() const;
61 };
62 
63 //----------------------------------------------------------------------------
64 // Init
65 //----------------------------------------------------------------------------
66 class Init : public Snippet::Init<InitToeplitzConnectivitySnippet::Base> // Snippet::Init instantiated for the toeplitz connectivity snippet Base type
67 {
68 public:
69  Init(const Base *snippet, const std::vector<double> &params) // snippet: pointer to the connectivity snippet; params: presumably the values for that snippet's parameters
70  : Snippet::Init<Base>(snippet, params) // Both arguments are forwarded unchanged to the Snippet::Init<Base> constructor
71  {
72  } // Body intentionally empty: this class declares no members of its own
73 };
74 
75 //----------------------------------------------------------------------------
76 // InitToeplitzConnectivitySnippet::Uninitialised
77 //----------------------------------------------------------------------------
79 class Uninitialised : public Base
80 {
81 public:
83 };
84 
85 //----------------------------------------------------------------------------
86 // InitToeplitzConnectivitySnippet::Conv2D
87 //----------------------------------------------------------------------------
91 class Conv2D : public Base
92 {
93 public:
95 
96  SET_PARAM_NAMES({"conv_kh", "conv_kw",
97  "conv_ih", "conv_iw", "conv_ic",
98  "conv_oh", "conv_ow", "conv_oc"});
99  SET_DERIVED_PARAMS({{"conv_bw", [](const std::vector<double> &pars, double){ return (((int)pars[3] + (int)pars[1] - 1) - (int)pars[6]) / 2; }}, // conv_bw = ((conv_iw + conv_kw - 1) - conv_ow) / 2, integer division; indices follow the SET_PARAM_NAMES order
100  {"conv_bh", [](const std::vector<double> &pars, double){ return (((int)pars[2] + (int)pars[0] - 1) - (int)pars[5]) / 2; }}}); // conv_bh = ((conv_ih + conv_kh - 1) - conv_oh) / 2, integer division
101 
102  SET_DIAGONAL_BUILD_STATE_VARS({{"kernRow", "int", "($(id_diag) / (int)$(conv_oc)) / (int)$(conv_kw)"},
103  {"kernCol", "int", "($(id_diag) / (int)$(conv_oc)) % (int)$(conv_kw)"},
104  {"kernOutChan", "int", "$(id_diag) % (int)$(conv_oc)"},
105  {"flipKernRow", "int", "(int)$(conv_kh) - $(kernRow) - 1"},
106  {"flipKernCol", "int", "(int)$(conv_kw) - $(kernCol) - 1"}});
107 
109  "const int preRow = ($(id_pre) / (int)$(conv_ic)) / (int)$(conv_iw);\n"
110  "const int preCol = ($(id_pre) / (int)$(conv_ic)) % (int)$(conv_iw);\n"
111  "const int preChan = $(id_pre) % (int)$(conv_ic);\n"
112  "// If we haven't gone off edge of output\n"
113  "const int postRow = preRow + $(kernRow) - (int)$(conv_bh);\n"
114  "const int postCol = preCol + $(kernCol) - (int)$(conv_bw);\n"
115  "if(postRow >= 0 && postCol >= 0 && postRow < (int)$(conv_oh) && postCol < (int)$(conv_ow)) {\n"
116  " // Calculate postsynaptic index\n"
117  " const int postInd = ((postRow * (int)$(conv_ow) * (int)$(conv_oc)) +\n"
118  " (postCol * (int)$(conv_oc)) +\n"
119  " $(kernOutChan));\n"
120  " $(addSynapse, postInd, $(flipKernRow), $(flipKernCol), preChan, $(kernOutChan));\n"
121  "}\n");
122 
124  [](unsigned int, unsigned int, const std::vector<double> &pars)
125  {
126  const unsigned int convKH = (unsigned int)pars[0];
127  const unsigned int convKW = (unsigned int)pars[1];
128  const unsigned int convOC = (unsigned int)pars[7];
129  return (convKH * convKW * convOC);
130  });
131 
133  [](const std::vector<double> &pars)->std::vector<unsigned int>
134  {
135  const unsigned int convKH = (unsigned int)pars[0];
136  const unsigned int convKW = (unsigned int)pars[1];
137  const unsigned int convIC = (unsigned int)pars[4];
138  const unsigned int convOC = (unsigned int)pars[7];
139  return {convKH, convKW, convIC, convOC};
140  });
141 };
142 
143 //----------------------------------------------------------------------------
144 // InitToeplitzConnectivitySnippet::AvgPoolConv2D
145 //----------------------------------------------------------------------------
149 class AvgPoolConv2D : public Base
150 {
151 public:
153 
154  SET_PARAM_NAMES({"conv_kh", "conv_kw",
155  "pool_kh", "pool_kw",
156  "pool_sh", "pool_sw",
157  "pool_ih", "pool_iw", "pool_ic",
158  "conv_oh", "conv_ow", "conv_oc"});
159  SET_DERIVED_PARAMS({{"conv_bw", [](const std::vector<double> &pars, double){ return (int(ceil((pars[7] - pars[3] + 1.0) / pars[5])) + (int)pars[1] - 1 - (int)pars[10]) / 2; }}, // conv_bw = (ceil((pool_iw - pool_kw + 1) / pool_sw) + conv_kw - 1 - conv_ow) / 2; the ceil term is computed in double, the final division is on ints
160  {"conv_bh", [](const std::vector<double> &pars, double){ return (int(ceil((pars[6] - pars[2] + 1.0) / pars[4])) + (int)pars[0] - 1 - (int)pars[9]) / 2; }}}); // conv_bh = (ceil((pool_ih - pool_kh + 1) / pool_sh) + conv_kh - 1 - conv_oh) / 2; same scheme using the height parameters
161 
162  SET_DIAGONAL_BUILD_STATE_VARS({{"kernRow", "int", "($(id_diag) / (int)$(conv_oc)) / (int)$(conv_kw)"},
163  {"kernCol", "int", "($(id_diag) / (int)$(conv_oc)) % (int)$(conv_kw)"},
164  {"kernOutChan", "int", "$(id_diag) % (int)$(conv_oc)"},
165  {"flipKernRow", "int", "(int)$(conv_kh) - $(kernRow) - 1"},
166  {"flipKernCol", "int", "(int)$(conv_kw) - $(kernCol) - 1"}});
167 
169  "// Convert spike ID into row, column and channel going INTO pool\n"
170  "const int prePoolInRow = ($(id_pre) / (int)$(pool_ic)) / (int)$(pool_iw);\n"
171  "const int prePoolInCol = ($(id_pre) / (int)$(pool_ic)) % (int)$(pool_iw);\n"
172  "const int preChan = $(id_pre) % (int)$(pool_ic);\n"
173  "// Calculate row and column going OUT of pool\n"
174  "const int poolPreOutRow = prePoolInRow / (int)$(pool_sh);\n"
175  "const int poolStrideRow = poolPreOutRow * (int)$(pool_sh);\n"
176  "const int poolPreOutCol = prePoolInCol / (int)$(pool_sw);\n"
177  "const int poolStrideCol = poolPreOutCol * (int)$(pool_sw);\n"
178  "if(prePoolInRow < (poolStrideRow + (int)$(pool_kh)) && prePoolInCol < (poolStrideCol + (int)$(pool_kw))) {\n"
179  " // If we haven't gone off edge of output\n"
180  " const int postRow = poolPreOutRow + $(kernRow) - (int)$(conv_bh);\n"
181  " const int postCol = poolPreOutCol + $(kernCol) - (int)$(conv_bw);\n"
182  " if(postRow >= 0 && postCol >= 0 && postRow < (int)$(conv_oh) && postCol < (int)$(conv_ow)) {\n"
183  " // Calculate postsynaptic index\n"
184  " const int postInd = ((postRow * (int)$(conv_ow) * (int)$(conv_oc)) +\n"
185  " (postCol * (int)$(conv_oc)) +\n"
186  " $(kernOutChan));\n"
187  " $(addSynapse, postInd, $(flipKernRow), $(flipKernCol), preChan, $(kernOutChan));\n"
188  " }\n"
189  "}\n");
190 
192  [](unsigned int, unsigned int, const std::vector<double> &pars)
193  {
194  const unsigned int convKH = (unsigned int)pars[0];
195  const unsigned int convKW = (unsigned int)pars[1];
196  const unsigned int convOC = (unsigned int)pars[11];
197  return (convKH * convKW * convOC);
198  });
199 
201  [](const std::vector<double> &pars)->std::vector<unsigned int>
202  {
203  const unsigned int convKH = (unsigned int)pars[0];
204  const unsigned int convKW = (unsigned int)pars[1];
205  const unsigned int poolIC = (unsigned int)pars[8];
206  const unsigned int convOC = (unsigned int)pars[11];
207  return {convKH, convKW, poolIC, convOC};
208  });
209 };
210 } // namespace InitToeplitzConnectivitySnippet
#define SET_CALC_MAX_ROW_LENGTH_FUNC(FUNC)
Definition: initToeplitzConnectivitySnippet.h:21
Definition: initToeplitzConnectivitySnippet.h:91
Definition: initToeplitzConnectivitySnippet.h:149
#define GENN_EXPORT
Definition: gennExport.h:13
std::function< std::vector< unsigned int >(const std::vector< double > &)> CalcKernelSizeFunc
Definition: initToeplitzConnectivitySnippet.h:39
Definition: snippet.h:250
#define SET_CALC_KERNEL_SIZE_FUNC(...)
Definition: initToeplitzConnectivitySnippet.h:22
Definition: initToeplitzConnectivitySnippet.h:66
#define DECLARE_SNIPPET(TYPE, NUM_PARAMS)
Definition: snippet.h:19
Base class for all toeplitz connectivity initialisation snippets.
Definition: initToeplitzConnectivitySnippet.h:30
virtual CalcKernelSizeFunc getCalcKernelSizeFunc() const
Get function to calculate kernel size required for this connector based on its parameters.
Definition: initToeplitzConnectivitySnippet.h:51
#define SET_DIAGONAL_BUILD_CODE(CODE)
Definition: initToeplitzConnectivitySnippet.h:18
#define SET_DERIVED_PARAMS(...)
Definition: snippet.h:37
virtual CalcMaxLengthFunc getCalcMaxRowLengthFunc() const
Get function to calculate the maximum row length of this connector based on the parameters and the si...
Definition: initToeplitzConnectivitySnippet.h:48
Base class for all code snippets.
Definition: snippet.h:120
#define SET_PARAM_NAMES(...)
Definition: snippet.h:36
#define SET_DIAGONAL_BUILD_STATE_VARS(...)
Definition: initToeplitzConnectivitySnippet.h:19
virtual std::string getDiagonalBuildCode() const
Definition: initToeplitzConnectivitySnippet.h:44
std::function< unsigned int(unsigned int, unsigned int, const std::vector< double > &)> CalcMaxLengthFunc
Definition: initToeplitzConnectivitySnippet.h:38
Definition: snippet.h:45
Used to mark connectivity as uninitialised - no initialisation code will be run.
Definition: initToeplitzConnectivitySnippet.h:79
virtual ParamValVec getDiagonalBuildStateVars() const
Definition: initToeplitzConnectivitySnippet.h:45
Definition: initToeplitzConnectivitySnippet.h:32
Init(const Base *snippet, const std::vector< double > &params)
Definition: initToeplitzConnectivitySnippet.h:69
std::vector< ParamVal > ParamValVec
Definition: snippet.h:180