PageRenderTime 61ms CodeModel.GetById 18ms RepoModel.GetById 0ms app.codeStats 1ms

/net/server_select.lua

https://github.com/murnko/prosody
Lua | 1016 lines | 859 code | 115 blank | 42 comment | 110 complexity | d896f2dfcf3dd4410ff07c583da5c655 MD5 | raw file
  1. --
  2. -- server.lua by blastbeat of the luadch project
  3. -- Re-used here under the MIT/X Consortium License
  4. --
  5. -- Modifications (C) 2008-2010 Matthew Wild, Waqas Hussain
  6. --
  7. -- // wrapping luadch stuff // --
  --- Look up a value in the global environment by name.
  -- @param what name of the global to fetch
  -- @return the value stored under that name in _G (nil when unset)
  local use = function( what )
    local value = _G[ what ]
    return value
  end
  local log, table_concat = require( "util.logger" ).init( "socket" ), table.concat
  --- Concatenate all arguments and log the result at "debug" level.
  local function out_put( ... )
    return log( "debug", table_concat( { ... } ) )
  end
  --- Concatenate all arguments and log the result at "warn" level.
  local function out_error( ... )
    return log( "warn", table_concat( { ... } ) )
  end
  14. ----------------------------------// DECLARATION //--
  15. --// constants //--
  16. local STAT_UNIT = 1 -- byte
  17. --// lua functions //--
  18. local type = use "type"
  19. local pairs = use "pairs"
  20. local ipairs = use "ipairs"
  21. local tonumber = use "tonumber"
  22. local tostring = use "tostring"
  23. --// lua libs //--
  24. local os = use "os"
  25. local table = use "table"
  26. local string = use "string"
  27. local coroutine = use "coroutine"
  28. --// lua lib methods //--
  29. local os_difftime = os.difftime
  30. local math_min = math.min
  31. local math_huge = math.huge
  32. local table_concat = table.concat
  33. local string_sub = string.sub
  34. local coroutine_wrap = coroutine.wrap
  35. local coroutine_yield = coroutine.yield
  36. --// extern libs //--
  37. local has_luasec, luasec = pcall ( require , "ssl" )
  38. local luasocket = use "socket" or require "socket"
  39. local luasocket_gettime = luasocket.gettime
  40. local getaddrinfo = luasocket.dns.getaddrinfo
  41. --// extern lib methods //--
  42. local ssl_wrap = ( has_luasec and luasec.wrap )
  43. local socket_bind = luasocket.bind
  44. local socket_sleep = luasocket.sleep
  45. local socket_select = luasocket.select
  46. --// functions //--
  -- Forward declarations: every name below is assigned later in this file.
  local id, idfalse
  local loop, stats, closeall
  local addsocket, addserver, addtimer
  local getserver, wrapserver, getsettings
  local closesocket, removesocket, removeserver
  local wrapconnection, changesettings
  --// tables //--
  local _server, _socketlist
  local _readlist, _sendlist, _timerlist, _closelist
  local _readtimes, _writetimes
  --// simple data types //--
  local _
  local _readlistlen, _sendlistlen, _timerlistlen
  local _sendtraffic, _readtraffic
  local _selecttimeout, _sleeptime, _tcpbacklog
  local _starttime, _currenttime
  local _maxsendlen, _maxreadlen
  local _checkinterval, _sendtimeout, _readtimeout
  local _timer
  local _maxselectlen, _maxfd, _maxsslhandshake
  93. ----------------------------------// DEFINITION //--
  -- Registries
  _server = { } -- key = "addr:port", value = handler; list of listening servers
  _socketlist = { } -- key = socket, value = wrapped socket (handlers)
  _closelist = { } -- handlers to close
  -- Lists handed to socket.select, plus the timer callbacks
  _readlist, _sendlist, _timerlist = { }, { }, { } -- sockets to read from / sockets to write to / timer functions
  _readlistlen, _sendlistlen, _timerlistlen = 0, 0, 0 -- lengths of the three arrays above
  -- Idle tracking
  _readtimes = { } -- key = handler, value = timestamp of last data reading
  _writetimes = { } -- key = handler, value = timestamp of last data writing/sending
  -- Traffic statistics
  _sendtraffic, _readtraffic = 0, 0
  -- Tunables (see getsettings / changesettings)
  _selecttimeout = 1 -- timeout of socket.select
  _sleeptime = 0 -- time to wait at the end of every loop
  _tcpbacklog = 128 -- some kind of hint to the OS
  _maxsendlen = 51000 * 1024 -- max len of send buffer
  _maxreadlen = 25000 * 1024 -- max len of read buffer
  _checkinterval = 30 -- interval in secs to check idle clients
  _sendtimeout = 60000 -- allowed send idle time in secs
  _readtimeout = 6 * 60 * 60 -- allowed read idle time in secs
  local is_windows = package.config:sub(1,1) == "\\" -- check the directory separator, to determine whether this is Windows
  _maxfd = (is_windows and math.huge) or luasocket._SETSIZE or 1024 -- max fd number, limit to 1024 by default to prevent glibc buffer overflow, but not on Windows
  _maxselectlen = luasocket._SETSIZE or 1024 -- But this still applies on Windows
  _maxsslhandshake = 30 -- max handshake round-trips
  119. ----------------------------------// PRIVATE //--
  --- Wrap a listening socket in a server handler object.
  -- The handler's readbuffer accepts one pending client and wraps it with
  -- wrapconnection; pause/resume take the listener out of / back into the
  -- read list (a "hard" pause also closes the listening socket).
  -- @param listeners table of callbacks; onconnect is used here, the rest is passed to wrapconnection
  -- @param socket the bound listening socket
  -- @param ip address the socket is bound to
  -- @param serverport port the socket is bound to
  -- @param pattern read pattern handed to every accepted connection
  -- @param sslctx optional SSL context; when set, accepted connections auto-start TLS
  -- @return handler, or nil and "fd-too-large" when the socket's FD is not allowed
  wrapserver = function( listeners, socket, ip, serverport, pattern, sslctx ) -- this function wraps a server
    if socket:getfd() >= _maxfd then
      out_error("server.lua: Disallowed FD number: "..socket:getfd())
      socket:close()
      return nil, "fd-too-large"
    end
    local connections = 0
    local dispatch = listeners.onconnect
    local accept = socket.accept
    --// public methods of the object //--
    local handler = { }
    handler.shutdown = function( ) end
    handler.ssl = function( )
      return sslctx ~= nil
    end
    handler.sslctx = function( )
      return sslctx
    end
    -- Called by a client handler when it closes; frees a slot and un-pauses the listener.
    handler.remove = function( )
      connections = connections - 1
      if handler then
        handler.resume( )
      end
    end
    handler.close = function()
      socket:close( )
      _sendlistlen = removesocket( _sendlist, socket, _sendlistlen )
      _readlistlen = removesocket( _readlist, socket, _readlistlen )
      _server[ip..":"..serverport] = nil;
      _socketlist[ socket ] = nil
      handler = nil
      socket = nil
      --mem_free( )
      out_put "server.lua: closed server handler and removed sockets from list"
    end
    handler.pause = function( hard )
      if not handler.paused then
        _readlistlen = removesocket( _readlist, socket, _readlistlen )
        if hard then
          _socketlist[ socket ] = nil
          socket:close( )
          socket = nil;
        end
        handler.paused = true;
      end
    end
    handler.resume = function( )
      if handler.paused then
        if not socket then
          -- The listener was hard-paused, so the port has to be bound again.
          -- A failed bind leaves the handler paused so a later resume can retry,
          -- instead of indexing a nil socket.
          local newsocket, err = socket_bind( ip, serverport, _tcpbacklog );
          if not newsocket then
            out_error( "server.lua: unable to re-bind '[", tostring(ip), "]:", tostring(serverport), "': ", tostring(err) )
            return nil, err
          end
          socket = newsocket
          socket:settimeout( 0 )
        end
        _readlistlen = addsocket(_readlist, socket, _readlistlen)
        _socketlist[ socket ] = handler
        handler.paused = false;
      end
    end
    handler.ip = function( )
      return ip
    end
    handler.serverport = function( )
      return serverport
    end
    handler.socket = function( )
      return socket
    end
    -- Invoked by the main loop when the listening socket is readable.
    handler.readbuffer = function( )
      if _readlistlen >= _maxselectlen or _sendlistlen >= _maxselectlen then
        handler.pause( )
        out_put( "server.lua: refused new client connection: server full" )
        return false
      end
      local client, err = accept( socket ) -- try to accept
      if client then
        local ip, clientport = client:getpeername( )
        local handler, client, err = wrapconnection( handler, listeners, client, ip, serverport, clientport, pattern, sslctx ) -- wrap new client socket
        if err then -- error while wrapping ssl socket
          return false
        end
        connections = connections + 1
        out_put( "server.lua: accepted new client connection from ", tostring(ip), ":", tostring(clientport), " to ", tostring(serverport))
        if dispatch and not sslctx then -- SSL connections will notify onconnect when handshake completes
          return dispatch( handler );
        end
        return;
      elseif err then -- maybe timeout or something else
        out_put( "server.lua: error with new client connection: ", tostring(err) )
        return false
      end
    end
    return handler
  end
  --- Wrap a client socket in a connection handler object.
  -- The handler buffers outgoing data, dispatches incoming data to the
  -- listener callbacks and, when LuaSec is available, can upgrade the
  -- connection to TLS (handler:starttls).
  -- @param server the server handler that accepted the socket, or nil for outgoing connections
  -- @param listeners table with the optional callbacks onincoming, onstatus, ondisconnect, ondrain, onreadtimeout (and onconnect, used after an auto-started TLS handshake)
  -- @param socket the connected client socket
  -- @param ip peer address (used for logging and handler.ip)
  -- @param serverport local/server port (handler.serverport)
  -- @param clientport peer port (handler.clientport)
  -- @param pattern read pattern passed to socket receive
  -- @param sslctx optional SSL context; when set and LuaSec is present TLS is started immediately
  -- @return handler, socket on success; nil, nil, error message on failure
  wrapconnection = function( server, listeners, socket, ip, serverport, clientport, pattern, sslctx ) -- this function wraps a client to a handler object
    if socket:getfd() >= _maxfd then
      out_error("server.lua: Disallowed FD number: "..socket:getfd()) -- PROTIP: Switch to libevent
      socket:close( ) -- Should we send some kind of error here?
      if server then
        server.pause( )
      end
      return nil, nil, "fd-too-large"
    end
    socket:settimeout( 0 )
    --// local import of socket methods //--
    local send
    local receive
    local shutdown
    --// private closures of the object //--
    local ssl
    local dispatch = listeners.onincoming
    local status = listeners.onstatus
    local disconnect = listeners.ondisconnect
    local drain = listeners.ondrain
    local onreadtimeout = listeners.onreadtimeout;
    local bufferqueue = { } -- buffer array
    local bufferqueuelen = 0 -- end of buffer array
    local toclose
    local fatalerror
    local needtls
    local bufferlen = 0
    local noread = false
    local nosend = false
    local sendtraffic, readtraffic = 0, 0
    local maxsendlen = _maxsendlen
    local maxreadlen = _maxreadlen
    --// public methods of the object //--
    local handler = bufferqueue -- saves a table ^_^
    handler.dispatch = function( )
      return dispatch
    end
    handler.disconnect = function( )
      return disconnect
    end
    handler.onreadtimeout = onreadtimeout;
    handler.setlistener = function( self, listeners )
      dispatch = listeners.onincoming
      disconnect = listeners.ondisconnect
      status = listeners.onstatus
      drain = listeners.ondrain
      handler.onreadtimeout = listeners.onreadtimeout
    end
    handler.getstats = function( )
      return readtraffic, sendtraffic
    end
    handler.ssl = function( )
      return ssl
    end
    handler.sslctx = function ( )
      return sslctx
    end
    handler.send = function( _, data, i, j )
      return send( socket, data, i, j )
    end
    handler.receive = function( pattern, prefix )
      return receive( socket, pattern, prefix )
    end
    handler.shutdown = function( pattern )
      return shutdown( socket, pattern )
    end
    handler.setoption = function (self, option, value)
      if socket.setoption then
        return socket:setoption(option, value);
      end
      return false, "setoption not implemented";
    end
    -- Close immediately, dropping whatever is still queued for sending.
    handler.force_close = function ( self, err )
      if bufferqueuelen ~= 0 then
        out_put("server.lua: discarding unwritten data for ", tostring(ip), ":", tostring(clientport))
        bufferqueuelen = 0;
      end
      return self:close(err);
    end
    -- Close the connection; returns false (and closes later) while queued
    -- data could not be flushed yet.
    handler.close = function( self, err )
      if not handler then return true; end
      _readlistlen = removesocket( _readlist, socket, _readlistlen )
      _readtimes[ handler ] = nil
      if bufferqueuelen ~= 0 then
        handler.sendbuffer() -- Try now to send any outstanding data
        if bufferqueuelen ~= 0 then -- Still not empty, so we'll try again later
          if handler then
            handler.write = nil -- ... but no further writing allowed
          end
          toclose = true
          return false
        end
      end
      if socket then
        _ = shutdown and shutdown( socket )
        socket:close( )
        _sendlistlen = removesocket( _sendlist, socket, _sendlistlen )
        _socketlist[ socket ] = nil
        socket = nil
      else
        out_put "server.lua: socket already closed"
      end
      if handler then
        _writetimes[ handler ] = nil
        _closelist[ handler ] = nil
        local _handler = handler;
        handler = nil
        if disconnect then
          disconnect(_handler, err or false);
          disconnect = nil
        end
      end
      if server then
        server.remove( )
      end
      out_put "server.lua: closed client handler and removed socket from list"
      return true
    end
    handler.ip = function( )
      return ip
    end
    handler.serverport = function( )
      return serverport
    end
    handler.clientport = function( )
      return clientport
    end
    handler.port = handler.clientport -- COMPAT server_event
    -- Queue data for sending and make sure the socket is in the send list.
    local write = function( self, data )
      bufferlen = bufferlen + #data
      if bufferlen > maxsendlen then
        _closelist[ handler ] = "send buffer exceeded" -- cannot close the client at the moment, have to wait to the end of the cycle
        handler.write = idfalse -- dont write anymore
        return false
      elseif socket and not _sendlist[ socket ] then
        _sendlistlen = addsocket(_sendlist, socket, _sendlistlen)
      end
      bufferqueuelen = bufferqueuelen + 1
      bufferqueue[ bufferqueuelen ] = data
      if handler then
        _writetimes[ handler ] = _writetimes[ handler ] or _currenttime
      end
      return true
    end
    handler.write = write
    handler.bufferqueue = function( self )
      return bufferqueue
    end
    handler.socket = function( self )
      return socket
    end
    handler.set_mode = function( self, new )
      pattern = new or pattern
      return pattern
    end
    handler.set_send = function ( self, newsend )
      send = newsend or send
      return send
    end
    handler.bufferlen = function( self, readlen, sendlen )
      maxsendlen = sendlen or maxsendlen
      maxreadlen = readlen or maxreadlen
      return bufferlen, maxreadlen, maxsendlen
    end
    --TODO: Deprecate
    handler.lock_read = function (self, switch)
      if switch == true then
        local tmp = _readlistlen
        _readlistlen = removesocket( _readlist, socket, _readlistlen )
        _readtimes[ handler ] = nil
        if _readlistlen ~= tmp then
          noread = true
        end
      elseif switch == false then
        if noread then
          noread = false
          _readlistlen = addsocket(_readlist, socket, _readlistlen)
          _readtimes[ handler ] = _currenttime
        end
      end
      return noread
    end
    handler.pause = function (self)
      return self:lock_read(true);
    end
    handler.resume = function (self)
      return self:lock_read(false);
    end
    -- Lock (switch == true) or unlock (switch == false) both directions.
    handler.lock = function( self, switch )
      -- lock_read expects (self, switch); calling it without self silently
      -- turned the request into a no-op.
      handler:lock_read( switch )
      if switch == true then
        handler.write = idfalse
        local tmp = _sendlistlen
        _sendlistlen = removesocket( _sendlist, socket, _sendlistlen )
        _writetimes[ handler ] = nil
        if _sendlistlen ~= tmp then
          nosend = true
        end
      elseif switch == false then
        handler.write = write
        if nosend then
          nosend = false
          -- An empty write puts the socket back into the send list; write
          -- takes (self, data), so the handler has to be passed explicitly.
          write( handler, "" )
        end
      end
      return noread, nosend
    end
    local _readbuffer = function( ) -- this function reads data
      local buffer, err, part = receive( socket, pattern ) -- receive buffer with "pattern"
      if not err or (err == "wantread" or err == "timeout") then -- received something
        local buffer = buffer or part or ""
        local len = #buffer
        if len > maxreadlen then
          handler:close( "receive buffer exceeded" )
          return false
        end
        local count = len * STAT_UNIT
        readtraffic = readtraffic + count
        _readtraffic = _readtraffic + count
        _readtimes[ handler ] = _currenttime
        --out_put( "server.lua: read data '", buffer:gsub("[^%w%p ]", "."), "', error: ", err )
        return dispatch( handler, buffer, err )
      else -- connections was closed or fatal error
        out_put( "server.lua: client ", tostring(ip), ":", tostring(clientport), " read error: ", tostring(err) )
        fatalerror = true
        _ = handler and handler:force_close( err )
        return false
      end
    end
    local _sendbuffer = function( ) -- this function sends data
      local succ, err, byte, buffer, count;
      if socket then
        buffer = table_concat( bufferqueue, "", 1, bufferqueuelen )
        succ, err, byte = send( socket, buffer, 1, bufferlen )
        count = ( succ or byte or 0 ) * STAT_UNIT
        sendtraffic = sendtraffic + count
        _sendtraffic = _sendtraffic + count
        for i = bufferqueuelen,1,-1 do
          bufferqueue[ i ] = nil
        end
        --out_put( "server.lua: sended '", buffer, "', bytes: ", tostring(succ), ", error: ", tostring(err), ", part: ", tostring(byte), ", to: ", tostring(ip), ":", tostring(clientport) )
      else
        succ, err, count = false, "unexpected close", 0;
      end
      if succ then -- sending succesful
        bufferqueuelen = 0
        bufferlen = 0
        _sendlistlen = removesocket( _sendlist, socket, _sendlistlen ) -- delete socket from writelist
        _writetimes[ handler ] = nil
        if drain then
          drain(handler)
        end
        _ = needtls and handler:starttls(nil)
        _ = toclose and handler:force_close( )
        return true
      elseif byte and ( err == "timeout" or err == "wantwrite" ) then -- want write
        buffer = string_sub( buffer, byte + 1, bufferlen ) -- new buffer
        bufferqueue[ 1 ] = buffer -- insert new buffer in queue
        bufferqueuelen = 1
        bufferlen = bufferlen - byte
        _writetimes[ handler ] = _currenttime
        return true
      else -- connection was closed during sending or fatal error
        out_put( "server.lua: client ", tostring(ip), ":", tostring(clientport), " write error: ", tostring(err) )
        fatalerror = true
        _ = handler and handler:force_close( err )
        return false
      end
    end
    -- Set the sslctx
    local handshake;
    function handler.set_sslctx(self, new_sslctx)
      sslctx = new_sslctx;
      local read, wrote
      handshake = coroutine_wrap( function( client ) -- create handshake coroutine
          local err
          for i = 1, _maxsslhandshake do
            -- undo the list membership requested by the previous round-trip
            _sendlistlen = ( wrote and removesocket( _sendlist, client, _sendlistlen ) ) or _sendlistlen
            _readlistlen = ( read and removesocket( _readlist, client, _readlistlen ) ) or _readlistlen
            read, wrote = nil, nil
            _, err = client:dohandshake( )
            if not err then
              out_put( "server.lua: ssl handshake done" )
              handler.readbuffer = _readbuffer -- when handshake is done, replace the handshake function with regular functions
              handler.sendbuffer = _sendbuffer
              _ = status and status( handler, "ssl-handshake-complete" )
              if self.autostart_ssl and listeners.onconnect then
                listeners.onconnect(self);
              end
              _readlistlen = addsocket(_readlist, client, _readlistlen)
              return true
            else
              if err == "wantwrite" then
                _sendlistlen = addsocket(_sendlist, client, _sendlistlen)
                wrote = true
              elseif err == "wantread" then
                _readlistlen = addsocket(_readlist, client, _readlistlen)
                read = true
              else
                break;
              end
              err = nil;
              coroutine_yield( ) -- handshake not finished
            end
          end
          out_put( "server.lua: ssl handshake error: ", tostring(err or "handshake too long") )
          _ = handler and handler:force_close("ssl handshake failed")
          return false, err -- handshake failed
        end
      )
    end
    if has_luasec then
      handler.starttls = function( self, _sslctx)
        if _sslctx then
          handler:set_sslctx(_sslctx);
        end
        if bufferqueuelen > 0 then
          out_put "server.lua: we need to do tls, but delaying until send buffer empty"
          needtls = true
          return
        end
        out_put( "server.lua: attempting to start tls on " .. tostring( socket ) )
        local oldsocket, err = socket
        socket, err = ssl_wrap( socket, sslctx ) -- wrap socket
        if not socket then
          -- NOTE(review): at this point `socket` is nil while the old socket is
          -- still registered in _socketlist/_readlist - confirm that callers
          -- close the handler after this error.
          out_put( "server.lua: error while starting tls on client: ", tostring(err or "unknown error") )
          return nil, err -- fatal error
        end
        socket:settimeout( 0 )
        -- add the new socket to our system
        send = socket.send
        receive = socket.receive
        shutdown = id
        _socketlist[ socket ] = handler
        _readlistlen = addsocket(_readlist, socket, _readlistlen)
        -- remove traces of the old socket
        _readlistlen = removesocket( _readlist, oldsocket, _readlistlen )
        _sendlistlen = removesocket( _sendlist, oldsocket, _sendlistlen )
        _socketlist[ oldsocket ] = nil
        handler.starttls = nil
        needtls = nil
        -- Secure now (if handshake fails connection will close)
        ssl = true
        handler.readbuffer = handshake
        handler.sendbuffer = handshake
        return handshake( socket ) -- do handshake
      end
    end
    handler.readbuffer = _readbuffer
    handler.sendbuffer = _sendbuffer
    send = socket.send
    receive = socket.receive
    shutdown = ( ssl and id ) or socket.shutdown
    _socketlist[ socket ] = handler
    _readlistlen = addsocket(_readlist, socket, _readlistlen)
    if sslctx and has_luasec then
      out_put "server.lua: auto-starting ssl negotiation..."
      handler.autostart_ssl = true;
      local ok, err = handler:starttls(sslctx);
      if ok == false then
        return nil, nil, err
      end
    end
    return handler, socket
  end
  --- No-op placeholder: accepts anything, returns nothing.
  id = function( )
    return
  end
  --- Placeholder that always answers false (installed as handler.write once writing is disabled).
  idfalse = function( )
    local refused = false
    return refused
  end
  --- Append a socket to a select list unless it is already a member.
  -- The list stores the socket at an array index and the index under the
  -- socket itself, so membership tests and removals need no search.
  -- @param list the select list
  -- @param socket the socket to add
  -- @param len current number of sockets in the list
  -- @return the new length of the list
  addsocket = function( list, socket, len )
    if list[ socket ] then
      return len;
    end
    local slot = len + 1
    list[ slot ] = socket
    list[ socket ] = slot
    return slot;
  end
  --- Remove a socket from a select list ( copied from copas ).
  -- The last array entry is moved into the freed slot so the array part
  -- stays without holes.
  -- @param list the select list
  -- @param socket the socket to remove
  -- @param len current number of sockets in the list
  -- @return the new length (unchanged when the socket was not a member)
  removesocket = function( list, socket, len )
    local slot = list[ socket ]
    if not slot then
      return len
    end
    list[ socket ] = nil
    local tail = list[ len ]
    list[ len ] = nil
    if tail ~= socket then
      list[ tail ] = slot
      list[ slot ] = tail
    end
    return len - 1
  end
  --- Close a raw socket after dropping it from every bookkeeping list.
  -- @param socket the socket to unregister and close
  closesocket = function( socket )
    _sendlistlen = removesocket( _sendlist, socket, _sendlistlen )
    _readlistlen = removesocket( _readlist, socket, _readlistlen )
    _socketlist[ socket ] = nil
    socket:close( )
  end
  --- Couple two handlers so the sender is read-locked while the receiver is congested.
  -- Both handlers get their buffer callbacks wrapped: reading from `sender`
  -- is locked once receiver.bufferlen() reaches `buffersize`, and unlocked
  -- after the receiver flushed below that size again.
  -- @param sender handler whose readbuffer is wrapped
  -- @param receiver handler whose sendbuffer is wrapped
  -- @param buffersize threshold compared against receiver.bufferlen()
  local function link(sender, receiver, buffersize)
    local read_paused;
    local flush = receiver.sendbuffer;
    local pull = sender.readbuffer;
    receiver.sendbuffer = function ()
      flush();
      if read_paused and receiver.bufferlen() < buffersize then
        sender:lock_read(false); -- Unlock now
        read_paused = nil;
      end
    end
    sender.readbuffer = function ()
      pull();
      if not read_paused and receiver.bufferlen() >= buffersize then
        read_paused = true;
        sender:lock_read(true);
      end
    end
  end
  630. ----------------------------------// PUBLIC //--
  --- Register a new listening server.
  -- @param addr address to bind to (defaults to "*")
  -- @param port port number in the range 0..65535
  -- @param listeners table of listener callbacks
  -- @param pattern read pattern for accepted connections
  -- @param sslctx optional SSL context (requires LuaSec)
  -- @return server handler, or nil and an error message
  addserver = function( addr, port, listeners, pattern, sslctx ) -- this function provides a way for other scripts to reg a server
    addr = addr or "*"
    -- log the failure and hand it back to the caller
    local function fail( reason )
      out_error( "server.lua, [", addr, "]:", port, ": ", reason )
      return nil, reason
    end
    if type( listeners ) ~= "table" then
      return fail( "invalid listener table" )
    end
    if type ( addr ) ~= "string" then
      return fail( "invalid address" )
    end
    if type( port ) ~= "number" or not ( port >= 0 and port <= 65535 ) then
      return fail( "invalid port" )
    end
    if _server[ addr..":"..port ] then
      return fail( "listeners on '[" .. addr .. "]:" .. port .. "' already exist" )
    end
    if sslctx and not has_luasec then
      return fail( "luasec not found" )
    end
    local server, binderr = socket_bind( addr, port, _tcpbacklog )
    if binderr then
      return fail( binderr )
    end
    local handler, wraperr = wrapserver( listeners, server, addr, port, pattern, sslctx ) -- wrap new server socket
    if not handler then
      server:close( )
      return nil, wraperr
    end
    server:settimeout( 0 )
    _readlistlen = addsocket(_readlist, server, _readlistlen)
    _server[ addr..":"..port ] = handler
    _socketlist[ server ] = handler
    out_put( "server.lua: new "..(sslctx and "ssl " or "").."server listener on '[", addr, "]:", port, "'" )
    return handler
  end
  --- Look up the server handler registered for an address/port pair.
  -- @return the handler, or nil when nothing listens there
  getserver = function ( addr, port )
    local key = addr..":"..port
    return _server[ key ];
  end
  --- Close and unregister the server listening on addr:port.
  -- @return true on success, or nil and an error message when no such server exists
  removeserver = function( addr, port )
    local key = addr..":"..port
    local handler = _server[ key ]
    if handler then
      handler:close( )
      _server[ key ] = nil
      return true
    end
    return nil, "no server found on '[" .. addr .. "]:" .. tostring( port ) .. "'"
  end
  --- Close every registered handler and reset all bookkeeping lists.
  closeall = function( )
    for socket, handler in pairs( _socketlist ) do
      handler:close( )
      _socketlist[ socket ] = nil
    end
    _readlistlen = 0
    _sendlistlen = 0
    _timerlistlen = 0
    _server = { }
    _readlist = { }
    _sendlist = { }
    _timerlist = { }
    -- The socket list is created with weak keys at start-up; the replacement
    -- must be weak as well, otherwise sockets stop being collectable after
    -- the first closeall().
    _socketlist = setmetatable( { }, { __mode = "k" } )
    --mem_free( )
  end
  --- Return a snapshot table of the current tunables.
  -- The keys are the same names that changesettings accepts.
  getsettings = function( )
    local settings = { }
    settings.select_timeout = _selecttimeout
    settings.select_sleep_time = _sleeptime
    settings.tcp_backlog = _tcpbacklog
    settings.max_send_buffer_size = _maxsendlen
    settings.max_receive_buffer_size = _maxreadlen
    settings.select_idle_check_interval = _checkinterval
    settings.send_timeout = _sendtimeout
    settings.read_timeout = _readtimeout
    settings.max_connections = _maxselectlen
    settings.max_ssl_handshake_roundtrips = _maxsslhandshake
    settings.highest_allowed_fd = _maxfd
    return settings
  end
  --- Update the tunables from a settings table.
  -- Every recognised key is converted with tonumber; keys that are missing
  -- or not numeric leave the current value untouched.
  -- @param new table using the same keys that getsettings returns
  -- @return true, or nil and "invalid settings table" when `new` is not a table
  changesettings = function( new )
    if type( new ) ~= "table" then
      return nil, "invalid settings table"
    end
    _selecttimeout = tonumber( new.select_timeout ) or _selecttimeout
    _sleeptime = tonumber( new.select_sleep_time ) or _sleeptime
    _maxsendlen = tonumber( new.max_send_buffer_size ) or _maxsendlen
    _maxreadlen = tonumber( new.max_receive_buffer_size ) or _maxreadlen
    _checkinterval = tonumber( new.select_idle_check_interval ) or _checkinterval
    _tcpbacklog = tonumber( new.tcp_backlog ) or _tcpbacklog
    _sendtimeout = tonumber( new.send_timeout ) or _sendtimeout
    _readtimeout = tonumber( new.read_timeout ) or _readtimeout
    -- These three are compared numerically against list lengths, FD numbers
    -- and loop counters, so they get the same tonumber treatment as the rest;
    -- a string value would otherwise raise on the first comparison.
    _maxselectlen = tonumber( new.max_connections ) or _maxselectlen
    _maxsslhandshake = tonumber( new.max_ssl_handshake_roundtrips ) or _maxsslhandshake
    _maxfd = tonumber( new.highest_allowed_fd ) or _maxfd
    return true
  end
  --- Register a timer callback that the main loop fires periodically.
  -- @param listener function to call
  -- @return true, or nil and an error message when `listener` is not a function
  addtimer = function( listener )
    if type( listener ) == "function" then
      local slot = _timerlistlen + 1
      _timerlistlen = slot
      _timerlist[ slot ] = listener
      return true
    end
    return nil, "invalid listener function"
  end
  --- Report global counters.
  -- @return bytes read, bytes sent, read list length, send list length, timer count
  stats = function( )
    local received, sent = _readtraffic, _sendtraffic
    return received, sent, _readlistlen, _sendlistlen, _timerlistlen
  end
  local quitting;
  --- Ask the main loop to stop (truthy argument) or clear the request (falsy argument).
  -- The flag is always stored as a real boolean.
  local function setquitting(quit)
    if quit then
      quitting = true;
    else
      quitting = false;
    end
  end
  --- Main loop: waits on socket.select, services ready sockets, closes
  -- handlers queued in _closelist, enforces idle timeouts and fires timers.
  -- @param once when truthy, run a single iteration and return
  -- @return "quitting" when stopped via the quitting flag; nothing after a single `once` pass
  loop = function(once) -- this is the main loop of the program
    if quitting then return "quitting"; end
    if once then quitting = "once"; end
    -- Tell the ondisconnect listener (if any) why the handler is going away.
    -- handler.disconnect() returns nil when no listener is set, so the
    -- result must not be called blindly.
    local function notify_disconnect( handler, reason )
      local ondisconnect = handler.disconnect( )
      if ondisconnect then
        ondisconnect( handler, reason )
      end
    end
    local next_timer_time = math_huge;
    repeat
      local read, write = socket_select( _readlist, _sendlist, math_min(_selecttimeout, next_timer_time) )
      for _, socket in ipairs( write ) do -- send data waiting in writequeues
        local handler = _socketlist[ socket ]
        if handler then
          handler.sendbuffer( )
        else
          closesocket( socket )
          out_put "server.lua: found no handler and closed socket (writelist)" -- this should not happen
        end
      end
      for _, socket in ipairs( read ) do -- receive data
        local handler = _socketlist[ socket ]
        if handler then
          handler.readbuffer( )
        else
          closesocket( socket )
          out_put "server.lua: found no handler and closed socket (readlist)" -- this can happen
        end
      end
      for handler, err in pairs( _closelist ) do
        notify_disconnect( handler, err )
        handler:force_close() -- forced disconnect
        _closelist[ handler ] = nil;
      end
      _currenttime = luasocket_gettime( )
      -- Check for socket timeouts
      -- Timestamps are plain numbers of seconds, so subtraction is enough;
      -- one-argument os.difftime is rejected by Lua 5.3 and later.
      local difftime = _currenttime - _starttime
      if difftime > _checkinterval then
        _starttime = _currenttime
        for handler, timestamp in pairs( _writetimes ) do
          if _currenttime - timestamp > _sendtimeout then
            notify_disconnect( handler, "send timeout" )
            handler:force_close() -- forced disconnect
          end
        end
        for handler, timestamp in pairs( _readtimes ) do
          if _currenttime - timestamp > _readtimeout then
            -- an onreadtimeout callback returning true keeps the connection open
            if not(handler.onreadtimeout) or handler:onreadtimeout() ~= true then
              notify_disconnect( handler, "read timeout" )
              handler:close( ) -- forced disconnect?
            end
          end
        end
      end
      -- Fire timers
      if _currenttime - _timer >= math_min(next_timer_time, 1) then
        next_timer_time = math_huge;
        for i = 1, _timerlistlen do
          local t = _timerlist[ i ]( _currenttime ) -- fire timers
          if t then next_timer_time = math_min(next_timer_time, t); end
        end
        _timer = _currenttime
      else
        next_timer_time = next_timer_time - (_currenttime - _timer);
      end
      -- wait some time (0 by default)
      socket_sleep( _sleeptime )
    until quitting;
    if once and quitting == "once" then quitting = nil; return; end
    return "quitting"
  end
  --- Run exactly one iteration of the main loop.
  local function step()
    local once = true
    return loop(once);
  end
  --- Name of the event backend implemented by this module.
  local function get_backend()
    local backend = "select"
    return backend;
  end
  812. --// EXPERIMENTAL //--
  --- Wrap an outgoing client socket in a connection handler.
  -- For non-SSL connections the socket is put on the send list so that
  -- listeners.onconnect fires as soon as it becomes writable.
  -- @return handler, socket on success; nil and an error message otherwise
  local wrapclient = function( socket, ip, serverport, listeners, pattern, sslctx )
    local handler, wrapped, err = wrapconnection( nil, listeners, socket, ip, serverport, "clientport", pattern, sslctx )
    if not handler then
      return nil, err
    end
    _socketlist[ wrapped ] = handler
    if sslctx then
      return handler, wrapped
    end
    _sendlistlen = addsocket(_sendlist, wrapped, _sendlistlen)
    if listeners.onconnect then
      -- One-shot hook: the first writability event restores the real
      -- sendbuffer, reports the connection, then flushes queued data.
      local flush = handler.sendbuffer;
      handler.sendbuffer = function ()
        handler.sendbuffer = flush;
        listeners.onconnect(handler);
        return flush(); -- Send any queued outgoing data
      end
    end
    return handler, wrapped
  end
  --- Open an outgoing, non-blocking TCP connection and wrap it in a handler.
  -- @param address host name or IP address to connect to
  -- @param port port number in the range 0..65535
  -- @param listeners table of listener callbacks
  -- @param pattern read pattern for the connection
  -- @param sslctx optional SSL context (requires LuaSec)
  -- @param typ optional name of the LuaSocket constructor to use ("tcp", "tcp6", ...);
  --   when omitted and getaddrinfo is available, "tcp6" is chosen if the first
  --   resolved address is IPv6
  -- @return handler, socket on success; nil and an error message otherwise
  local addclient = function( address, port, listeners, pattern, sslctx, typ )
    local err
    if type( listeners ) ~= "table" then
      err = "invalid listener table"
    elseif type ( address ) ~= "string" then
      err = "invalid address"
    elseif type( port ) ~= "number" or not ( port >= 0 and port <= 65535 ) then
      err = "invalid port"
    elseif sslctx and not has_luasec then
      err = "luasec not found"
    end
    -- Report bad arguments before doing any work: the resolver must never
    -- see an address that already failed validation.
    if err then
      out_error( "server.lua, addclient: ", err )
      return nil, err
    end
    if getaddrinfo and not typ then
      local addrinfo, resolve_err = getaddrinfo(address)
      if not addrinfo then return nil, resolve_err end
      if addrinfo[1] and addrinfo[1].family == "inet6" then
        typ = "tcp6"
      end
    end
    local create = luasocket[typ or "tcp"]
    if type( create ) ~= "function" then
      err = "invalid socket type"
      out_error( "server.lua, addclient: ", err )
      return nil, err
    end
    local client, create_err = create( )
    if not client then
      return nil, create_err
    end
    client:settimeout( 0 )
    local ok, connect_err = client:connect( address, port )
    if ok or connect_err == "timeout" then -- "timeout" means the non-blocking connect is still in progress
      return wrapclient( client, address, port, listeners, pattern, sslctx )
    end
    client:close( ) -- do not leave the unused descriptor open until it is garbage collected
    return nil, connect_err
  end
  869. --// EXPERIMENTAL //--
  870. ----------------------------------// BEGIN //--
  -- Keys of these tables are sockets/handlers; weak keys keep the
  -- bookkeeping from pinning objects nothing else references.
  local setweak = use "setmetatable"
  setweak( _socketlist, { __mode = "k" } )
  setweak( _readtimes, { __mode = "k" } )
  setweak( _writetimes, { __mode = "k" } )
  -- Reference points for the timer and idle-check intervals in loop()
  _timer = luasocket_gettime( )
  _starttime = luasocket_gettime( )
  --- Swap the logging function used by this module.
  -- @param new_logger replacement logger; a falsy value leaves the current one in place
  -- @return the logger that was active before the call
  local function setlogger(new_logger)
    local previous = log;
    log = new_logger or previous;
    return previous;
  end
  883. ----------------------------------// PUBLIC INTERFACE //--
  return {
    -- main loop control
    loop = loop,
    step = step,
    setquitting = setquitting,
    get_backend = get_backend,
    -- servers
    addserver = addserver,
    getserver = getserver,
    removeserver = removeserver,
    -- outgoing connections
    addclient = addclient,
    wrapclient = wrapclient,
    link = link,
    closeall = closeall,
    -- timers, statistics, logging
    _addtimer = addtimer,
    stats = stats,
    setlogger = setlogger,
    -- configuration
    getsettings = getsettings,
    changesettings = changesettings,
  }