// Convenience macros for snippet classes. Each expands to an inline `override` of the
// corresponding virtual getter, so they can only be used inside a class whose base
// declares that getter as virtual.

//! Override getDiagonalBuildCode() so that it returns CODE.
#define SET_DIAGONAL_BUILD_CODE(CODE) virtual std::string getDiagonalBuildCode() const override{ return CODE; }
//! Override getDiagonalBuildStateVars() so that it returns the given ParamValVec initialiser
//! (variadic so that a braced list containing commas can be passed directly).
#define SET_DIAGONAL_BUILD_STATE_VARS(...) virtual ParamValVec getDiagonalBuildStateVars() const override{ return __VA_ARGS__; }

//! Override getCalcMaxRowLengthFunc() so that it returns FUNC (anything convertible to CalcMaxLengthFunc).
#define SET_CALC_MAX_ROW_LENGTH_FUNC(FUNC) virtual CalcMaxLengthFunc getCalcMaxRowLengthFunc() const override{ return FUNC; }
//! Override getCalcKernelSizeFunc() so that it returns the given callable
//! (variadic so that lambdas whose bodies contain top-level commas can be passed).
#define SET_CALC_KERNEL_SIZE_FUNC(...) virtual CalcKernelSizeFunc getCalcKernelSizeFunc() const override{ return __VA_ARGS__; }

//! Shorthand for SET_CALC_MAX_ROW_LENGTH_FUNC with a constant: the generated function ignores
//! all of its arguments and always returns MAX_ROW_LENGTH.
//! Use either this or SET_CALC_MAX_ROW_LENGTH_FUNC in one class, not both; both define the same override.
#define SET_MAX_ROW_LENGTH(MAX_ROW_LENGTH) virtual CalcMaxLengthFunc getCalcMaxRowLengthFunc() const override{ return [](unsigned int, unsigned int, const std::vector<double> &){ return MAX_ROW_LENGTH; }; }

//! Computes a maximum row length from two unsigned counts and the snippet's parameter values.
//! NOTE(review): the two counts are presumably the pre- and postsynaptic population sizes — confirm at the call site.
typedef std::function<unsigned int(unsigned int, unsigned int, const std::vector<double> &)> CalcMaxLengthFunc;

