10 #include <unordered_map> 30 template<
typename scalar =
float>
35 : m_Library(nullptr), m_AllocateMem(nullptr), m_AllocateRecordingBuffers(nullptr),
36 m_FreeMem(nullptr), m_Initialize(nullptr), m_InitializeSparse(nullptr),
37 m_StepTime(nullptr), m_PullRecordingBuffersFromDevice(nullptr),
38 m_NCCLGenerateUniqueID(nullptr), m_NCCLGetUniqueID(nullptr),
39 m_NCCLInitCommunicator(nullptr), m_NCCLUniqueIDBytes(nullptr)
44 bool includeModelNameInDLL =
false)
46 if(!
open(pathToModel, modelName, includeModelNameInDLL)) {
47 throw std::runtime_error(
"Unable to open library");
60 bool open(
const std::string &pathToModel,
const std::string &modelName,
61 bool includeModelNameInDLL =
false)
64 const std::string runnerName = includeModelNameInDLL ? (
"runner_" + modelName) :
"runner";
66 const std::string libraryName = pathToModel +
"\\" + runnerName +
"_Debug.dll";
68 const std::string libraryName = pathToModel +
"\\" + runnerName +
"_Release.dll";
70 m_Library = LoadLibrary(libraryName.c_str());
72 const std::string libraryName = pathToModel + modelName +
"_CODE/librunner.so";
73 m_Library = dlopen(libraryName.c_str(), RTLD_NOW);
77 if(m_Library !=
nullptr) {
78 m_AllocateMem = (VoidFunction)
getSymbol(
"allocateMem");
79 m_AllocateRecordingBuffers = (EGPFunction)
getSymbol(
"allocateRecordingBuffers",
true);
80 m_FreeMem = (VoidFunction)
getSymbol(
"freeMem");
81 m_GetFreeDeviceMemBytes = (GetFreeMemFunction)
getSymbol(
"getFreeDeviceMemBytes");
83 m_Initialize = (VoidFunction)
getSymbol(
"initialize");
84 m_InitializeSparse = (VoidFunction)
getSymbol(
"initializeSparse");
86 m_StepTime = (VoidFunction)
getSymbol(
"stepTime");
87 m_PullRecordingBuffersFromDevice = (VoidFunction)
getSymbol(
"pullRecordingBuffersFromDevice",
true);
90 m_Timestep = (
unsigned long long*)
getSymbol(
"iT");
92 m_NCCLGenerateUniqueID = (VoidFunction)
getSymbol(
"ncclGenerateUniqueID",
true);
93 m_NCCLGetUniqueID = (UCharPtrFunction)
getSymbol(
"ncclGetUniqueID",
true);
94 m_NCCLInitCommunicator = (NCCLInitCommunicatorFunction)
getSymbol(
"ncclInitCommunicator",
true);
95 m_NCCLUniqueIDBytes = (
unsigned int*)
getSymbol(
"ncclUniqueIDBytes",
true);
100 std::cerr <<
"Unable to load library - error:" << std::to_string(GetLastError()) << std::endl;;
102 std::cerr <<
"Unable to load library - error:" << dlerror() << std::endl;
113 FreeLibrary(m_Library);
121 m_AllocateMem =
nullptr;
122 m_AllocateRecordingBuffers =
nullptr;
124 m_GetFreeDeviceMemBytes =
nullptr;
125 m_Initialize =
nullptr;
126 m_InitializeSparse =
nullptr;
127 m_StepTime =
nullptr;
128 m_PullRecordingBuffersFromDevice =
nullptr;
129 m_NCCLGenerateUniqueID =
nullptr;
130 m_NCCLGetUniqueID =
nullptr;
131 m_NCCLInitCommunicator =
nullptr;
132 m_NCCLUniqueIDBytes =
nullptr;
134 m_Timestep =
nullptr;
137 m_PopulationVars.clear();
138 m_PopulationEPGs.clear();
139 m_CustomUpdates.clear();
145 const auto funcs = getEGPFunctions(egpName + popName);
146 if(std::get<0>(funcs) ==
nullptr) {
147 throw std::runtime_error(
"You cannot allocate EGP '" + egpName +
"' in population '" + popName +
"'");
151 std::get<0>(funcs)(count);
154 void allocateExtraGlobalParam(
const std::string &popName,
const std::string &varName,
const std::string &egpName,
unsigned int count)
157 const auto funcs = getEGPFunctions(egpName + varName + popName);
158 if(std::get<0>(funcs) ==
nullptr) {
159 throw std::runtime_error(
"You cannot allocate EGP '" + egpName +
"' for initializing '" + varName +
"'in population '" + popName +
"'");
163 std::get<0>(funcs)(count);
169 const auto funcs = getEGPFunctions(egpName + popName);
170 if(std::get<1>(funcs) ==
nullptr) {
171 throw std::runtime_error(
"You cannot free EGP '" + egpName +
"' in population '" + popName +
"'");
175 std::get<1>(funcs)();
181 const auto funcs = getEGPFunctions(egpName + varName + popName);
182 if(std::get<1>(funcs) ==
nullptr) {
183 throw std::runtime_error(
"You cannot free EGP '" + egpName +
"' for initializing '" + varName +
"'in population '" + popName +
"'");
187 std::get<1>(funcs)();
193 const auto pushPull = getPopPushPullFunction(popName +
"State");
194 if(pushPull.second ==
nullptr) {
195 throw std::runtime_error(
"You cannot pull state from population '" + popName +
"'");
205 const auto pushPull = getPopPushPullFunction(popName +
"Spikes");
206 if(pushPull.second ==
nullptr) {
207 throw std::runtime_error(
"You cannot pull spikes from population '" + popName +
"'");
217 const auto pushPull = getPopPushPullFunction(popName +
"SpikeEvents");
218 if(pushPull.second ==
nullptr) {
219 throw std::runtime_error(
"You cannot pull spike events from population '" + popName +
"'");
229 const auto pushPull = getPopPushPullFunction(popName +
"CurrentSpikes");
230 if(pushPull.second ==
nullptr) {
231 throw std::runtime_error(
"You cannot pull current spikes from population '" + popName +
"'");
241 const auto pushPull = getPopPushPullFunction(popName +
"CurrentSpikeEvents");
242 if(pushPull.second ==
nullptr) {
243 throw std::runtime_error(
"You cannot pull current spike events from population '" + popName +
"'");
253 const auto pushPull = getPopPushPullFunction(popName +
"Connectivity");
254 if(pushPull.second ==
nullptr) {
255 throw std::runtime_error(
"You cannot pull connectivity from population '" + popName +
"'");
265 const auto pushPull = getPopPushPullFunction(varName + popName);
266 if(pushPull.second ==
nullptr) {
267 throw std::runtime_error(
"You cannot pull var '" + varName +
"' from population '" + popName +
"'");
277 const auto funcs = getEGPFunctions(egpName + popName);
278 if(std::get<3>(funcs) ==
nullptr) {
279 throw std::runtime_error(
"You cannot pull EGP '" + egpName +
"' from population '" + popName +
"'");
283 std::get<3>(funcs)(count);
286 void pullExtraGlobalParam(
const std::string &popName,
const std::string &varName,
const std::string &egpName,
unsigned int count)
289 const auto funcs = getEGPFunctions(egpName + varName + popName);
290 if(std::get<3>(funcs) ==
nullptr) {
291 throw std::runtime_error(
"You cannot pull EGP '" + egpName +
"' for initializing '" + varName +
"'in population '" + popName +
"'");
295 std::get<3>(funcs)(count);
301 const auto pushPull = getPopPushPullFunction(popName +
"State");
302 if(pushPull.first ==
nullptr) {
303 throw std::runtime_error(
"You cannot push state to population '" + popName +
"'");
307 pushPull.first(uninitialisedOnly);
313 const auto pushPull = getPopPushPullFunction(popName +
"Spikes");
314 if(pushPull.first ==
nullptr) {
315 throw std::runtime_error(
"You cannot push spikes to population '" + popName +
"'");
319 pushPull.first(uninitialisedOnly);
325 const auto pushPull = getPopPushPullFunction(popName +
"SpikeEvents");
326 if(pushPull.first ==
nullptr) {
327 throw std::runtime_error(
"You cannot push spike events to population '" + popName +
"'");
331 pushPull.first(uninitialisedOnly);
337 const auto pushPull = getPopPushPullFunction(popName +
"CurrentSpikes");
338 if(pushPull.first ==
nullptr) {
339 throw std::runtime_error(
"You cannot push current spikes to population '" + popName +
"'");
343 pushPull.first(uninitialisedOnly);
349 const auto pushPull = getPopPushPullFunction(popName +
"CurrentSpikeEvents");
350 if(pushPull.first ==
nullptr) {
351 throw std::runtime_error(
"You cannot push current spike events to population '" + popName +
"'");
355 pushPull.first(uninitialisedOnly);
361 const auto pushPull = getPopPushPullFunction(popName +
"Connectivity");
362 if(pushPull.first ==
nullptr) {
363 throw std::runtime_error(
"You cannot push connectivity to population '" + popName +
"'");
367 pushPull.first(uninitialisedOnly);
370 void pushVarToDevice(
const std::string &popName,
const std::string &varName,
bool uninitialisedOnly =
false)
373 const auto pushPull = getPopPushPullFunction(varName + popName);
374 if(pushPull.first ==
nullptr) {
375 throw std::runtime_error(
"You cannot push var '" + varName +
"' to population '" + popName +
"'");
379 pushPull.first(uninitialisedOnly);
385 const auto funcs = getEGPFunctions(egpName + popName);
386 if(std::get<2>(funcs) ==
nullptr) {
387 throw std::runtime_error(
"You cannot push EGP '" + egpName +
"' to population '" + popName +
"'");
391 std::get<2>(funcs)(count);
394 void pushExtraGlobalParam(
const std::string &popName,
const std::string &varName,
const std::string &egpName,
unsigned int count)
397 const auto funcs = getEGPFunctions(egpName + varName + popName);
398 if(std::get<2>(funcs) ==
nullptr) {
399 throw std::runtime_error(
"You cannot push EGP '" + egpName +
"' for initializing '" + varName +
"'in population '" + popName +
"'");
403 std::get<2>(funcs)(count);
406 template<
typename Writer,
typename... WriterArgs>
421 return *(
static_cast<T**
>(
getSymbol(varName)));
428 return (static_cast<T*>(
getSymbol(varName)));
438 if(m_AllocateRecordingBuffers ==
nullptr) {
439 throw std::runtime_error(
"Cannot allocate recording buffers - model may not have recording enabled");
441 m_AllocateRecordingBuffers(timesteps);
451 return m_GetFreeDeviceMemBytes();
456 if(m_NCCLGenerateUniqueID ==
nullptr) {
457 throw std::runtime_error(
"Cannot generate NCCL unique ID - model may not have been built with NCCL support");
459 m_NCCLGenerateUniqueID();
464 if(m_NCCLGetUniqueID ==
nullptr) {
465 throw std::runtime_error(
"Cannot get NCCL unique ID - model may not have been built with NCCL support");
467 return m_NCCLGetUniqueID();
472 if(m_NCCLUniqueIDBytes ==
nullptr) {
473 throw std::runtime_error(
"Cannot get NCCL unique ID bytes - model may not have been built with NCCL support");
476 return *m_NCCLUniqueIDBytes;
481 if(m_NCCLInitCommunicator ==
nullptr) {
482 throw std::runtime_error(
"Cannot initialise NCCL communicator - model may not have been built with NCCL support");
484 m_NCCLInitCommunicator(rank, numRanks);
496 m_InitializeSparse();
506 auto c = m_CustomUpdates.find(name);
507 if(c != m_CustomUpdates.cend()) {
511 auto customUpdateFn = (VoidFunction)
getSymbol(
"update" + name);
512 m_CustomUpdates.emplace(name, customUpdateFn);
519 if(m_PullRecordingBuffersFromDevice ==
nullptr) {
520 throw std::runtime_error(
"Cannot pull recording buffers from device - model may not have recording enabled");
522 m_PullRecordingBuffersFromDevice();
554 void *
getSymbol(
const std::string &symbolName,
bool allowMissing =
false,
void *defaultSymbol =
nullptr)
const 557 void *symbol = GetProcAddress(m_Library, symbolName.c_str());
559 void *symbol = dlsym(m_Library, symbolName.c_str());
563 if(symbol ==
nullptr) {
566 throw std::runtime_error(
"Cannot find symbol '" + symbolName +
"'");
570 return defaultSymbol;
// Signatures of the entry points looked up in the model library
using VoidFunction = void (*)();
using UCharPtrFunction = unsigned char *(*)();
using PushFunction = void (*)(bool);
using PullFunction = void (*)();
using EGPFunction = void (*)(unsigned int);
using GetFreeMemFunction = size_t (*)();
using NCCLInitCommunicatorFunction = void (*)(int, int);

// (push, pull) function pair cached per population/variable description
using PushPullFunc = std::pair<PushFunction, PullFunction>;

// (allocate, free, push, pull) functions cached per extra global parameter description
using EGPFunc = std::tuple<EGPFunction, VoidFunction, EGPFunction, EGPFunction>;
597 PushPullFunc getPopPushPullFunction(
const std::string &description)
600 const auto popVar = m_PopulationVars.find(description);
601 if(popVar != m_PopulationVars.end()) {
602 return popVar->second;
606 auto pushFunc = (PushFunction)
getSymbol(
"push" + description +
"ToDevice",
true);
607 auto pullFunc = (PullFunction)
getSymbol(
"pull" + description +
"FromDevice",
true);
610 auto newPopVar = m_PopulationVars.emplace(std::piecewise_construct,
611 std::forward_as_tuple(description),
612 std::forward_as_tuple(pushFunc, pullFunc));
615 return newPopVar.first->second;
619 EGPFunc getEGPFunctions(
const std::string &description)
622 const auto popEGP = m_PopulationEPGs.find(description);
623 if(popEGP != m_PopulationEPGs.end()) {
624 return popEGP->second;
628 auto allocateFunc = (EGPFunction)
getSymbol(
"allocate" + description,
true);
629 auto freeFunc = (VoidFunction)
getSymbol(
"free" + description,
true);
630 auto pushFunc = (EGPFunction)
getSymbol(
"push" + description +
"ToDevice",
true);
631 auto pullFunc = (EGPFunction)
getSymbol(
"pull" + description +
"FromDevice",
true);
634 auto newPopEGP = m_PopulationEPGs.emplace(std::piecewise_construct,
635 std::forward_as_tuple(description),
636 std::forward_as_tuple(allocateFunc, freeFunc,
637 pushFunc, pullFunc));
640 return newPopEGP.first->second;
// Entry points of the loaded model library; assigned from getSymbol() lookups in open()
653 VoidFunction m_AllocateMem;
// Looked up with allowMissing = true; allocateRecordingBuffers() throws if this is null
654 EGPFunction m_AllocateRecordingBuffers;
655 VoidFunction m_FreeMem;
656 GetFreeMemFunction m_GetFreeDeviceMemBytes;
657 VoidFunction m_Initialize;
658 VoidFunction m_InitializeSparse;
659 VoidFunction m_StepTime;
// Looked up with allowMissing = true; checked for null before being called
661 PullFunction m_PullRecordingBuffersFromDevice;
// NCCL entry points: all looked up with allowMissing = true and checked for null before use
663 VoidFunction m_NCCLGenerateUniqueID;
664 UCharPtrFunction m_NCCLGetUniqueID;
665 NCCLInitCommunicatorFunction m_NCCLInitCommunicator;
// Address of the library's "ncclUniqueIDBytes" symbol (dereferenced on read)
666 const unsigned int *m_NCCLUniqueIDBytes;
// Cache: description -> (push, pull) pair, filled by getPopPushPullFunction()
668 std::unordered_map<std::string, PushPullFunc> m_PopulationVars;
// Cache: description -> (allocate, free, push, pull) tuple, filled by getEGPFunctions()
669 std::unordered_map<std::string, EGPFunc> m_PopulationEPGs;
// Cache: custom update name -> the library's "update<name>" function
670 std::unordered_map<std::string, VoidFunction> m_CustomUpdates;
// Address of the library's "iT" symbol
672 unsigned long long *m_Timestep;
SharedLibraryModel()
Definition: sharedLibraryModel.h:34
double getNeuronUpdateTime() const
Definition: sharedLibraryModel.h:545
void * getSymbol(const std::string &symbolName, bool allowMissing=false, void *defaultSymbol=nullptr) const
Definition: sharedLibraryModel.h:554
void close()
Definition: sharedLibraryModel.h:108
void pushExtraGlobalParam(const std::string &popName, const std::string &varName, const std::string &egpName, unsigned int count)
Definition: sharedLibraryModel.h:394
unsigned char * ncclGetUniqueID()
Definition: sharedLibraryModel.h:462
void ncclInitCommunicator(int rank, int numRanks)
Definition: sharedLibraryModel.h:479
void setTimestep(unsigned long long iT)
Definition: sharedLibraryModel.h:540
void pushStateToDevice(const std::string &popName, bool uninitialisedOnly=false)
Definition: sharedLibraryModel.h:298
double getInitSparseTime() const
Definition: sharedLibraryModel.h:550
void allocateExtraGlobalParam(const std::string &popName, const std::string &egpName, unsigned int count)
Definition: sharedLibraryModel.h:142
void pullExtraGlobalParam(const std::string &popName, const std::string &varName, const std::string &egpName, unsigned int count)
Definition: sharedLibraryModel.h:286
void allocateRecordingBuffers(unsigned int timesteps)
Definition: sharedLibraryModel.h:436
SharedLibraryModel(const std::string &pathToModel, const std::string &modelName, bool includeModelNameInDLL=false)
Definition: sharedLibraryModel.h:43
double getSynapseDynamicsTime() const
Definition: sharedLibraryModel.h:549
void pushSpikeEventsToDevice(const std::string &popName, bool uninitialisedOnly=false)
Definition: sharedLibraryModel.h:322
Class to read spikes from neuron groups.
Definition: spikeRecorder.h:150
void freeMem()
Definition: sharedLibraryModel.h:444
void freeExtraGlobalParam(const std::string &popName, const std::string &varName, const std::string &egpName)
Definition: sharedLibraryModel.h:178
size_t getFreeDeviceMemBytes()
Definition: sharedLibraryModel.h:449
double getPresynapticUpdateTime() const
Definition: sharedLibraryModel.h:547
void pullCurrentSpikesFromDevice(const std::string &popName)
Definition: sharedLibraryModel.h:226
T * getArray(const std::string &varName)
Definition: sharedLibraryModel.h:419
double getCustomUpdateTransposeTime(const std::string &name) const
Definition: sharedLibraryModel.h:552
void pullCurrentSpikesEventsFromDevice(const std::string &popName)
Definition: sharedLibraryModel.h:238
double getPostsynapticUpdateTime() const
Definition: sharedLibraryModel.h:548
unsigned long long getTimestep() const
Definition: sharedLibraryModel.h:530
void pushCurrentSpikeEventsToDevice(const std::string &popName, bool uninitialisedOnly=false)
Definition: sharedLibraryModel.h:346
void initializeSparse()
Definition: sharedLibraryModel.h:494
bool open(const std::string &pathToModel, const std::string &modelName, bool includeModelNameInDLL=false)
Definition: sharedLibraryModel.h:60
void initialize()
Definition: sharedLibraryModel.h:489
Definition: sharedLibraryModel.h:31
void allocateMem()
Definition: sharedLibraryModel.h:431
void pullSpikeEventsFromDevice(const std::string &popName)
Definition: sharedLibraryModel.h:214
void pullRecordingBuffersFromDevice()
Definition: sharedLibraryModel.h:517
void pullConnectivityFromDevice(const std::string &popName)
Definition: sharedLibraryModel.h:250
void pushSpikesToDevice(const std::string &popName, bool uninitialisedOnly=false)
Definition: sharedLibraryModel.h:310
virtual ~SharedLibraryModel()
Definition: sharedLibraryModel.h:51
void customUpdate(const std::string &name)
Definition: sharedLibraryModel.h:504
void pullVarFromDevice(const std::string &popName, const std::string &varName)
Definition: sharedLibraryModel.h:262
void pullExtraGlobalParam(const std::string &popName, const std::string &egpName, unsigned int count)
Definition: sharedLibraryModel.h:274
void freeExtraGlobalParam(const std::string &popName, const std::string &egpName)
Definition: sharedLibraryModel.h:166
void pushCurrentSpikesToDevice(const std::string &popName, bool uninitialisedOnly=false)
Definition: sharedLibraryModel.h:334
void stepTime()
Definition: sharedLibraryModel.h:499
T * getScalar(const std::string &varName)
Definition: sharedLibraryModel.h:426
void ncclGenerateUniqueID()
Definition: sharedLibraryModel.h:454
void pushVarToDevice(const std::string &popName, const std::string &varName, bool uninitialisedOnly=false)
Definition: sharedLibraryModel.h:370
scalar getTime() const
Definition: sharedLibraryModel.h:525
void allocateExtraGlobalParam(const std::string &popName, const std::string &varName, const std::string &egpName, unsigned int count)
Definition: sharedLibraryModel.h:154
void setTime(scalar t)
Definition: sharedLibraryModel.h:535
void pullSpikesFromDevice(const std::string &popName)
Definition: sharedLibraryModel.h:202
void pullStateFromDevice(const std::string &popName)
Definition: sharedLibraryModel.h:190
SpikeRecorder< Writer > getSpikeRecorder(const std::string &popName, WriterArgs &&... writerArgs)
Definition: sharedLibraryModel.h:407
double getInitTime() const
Definition: sharedLibraryModel.h:546
void pushConnectivityToDevice(const std::string &popName, bool uninitialisedOnly=false)
Definition: sharedLibraryModel.h:358
double getCustomUpdateTime(const std::string &name) const
Definition: sharedLibraryModel.h:551
unsigned int ncclGetUniqueIDBytes() const
Definition: sharedLibraryModel.h:470
float scalar
Definition: modelProperty.h:27
void pushExtraGlobalParam(const std::string &popName, const std::string &egpName, unsigned int count)
Definition: sharedLibraryModel.h:382