/local/dnsproxy.py
Python | 321 lines | 277 code | 32 blank | 12 comment | 96 complexity | c91d353d5855c99f5aa5df68ed71388f MD5 | raw file
#!/usr/bin/env python
# coding:utf-8
"""DNS proxy based on gevent and dnslib.

Answers client queries by forwarding them to upstream DNS servers (over UDP,
or over TCP for configured domain suffixes), discarding replies that contain
blacklisted addresses, and caching the answers.
"""

__version__ = '1.0'

import sys
import os
import sysconfig

# Make the bundled packages importable: packages.egg/noarch and
# packages.egg/<platform>, both located next to this file, where <platform> is
# the part of sysconfig.get_platform() before the first '-' (e.g. 'linux' for
# 'linux-x86_64').
sys.path += [os.path.abspath(os.path.join(__file__, '../packages.egg/%s' % x)) for x in ('noarch', sysconfig.get_platform().split('-')[0])]

import gevent
import gevent.server
import gevent.timeout
import gevent.monkey
# Replace blocking standard-library primitives (and, with subprocess=True,
# subprocess) with gevent's cooperative versions, so the thread/Queue/socket
# code below runs on greenlets.
gevent.monkey.patch_all(subprocess=True)

import re
import time
import logging
import heapq
import socket
import select
import struct
import errno
import thread  # Python 2 module name (_thread on Python 3)
import dnslib
import Queue  # Python 2 module name (queue on Python 3)
import pygeoip
-
-
# Recognises, at the start of a string, loopback 127.x.x.x and the private
# ranges 10.x.x.x, 172.16-31.x.x and 192.168.x.x, optionally preceded by an
# ISATAP-style "...0:5efe:" IPv6 prefix. Because it is bound to .match, any
# trailing text (such as a "#port" suffix) is ignored.
_LOCAL_ADDR_PATTERN = re.compile(r'(?i)(?:[0-9a-f:]+0:5efe:)?(?:127(?:\.\d+){3}|10(?:\.\d+){3}|192\.168(?:\.\d+){2}|172\.(?:1[6-9]|2\d|3[01])(?:\.\d+){2})')
is_local_addr = _LOCAL_ADDR_PATTERN.match
-
-
def get_dnsserver_list():
    """Return the system's configured DNS server addresses as a list of strings.

    On Windows the list comes from ``DnsQueryConfig(DNS_CONFIG_DNS_SERVER_LIST)``
    (IPv4 addresses only). Elsewhere, the ``nameserver`` lines of
    /etc/resolv.conf are used. An empty list is returned (with a warning) when
    the Windows query fails or no source is available.
    """
    if os.name == 'nt':
        import ctypes
        import ctypes.wintypes
        DNS_CONFIG_DNS_SERVER_LIST = 6
        buf = ctypes.create_string_buffer(2048)
        buflen = ctypes.wintypes.DWORD(len(buf))
        status = ctypes.windll.dnsapi.DnsQueryConfig(DNS_CONFIG_DNS_SERVER_LIST, 0, None, None, ctypes.byref(buf), ctypes.byref(buflen))
        if status != 0:
            # Non-zero DNS_STATUS: the buffer content is not a valid result.
            logging.warning('get_dnsserver_list failed: DnsQueryConfig returned %d', status)
            return []
        # The result is an IP4_ARRAY: a DWORD count followed by that many
        # 4-byte IPv4 addresses.
        ipcount = struct.unpack('I', buf[0:4])[0]
        # Never read past the end of the buffer, whatever count it reports.
        ipcount = min(ipcount, (len(buf) - 4) // 4)
        return [socket.inet_ntoa(buf[i:i + 4]) for i in range(4, ipcount * 4 + 4, 4)]
    if os.path.isfile('/etc/resolv.conf'):
        # Text mode: the pattern is a str, which cannot be matched against
        # bytes on Python 3 (identical behaviour on Python 2).
        with open('/etc/resolv.conf', 'r') as fp:
            return re.findall(r'(?m)^nameserver\s+(\S+)', fp.read())
    logging.warning("get_dnsserver_list failed: unsupport platform '%s-%s'", sys.platform, os.name)
    return []
-
-
def parse_hostport(host, default_port=80):
    """Split a ``host#port`` string into ``(host, port)``.

    Square brackets around the host (as in ``[::1]#53``) are stripped. When no
    ``#<digits>`` suffix is present, ``default_port`` is returned as the port.
    """
    match = re.match(r'(.+)[#](\d+)$', host)
    if not match:
        return host.strip('[]'), default_port
    hostname, port = match.groups()
    return hostname.strip('[]'), int(port)
-
-
class ExpireCache(object):
    """A bounded dictionary-like cache whose entries expire.

    Each entry expires ``expire`` seconds after it is ``set``; the expiry
    timestamp is rounded down to a whole second. A min-heap of
    ``(expire_at, key)`` pairs keeps entries in expiry order, so ``cleanup``
    can drop expired entries and, while more than ``max_size`` entries are
    held, evict the ones that expire soonest.
    """

    def __init__(self, max_size=1024):
        self.__maxsize = max_size
        self.__values = {}          # key -> cached value
        self.__expire_times = {}    # key -> integer expiry timestamp
        self.__expire_heap = []     # min-heap of (expiry timestamp, key)

    def size(self):
        """Return the number of stored entries, including expired ones not yet cleaned up."""
        return len(self.__values)

    def clear(self):
        """Remove every entry."""
        self.__values.clear()
        self.__expire_times.clear()
        del self.__expire_heap[:]

    def exists(self, key):
        """Return True if ``key`` is stored and not expired (the same test ``get`` applies)."""
        return key in self.__values and self.__expire_times[key] >= time.time()

    def set(self, key, value, expire):
        """Store ``value`` under ``key`` for ``expire`` seconds, replacing any previous entry."""
        if key in self.__expire_times:
            self.__unschedule(key)
        et = int(time.time() + expire)
        self.__expire_times[key] = et
        heapq.heappush(self.__expire_heap, (et, key))
        self.__values[key] = value
        self.cleanup()

    def get(self, key):
        """Return the value for ``key``; raise KeyError if it is missing or expired."""
        if self.__expire_times[key] < time.time():
            self.cleanup()
            raise KeyError(key)
        return self.__values[key]

    def delete(self, key):
        """Remove ``key``; raise KeyError if it is not stored."""
        self.__unschedule(key)
        del self.__values[key]

    def cleanup(self):
        """Pop the soonest-expiring entries while the soonest has expired or the cache exceeds ``max_size``."""
        now = int(time.time())
        heap = self.__expire_heap
        values = self.__values
        expire_times = self.__expire_times
        while heap and (heap[0][0] <= now or len(values) > self.__maxsize):
            _, key = heapq.heappop(heap)
            del values[key], expire_times[key]

    def __unschedule(self, key):
        """Pop ``key``'s expiry time and remove its heap entry, keeping the heap valid.

        Raises KeyError if ``key`` has no expiry time.
        """
        heap = self.__expire_heap
        heap.remove((self.__expire_times.pop(key), key))
        # Deleting from the middle of the list shifts every later element and
        # breaks the heap invariant beyond the removal point, so rebuild it
        # (O(n), the same cost as the search done by remove()).
        heapq.heapify(heap)
-
-
def dnslib_resolve_over_udp(query, dnsservers, timeout, **kwargs):
    """Resolve ``query`` over UDP by querying every server in ``dnsservers`` at once.

    Background on the DNS-poisoning countermeasures applied here:
    http://gfwrev.blogspot.com/2009/11/gfwdns.html
    http://zh.wikipedia.org/wiki/%E5%9F%9F%E5%90%8D%E6%9C%8D%E5%8A%A1%E5%99%A8%E7%BC%93%E5%AD%98%E6%B1%A1%E6%9F%93
    http://support.microsoft.com/kb/241352
    https://gist.github.com/klzgrad/f124065c0616022b65e5

    :param query: hostname string or ``dnslib.DNSRecord``. A string becomes an
        A question, or an AAAA question when only IPv6 servers are given.
    :param dnsservers: server addresses, optionally ``host#port``; an address
        containing ':' is treated as IPv6.
    :param timeout: total time budget, in seconds, for collecting replies.
    :keyword blacklist: container of IP strings; a reply carrying any of them
        is discarded.
    :keyword trustservers: servers whose replies are accepted even when they
        carry an error rcode and no addresses. The historical misspelling
        ``turstservers`` is still accepted.
    :returns: the first acceptable ``dnslib.DNSRecord`` reply.
    :raises TypeError: if ``query`` is neither a string nor a DNSRecord.
    :raises socket.gaierror: if no acceptable reply arrives in time.
    """
    if not isinstance(query, (basestring, dnslib.DNSRecord)):
        raise TypeError('query argument requires string/DNSRecord')
    blacklist = kwargs.get('blacklist', ())
    trust_servers = kwargs.get('trustservers', kwargs.get('turstservers', ()))
    dns_v4_servers = [x for x in dnsservers if ':' not in x]
    dns_v6_servers = [x for x in dnsservers if ':' in x]
    sock_v4 = sock_v6 = None
    socks = []
    try:
        # Sockets are created inside the try so that a failure creating the
        # IPv6 socket still closes the IPv4 one in the finally clause.
        if dns_v4_servers:
            sock_v4 = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
            socks.append(sock_v4)
        if dns_v6_servers:
            sock_v6 = socket.socket(socket.AF_INET6, socket.SOCK_DGRAM)
            socks.append(sock_v6)
        timeout_at = time.time() + timeout
        for attempt in xrange(4):
            # Extra rounds exist only to recover from socket errors; once the
            # time budget is spent, re-sending queries achieves nothing.
            if attempt and time.time() >= timeout_at:
                break
            sent = 0
            for dnsserver in dns_v4_servers:
                if isinstance(query, basestring):
                    if dnsserver in ('8.8.8.8', '8.8.4.4'):
                        # NOTE(review): str.title() lowercases all but the first
                        # letter of each word, which largely undoes the
                        # upper-casing of each label's last character.
                        query = '.'.join(x[:-1] + x[-1:].upper() for x in query.split('.')).title()
                    query = dnslib.DNSRecord(q=dnslib.DNSQuestion(query))
                query_data = query.pack()
                if query.q.qtype == 1 and dnsserver in ('8.8.8.8', '8.8.4.4'):
                    # Replace the question name's terminating zero byte with a
                    # compression pointer (0xC004) into the header; see the
                    # gist referenced above.
                    query_data = query_data[:-5] + '\xc0\x04' + query_data[-4:]
                try:
                    sock_v4.sendto(query_data, parse_hostport(dnsserver, 53))
                    sent += 1
                except socket.error as e:
                    # One unreachable server must not stop the others being queried.
                    logging.warning('handle dns query=%s sendto %r socket: %r', query, dnsserver, e)
            for dnsserver in dns_v6_servers:
                if isinstance(query, basestring):
                    query = dnslib.DNSRecord(q=dnslib.DNSQuestion(query, qtype=dnslib.QTYPE.AAAA))
                query_data = query.pack()
                try:
                    sock_v6.sendto(query_data, parse_hostport(dnsserver, 53))
                    sent += 1
                except socket.error as e:
                    logging.warning('handle dns query=%s sendto %r socket: %r', query, dnsserver, e)
            if not sent:
                # Nothing is in flight this round; retry the sends straight away.
                continue
            try:
                while time.time() < timeout_at:
                    ins, _, _ = select.select(socks, [], [], 0.1)
                    for sock in ins:
                        reply_data, reply_address = sock.recvfrom(512)
                        reply_server = reply_address[0]
                        record = dnslib.DNSRecord.parse(reply_data)
                        # Addresses from record types 1 (A), 28 (AAAA) and 255 (ANY).
                        iplist = [str(x.rdata) for x in record.rr if x.rtype in (1, 28, 255)]
                        if any(x in blacklist for x in iplist):
                            logging.warning('query=%r dnsservers=%r record bad iplist=%r', query, dnsservers, iplist)
                        elif record.header.rcode and not iplist and reply_server in trust_servers:
                            logging.info('query=%r trust reply_server=%r record rcode=%s', query, reply_server, record.header.rcode)
                            return record
                        elif iplist:
                            logging.debug('query=%r reply_server=%r record iplist=%s', query, reply_server, iplist)
                            return record
                        else:
                            logging.debug('query=%r reply_server=%r record null iplist=%s', query, reply_server, iplist)
            except socket.error as e:
                logging.warning('handle dns query=%s socket: %r', query, e)
        raise socket.gaierror(11004, 'getaddrinfo %r from %r failed' % (query, dnsservers))
    finally:
        for sock in socks:
            sock.close()
-
-
def dnslib_resolve_over_tcp(query, dnsservers, timeout, **kwargs):
    """Resolve ``query`` over TCP, racing one worker connection per server.

    :param query: hostname string or ``dnslib.DNSRecord``. A string becomes an
        A question, or an AAAA question for a server address containing ':'.
    :param dnsservers: sequence of server addresses, optionally ``host#port``.
    :param timeout: total seconds to wait for an acceptable reply; also each
        connection's socket timeout. A false value (0/None) means no limit.
    :keyword blacklist: container of IP strings; a reply carrying any of them
        is rejected.
    :returns: the first successful ``dnslib.DNSRecord``.
    :raises TypeError: if ``query`` is neither a string nor a DNSRecord.
    :raises socket.gaierror: if no server is given, every server fails, or
        time runs out.
    """
    if not isinstance(query, (basestring, dnslib.DNSRecord)):
        raise TypeError('query argument requires string/DNSRecord')
    blacklist = kwargs.get('blacklist', ())

    def do_resolve(query, dnsserver, timeout, queobj):
        """Query one server; put the reply DNSRecord, or the exception, into ``queobj``."""
        sock = rfile = None
        try:
            if isinstance(query, basestring):
                qtype = dnslib.QTYPE.AAAA if ':' in dnsserver else dnslib.QTYPE.A
                query = dnslib.DNSRecord(q=dnslib.DNSQuestion(query, qtype=qtype))
            query_data = query.pack()
            sock_family = socket.AF_INET6 if ':' in dnsserver else socket.AF_INET
            sock = socket.socket(sock_family)
            sock.settimeout(timeout or None)
            sock.connect(parse_hostport(dnsserver, 53))
            # DNS over TCP prefixes each message with its length as an unsigned
            # 16-bit big-endian integer (RFC 1035, section 4.2.2).
            sock.sendall(struct.pack('>H', len(query_data)) + query_data)
            rfile = sock.makefile('r', 1024)
            reply_data_length = rfile.read(2)
            if len(reply_data_length) < 2:
                raise socket.gaierror(11004, 'getaddrinfo %r from %r failed' % (query, dnsserver))
            reply_length = struct.unpack('>H', reply_data_length)[0]
            reply_data = rfile.read(reply_length)
            if len(reply_data) < reply_length:
                # The connection closed in the middle of the reply.
                raise socket.gaierror(11004, 'getaddrinfo %r from %r failed' % (query, dnsserver))
            record = dnslib.DNSRecord.parse(reply_data)
            # Addresses from record types 1 (A), 28 (AAAA) and 255 (ANY).
            iplist = [str(x.rdata) for x in record.rr if x.rtype in (1, 28, 255)]
            if any(x in blacklist for x in iplist):
                logging.debug('query=%r dnsserver=%r record bad iplist=%r', query, dnsserver, iplist)
                raise socket.gaierror(11004, 'getaddrinfo %r from %r failed' % (query, dnsserver))
            logging.debug('query=%r dnsserver=%r record iplist=%s', query, dnsserver, iplist)
            queobj.put(record)
        except Exception as e:
            # Report every failure, not only socket errors, so the caller is
            # never left waiting for a worker that died without a result.
            logging.debug('query=%r dnsserver=%r failed %r', query, dnsserver, e)
            queobj.put(e)
        finally:
            if rfile is not None:
                rfile.close()
            if sock is not None:
                sock.close()

    queobj = Queue.Queue()
    for dnsserver in dnsservers:
        thread.start_new_thread(do_resolve, (query, dnsserver, timeout, queobj))
    # One overall deadline, so N servers cannot stretch the wait to N * timeout.
    deadline = time.time() + timeout if timeout else None
    result = None
    for _ in range(len(dnsservers)):
        remaining = None if deadline is None else max(0, deadline - time.time())
        try:
            # Must be passed by keyword: Queue.get's first positional
            # parameter is ``block``, not ``timeout``.
            result = queobj.get(timeout=remaining)
        except Queue.Empty:
            raise socket.gaierror(11004, 'getaddrinfo %r from %r failed' % (query, dnsservers))
        if result and not isinstance(result, Exception):
            return result
    logging.warning('dnslib_resolve_over_tcp %r with %s return %r', query, dnsservers, result)
    raise socket.gaierror(11004, 'getaddrinfo %r from %r failed' % (query, dnsservers))
-
-
class DNSServer(gevent.server.DatagramServer):
    """DNS Proxy based on gevent/dnslib.

    Answers UDP DNS queries by forwarding them to upstream servers (over TCP
    for names ending in a ``dns_tcpover`` suffix, otherwise over UDP),
    rejecting replies that contain blacklisted addresses, and caching answers.
    """

    def __init__(self, *args, **kwargs):
        """Create the server.

        Keyword arguments consumed here (everything else is passed on to
        ``gevent.server.DatagramServer``):

        dns_servers -- upstream server addresses (required).
        dns_blacklist -- IP strings that mark a reply as bogus (required).
        dns_tcpover -- domain suffixes to resolve over TCP (default: none).
        dns_timeout -- upstream timeout in seconds (default: 2).
        """
        dns_blacklist = kwargs.pop('dns_blacklist')
        dns_servers = kwargs.pop('dns_servers')
        dns_tcpover = kwargs.pop('dns_tcpover', [])
        dns_timeout = kwargs.pop('dns_timeout', 2)
        # Name the class explicitly: super(self.__class__, self) recurses
        # forever once DNSServer is subclassed.
        super(DNSServer, self).__init__(*args, **kwargs)
        self.dns_servers = list(dns_servers)
        self.dns_tcpover = tuple(dns_tcpover)
        self.dns_intranet_servers = [x for x in self.dns_servers if is_local_addr(x)]
        self.dns_blacklist = set(dns_blacklist)
        # float, not int: int() truncated sub-second timeouts such as 0.5 to 0.
        self.dns_timeout = float(dns_timeout)
        self.dns_cache = ExpireCache(max_size=65536)
        self.dns_trust_servers = set(['8.8.8.8', '8.8.4.4', '2001:4860:4860::8888', '2001:4860:4860::8844'])
        # IPv4 upstream servers that the first GeoIP.dat found does not place
        # in China are trusted as well.
        for dirname in ('.', '/usr/share/GeoIP/', '/usr/local/share/GeoIP/'):
            filename = os.path.join(dirname, 'GeoIP.dat')
            if os.path.isfile(filename):
                geoip = pygeoip.GeoIP(filename)
                for dnsserver in self.dns_servers:
                    if ':' not in dnsserver and geoip.country_name_by_addr(parse_hostport(dnsserver, 53)[0]) not in ('China',):
                        self.dns_trust_servers.add(dnsserver)
                break

    def do_read(self):
        """Read one datagram; ignore connection-aborted, connection-reset and broken-pipe errors."""
        try:
            return gevent.server.DatagramServer.do_read(self)
        except socket.error as e:
            if e.args[0] not in (errno.ECONNABORTED, errno.ECONNRESET, errno.EPIPE):
                raise

    @staticmethod
    def _error_reply(request, rcode=3):
        """Build an error response (default rcode 3, NXDOMAIN) to ``request``.

        QR=1 marks the message as a response, and the question is echoed back,
        as RFC 1035 expects of a reply.
        """
        return dnslib.DNSRecord(header=dnslib.DNSHeader(id=request.header.id, qr=1, ra=1, rcode=rcode), q=request.q)

    def get_reply_record(self, data):
        """Return the ``dnslib.DNSRecord`` to send back for the raw query ``data``.

        Reverse (in-addr.arpa) lookups are answered locally. Plain hostnames in
        the USERDNSDOMAIN domain go to intranet servers only. Cached answers are
        reused; otherwise upstream servers are queried. If upstream resolution
        fails with socket.gaierror, an NXDOMAIN reply is returned.
        """
        request = dnslib.DNSRecord.parse(data)
        # NOTE(review): assumes str(qname) has no trailing '.', otherwise the
        # endswith() checks below never match -- confirm against the bundled
        # dnslib version.
        qname = str(request.q.qname).lower()
        qtype = request.q.qtype
        dnsservers = self.dns_servers
        if qname.endswith('.in-addr.arpa'):
            # NOTE(review): answered with an A record holding the reversed
            # address rather than a PTR record -- confirm this is intended.
            ipaddr = '.'.join(reversed(qname[:-13].split('.')))
            record = dnslib.DNSRecord(header=dnslib.DNSHeader(id=request.header.id, qr=1, aa=1, ra=1), a=dnslib.RR(qname, rdata=dnslib.A(ipaddr)))
            return record
        if 'USERDNSDOMAIN' in os.environ:
            user_dnsdomain = '.' + os.environ['USERDNSDOMAIN'].lower()
            if qname.endswith(user_dnsdomain):
                hostname = qname[:-len(user_dnsdomain)]
                if '.' not in hostname:
                    # A plain hostname inside the user's DNS domain: only
                    # intranet servers can answer it.
                    if not self.dns_intranet_servers:
                        logging.warning('qname=%r is a plain hostname, need intranet dns server!!!', hostname)
                        return self._error_reply(request)
                    dnsservers = self.dns_intranet_servers
        # qname is the full queried name here, so the cache key cannot
        # collide with an unrelated shorter name.
        try:
            return self.dns_cache.get((qname, qtype))
        except KeyError:
            pass
        try:
            dns_resolve = dnslib_resolve_over_tcp if qname.endswith(self.dns_tcpover) else dnslib_resolve_over_udp
            kwargs = {'blacklist': self.dns_blacklist, 'turstservers': self.dns_trust_servers}
            record = dns_resolve(request, dnsservers, self.dns_timeout, **kwargs)
            # Cache for twice the largest TTL in the answer (600s if it has none).
            ttl = max(x.ttl for x in record.rr) if record.rr else 600
            self.dns_cache.set((qname, qtype), record, ttl * 2)
            return record
        except socket.gaierror as e:
            logging.warning('resolve %r failed: %r', qname, e)
            return self._error_reply(request)

    def handle(self, data, address):
        """Send the reply for one datagram; its first two bytes are the request's transaction id."""
        logging.debug('receive from %r data=%r', address, data)
        record = self.get_reply_record(data)
        return self.sendto(data[:2] + record.pack()[2:], address)
-
-
def test():
    """Serve the DNS proxy on UDP port 53 with a built-in upstream list and reply blacklist.

    Names under .youtube.com and .googlevideo.com are resolved over TCP.
    """
    logging.basicConfig(level=logging.INFO, format='%(levelname)s - %(asctime)s %(message)s', datefmt='[%b %d %H:%M:%S]')
    server_options = dict(
        dns_servers='8.8.8.8|8.8.4.4|168.95.1.1|168.95.192.1|223.5.5.5|223.6.6.6|114.114.114.114|114.114.115.115'.split('|'),
        dns_blacklist='1.1.1.1|255.255.255.255|74.125.127.102|74.125.155.102|74.125.39.102|74.125.39.113|209.85.229.138|4.36.66.178|8.7.198.45|37.61.54.158|46.82.174.68|59.24.3.173|64.33.88.161|64.33.99.47|64.66.163.251|65.104.202.252|65.160.219.113|66.45.252.237|72.14.205.104|72.14.205.99|78.16.49.15|93.46.8.89|128.121.126.139|159.106.121.75|169.132.13.103|192.67.198.6|202.106.1.2|202.181.7.85|203.161.230.171|203.98.7.65|207.12.88.98|208.56.31.43|209.145.54.50|209.220.30.174|209.36.73.33|211.94.66.147|213.169.251.35|216.221.188.182|216.234.179.13|243.185.187.3|243.185.187.39|23.89.5.60|37.208.111.120|49.2.123.56|54.76.135.1|77.4.7.92|118.5.49.6|188.5.4.96|189.163.17.5|197.4.4.12|249.129.46.48|253.157.14.165|183.207.229.|183.207.232.'.split('|'),
        dns_tcpover=['.youtube.com', '.googlevideo.com'],
    )
    logging.info('serving at port 53...')
    server = DNSServer(('', 53), **server_options)
    server.serve_forever()


if __name__ == '__main__':
    test()