//! Computes the kernel dimensions (one entry per dimension) from the snippet's parameter values.
typedef std::function<std::vector<unsigned int>(const std::vector<double> &)> CalcKernelSizeFunc;
57 boost::uuids::detail::sha1::digest_type getHashDigest()
const;
60 void validate()
const;
69 Init(
const Base *snippet,
const std::vector<double> ¶ms)
97 "conv_ih",
"conv_iw",
"conv_ic",
98 "conv_oh",
"conv_ow",
"conv_oc"});
99 SET_DERIVED_PARAMS({{
"conv_bw", [](
const std::vector<double> &pars, double){
return (((
int)pars[3] + (
int)pars[1] - 1) - (
int)pars[6]) / 2; }},
100 {
"conv_bh", [](
const std::vector<double> &pars, double){
return (((
int)pars[2] + (
int)pars[0] - 1) - (int)pars[5]) / 2; }}});
103 {
"kernCol",
"int",
"($(id_diag) / (int)$(conv_oc)) % (int)$(conv_kw)"},
104 {
"kernOutChan",
"int",
"$(id_diag) % (int)$(conv_oc)"},
105 {
"flipKernRow",
"int",
"(int)$(conv_kh) - $(kernRow) - 1"},
106 {
"flipKernCol",
"int",
"(int)$(conv_kw) - $(kernCol) - 1"}});
109 "const int preRow = ($(id_pre) / (int)$(conv_ic)) / (int)$(conv_iw);\n" 110 "const int preCol = ($(id_pre) / (int)$(conv_ic)) % (int)$(conv_iw);\n" 111 "const int preChan = $(id_pre) % (int)$(conv_ic);\n" 112 "// If we haven't gone off edge of output\n" 113 "const int postRow = preRow + $(kernRow) - (int)$(conv_bh);\n" 114 "const int postCol = preCol + $(kernCol) - (int)$(conv_bw);\n" 115 "if(postRow >= 0 && postCol >= 0 && postRow < (int)$(conv_oh) && postCol < (int)$(conv_ow)) {\n" 116 " // Calculate postsynaptic index\n" 117 " const int postInd = ((postRow * (int)$(conv_ow) * (int)$(conv_oc)) +\n" 118 " (postCol * (int)$(conv_oc)) +\n" 119 " $(kernOutChan));\n" 120 " $(addSynapse, postInd, $(flipKernRow), $(flipKernCol), preChan, $(kernOutChan));\n" 124 [](
unsigned int,
unsigned int,
const std::vector<double> &pars)
126 const unsigned int convKH = (
unsigned int)pars[0];
127 const unsigned int convKW = (
unsigned int)pars[1];
128 const unsigned int convOC = (
unsigned int)pars[7];
129 return (convKH * convKW * convOC);
133 [](
const std::vector<double> &pars)->std::vector<unsigned int>
135 const unsigned int convKH = (
unsigned int)pars[0];
136 const unsigned int convKW = (
unsigned int)pars[1];
137 const unsigned int convIC = (
unsigned int)pars[4];
138 const unsigned int convOC = (
unsigned int)pars[7];
139 return {convKH, convKW, convIC, convOC};
155 "pool_kh",
"pool_kw",
156 "pool_sh",
"pool_sw",
157 "pool_ih",
"pool_iw",
"pool_ic",
158 "conv_oh",
"conv_ow",
"conv_oc"});
159 SET_DERIVED_PARAMS({{
"conv_bw", [](
const std::vector<double> &pars, double){
return (
int(ceil((pars[7] - pars[3] + 1.0) / pars[5])) + (
int)pars[1] - 1 - (
int)pars[10]) / 2; }},
160 {
"conv_bh", [](
const std::vector<double> &pars, double){
return (
int(ceil((pars[6] - pars[2] + 1.0) / pars[4])) + (int)pars[0] - 1 - (
int)pars[9]) / 2; }}});
163 {
"kernCol",
"int",
"($(id_diag) / (int)$(conv_oc)) % (int)$(conv_kw)"},
164 {
"kernOutChan",
"int",
"$(id_diag) % (int)$(conv_oc)"},
165 {
"flipKernRow",
"int",
"(int)$(conv_kh) - $(kernRow) - 1"},
166 {
"flipKernCol",
"int",
"(int)$(conv_kw) - $(kernCol) - 1"}});
169 "// Convert spike ID into row, column and channel going INTO pool\n" 170 "const int prePoolInRow = ($(id_pre) / (int)$(pool_ic)) / (int)$(pool_iw);\n" 171 "const int prePoolInCol = ($(id_pre) / (int)$(pool_ic)) % (int)$(pool_iw);\n" 172 "const int preChan = $(id_pre) % (int)$(pool_ic);\n" 173 "// Calculate row and column going OUT of pool\n" 174 "const int poolPreOutRow = prePoolInRow / (int)$(pool_sh);\n" 175 "const int poolStrideRow = poolPreOutRow * (int)$(pool_sh);\n" 176 "const int poolPreOutCol = prePoolInCol / (int)$(pool_sw);\n" 177 "const int poolStrideCol = poolPreOutCol * (int)$(pool_sw);\n" 178 "if(prePoolInRow < (poolStrideRow + (int)$(pool_kh)) && prePoolInCol < (poolStrideCol + (int)$(pool_kw))) {\n" 179 " // If we haven't gone off edge of output\n" 180 " const int postRow = poolPreOutRow + $(kernRow) - (int)$(conv_bh);\n" 181 " const int postCol = poolPreOutCol + $(kernCol) - (int)$(conv_bw);\n" 182 " if(postRow >= 0 && postCol >= 0 && postRow < (int)$(conv_oh) && postCol < (int)$(conv_ow)) {\n" 183 " // Calculate postsynaptic index\n" 184 " const int postInd = ((postRow * (int)$(conv_ow) * (int)$(conv_oc)) +\n" 185 " (postCol * (int)$(conv_oc)) +\n" 186 " $(kernOutChan));\n" 187 " $(addSynapse, postInd, $(flipKernRow), $(flipKernCol), preChan, $(kernOutChan));\n" 192 [](
unsigned int,
unsigned int,
const std::vector<double> &pars)
194 const unsigned int convKH = (
unsigned int)pars[0];
195 const unsigned int convKW = (
unsigned int)pars[1];
196 const unsigned int convOC = (
unsigned int)pars[11];
197 return (convKH * convKW * convOC);
201 [](
const std::vector<double> &pars)->std::vector<unsigned int>
203 const unsigned int convKH = (
unsigned int)pars[0];
204 const unsigned int convKW = (
unsigned int)pars[1];
205 const unsigned int poolIC = (
unsigned int)pars[8];
206 const unsigned int convOC = (
unsigned int)pars[11];
207 return {convKH, convKW, poolIC, convOC};
#define SET_CALC_MAX_ROW_LENGTH_FUNC(FUNC)
Definition: initToeplitzConnectivitySnippet.h:21
Definition: initToeplitzConnectivitySnippet.h:91
Definition: initToeplitzConnectivitySnippet.h:149
#define GENN_EXPORT
Definition: gennExport.h:13
std::function< std::vector< unsigned int >(const std::vector< double > &)> CalcKernelSizeFunc
Definition: initToeplitzConnectivitySnippet.h:39
Definition: snippet.h:250
#define SET_CALC_KERNEL_SIZE_FUNC(...)
Definition: initToeplitzConnectivitySnippet.h:22
Definition: initToeplitzConnectivitySnippet.h:66
#define DECLARE_SNIPPET(TYPE, NUM_PARAMS)
Definition: snippet.h:19
Base class for all Toeplitz connectivity initialisation snippets.
Definition: initToeplitzConnectivitySnippet.h:30
virtual CalcKernelSizeFunc getCalcKernelSizeFunc() const
Get function to calculate kernel size required for this connector based on its parameters.
Definition: initToeplitzConnectivitySnippet.h:51
#define SET_DIAGONAL_BUILD_CODE(CODE)
Definition: initToeplitzConnectivitySnippet.h:18
#define SET_DERIVED_PARAMS(...)
Definition: snippet.h:37
virtual CalcMaxLengthFunc getCalcMaxRowLengthFunc() const
Get function to calculate the maximum row length of this connector based on the parameters and the sizes of the pre and postsynaptic populations.
Definition: initToeplitzConnectivitySnippet.h:48
Base class for all code snippets.
Definition: snippet.h:120
#define SET_PARAM_NAMES(...)
Definition: snippet.h:36
#define SET_DIAGONAL_BUILD_STATE_VARS(...)
Definition: initToeplitzConnectivitySnippet.h:19
virtual std::string getDiagonalBuildCode() const
Definition: initToeplitzConnectivitySnippet.h:44
std::function< unsigned int(unsigned int, unsigned int, const std::vector< double > &)> CalcMaxLengthFunc
Definition: initToeplitzConnectivitySnippet.h:38
Used to mark connectivity as uninitialised - no initialisation code will be run.
Definition: initToeplitzConnectivitySnippet.h:79
virtual ParamValVec getDiagonalBuildStateVars() const
Definition: initToeplitzConnectivitySnippet.h:45
Definition: initToeplitzConnectivitySnippet.h:32
Init(const Base *snippet, const std::vector< double > &params)
Definition: initToeplitzConnectivitySnippet.h:69
std::vector< ParamVal > ParamValVec
Definition: snippet.h:180