
/bin/apiary_mysql_logger.py

https://bitbucket.org/lindenlab/apiary/
Python | 787 lines | 499 code | 112 blank | 176 comment
#!/usr/bin/python
#
# $LicenseInfo:firstyear=2010&license=mit$
#
# Copyright (c) 2010, Linden Research, Inc.
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in
# all copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
# THE SOFTWARE.
# $/LicenseInfo$
#

"""
Log all queries hitting a particular mysql database.
"""

# Use psyco for JIT compilation if it's available
try:
    import psyco
    psyco.full()
except ImportError:
    pass

import array
import getopt
import math
import os.path
import re
import socket
import struct
import sys
import time

LOG_ROTATION_INTERVAL = 3600    # seconds between log rotations
MAX_LOGS = 36                   # number of rotated logs to keep
MIN_BIN = -15                   # query-time histogram bins span 2^-15s ...
MAX_BIN = 10                    # ... through 2^10s
ip_table = {}
host_type_cache = {}

sim_re = re.compile(r".*sim\d+.*")
web_re = re.compile(r"int\.web\d+.*")
iweb_re = re.compile(r"int\.iweb\d+.*")
webds_re = re.compile(r".*web-ds\d+.*")
login_re = re.compile(r".*login\d+.*")
data_re = re.compile(r".*data\..*")
xmlrpc_re = re.compile(r"(?:int\.omgiwanna.*)|(?:int\.pony.*)")
ip_re = re.compile(r"\d+\.\d+\.\d+\.\d+")
ll_re = re.compile(r"(.*)\.lindenlab\.com")

#
# Utility stuff for query cleaner
#
hex_wildcard = r"[0-9a-fA-F]"
word = hex_wildcard + r"{4,4}-"
long_word = hex_wildcard + r"{8,8}-"
very_long_word = hex_wildcard + r"{12,12}"
UUID_REGEX_STRING = long_word + word + word + word + very_long_word
uuid_re = re.compile("[\"\']" + UUID_REGEX_STRING + "[\"\']")
hex_re = re.compile(r"[\"\'][\da-f]+[\"\']")
num_re = re.compile(r"[-+]?[0-9]*\.?[0-9]+(?:[eE][-+]?[0-9]+)?")

# Quoted string re from: http://blog.stevenlevithan.com/archives/match-quoted-string
string_re = re.compile(r'([\"\'])(?:(?=(\\?))\2.)*?\1')

values_re = re.compile(r'VALUES\s+\(.*\)', re.IGNORECASE)
in_re = re.compile(r'IN\s+\(.*\)', re.IGNORECASE)

prepare_re = re.compile('PREPARE.*', re.IGNORECASE)
deallocate_re = re.compile(r'DEALLOCATE\s+PREPARE.*', re.IGNORECASE)
execute_re = re.compile('EXECUTE.*', re.IGNORECASE)
mdb_re = re.compile(r'MDB2_STATEMENT\S+')
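
# Illustrative examples of what these patterns match (the literals are made up):
#   uuid_re   -> "'6bd97e0d-7f8c-4d1e-9a2b-0c3f4a5b6c7d'"
#   hex_re    -> "'deadbeef'"
#   num_re    -> 42, -7.5, 1e-3
#   string_re -> "'some quoted string'"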


def llquery_from_llsd(query_llsd):
    # Hack, fill in arbitrary data for info that isn't serialized
    query = LLQuery(None, None, query_llsd['query'], 0.0)
    query.mData['host_clean'] = query_llsd['host_clean']
    query.mData['query_clean'] = query_llsd['query_clean']

    # Hack, keeps correctOutliers from trashing the data
    #query.mNumQueries = query_llsd['num_queries']
    #query.mTotalTime = query_llsd['total_time']
    try:
        query.mNumQueriesCorrected = query_llsd['num_queries_corrected']
        query.mTotalTimeCorrected = query_llsd['total_time_corrected']
    except KeyError:
        # Hack for old output which didn't generate this data
        query.mNumQueriesCorrected = query_llsd['num_queries']
        query.mTotalTimeCorrected = query_llsd['total_time']

    return query


# MySQL protocol sniffer, using tcpdump, pcap packet parsing and mysql internals
# http://forge.mysql.com/wiki/MySQL_Internals_ClientServer_Protocol
class LLQueryStream:
    "Process a raw tcpdump stream (in raw libpcap format)"
    def __init__(self, in_file):
        self.mInFile = in_file
        self.mStartTime = time.time()

        #
        # A list of all outstanding "connections", and what they're doing.
        # This is necessary in order to get script timing and other information.
        #
        self.mConnStatus = {}

        #
        # Parse/skip past the libpcap global header
        #
        # guint32 magic_number;   /* magic number */
        # guint16 version_major;  /* major version number */
        # guint16 version_minor;  /* minor version number */
        # gint32  thiszone;       /* GMT to local correction */
        # guint32 sigfigs;        /* accuracy of timestamps */
        # guint32 snaplen;        /* max length of captured packets, in octets */
        # guint32 network;        /* data link type */
        format = 'IHHiIII'
        size = struct.calcsize(format)
        header_bin = self.mInFile.read(size)
        res = struct.unpack(format, header_bin)
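        # Sketch (not in the original code): res could be used to sanity-check
        # the stream, e.g. by verifying the libpcap magic number:
        #   if res[0] not in (0xa1b2c3d4, 0xd4c3b2a1):
        #       raise ValueError("input does not look like a libpcap stream")
        # (0xd4c3b2a1 indicates a capture written with the opposite byte order.)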

    def closeConnection(self, ip_port):
        if ip_port in self.mConnStatus:
            del self.mConnStatus[ip_port]


    def getNextEvent(self):
        # Get the next event out of the packet stream

        td_format = 'IIII'
        ip_format = '!BBHHHBBHII'
        tcp_format = '!HHIIBBHHH'
        while 1:
            #
            # Parse an individual packet out of the tcpdump stream,
            # starting with the per-packet header:
            #
            # guint32 ts_sec;         /* timestamp seconds */
            # guint32 ts_usec;        /* timestamp microseconds */
            # guint32 incl_len;       /* number of octets of packet saved in file */
            # guint32 orig_len;       /* actual length of packet */
            ph_bin = self.mInFile.read(16)
            res = struct.unpack(td_format, ph_bin)
            ts_sec = res[0]
            ts_usec = res[1]
            pkt_time = ts_sec + (ts_usec/1000000.0)
            incl_len = res[2]
            orig_len = res[3]

            # Packet data (incl_len bytes)
            raw_data = self.mInFile.read(incl_len)

            # Skip the MAC header - we don't care about it (14 bytes)
            mac_offset = 14

            # Parse out the IP header (min 20 bytes)
            # 4 bits - version
            # 4 bits - header length in 32 bit words
            # 1 byte - type of service
            # 2 bytes - total length
            # 2 bytes - fragment identification
            # 3 bits - flags
            # 13 bits - fragment offset
            # 1 byte - TTL
            # 1 byte - Protocol (should be 6)
            # 2 bytes - header checksum
            # 4 bytes - source IP
            # 4 bytes - dest IP
            ip_header = struct.unpack(ip_format, raw_data[mac_offset:mac_offset + 20])

            # Assume all packets are TCP
            #if ip_header[6] != 6:
            #    print "Not TCP!"
            #    continue

            src_ip_bin = ip_header[8]
            src_ip = lookup_ip_string(src_ip_bin)
            dst_ip_bin = ip_header[9]
            dst_ip = lookup_ip_string(dst_ip_bin)

            ip_size = (ip_header[0] & 0x0f) * 4

            # Parse out the TCP packet header
            # 2 bytes - src_port
            # 2 bytes - dst_port
            # 4 bytes - sequence number
            # 4 bytes - ack number
            # 4 bits - data offset (size of the TCP header in 32 bit words)
            # 6 bits - reserved
            # 6 bits - control bits
            # 2 bytes - window
            # 2 bytes - checksum
            # 2 bytes - urgent pointer
            tcp_offset = mac_offset + ip_size
            tcp_header = struct.unpack(tcp_format, raw_data[tcp_offset:tcp_offset+20])
            tcp_size = ((tcp_header[4] & 0xf0) >> 4) * 4
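            # Worked example (illustrative): a data-offset byte of 0x50 has a
            # high nibble of 5, so the TCP header is 5 * 4 = 20 bytes and the
            # MySQL payload begins 20 bytes past tcp_offset.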

            src_port = tcp_header[0]
            dst_port = tcp_header[1]

            # 3 bytes - packet length
            # 1 byte - packet number
            # 1 byte - command
            # <n bytes> - args
            pkt_offset = tcp_offset + tcp_size

            if len(raw_data) == pkt_offset:
                continue

            # Clearly not a mysql packet if it's less than 5 bytes of data
            if len(raw_data) - pkt_offset < 5:
                continue

            src_ip_port = "%s:%d" % (src_ip, src_port)
            dst_ip_port = "%s:%d" % (dst_ip, dst_port)

            if src_port == 3306:
                #
                # We are processing traffic from mysql server -> client
                # This is primarily used to time how long it takes for us
                # to start receiving data at the client from the server.
                #
                result_type = ord(raw_data[pkt_offset])

                # Track the connection if we don't know about it yet.
                if not dst_ip_port in self.mConnStatus:
                    self.mConnStatus[dst_ip_port] = LLConnStatus(dst_ip_port, pkt_time)
                conn = self.mConnStatus[dst_ip_port]

                # Update the status of this connection, including query times on
                # connections
                if conn.updateResponse(pkt_time, result_type):
                    # Event: Initial query response
                    return "QueryResponse", conn.mLastQuery
                continue
            if dst_port == 3306:
                #
                # Processing a packet from the client to the server
                #

                # Pull out packet length from the header (3 bytes, little-endian)
                mysql_arr = array.array('B', raw_data[pkt_offset:pkt_offset+5])
                pkt_len = mysql_arr[0] + (long(mysql_arr[1]) << 8) + (long(mysql_arr[2]) << 16)
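                # Worked example (illustrative): header bytes 0x2a 0x01 0x00
                # give pkt_len = 0x2a + (0x01 << 8) + (0x00 << 16) = 298 bytes.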

                pkt_number = mysql_arr[3]

                # Find the connection associated with this packet
                if not src_ip_port in self.mConnStatus:
                    self.mConnStatus[src_ip_port] = LLConnStatus(src_ip_port, pkt_time)
                conn = self.mConnStatus[src_ip_port]

                #if conn.mLastMysqlPacketNumber != (pkt_number - 1):
                #    print "Prev:", conn.mLastMysqlPacketNumber, "Cur:", pkt_number
                conn.mLastMysqlPacketNumber = pkt_number

                cmd = mysql_arr[4]
                # If it's not a command, do stuff
                if cmd > 0x1c:
                    # Unfortunately, we can't trivially tell the difference between
                    # various non-command packets.
                    # Assume that these are all AuthResponses for now.
                    conn.updateNonCommand(pkt_time, raw_data[pkt_offset:])
                    if "QuerySent" == conn.mCurState:
                        return ("QueryStart", conn.mLastQuery)
                    continue

                query = None

                if cmd == 1:
                    # Event: Quitting a connection (COM_QUIT)
                    conn.quit(src_ip, src_port, pkt_time)
                    # This connection is closing, get rid of it
                    self.closeConnection(src_ip_port)
                    return ("Quit", conn.mLastQuery)
                elif cmd == 3:
                    # Event: Starting a query (COM_QUERY)
                    conn.queryStart(src_ip, src_port, pkt_time, raw_data, pkt_len, pkt_offset + 5)

                    # Only return a QueryStart if we have the whole query
                    if "QuerySent" == conn.mCurState:
                        return ("QueryStart", conn.mLastQuery)
                else:
                    pass

class LLQuery:
    fromLLSDStats = staticmethod(llquery_from_llsd)
    def __init__(self, host, port, query, start_time):
        # Store information which will be serialized for metadata in a map
        self.mData = {}
        self.mData['host'] = host
        self.mData['port'] = port
        self.mData['query'] = query

        # Metadata
        self.mData['host_clean'] = None
        self.mData['query_clean'] = None
        self.mData['tables'] = []

        # Stats information
        self.mNumQueries = 0
        self.mTotalTime = 0.0
        self.mOutQueries = 0
        self.mTotalTimeCorrected = 0.0
        self.mNumQueriesCorrected = 0
        self.mBins = {} # Bins for histogram

        # This stuff doesn't usually get serialized
        self.mQueryLen = len(query)
        self.mStartTime = start_time
        self.mResponseTime = start_time

    def __hash__(self):
        return (self.mData['host_clean'] + ":" + self.mData['query_clean']).__hash__()

    def __eq__(self, other):
        # Note, this matches on clean, not strictly correct
        if ((self.mData['query_clean'] == other.mData['query_clean']) and
            (self.mData['host_clean'] == other.mData['host_clean'])):
            return True
        return False

    def getKey(self):
        # The string key is just the clean host and query, concatenated
        return self.mData['host_clean'] + ":" + self.mData['query_clean']

    def clean(self):
        "Generate the clean query so it can be used for statistics"
        if not self.mData['host_clean']:
            self.mData['host_clean'] = host_type(self.mData['host'])
            self.mData['query_clean'] = clean_query(self.mData['query'], 0)

    def getAvgTimeCorrected(self):
        return self.mTotalTimeCorrected/self.mNumQueriesCorrected

    def queryStart(self):
        "When collecting query stats, use this when the query is received"
        self.mNumQueries += 1
        self.mOutQueries += 1

    def queryResponse(self, elapsed):
        "When collecting stats, use this when the response is received"
        self.mTotalTime += elapsed
        self.mOutQueries -= 1
        # Bin index is int(log2(elapsed)), clamped to [MIN_BIN, MAX_BIN]
        bin = MIN_BIN
        if elapsed:
            bin = int(math.log(elapsed, 2))
        bin = max(MIN_BIN, bin)
        bin = min(MAX_BIN, bin)
        if bin not in self.mBins:
            self.mBins[bin] = LLQueryStatBin(bin)
        self.mBins[bin].accumulate(elapsed)
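
    # Worked example (illustrative): elapsed = 0.1s gives log2(0.1) ~= -3.32,
    # which int() truncates to bin -3; elapsed = 0.004s lands in bin -7.
    # NOTE: LLQueryStatBin is not defined in this file; the binning path is
    # presumably only exercised when this class is reused by apiary's
    # companion analysis tools.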

    def correctOutliers(self):
        "Find outlier bins and calculate corrected results"
        # Outlier bins have a count 3 orders of magnitude below the total count
        if not self.mNumQueries:
            # FIXME: This is a hack because we don't save this information in the query count dump
            return
        min_queries = self.mNumQueries/1000
        self.mTotalTimeCorrected = 0.0
        self.mNumQueriesCorrected = 0
        for i in self.mBins.keys():
            if self.mBins[i].mNumQueries < min_queries:
                # Outlier, flag as such.
                self.mBins[i].mOutlier = True
            else:
                self.mTotalTimeCorrected += self.mBins[i].mTotalTime
                self.mNumQueriesCorrected += self.mBins[i].mNumQueries
        if self.mNumQueriesCorrected == 0:
            #HACK: Deal with divide by zero
            self.mNumQueriesCorrected = 1

    sReadRE = re.compile(r"(SELECT.*)|(USE.*)", re.IGNORECASE)
    sSelectWhereRE = re.compile(r"\(?\s*?SELECT.+?FROM\s+\(?(.*?)\)?\s+WHERE.*", re.IGNORECASE)
    sSelectRE = re.compile(r"\(?\s*?SELECT.+?FROM\s+(.+)(?:\s+LIMIT.*|.*)", re.IGNORECASE)
    sUpdateRE = re.compile(r"UPDATE\s+(.+?)\s+SET.*", re.IGNORECASE)
    sReplaceRE = re.compile(r"REPLACE INTO\s+(.+?)(?:\s*\(|\s+SET).*", re.IGNORECASE)
    sInsertRE = re.compile(r"INSERT.+?INTO\s+(.+?)(?:\s*\(|\s+SET).*", re.IGNORECASE)
    sDeleteRE = re.compile(r"DELETE.+?FROM\s+(.+?)\s+WHERE.*", re.IGNORECASE)

    def analyze(self):
        "Does some simple analysis on the query"
        if 'type' in self.mData:
            # Already analyzed
            return
        query = self.mData['query_clean']
        if LLQuery.sReadRE.match(query):
            self.mData['type'] = 'read'
        else:
            self.mData['type'] = 'write'

        # NOTE: get_query_tables is not defined in this file
        self.mData['tables'] = get_query_tables(query)

    def dumpLine(self, elapsed, query_len = 0):
        # One character per bin: '.' = empty, '*' = outlier, otherwise the
        # order of magnitude (log10) of the bin's query count
        bin_str = ''
        for i in range(MIN_BIN, MAX_BIN+1):
            if i in self.mBins:
                if self.mBins[i].mOutlier:
                    bin_str += '*'
                else:
                    bin_str += str(int(math.log10(self.mBins[i].mNumQueries)))
            else:
                bin_str += '.'
        if not query_len:
            query_len = 4096
        num_queries = self.mNumQueriesCorrected
        if not num_queries:
            num_queries = 1
        return ("%s\t%5d\t%6.2f\t%6.2f\t%1.4f\t%s\t" % (bin_str, num_queries,
                                                       num_queries/elapsed, self.mTotalTimeCorrected,
                                                       self.mTotalTimeCorrected/num_queries, self.mData['host_clean'])) \
                                                       + self.mData['query_clean'][0:query_len]
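
    # Example bin string (illustrative): "....12*3.................." means
    # tens of queries in one bin, hundreds in the next, then an outlier bin,
    # then thousands in the one after that; '.' bins saw no queries at all.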

    def as_map(self):
        "Make an LLSD map version of data that can be used for merging"
        self.analyze()
        self.mData['num_queries'] = self.mNumQueries
        self.mData['total_time'] = self.mTotalTime
        self.mData['num_queries_corrected'] = self.mNumQueriesCorrected
        self.mData['total_time_corrected'] = self.mTotalTimeCorrected
        return self.mData


class LLConnStatus:
    "Keeps track of the status of a connection talking to mysql"
    def __init__(self, ip_port, start_time):
        self.mLastMysqlPacketNumber = 0
        self.mNumPackets = 0
        self.mIPPort = ip_port
        self.mStartTime = start_time
        self.mLastUpdate = start_time
        self.mCurState = ""
        self.mLastQuery = None
        self.mNumQueries = 0

    def quit(self, src_ip, src_port, pkt_time):
        query = LLQuery(src_ip, src_port, "Quit", pkt_time)
        query.clean()
        self.mLastUpdate = pkt_time
        self.mLastQuery = query
        self.mNumPackets += 1

    def queryStart(self, src_ip, src_port, pkt_time, raw, pkt_len, offset):
        query_len = pkt_len - 1
        query = LLQuery(src_ip, src_port, raw[offset:offset + (pkt_len - 1)], pkt_time)
        self.mLastUpdate = pkt_time
        # Packet length includes the command byte; the offset into raw doesn't
        if query_len > (len(raw) - offset):
            # We only have a partial query; more packets are coming
            query.mQueryLen = query_len
            self.mCurState = "SendingQuery"
        else:
            self.mCurState = "QuerySent"
            query.clean()
        self.mNumQueries += 1
        self.mLastQuery = query
        self.mNumPackets += 1

    def queryStartProcessed(self, src_ip, src_port, pkt_time, query_str):
        query = LLQuery(src_ip, src_port, query_str, pkt_time)
        query.clean()
        self.mLastUpdate = pkt_time
        self.mCurState = "QuerySent"
        self.mNumQueries += 1
        self.mLastQuery = query
        self.mNumPackets += 1

    def updateNonCommand(self, pkt_time, raw):
        # Clean up an existing query if you get a non-command.
        self.mNumPackets += 1
        self.mLastUpdate = pkt_time
        if self.mLastQuery:
            if self.mCurState == "SendingQuery":
                # We're continuing a query
                # We won't generate a new clean version, because it'll $!@# up all the sorting.
                self.mLastQuery.mData['query'] += raw
                if len(self.mLastQuery.mData['query']) == self.mLastQuery.mQueryLen:
                    self.mCurState = "QuerySent"
                    self.mLastQuery.clean()
                return
            else:
                #
                # A non-command that's continuing a query. Not sure why this is happening,
                # but clear the last query to avoid generating inadvertent long query results.
                #
                self.mLastQuery = None
        # Default to setting state to "NonCommand"
        self.mCurState = "NonCommand"

    def updateResponse(self, pkt_time, result_type):
        # If we've got a query running, accumulate the elapsed time
        start_query_response = False
        if self.mCurState == "QuerySent":
            lq = self.mLastQuery
            if lq:
                if lq.mStartTime == 0.0:
                    lq.mStartTime = pkt_time
                lq.mResponseTime = pkt_time
                start_query_response = True

        self.mLastUpdate = pkt_time
        if result_type == 0:
            self.mCurState = "Result:RecvOK"
        elif result_type == 0xff:
            self.mCurState = "Result:Error"
        elif result_type == 0xfe:
            self.mCurState = "Result:EOF"
        elif result_type == 0x01:
            self.mCurState = "Result:Header"
        else:
            self.mCurState = "Result:Data"
        return start_query_response

    def dump(self):
        if self.mLastQuery:
            print "%s: NumQ: %d State:%s\n\tLast: %s" % (self.mIPPort, self.mNumQueries, self.mCurState,
                                                         self.mLastQuery.mData['query_clean'][0:40])
        else:
            print "%s: NumQ: %d State:%s\n\tLast: None" % (self.mIPPort, self.mNumQueries, self.mCurState)

def clean_query(query, num_words):
    "Generalizes a query by removing all unique information"

    # Strip newlines
    query = query.replace("\n", " ")

    # Screw it, if it's a prepared statement or an execute, generalize the statement name
    if prepare_re.match(query):
        query = mdb_re.sub('*statement*', query)
        return query
    if execute_re.match(query):
        query = mdb_re.sub('*statement*', query)
    if deallocate_re.match(query):
        query = "DEALLOCATE PREPARE"
        return query

    # Replace all "unique" information - strings, uuids, numbers
    query = uuid_re.sub("*uuid*", query)
    query = hex_re.sub("*hex*", query)
    try:
        query = string_re.sub("*string*", query)
    except:
        # The quoted-string regex can fail on pathological input; keep going
        pass
    query = num_re.sub("*num*", query)

    # Get rid of all "VALUES ()" data.
    query = values_re.sub("VALUES (*values*)", query)
    # Get rid of all "IN ()" data.
    query = in_re.sub("IN (*values*)", query)
    # After we do the cleanup, then we get rid of extra whitespace
    words = query.split(None)
    query = " ".join(words)
    return query
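
# Illustrative example (not from the original source) of what clean_query
# produces; structurally identical queries collapse to the same string:
#
#   >>> clean_query("SELECT * FROM users WHERE id = 42 AND name = 'bob'", 0)
#   'SELECT * FROM users WHERE id = *num* AND name = *string*'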


def host_type(host):
    "Returns the genericized linden host type from an IP address or hostname"
    if host in host_type_cache:
        return host_type_cache[host]

    named_host = host
    if ip_re.match(host):
        # Look up the hostname for the IP address
        try:
            named_host = socket.gethostbyaddr(host)[0]
        except socket.error:
            pass

    # Figure out generic host type
    host_type = named_host
    if sim_re.match(named_host):
        host_type = "sim"
    elif login_re.match(named_host):
        host_type = "login"
    elif web_re.match(named_host):
        host_type = "web"
    elif iweb_re.match(named_host):
        host_type = "iweb"
    elif webds_re.match(named_host):
        host_type = "web-ds"
    elif data_re.match(named_host):
        host_type = "data"
    elif xmlrpc_re.match(named_host):
        host_type = "xmlrpc"
    m = ll_re.match(host_type)
    if m:
        host_type = m.group(1)
    host_type_cache[host] = host_type
    return host_type
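
# Illustrative examples (hostnames are made up): "sim1234.lindenlab.com" maps
# to "sim", while an unrecognized "foo.lindenlab.com" merely has the
# ".lindenlab.com" suffix stripped, yielding "foo".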


def start_dump(host, port):
    # Start up tcpdump on the sql server, pushing the capture into netcat
    # pointed back at our listening port
    interface = "eth0"

    SRC_DUMP_CMD = "ssh root@%s '/usr/sbin/tcpdump -n -s 0 -w - -i %s dst port 3306 or src port 3306 | nc %s %d'" \
                   % (host, interface, socket.getfqdn(), port)
    os.popen2(SRC_DUMP_CMD, "r")

def lookup_ip_string(ip_bin):
    # Convert a 32-bit IP address into dotted-quad notation, with memoization
    if not ip_bin in ip_table:
        ip_table[ip_bin] = "%d.%d.%d.%d" % ((ip_bin & 0xff000000L) >> 24,
                                            (ip_bin & 0x00ff0000L) >> 16,
                                            (ip_bin & 0x0000ff00L) >> 8,
                                            ip_bin & 0x000000ffL)
    return ip_table[ip_bin]
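
# Worked example (illustrative): ip_bin = 0x0a000001 formats to "10.0.0.1";
# repeated lookups of the same address are served from the ip_table cache.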


def remote_mysql_stream(host):
    # Create a server socket, then have tcpdump dump stuff to it.
    serversocket = socket.socket(socket.AF_INET, socket.SOCK_STREAM)

    bound = False
    port = 9999
    while not bound:
        try:
            serversocket.bind((socket.gethostname(), port))
            bound = True
        except socket.error:
            print "Port %d already bound, trying again" % port
            port += 1
    print "Bound port %d" % port
    serversocket.listen(1)

    # Fork off the dumper, start the server on the main connection
    pid = os.fork()
    if not pid:
        # Child process which gets data from the database
        time.sleep(1.0)
        print "Starting dump!"
        start_dump(host, port)
        print "Exiting dump!"
        sys.exit(0)

    print "Starting server"
    (clientsocket, address) = serversocket.accept()
    print "Accepted connection", address

    # Start listening to the data stream
    return clientsocket.makefile("rb")


def rotate_logs(log_path, query_log_file):
    # Fork to do the actual rotation/compression
    print "Rotating query logs"
    if query_log_file:
        query_log_file.close()
    need_gzip = False

    if os.path.exists(log_path+"/query.log"):
        os.rename(log_path+"/query.log", log_path+"/query.log.tmp")
        need_gzip = True

    query_log_file = open("%s/query.log" % log_path, "w")

    pid = os.fork()
    if pid:
        return query_log_file

    # Child process actually does the log rotation
    # Delete the oldest
    log_filename = log_path+"/query.log.%d.gz" % (MAX_LOGS)
    if os.path.exists(log_filename):
        os.remove(log_filename)

    for i in range(0, MAX_LOGS):
        # Count down from the max, renaming each log up one slot
        n = MAX_LOGS - i
        log_filename = log_path+"/query.log.%d.gz" % n
        if os.path.exists(log_filename):
            os.rename(log_path + ("/query.log.%d.gz" % n), log_path + ("/query.log.%d.gz" % (n+1)))

    if need_gzip:
        # Compress the "first" log (query.log.tmp)
        os.rename(log_path + "/query.log.tmp", log_path + "/query.log.1")
        os.system('gzip -f %s' % (log_path + "/query.log.1"))
    print "Done rotating logs!"
    sys.exit(0)
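
# Resulting layout (illustrative), newest first: query.log is the live log,
# query.log.1.gz is the most recently rotated one, and so on up through
# query.log.36.gz (MAX_LOGS), which is deleted on the next rotation.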


def watch_host(query_stream, host):
    "Watches query traffic for a particular host, writing it to rotating log files until interrupted"

    # Make output path
    output_path = "./%s" % host
    os.system("mkdir -p %s" % output_path)
    query_log_file = rotate_logs(output_path, None)

    last_log_time = time.time()

    done = False
    count = 0
    try:
        while not done:
            (event_type, query) = query_stream.getNextEvent()

            if event_type == "QueryStart":
                query_log_file.write("%f\t%s:%d\t%s\tQueryStart\n" % (query.mStartTime, query.mData['host'], query.mData['port'], query.mData['host_clean']))
                query_log_file.write("%s\n" % (query.mData['query']))
                query_log_file.write("**************************************\n")
                count += 1
            elif event_type == "QueryResponse":
                query_log_file.write("%f\t%s:%d\t%s\tQueryResponse\n" % (query.mResponseTime, query.mData['host'], query.mData['port'], query.mData['host_clean']))
                query_log_file.write("%s\n" % (query.mData['query']))
                query_log_file.write("**************************************\n")
            elif event_type == "Quit":
                # Quit is an "instantaneous" query, both start and response
                query_log_file.write("%f\t%s:%d\t%s\tQuit\n" % (query.mStartTime, query.mData['host'], query.mData['port'], query.mData['host_clean']))
                query_log_file.write("%s\n" % (query.mData['query']))
                query_log_file.write("**************************************\n")
                continue
            if not (count % 1000):
                # Every 1000 queries, reap any finished child processes and
                # rotate the logs if the rotation interval has passed
                try:
                    os.waitpid(-1, os.WNOHANG)
                except OSError:
                    pass
                if (time.time() - last_log_time) > LOG_ROTATION_INTERVAL:
                    last_log_time = time.time()
                    query_log_file = rotate_logs(output_path, query_log_file)

    except KeyboardInterrupt:
        pass
    query_log_file.close()


if __name__ == "__main__":
    opts, args = getopt.getopt(sys.argv[1:], "", ["host="])

    host = None
    for o, a in opts:
        if o == "--host":
            host = a
    if not host:
        print "Specify a host using --host="
        sys.exit(1)

    # Start up the stream from the target host and create a file
    # that we can hand to LLQueryStream
    query_stream_file = remote_mysql_stream(host)
    query_stream = LLQueryStream(query_stream_file)

    watch_host(query_stream, host)