PageRenderTime 24ms CodeModel.GetById 23ms RepoModel.GetById 1ms app.codeStats 0ms

/src/Platform/Windows/System/Dispatcher.cpp

https://github.com/simplecoin/simplecoin
C++ | 419 lines | 363 code | 53 blank | 3 comment | 126 complexity | 552554682acf856f61074667da9eff31 MD5 | raw file
  1. // Copyright (c) 2011-2016 The Cryptonote developers
  2. // Distributed under the MIT/X11 software license, see the accompanying
  3. // file COPYING or http://www.opensource.org/licenses/mit-license.php.
  4. #include "Dispatcher.h"
  5. #include <cassert>
  6. #include <string>
  7. #ifndef WIN32_LEAN_AND_MEAN
  8. #define WIN32_LEAN_AND_MEAN
  9. #endif
  10. #ifndef NOMINMAX
  11. #define NOMINMAX
  12. #endif
  13. #include <winsock2.h>
  14. #include "ErrorMessage.h"
  15. namespace System {
  16. namespace {
  17. struct DispatcherContext : public OVERLAPPED {
  18. NativeContext* context;
  19. };
  20. const size_t STACK_SIZE = 16384;
  21. const size_t RESERVE_STACK_SIZE = 2097152;
  22. }
  23. Dispatcher::Dispatcher() {
  24. static_assert(sizeof(CRITICAL_SECTION) == sizeof(Dispatcher::criticalSection), "CRITICAL_SECTION size doesn't fit sizeof(Dispatcher::criticalSection)");
  25. BOOL result = InitializeCriticalSectionAndSpinCount(reinterpret_cast<LPCRITICAL_SECTION>(criticalSection), 4000);
  26. assert(result != FALSE);
  27. std::string message;
  28. if (ConvertThreadToFiberEx(NULL, 0) == NULL) {
  29. message = "ConvertThreadToFiberEx failed, " + lastErrorMessage();
  30. } else {
  31. completionPort = CreateIoCompletionPort(INVALID_HANDLE_VALUE, NULL, 0, 0);
  32. if (completionPort == NULL) {
  33. message = "CreateIoCompletionPort failed, " + lastErrorMessage();
  34. } else {
  35. WSADATA wsaData;
  36. int wsaResult = WSAStartup(0x0202, &wsaData);
  37. if (wsaResult != 0) {
  38. message = "WSAStartup failed, " + errorMessage(wsaResult);
  39. } else {
  40. remoteNotificationSent = false;
  41. reinterpret_cast<LPOVERLAPPED>(remoteSpawnOverlapped)->hEvent = NULL;
  42. threadId = GetCurrentThreadId();
  43. mainContext.fiber = GetCurrentFiber();
  44. mainContext.interrupted = false;
  45. mainContext.group = &contextGroup;
  46. mainContext.groupPrev = nullptr;
  47. mainContext.groupNext = nullptr;
  48. contextGroup.firstContext = nullptr;
  49. contextGroup.lastContext = nullptr;
  50. contextGroup.firstWaiter = nullptr;
  51. contextGroup.lastWaiter = nullptr;
  52. currentContext = &mainContext;
  53. firstResumingContext = nullptr;
  54. firstReusableContext = nullptr;
  55. runningContextCount = 0;
  56. return;
  57. }
  58. BOOL result2 = CloseHandle(completionPort);
  59. assert(result2 == TRUE);
  60. }
  61. BOOL result2 = ConvertFiberToThread();
  62. assert(result == TRUE);
  63. }
  64. DeleteCriticalSection(reinterpret_cast<LPCRITICAL_SECTION>(criticalSection));
  65. throw std::runtime_error("Dispatcher::Dispatcher, " + message);
  66. }
  67. Dispatcher::~Dispatcher() {
  68. assert(GetCurrentThreadId() == threadId);
  69. for (NativeContext* context = contextGroup.firstContext; context != nullptr; context = context->groupNext) {
  70. interrupt(context);
  71. }
  72. yield();
  73. assert(timers.empty());
  74. assert(contextGroup.firstContext == nullptr);
  75. assert(contextGroup.firstWaiter == nullptr);
  76. assert(firstResumingContext == nullptr);
  77. assert(runningContextCount == 0);
  78. while (firstReusableContext != nullptr) {
  79. void* fiber = firstReusableContext->fiber;
  80. firstReusableContext = firstReusableContext->next;
  81. DeleteFiber(fiber);
  82. }
  83. int wsaResult = WSACleanup();
  84. assert(wsaResult == 0);
  85. BOOL result = CloseHandle(completionPort);
  86. assert(result == TRUE);
  87. result = ConvertFiberToThread();
  88. assert(result == TRUE);
  89. DeleteCriticalSection(reinterpret_cast<LPCRITICAL_SECTION>(criticalSection));
  90. }
  91. void Dispatcher::clear() {
  92. assert(GetCurrentThreadId() == threadId);
  93. while (firstReusableContext != nullptr) {
  94. void* fiber = firstReusableContext->fiber;
  95. firstReusableContext = firstReusableContext->next;
  96. DeleteFiber(fiber);
  97. }
  98. }
  99. void Dispatcher::dispatch() {
  100. assert(GetCurrentThreadId() == threadId);
  101. NativeContext* context;
  102. for (;;) {
  103. if (firstResumingContext != nullptr) {
  104. context = firstResumingContext;
  105. firstResumingContext = context->next;
  106. break;
  107. }
  108. LARGE_INTEGER frequency;
  109. LARGE_INTEGER ticks;
  110. QueryPerformanceCounter(&ticks);
  111. QueryPerformanceFrequency(&frequency);
  112. uint64_t currentTime = ticks.QuadPart / (frequency.QuadPart / 1000);
  113. auto timerContextPair = timers.begin();
  114. auto end = timers.end();
  115. while (timerContextPair != end && timerContextPair->first <= currentTime) {
  116. pushContext(timerContextPair->second);
  117. timerContextPair = timers.erase(timerContextPair);
  118. }
  119. if (firstResumingContext != nullptr) {
  120. context = firstResumingContext;
  121. firstResumingContext = context->next;
  122. break;
  123. }
  124. DWORD timeout = timers.empty() ? INFINITE : static_cast<DWORD>(std::min(timers.begin()->first - currentTime, static_cast<uint64_t>(INFINITE - 1)));
  125. OVERLAPPED_ENTRY entry;
  126. ULONG actual = 0;
  127. if (GetQueuedCompletionStatusEx(completionPort, &entry, 1, &actual, timeout, TRUE) == TRUE) {
  128. if (entry.lpOverlapped == reinterpret_cast<LPOVERLAPPED>(remoteSpawnOverlapped)) {
  129. EnterCriticalSection(reinterpret_cast<LPCRITICAL_SECTION>(criticalSection));
  130. assert(remoteNotificationSent);
  131. assert(!remoteSpawningProcedures.empty());
  132. do {
  133. spawn(std::move(remoteSpawningProcedures.front()));
  134. remoteSpawningProcedures.pop();
  135. } while (!remoteSpawningProcedures.empty());
  136. remoteNotificationSent = false;
  137. LeaveCriticalSection(reinterpret_cast<LPCRITICAL_SECTION>(criticalSection));
  138. continue;
  139. }
  140. context = reinterpret_cast<DispatcherContext*>(entry.lpOverlapped)->context;
  141. break;
  142. }
  143. DWORD lastError = GetLastError();
  144. if (lastError == WAIT_TIMEOUT) {
  145. continue;
  146. }
  147. if (lastError != WAIT_IO_COMPLETION) {
  148. throw std::runtime_error("Dispatcher::dispatch, GetQueuedCompletionStatusEx failed, " + errorMessage(lastError));
  149. }
  150. }
  151. if (context != currentContext) {
  152. currentContext = context;
  153. SwitchToFiber(context->fiber);
  154. }
  155. }
  156. NativeContext* Dispatcher::getCurrentContext() const {
  157. assert(GetCurrentThreadId() == threadId);
  158. return currentContext;
  159. }
  160. void Dispatcher::interrupt() {
  161. interrupt(currentContext);
  162. }
  163. void Dispatcher::interrupt(NativeContext* context) {
  164. assert(GetCurrentThreadId() == threadId);
  165. assert(context != nullptr);
  166. if (!context->interrupted) {
  167. if (context->interruptProcedure != nullptr) {
  168. context->interruptProcedure();
  169. context->interruptProcedure = nullptr;
  170. } else {
  171. context->interrupted = true;
  172. }
  173. }
  174. }
  175. bool Dispatcher::interrupted() {
  176. if (currentContext->interrupted) {
  177. currentContext->interrupted = false;
  178. return true;
  179. }
  180. return false;
  181. }
  182. void Dispatcher::pushContext(NativeContext* context) {
  183. assert(GetCurrentThreadId() == threadId);
  184. assert(context != nullptr);
  185. context->next = nullptr;
  186. if (firstResumingContext != nullptr) {
  187. assert(lastResumingContext->next == nullptr);
  188. lastResumingContext->next = context;
  189. } else {
  190. firstResumingContext = context;
  191. }
  192. lastResumingContext = context;
  193. }
  194. void Dispatcher::remoteSpawn(std::function<void()>&& procedure) {
  195. EnterCriticalSection(reinterpret_cast<LPCRITICAL_SECTION>(criticalSection));
  196. remoteSpawningProcedures.push(std::move(procedure));
  197. if (!remoteNotificationSent) {
  198. remoteNotificationSent = true;
  199. if (PostQueuedCompletionStatus(completionPort, 0, 0, reinterpret_cast<LPOVERLAPPED>(remoteSpawnOverlapped)) == NULL) {
  200. LeaveCriticalSection(reinterpret_cast<LPCRITICAL_SECTION>(criticalSection));
  201. throw std::runtime_error("Dispatcher::remoteSpawn, PostQueuedCompletionStatus failed, " + lastErrorMessage());
  202. };
  203. }
  204. LeaveCriticalSection(reinterpret_cast<LPCRITICAL_SECTION>(criticalSection));
  205. }
  206. void Dispatcher::spawn(std::function<void()>&& procedure) {
  207. assert(GetCurrentThreadId() == threadId);
  208. NativeContext* context = &getReusableContext();
  209. if (contextGroup.firstContext != nullptr) {
  210. context->groupPrev = contextGroup.lastContext;
  211. assert(contextGroup.lastContext->groupNext == nullptr);
  212. contextGroup.lastContext->groupNext = context;
  213. } else {
  214. context->groupPrev = nullptr;
  215. contextGroup.firstContext = context;
  216. contextGroup.firstWaiter = nullptr;
  217. }
  218. context->interrupted = false;
  219. context->group = &contextGroup;
  220. context->groupNext = nullptr;
  221. context->procedure = std::move(procedure);
  222. contextGroup.lastContext = context;
  223. pushContext(context);
  224. }
  225. void Dispatcher::yield() {
  226. assert(GetCurrentThreadId() == threadId);
  227. for (;;) {
  228. LARGE_INTEGER frequency;
  229. LARGE_INTEGER ticks;
  230. QueryPerformanceCounter(&ticks);
  231. QueryPerformanceFrequency(&frequency);
  232. uint64_t currentTime = ticks.QuadPart / (frequency.QuadPart / 1000);
  233. auto timerContextPair = timers.begin();
  234. auto end = timers.end();
  235. while (timerContextPair != end && timerContextPair->first <= currentTime) {
  236. timerContextPair->second->interruptProcedure = nullptr;
  237. pushContext(timerContextPair->second);
  238. timerContextPair = timers.erase(timerContextPair);
  239. }
  240. OVERLAPPED_ENTRY entries[16];
  241. ULONG actual = 0;
  242. if (GetQueuedCompletionStatusEx(completionPort, entries, 16, &actual, 0, TRUE) == TRUE) {
  243. assert(actual > 0);
  244. for (ULONG i = 0; i < actual; ++i) {
  245. if (entries[i].lpOverlapped == reinterpret_cast<LPOVERLAPPED>(remoteSpawnOverlapped)) {
  246. EnterCriticalSection(reinterpret_cast<LPCRITICAL_SECTION>(criticalSection));
  247. assert(remoteNotificationSent);
  248. assert(!remoteSpawningProcedures.empty());
  249. do {
  250. spawn(std::move(remoteSpawningProcedures.front()));
  251. remoteSpawningProcedures.pop();
  252. } while (!remoteSpawningProcedures.empty());
  253. remoteNotificationSent = false;
  254. LeaveCriticalSection(reinterpret_cast<LPCRITICAL_SECTION>(criticalSection));
  255. continue;
  256. }
  257. NativeContext* context = reinterpret_cast<DispatcherContext*>(entries[i].lpOverlapped)->context;
  258. context->interruptProcedure = nullptr;
  259. pushContext(context);
  260. }
  261. } else {
  262. DWORD lastError = GetLastError();
  263. if (lastError == WAIT_TIMEOUT) {
  264. break;
  265. } else if (lastError != WAIT_IO_COMPLETION) {
  266. throw std::runtime_error("Dispatcher::yield, GetQueuedCompletionStatusEx failed, " + errorMessage(lastError));
  267. }
  268. }
  269. }
  270. if (firstResumingContext != nullptr) {
  271. pushContext(currentContext);
  272. dispatch();
  273. }
  274. }
  275. void Dispatcher::addTimer(uint64_t time, NativeContext* context) {
  276. assert(GetCurrentThreadId() == threadId);
  277. timers.insert(std::make_pair(time, context));
  278. }
  279. void* Dispatcher::getCompletionPort() const {
  280. return completionPort;
  281. }
  282. NativeContext& Dispatcher::getReusableContext() {
  283. if (firstReusableContext == nullptr) {
  284. void* fiber = CreateFiberEx(STACK_SIZE, RESERVE_STACK_SIZE, 0, contextProcedureStatic, this);
  285. if (fiber == NULL) {
  286. throw std::runtime_error("Dispatcher::getReusableContext, CreateFiberEx failed, " + lastErrorMessage());
  287. }
  288. SwitchToFiber(fiber);
  289. assert(firstReusableContext != nullptr);
  290. firstReusableContext->fiber = fiber;
  291. }
  292. NativeContext* context = firstReusableContext;
  293. firstReusableContext = context->next;
  294. return *context;
  295. }
  296. void Dispatcher::pushReusableContext(NativeContext& context) {
  297. context.next = firstReusableContext;
  298. firstReusableContext = &context;
  299. --runningContextCount;
  300. }
  301. void Dispatcher::interruptTimer(uint64_t time, NativeContext* context) {
  302. assert(GetCurrentThreadId() == threadId);
  303. auto range = timers.equal_range(time);
  304. for (auto it = range.first; ; ++it) {
  305. assert(it != range.second);
  306. if (it->second == context) {
  307. pushContext(context);
  308. timers.erase(it);
  309. break;
  310. }
  311. }
  312. }
  313. void Dispatcher::contextProcedure() {
  314. assert(GetCurrentThreadId() == threadId);
  315. assert(firstReusableContext == nullptr);
  316. NativeContext context;
  317. context.interrupted = false;
  318. context.next = nullptr;
  319. firstReusableContext = &context;
  320. SwitchToFiber(currentContext->fiber);
  321. for (;;) {
  322. ++runningContextCount;
  323. try {
  324. context.procedure();
  325. } catch (std::exception&) {
  326. }
  327. if (context.group != nullptr) {
  328. if (context.groupPrev != nullptr) {
  329. assert(context.groupPrev->groupNext == &context);
  330. context.groupPrev->groupNext = context.groupNext;
  331. if (context.groupNext != nullptr) {
  332. assert(context.groupNext->groupPrev == &context);
  333. context.groupNext->groupPrev = context.groupPrev;
  334. } else {
  335. assert(context.group->lastContext == &context);
  336. context.group->lastContext = context.groupPrev;
  337. }
  338. } else {
  339. assert(context.group->firstContext == &context);
  340. context.group->firstContext = context.groupNext;
  341. if (context.groupNext != nullptr) {
  342. assert(context.groupNext->groupPrev == &context);
  343. context.groupNext->groupPrev = nullptr;
  344. } else {
  345. assert(context.group->lastContext == &context);
  346. if (context.group->firstWaiter != nullptr) {
  347. if (firstResumingContext != nullptr) {
  348. assert(lastResumingContext->next == nullptr);
  349. lastResumingContext->next = context.group->firstWaiter;
  350. } else {
  351. firstResumingContext = context.group->firstWaiter;
  352. }
  353. lastResumingContext = context.group->lastWaiter;
  354. context.group->firstWaiter = nullptr;
  355. }
  356. }
  357. }
  358. pushReusableContext(context);
  359. }
  360. dispatch();
  361. }
  362. }
  363. void __stdcall Dispatcher::contextProcedureStatic(void* context) {
  364. static_cast<Dispatcher*>(context)->contextProcedure();
  365. }
  366. }