PageRenderTime 402ms CodeModel.GetById 24ms RepoModel.GetById 0ms app.codeStats 1ms

/local/dnsproxy.py

https://gitlab.com/0072016/Google-5
Python | 321 lines | 277 code | 32 blank | 12 comment | 96 complexity | c91d353d5855c99f5aa5df68ed71388f MD5 | raw file
#!/usr/bin/env python
# coding:utf-8
__version__ = '1.0'
import sys
import os
import sysconfig
# Append <this file's directory>/packages.egg/noarch and
# <this file's directory>/packages.egg/<platform> to the import path, where
# <platform> is the part of sysconfig.get_platform() before the first '-'
# (e.g. 'linux' from 'linux-x86_64'). Presumably these directories hold bundled
# copies of the third-party packages imported below (gevent, dnslib, pygeoip)
# -- confirm. The paths are appended, so copies found earlier on sys.path take
# precedence. This must run before those third-party imports.
sys.path += [os.path.abspath(os.path.join(__file__, '../packages.egg/%s' % x)) for x in ('noarch', sysconfig.get_platform().split('-')[0])]
import gevent
import gevent.server
import gevent.timeout
import gevent.monkey
# Monkey-patch the standard library's blocking primitives for gevent (subprocess
# patching is requested explicitly). Called before the remaining imports so
# that the socket/select/thread/Queue code in this module runs cooperatively.
gevent.monkey.patch_all(subprocess=True)
import re
import time
import logging
import heapq
import socket
import select
import struct
import errno
import thread  # Python 2 module name (Python 3: _thread)
import dnslib
import Queue  # Python 2 module name (Python 3: queue)
import pygeoip
  25. is_local_addr = re.compile(r'(?i)(?:[0-9a-f:]+0:5efe:)?(?:127(?:\.\d+){3}|10(?:\.\d+){3}|192\.168(?:\.\d+){2}|172\.(?:1[6-9]|2\d|3[01])(?:\.\d+){2})').match
  26. def get_dnsserver_list():
  27. if os.name == 'nt':
  28. import ctypes, ctypes.wintypes, struct, socket
  29. DNS_CONFIG_DNS_SERVER_LIST = 6
  30. buf = ctypes.create_string_buffer(2048)
  31. ctypes.windll.dnsapi.DnsQueryConfig(DNS_CONFIG_DNS_SERVER_LIST, 0, None, None, ctypes.byref(buf), ctypes.byref(ctypes.wintypes.DWORD(len(buf))))
  32. ipcount = struct.unpack('I', buf[0:4])[0]
  33. iplist = [socket.inet_ntoa(buf[i:i+4]) for i in xrange(4, ipcount*4+4, 4)]
  34. return iplist
  35. elif os.path.isfile('/etc/resolv.conf'):
  36. with open('/etc/resolv.conf', 'rb') as fp:
  37. return re.findall(r'(?m)^nameserver\s+(\S+)', fp.read())
  38. else:
  39. logging.warning("get_dnsserver_list failed: unsupport platform '%s-%s'", sys.platform, os.name)
  40. return []
  41. def parse_hostport(host, default_port=80):
  42. m = re.match(r'(.+)[#](\d+)$', host)
  43. if m:
  44. return m.group(1).strip('[]'), int(m.group(2))
  45. else:
  46. return host.strip('[]'), default_port
  47. class ExpireCache(object):
  48. """ A dictionary-like object, supporting expire semantics."""
  49. def __init__(self, max_size=1024):
  50. self.__maxsize = max_size
  51. self.__values = {}
  52. self.__expire_times = {}
  53. self.__expire_heap = []
  54. def size(self):
  55. return len(self.__values)
  56. def clear(self):
  57. self.__values.clear()
  58. self.__expire_times.clear()
  59. del self.__expire_heap[:]
  60. def exists(self, key):
  61. return key in self.__values
  62. def set(self, key, value, expire):
  63. try:
  64. et = self.__expire_times[key]
  65. pos = self.__expire_heap.index((et, key))
  66. del self.__expire_heap[pos]
  67. if pos < len(self.__expire_heap):
  68. heapq._siftup(self.__expire_heap, pos)
  69. except KeyError:
  70. pass
  71. et = int(time.time() + expire)
  72. self.__expire_times[key] = et
  73. heapq.heappush(self.__expire_heap, (et, key))
  74. self.__values[key] = value
  75. self.cleanup()
  76. def get(self, key):
  77. et = self.__expire_times[key]
  78. if et < time.time():
  79. self.cleanup()
  80. raise KeyError(key)
  81. return self.__values[key]
  82. def delete(self, key):
  83. et = self.__expire_times.pop(key)
  84. pos = self.__expire_heap.index((et, key))
  85. del self.__expire_heap[pos]
  86. if pos < len(self.__expire_heap):
  87. heapq._siftup(self.__expire_heap, pos)
  88. del self.__values[key]
  89. def cleanup(self):
  90. t = int(time.time())
  91. eh = self.__expire_heap
  92. ets = self.__expire_times
  93. v = self.__values
  94. size = self.__maxsize
  95. heappop = heapq.heappop
  96. #Delete expired, ticky
  97. while eh and eh[0][0] <= t or len(v) > size:
  98. _, key = heappop(eh)
  99. del v[key], ets[key]
