PageRenderTime 43ms CodeModel.GetById 17ms RepoModel.GetById 0ms app.codeStats 0ms

/lualib/socket.lua

https://github.com/382411196/skynet
Lua | 314 lines | 265 code | 31 blank | 18 comment | 46 complexity | 7e7d6207790f86d59961fc9ad715e220 MD5 | raw file
  1. local driver = require "socketdriver"
  2. local skynet = require "skynet"
  3. local assert = assert
  -- Public API table; it is what this module returns.
  local socket = {}

  -- Pool of buffer nodes shared by every per-socket buffer
  -- (passed to driver.push / pop / readall / clear).
  local buffer_pool = {}

  -- When socket_pool itself is collected, close every socket still registered.
  -- The per-socket buffers are not cleared here because buffer_pool is freed
  -- at the end as well.
  local function close_all_on_gc(pool)
    for fd in pairs(pool) do
      driver.close(fd)
      pool[fd] = nil
    end
  end

  -- Socket id -> state table for every socket this service manages.
  local socket_pool = setmetatable({}, { __gc = close_all_on_gc })

  -- Message handlers indexed by SKYNET_SOCKET_TYPE_* code.
  local socket_message = {}
  -- Wake the coroutine parked on socket state `s`, if there is one.
  -- The slot is cleared before waking so that a later suspend(s) on the same
  -- socket passes its `assert(not s.co)`.
  local function wakeup(s)
    local waiting = s.co
    if not waiting then
      return
    end
    s.co = nil
    skynet.wakeup(waiting)
  end
  -- Park the running coroutine on socket state `s` until wakeup(s).
  -- Only one coroutine may wait on a given socket at a time.
  local function suspend(s)
    assert(not s.co)
    s.co = coroutine.running()
    skynet.wait()
    -- socket.close() may be waiting for this pending operation to finish
    -- before it clears the buffer, so nudge it on every return from a wait.
    local closer = s.closing
    if closer then
      skynet.wakeup(closer)
    end
  end
  -- Handler indices mirror the SKYNET_SOCKET_TYPE_* macros in skynet_socket.h.
  -- SKYNET_SOCKET_TYPE_DATA = 1: `size` bytes at `data` arrived on `id`.
  -- Data for unknown sockets is dropped. Otherwise it is buffered, and a
  -- pending reader whose requirement is now met is woken.
  socket_message[1] = function(id, size, data)
    local s = socket_pool[id]
    if s == nil then
      skynet.error("socket: drop package from " .. id)
      driver.drop(data, size)
      return
    end

    local buffered = driver.push(s.buffer, buffer_pool, data, size)
    local need = s.read_required
    local satisfied
    if type(need) == "number" then
      -- socket.read / socket.block: waiting for at least `need` bytes
      satisfied = buffered >= need
    elseif type(need) == "string" then
      -- socket.readline: waiting for separator `need`. Passing nil instead of
      -- buffer_pool looks like a non-consuming probe.
      -- NOTE(review): confirm in the driver's readline implementation.
      satisfied = driver.readline(s.buffer, nil, need)
    end
    if satisfied then
      s.read_required = nil
      wakeup(s)
    end
  end
  -- SKYNET_SOCKET_TYPE_CONNECT = 2: socket `id` came up.
  -- Marks it connected and wakes the coroutine waiting in connect().
  -- `addr` is received but unused (the original noted it could be logged).
  socket_message[2] = function(id, _, addr)
    local s = socket_pool[id]
    if s ~= nil then
      s.connected = true
      wakeup(s)
    end
  end
  -- SKYNET_SOCKET_TYPE_CLOSE = 3: the driver reported socket `id` closed.
  -- Marks it disconnected and wakes any coroutine waiting on it.
  socket_message[3] = function(id)
    local s = socket_pool[id]
    if s ~= nil then
      s.connected = false
      wakeup(s)
    end
  end
  -- SKYNET_SOCKET_TYPE_ACCEPT = 4: listening socket `id` accepted `newid`
  -- from `addr`. If the listener is unknown, the new connection is closed at
  -- once. Otherwise it goes to the callback given to socket.start(id, func).
  socket_message[4] = function(id, newid, addr)
    local listener = socket_pool[id]
    if listener ~= nil then
      listener.callback(newid, addr)
      return
    end
    driver.close(newid)
  end
  -- SKYNET_SOCKET_TYPE_ERROR = 5: the driver reported an error on `id`.
  -- Only sockets that were connected are logged, so an error during connect
  -- is silent here. The socket is then marked disconnected and any waiter
  -- is woken.
  socket_message[5] = function(id)
    local s = socket_pool[id]
    if s == nil then
      skynet.error("socket: error on unknown", id)
    else
      if s.connected then
        skynet.error("socket: error on", id)
      end
      s.connected = false
      wakeup(s)
    end
  end
  -- Route PTYPE_SOCKET messages to socket_message[type].
  -- driver.unpack yields (type, n1, n2, data); the meaning of n1/n2 depends on
  -- the type (see the individual handlers above).
  skynet.register_protocol {
    name = "socket",
    id = skynet.PTYPE_SOCKET, -- PTYPE_SOCKET = 6
    unpack = driver.unpack,
    dispatch = function(_, _, msg_type, n1, n2, data)
      local handler = socket_message[msg_type]
      handler(n1, n2, data)
    end,
  }
  -- Register socket `id` in socket_pool, then block until a driver event
  -- (CONNECT, CLOSE or ERROR) reports whether it came up.
  -- @param id    socket id obtained from the driver
  -- @param func  optional accept callback (see socket_message[4]); when given,
  --              no read buffer is allocated for the socket
  -- @return id on success; nil on failure, in which case the id is removed
  --         from socket_pool and any buffered nodes are returned to the pool
  local function connect(id, func)
    local newbuffer
    if func == nil then
      newbuffer = driver.buffer()
    end
    local s = {
      id = id,
      buffer = newbuffer,
      connected = false,
      read_required = false,  -- was misspelled `read_require`
      co = false,
      callback = func,
    }
    socket_pool[id] = s
    suspend(s)
    if s.connected then
      return id
    end
    -- Failed to come up: don't leave a dead entry behind. Otherwise it leaks
    -- its buffer and socket.invalid(id) keeps reporting the id as valid.
    if s.buffer then
      driver.clear(s.buffer, buffer_pool)
    end
    socket_pool[id] = nil
  end
  -- Connect to addr:port through the socket driver and wait for the outcome.
  -- Returns the socket id, or nil if the connection did not come up.
  function socket.open(addr, port)
    local conn_id = driver.connect(addr, port)
    return connect(conn_id)
  end
  -- Wrap an existing OS file descriptor `os_fd` as a skynet socket and wait
  -- until the driver reports it ready. Returns the socket id, or nil on failure.
  function socket.bind(os_fd)
    local bound_id = driver.bind(os_fd)
    return connect(bound_id)
  end
  -- Wrap this process's standard input (OS fd 0) as a skynet socket.
  function socket.stdin()
    local STDIN_FD = 0
    return socket.bind(STDIN_FD)
  end
  -- Tell the driver to start delivering events for `id`, then wait until it
  -- is up. With `func`, accepted connections on `id` are handed to
  -- func(newid, addr) and no read buffer is created.
  -- Returns the id, or nil on failure.
  function socket.start(id, func)
    driver.start(id)
    local started = connect(id, func)
    return started
  end
  -- Release the read buffer of `id` and close its fd if still connected.
  -- The entry is NOT removed from socket_pool here (socket.close does that).
  function socket.shutdown(id)
    local s = socket_pool[id]
    if not s then
      return
    end
    local buf = s.buffer
    if buf then
      driver.clear(buf, buffer_pool)
    end
    if s.connected then
      driver.close(id)
    end
  end
  -- Close socket `id` and stop tracking it.
  -- If still connected, ask the driver to close the fd and wait before
  -- releasing the buffer:
  --  * if another coroutine is parked reading this socket, wait until that
  --    reader returns from its wait (suspend() wakes s.closing);
  --  * otherwise park here until the close/error event wakes us.
  -- NOTE: be careful calling this from a __gc metamethod. skynet.wait never
  -- returns there, so the buffer may never be cleared.
  function socket.close(id)
    local s = socket_pool[id]
    if s == nil then
      return
    end
    if s.connected then
      driver.close(s.id)
      if s.co then
        -- A reader is parked on this socket: don't clear the buffer
        -- immediately; let it consume what is buffered first.
        assert(not s.closing)
        s.closing = coroutine.running()
        skynet.wait()
      else
        suspend(s)
      end
      s.connected = false
    end
    socket.shutdown(id)
    -- Closing while socket.lock is held or queued is a caller bug.
    -- socket.lock/socket.unlock keep the queue in `s.lock`; the previous check
    -- read the never-set `s.lock_set` and so could never fire.
    assert(s.lock == nil or next(s.lock) == nil)
    socket_pool[id] = nil
  end
  -- Read `sz` bytes from `id`, blocking until enough are buffered.
  -- Returns the data on success. If the socket is (or becomes) disconnected
  -- first, returns false plus whatever bytes remain in the buffer.
  function socket.read(id, sz)
    local s = assert(socket_pool[id])
    local data = driver.pop(s.buffer, buffer_pool, sz)
    if not data and s.connected then
      assert(not s.read_required)
      s.read_required = sz
      suspend(s)
      data = driver.pop(s.buffer, buffer_pool, sz)
    end
    if data then
      return data
    end
    return false, driver.readall(s.buffer, buffer_pool)
  end
  -- Read everything until socket `id` is disconnected.
  -- If it is already disconnected, returns the remaining buffered bytes, or
  -- false when nothing is left. Otherwise blocks until the disconnect and
  -- returns the full buffer contents.
  function socket.readall(id)
    local s = assert(socket_pool[id])
    if s.connected then
      assert(not s.read_required)
      -- `true` is neither a number nor a string, so the DATA handler never
      -- satisfies it; only the disconnect handlers are expected to wake us.
      s.read_required = true
      suspend(s)
      assert(s.connected == false)
      return driver.readall(s.buffer, buffer_pool)
    end
    local rest = driver.readall(s.buffer, buffer_pool)
    return rest ~= "" and rest
  end
  -- Read one line from `id`, delimited by `sep` (default "\n").
  -- Blocks until the separator is buffered. If the socket is (or becomes)
  -- disconnected first, returns false plus the remaining buffered bytes.
  function socket.readline(id, sep)
    if not sep then
      sep = "\n"
    end
    local s = socket_pool[id]
    assert(s)
    local line = driver.readline(s.buffer, buffer_pool, sep)
    if line then
      return line
    end
    if s.connected then
      assert(not s.read_required)
      s.read_required = sep
      suspend(s)
      if s.connected then
        return driver.readline(s.buffer, buffer_pool, sep)
      end
    end
    return false, driver.readall(s.buffer, buffer_pool)
  end
  -- Wait until socket `id` receives any data or is disconnected.
  -- Returns false at once for unknown or disconnected sockets. Otherwise
  -- returns whether the socket is still connected after waking.
  function socket.block(id)
    local s = socket_pool[id]
    if s and s.connected then
      assert(not s.read_required)
      -- a 0-byte requirement is met by the next DATA event
      s.read_required = 0
      suspend(s)
      return s.connected
    end
    return false
  end
  -- Sending and listening are exposed directly from the driver. The asserts
  -- make module loading fail if the driver lacks any of them.
  socket.write = assert(driver.send)
  socket.lwrite = assert(driver.lsend)
  socket.listen = assert(driver.listen)

  -- True when `id` is not (or no longer) tracked by this module.
  function socket.invalid(id)
    local s = socket_pool[id]
    return s == nil
  end
  -- Acquire the FIFO lock for socket `id`.
  -- The queue lives in `s.lock`, and slot 1 is the current holder. The first
  -- caller takes the lock immediately by storing `true` in slot 1. Later
  -- callers append their coroutine and wait until socket.unlock wakes them.
  function socket.lock(id)
    local s = socket_pool[id]
    assert(s)
    local queue = s.lock
    if not queue then
      queue = {}
      s.lock = queue
    end
    if #queue > 0 then
      queue[#queue + 1] = coroutine.running()
      skynet.wait()
    else
      queue[1] = true
    end
  end
  -- Release the lock for socket `id`: drop the current holder (slot 1) and
  -- wake the next queued coroutine, if any, which then becomes the holder.
  function socket.unlock(id)
    local s = socket_pool[id]
    assert(s)
    local queue = assert(s.lock)
    table.remove(queue, 1)
    local next_holder = queue[1]
    if next_holder then
      skynet.wakeup(next_holder)
    end
  end
  -- Stop tracking socket `id` in this service WITHOUT closing it, typically to
  -- forward the id to another service. Any buffered data is released. The
  -- receiving service must call socket.start(id) before using it.
  function socket.abandon(id)
    local s = socket_pool[id]
    local buf = s and s.buffer
    if buf then
      driver.clear(buf, buffer_pool)
    end
    socket_pool[id] = nil
  end

  return socket