def dnslib_resolve_over_udp(query, dnsservers, timeout, **kwargs):
    """
    Resolve *query* over UDP by sending it to every server in *dnsservers*
    and returning the first acceptable reply.

    :param query: a domain name string or a ``dnslib.DNSRecord``. A string is
        converted to a DNSRecord the first time it is sent. It becomes an A
        question if there are IPv4 servers (they are queried first), otherwise
        an AAAA question; the converted record is then reused for every server.
    :param dnsservers: server addresses, optionally ``host#port`` (default
        port 53). Entries containing ':' are treated as IPv6 and share one
        AF_INET6 socket; the rest share one AF_INET socket.
    :param timeout: seconds until the overall reply deadline of this call.
    :param kwargs: ``blacklist`` -- answer addresses that mark a reply as
        poisoned; ``turstservers`` (sic) -- servers whose error replies
        (non-zero rcode, no addresses) are accepted as final.
    :returns: the accepted ``dnslib.DNSRecord``.
    :raises TypeError: if *query* is neither a string nor a DNSRecord.
    :raises socket.gaierror: if no acceptable reply arrives.

    Background:
    http://gfwrev.blogspot.com/2009/11/gfwdns.html
    http://zh.wikipedia.org/wiki/%E5%9F%9F%E5%90%8D%E6%9C%8D%E5%8A%A1%E5%99%A8%E7%BC%93%E5%AD%98%E6%B1%A1%E6%9F%93
    http://support.microsoft.com/kb/241352
    https://gist.github.com/klzgrad/f124065c0616022b65e5
    """
    if not isinstance(query, (basestring, dnslib.DNSRecord)):
        raise TypeError('query argument requires string/DNSRecord')
    blacklist = kwargs.get('blacklist', ())
    turstservers = kwargs.get('turstservers', ())
    dns_v4_servers = [x for x in dnsservers if ':' not in x]
    dns_v6_servers = [x for x in dnsservers if ':' in x]
    sock_v4 = sock_v6 = None
    socks = []
    if dns_v4_servers:
        sock_v4 = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
        socks.append(sock_v4)
    if dns_v6_servers:
        sock_v6 = socket.socket(socket.AF_INET6, socket.SOCK_DGRAM)
        socks.append(sock_v6)
    # One deadline shared by all rounds below.
    timeout_at = time.time() + timeout
    try:
        # Up to 4 rounds. Each round sends to every server, then reads replies
        # until the shared deadline. Once a round has waited out the deadline,
        # later rounds only resend and fall straight through; extra rounds
        # mainly matter after a socket.error ends a round early.
        for _ in xrange(4):
            try:
                for dnsserver in dns_v4_servers:
                    if isinstance(query, basestring):
                        if dnsserver in ('8.8.8.8', '8.8.4.4'):
                            # Letter-case rewrite of the name for Google's public
                            # resolvers (DNS names are case-insensitive). The
                            # trailing .title() capitalizes the first letter of
                            # each alphabetic run and lowercases the rest.
                            # NOTE(review): an empty label (e.g. a trailing '.')
                            # makes x[-1] raise IndexError.
                            query = '.'.join(x[:-1] + x[-1].upper() for x in query.split('.')).title()
                        query = dnslib.DNSRecord(q=dnslib.DNSQuestion(query))
                    query_data = query.pack()
                    if query.q.qtype == 1 and dnsserver in ('8.8.8.8', '8.8.4.4'):
                        # A queries to Google only. Replace the last 5 bytes'
                        # leading 0x00 (the root label ending QNAME) with a
                        # compression pointer to offset 4, keeping QTYPE/QCLASS.
                        # Presumably, per the RFC 1035 header layout, offset 4 is
                        # the high byte of QDCOUNT (0 here), so the pointer
                        # decodes to the root label: same name, different encoding.
                        query_data = query_data[:-5] + '\xc0\x04' + query_data[-4:]
                    sock_v4.sendto(query_data, parse_hostport(dnsserver, 53))
                for dnsserver in dns_v6_servers:
                    if isinstance(query, basestring):
                        query = dnslib.DNSRecord(q=dnslib.DNSQuestion(query, qtype=dnslib.QTYPE.AAAA))
                    query_data = query.pack()
                    sock_v6.sendto(query_data, parse_hostport(dnsserver, 53))
                while time.time() < timeout_at:
                    ins, _, _ = select.select(socks, [], [], 0.1)
                    for sock in ins:
                        # NOTE(review): replies are not matched against the query
                        # ID or the set of queried servers. An exception from
                        # DNSRecord.parse is not a socket.error, so it propagates
                        # out of this function.
                        reply_data, reply_address = sock.recvfrom(512)
                        reply_server = reply_address[0]
                        record = dnslib.DNSRecord.parse(reply_data)
                        # Addresses from A (1), AAAA (28) and ANY (255) records.
                        iplist = [str(x.rdata) for x in record.rr if x.rtype in (1, 28, 255)]
                        # Exact-membership test: a blacklist entry such as
                        # '183.207.229.' does not act as a prefix.
                        if any(x in blacklist for x in iplist):
                            logging.warning('query=%r dnsservers=%r record bad iplist=%r', query, dnsservers, iplist)
                        elif record.header.rcode and not iplist and reply_server in turstservers:
                            logging.info('query=%r trust reply_server=%r record rcode=%s', query, reply_server, record.header.rcode)
                            return record
                        elif iplist:
                            logging.debug('query=%r reply_server=%r record iplist=%s', query, reply_server, iplist)
                            return record
                        else:
                            # No usable answer: keep waiting for other servers.
                            logging.debug('query=%r reply_server=%r record null iplist=%s', query, reply_server, iplist)
                            continue
            except socket.error as e:
                # Send/receive failure: log it and start another round.
                logging.warning('handle dns query=%s socket: %r', query, e)
        # All rounds exhausted without an acceptable reply.
        raise socket.gaierror(11004, 'getaddrinfo %r from %r failed' % (query, dnsservers))
    finally:
        for sock in socks:
            sock.close()
  163. def dnslib_resolve_over_tcp(query, dnsservers, timeout, **kwargs):
  164. """dns query over tcp"""
  165. if not isinstance(query, (basestring, dnslib.DNSRecord)):
  166. raise TypeError('query argument requires string/DNSRecord')
  167. blacklist = kwargs.get('blacklist', ())
  168. def do_resolve(query, dnsserver, timeout, queobj):
  169. if isinstance(query, basestring):
  170. qtype = dnslib.QTYPE.AAAA if ':' in dnsserver else dnslib.QTYPE.A
  171. query = dnslib.DNSRecord(q=dnslib.DNSQuestion(query, qtype=qtype))
  172. query_data = query.pack()
  173. sock_family = socket.AF_INET6 if ':' in dnsserver else socket.AF_INET
  174. sock = socket.socket(sock_family)
  175. rfile = None
  176. try:
  177. sock.settimeout(timeout or None)
  178. sock.connect(parse_hostport(dnsserver, 53))
  179. sock.send(struct.pack('>h', len(query_data)) + query_data)
  180. rfile = sock.makefile('r', 1024)
  181. reply_data_length = rfile.read(2)
  182. if len(reply_data_length) < 2:
  183. raise socket.gaierror(11004, 'getaddrinfo %r from %r failed' % (query, dnsserver))
  184. reply_data = rfile.read(struct.unpack('>h', reply_data_length)[0])
  185. record = dnslib.DNSRecord.parse(reply_data)
  186. iplist = [str(x.rdata) for x in record.rr if x.rtype in (1, 28, 255)]
  187. if any(x in blacklist for x in iplist):
  188. logging.debug('query=%r dnsserver=%r record bad iplist=%r', query, dnsserver, iplist)
  189. raise socket.gaierror(11004, 'getaddrinfo %r from %r failed' % (query, dnsserver))
  190. else:
  191. logging.debug('query=%r dnsserver=%r record iplist=%s', query, dnsserver, iplist)
  192. queobj.put(record)
  193. except socket.error as e:
  194. logging.debug('query=%r dnsserver=%r failed %r', query, dnsserver, e)
  195. queobj.put(e)
  196. finally:
  197. if rfile:
  198. rfile.close()
  199. sock.close()
  200. queobj = Queue.Queue()
  201. for dnsserver in dnsservers:
  202. thread.start_new_thread(do_resolve, (query, dnsserver, timeout, queobj))
  203. for i in range(len(dnsservers)):
  204. try:
  205. result = queobj.get(timeout)
  206. except Queue.Empty:
  207. raise socket.gaierror(11004, 'getaddrinfo %r from %r failed' % (query, dnsservers))
  208. if result and not isinstance(result, Exception):
  209. return result
  210. elif i == len(dnsservers) - 1:
  211. logging.warning('dnslib_resolve_over_tcp %r with %s return %r', query, dnsservers, result)
  212. raise socket.gaierror(11004, 'getaddrinfo %r from %r failed' % (query, dnsservers))
  213. class DNSServer(gevent.server.DatagramServer):
  214. """DNS Proxy based on gevent/dnslib"""
  215. def __init__(self, *args, **kwargs):
  216. dns_blacklist = kwargs.pop('dns_blacklist')
  217. dns_servers = kwargs.pop('dns_servers')
  218. dns_tcpover = kwargs.pop('dns_tcpover', [])
  219. dns_timeout = kwargs.pop('dns_timeout', 2)
  220. super(self.__class__, self).__init__(*args, **kwargs)
  221. self.dns_servers = list(dns_servers)
  222. self.dns_tcpover = tuple(dns_tcpover)
  223. self.dns_intranet_servers = [x for x in self.dns_servers if is_local_addr(x)]
  224. self.dns_blacklist = set(dns_blacklist)
  225. self.dns_timeout = int(dns_timeout)
  226. self.dns_cache = ExpireCache(max_size=65536)
  227. self.dns_trust_servers = set(['8.8.8.8', '8.8.4.4', '2001:4860:4860::8888', '2001:4860:4860::8844'])
  228. for dirname in ('.', '/usr/share/GeoIP/', '/usr/local/share/GeoIP/'):
  229. filename = os.path.join(dirname, 'GeoIP.dat')
  230. if os.path.isfile(filename):
  231. geoip = pygeoip.GeoIP(filename)
  232. for dnsserver in self.dns_servers:
  233. if ':' not in dnsserver and geoip.country_name_by_addr(parse_hostport(dnsserver, 53)[0]) not in ('China',):
  234. self.dns_trust_servers.add(dnsserver)
  235. break
  236. def do_read(self):
  237. try:
  238. return gevent.server.DatagramServer.do_read(self)
  239. except socket.error as e:
  240. if e[0] not in (errno.ECONNABORTED, errno.ECONNRESET, errno.EPIPE):
  241. raise
  242. def get_reply_record(self, data):
  243. request = dnslib.DNSRecord.parse(data)
  244. qname = str(request.q.qname).lower()
  245. qtype = request.q.qtype
  246. dnsservers = self.dns_servers
  247. if qname.endswith('.in-addr.arpa'):
  248. ipaddr = '.'.join(reversed(qname[:-13].split('.')))
  249. record = dnslib.DNSRecord(header=dnslib.DNSHeader(id=request.header.id, qr=1,aa=1,ra=1), a=dnslib.RR(qname, rdata=dnslib.A(ipaddr)))
  250. return record
  251. if 'USERDNSDOMAIN' in os.environ:
  252. user_dnsdomain = '.' + os.environ['USERDNSDOMAIN'].lower()
  253. if qname.endswith(user_dnsdomain):
  254. qname = qname[:-len(user_dnsdomain)]
  255. if '.' not in qname:
  256. if not self.dns_intranet_servers:
  257. logging.warning('qname=%r is a plain hostname, need intranet dns server!!!', qname)
  258. return dnslib.DNSRecord(header=dnslib.DNSHeader(id=request.header.id, rcode=3))
  259. qname += user_dnsdomain
  260. dnsservers = self.dns_intranet_servers
  261. try:
  262. return self.dns_cache.get((qname, qtype))
  263. except KeyError:
  264. pass
  265. try:
  266. dns_resolve = dnslib_resolve_over_tcp if qname.endswith(self.dns_tcpover) else dnslib_resolve_over_udp
  267. kwargs = {'blacklist': self.dns_blacklist, 'turstservers': self.dns_trust_servers}
  268. record = dns_resolve(request, dnsservers, self.dns_timeout, **kwargs)
  269. ttl = max(x.ttl for x in record.rr) if record.rr else 600
  270. self.dns_cache.set((qname, qtype), record, ttl * 2)
  271. return record
  272. except socket.gaierror as e:
  273. logging.warning('resolve %r failed: %r', qname, e)
  274. return dnslib.DNSRecord(header=dnslib.DNSHeader(id=request.header.id, rcode=3))
  275. def handle(self, data, address):
  276. logging.debug('receive from %r data=%r', address, data)
  277. record = self.get_reply_record(data)
  278. return self.sendto(data[:2] + record.pack()[2:], address)
  279. def test():
  280. logging.basicConfig(level=logging.INFO, format='%(levelname)s - %(asctime)s %(message)s', datefmt='[%b %d %H:%M:%S]')
  281. dns_servers = '8.8.8.8|8.8.4.4|168.95.1.1|168.95.192.1|223.5.5.5|223.6.6.6|114.114.114.114|114.114.115.115'.split('|')
  282. dns_blacklist = '1.1.1.1|255.255.255.255|74.125.127.102|74.125.155.102|74.125.39.102|74.125.39.113|209.85.229.138|4.36.66.178|8.7.198.45|37.61.54.158|46.82.174.68|59.24.3.173|64.33.88.161|64.33.99.47|64.66.163.251|65.104.202.252|65.160.219.113|66.45.252.237|72.14.205.104|72.14.205.99|78.16.49.15|93.46.8.89|128.121.126.139|159.106.121.75|169.132.13.103|192.67.198.6|202.106.1.2|202.181.7.85|203.161.230.171|203.98.7.65|207.12.88.98|208.56.31.43|209.145.54.50|209.220.30.174|209.36.73.33|211.94.66.147|213.169.251.35|216.221.188.182|216.234.179.13|243.185.187.3|243.185.187.39|23.89.5.60|37.208.111.120|49.2.123.56|54.76.135.1|77.4.7.92|118.5.49.6|188.5.4.96|189.163.17.5|197.4.4.12|249.129.46.48|253.157.14.165|183.207.229.|183.207.232.'.split('|')
  283. dns_tcpover = ['.youtube.com', '.googlevideo.com']
  284. logging.info('serving at port 53...')
  285. DNSServer(('', 53), dns_servers=dns_servers, dns_blacklist=dns_blacklist, dns_tcpover=dns_tcpover).serve_forever()
  286. if __name__ == '__main__':
  287. test()