PageRenderTime 48ms CodeModel.GetById 13ms app.highlight 29ms RepoModel.GetById 1ms app.codeStats 0ms

/lib/ansible/inventory/__init__.py

https://github.com/ajanthanm/ansible
Python | 641 lines | 599 code | 13 blank | 29 comment | 22 complexity | c01298e38c85af9acd73414a6857cded MD5 | raw file
  1# (c) 2012-2014, Michael DeHaan <michael.dehaan@gmail.com>
  2#
  3# This file is part of Ansible
  4#
  5# Ansible is free software: you can redistribute it and/or modify
  6# it under the terms of the GNU General Public License as published by
  7# the Free Software Foundation, either version 3 of the License, or
  8# (at your option) any later version.
  9#
 10# Ansible is distributed in the hope that it will be useful,
 11# but WITHOUT ANY WARRANTY; without even the implied warranty of
 12# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 13# GNU General Public License for more details.
 14#
 15# You should have received a copy of the GNU General Public License
 16# along with Ansible.  If not, see <http://www.gnu.org/licenses/>.
 17
 18#############################################
 19import fnmatch
 20import os
 21import sys
 22import re
 23import subprocess
 24
 25import ansible.constants as C
 26from ansible.inventory.ini import InventoryParser
 27from ansible.inventory.script import InventoryScript
 28from ansible.inventory.dir import InventoryDirectory
 29from ansible.inventory.group import Group
 30from ansible.inventory.host import Host
 31from ansible import errors
 32from ansible import utils
 33
 34class Inventory(object):
 35    """
 36    Host inventory for ansible.
 37    """
 38
 39    __slots__ = [ 'host_list', 'groups', '_restriction', '_also_restriction', '_subset', 
 40                  'parser', '_vars_per_host', '_vars_per_group', '_hosts_cache', '_groups_list',
 41                  '_pattern_cache', '_vault_password', '_vars_plugins', '_playbook_basedir']
 42
 43    def __init__(self, host_list=C.DEFAULT_HOST_LIST, vault_password=None):
 44
 45        # the host file file, or script path, or list of hosts
 46        # if a list, inventory data will NOT be loaded
 47        self.host_list = host_list
 48        self._vault_password=vault_password
 49
 50        # caching to avoid repeated calculations, particularly with
 51        # external inventory scripts.
 52
 53        self._vars_per_host  = {}
 54        self._vars_per_group = {}
 55        self._hosts_cache    = {}
 56        self._groups_list    = {} 
 57        self._pattern_cache  = {}
 58
 59        # to be set by calling set_playbook_basedir by playbook code
 60        self._playbook_basedir = None
 61
 62        # the inventory object holds a list of groups
 63        self.groups = []
 64
 65        # a list of host(names) to contain current inquiries to
 66        self._restriction = None
 67        self._also_restriction = None
 68        self._subset = None
 69
 70        if isinstance(host_list, basestring):
 71            if "," in host_list:
 72                host_list = host_list.split(",")
 73                host_list = [ h for h in host_list if h and h.strip() ]
 74
 75        if host_list is None:
 76            self.parser = None
 77        elif isinstance(host_list, list):
 78            self.parser = None
 79            all = Group('all')
 80            self.groups = [ all ]
 81            ipv6_re = re.compile('\[([a-f:A-F0-9]*[%[0-z]+]?)\](?::(\d+))?')
 82            for x in host_list:
 83                m = ipv6_re.match(x)
 84                if m:
 85                    all.add_host(Host(m.groups()[0], m.groups()[1]))
 86                else:
 87                    if ":" in x:
 88                        tokens = x.rsplit(":", 1)
 89                        # if there is ':' in the address, then this is a ipv6
 90                        if ':' in tokens[0]:
 91                            all.add_host(Host(x))
 92                        else:
 93                            all.add_host(Host(tokens[0], tokens[1]))
 94                    else:
 95                        all.add_host(Host(x))
 96        elif os.path.exists(host_list):
 97            if os.path.isdir(host_list):
 98                # Ensure basedir is inside the directory
 99                self.host_list = os.path.join(self.host_list, "")
100                self.parser = InventoryDirectory(filename=host_list)
101                self.groups = self.parser.groups.values()
102            else:
103                # check to see if the specified file starts with a
104                # shebang (#!/), so if an error is raised by the parser
105                # class we can show a more apropos error
106                shebang_present = False
107                try:
108                    inv_file = open(host_list)
109                    first_line = inv_file.readlines()[0]
110                    inv_file.close()
111                    if first_line.startswith('#!'):
112                        shebang_present = True
113                except:
114                    pass
115
116                if utils.is_executable(host_list):
117                    try:
118                        self.parser = InventoryScript(filename=host_list)
119                        self.groups = self.parser.groups.values()
120                    except:
121                        if not shebang_present:
122                            raise errors.AnsibleError("The file %s is marked as executable, but failed to execute correctly. " % host_list + \
123                                                      "If this is not supposed to be an executable script, correct this with `chmod -x %s`." % host_list)
124                        else:
125                            raise
126                else:
127                    try:
128                        self.parser = InventoryParser(filename=host_list)
129                        self.groups = self.parser.groups.values()
130                    except:
131                        if shebang_present:
132                            raise errors.AnsibleError("The file %s looks like it should be an executable inventory script, but is not marked executable. " % host_list + \
133                                                      "Perhaps you want to correct this with `chmod +x %s`?" % host_list)
134                        else:
135                            raise
136
137            utils.plugins.vars_loader.add_directory(self.basedir(), with_subdir=True)
138        else:
139            raise errors.AnsibleError("Unable to find an inventory file, specify one with -i ?")
140
141        self._vars_plugins = [ x for x in utils.plugins.vars_loader.all(self) ]
142
143        # get group vars from group_vars/ files and vars plugins
144        for group in self.groups:
145            group.vars = utils.combine_vars(group.vars, self.get_group_variables(group.name, self._vault_password))
146
147        # get host vars from host_vars/ files and vars plugins
148        for host in self.get_hosts():
149            host.vars = utils.combine_vars(host.vars, self.get_variables(host.name, self._vault_password))
150
151
152    def _match(self, str, pattern_str):
153        if pattern_str.startswith('~'):
154            return re.search(pattern_str[1:], str)
155        else:
156            return fnmatch.fnmatch(str, pattern_str)
157
158    def _match_list(self, items, item_attr, pattern_str):
159        results = []
160        if not pattern_str.startswith('~'):
161            pattern = re.compile(fnmatch.translate(pattern_str))
162        else:
163            pattern = re.compile(pattern_str[1:])
164        for item in items:
165            if pattern.search(getattr(item, item_attr)):
166                results.append(item)
167        return results
168
169    def get_hosts(self, pattern="all"):
170        """ 
171        find all host names matching a pattern string, taking into account any inventory restrictions or
172        applied subsets.
173        """
174
175        # process patterns
176        if isinstance(pattern, list):
177            pattern = ';'.join(pattern)
178        patterns = pattern.replace(";",":").split(":")
179        hosts = self._get_hosts(patterns)
180
181        # exclude hosts not in a subset, if defined
182        if self._subset:
183            subset = self._get_hosts(self._subset)
184            hosts = [ h for h in hosts if h in subset ]
185
186        # exclude hosts mentioned in any restriction (ex: failed hosts)
187        if self._restriction is not None:
188            hosts = [ h for h in hosts if h.name in self._restriction ]
189        if self._also_restriction is not None:
190            hosts = [ h for h in hosts if h.name in self._also_restriction ]
191
192        return hosts
193
194    def _get_hosts(self, patterns):
195        """
196        finds hosts that match a list of patterns. Handles negative
197        matches as well as intersection matches.
198        """
199
200        # Host specifiers should be sorted to ensure consistent behavior
201        pattern_regular = []
202        pattern_intersection = []
203        pattern_exclude = []
204        for p in patterns:
205            if p.startswith("!"):
206                pattern_exclude.append(p)
207            elif p.startswith("&"):
208                pattern_intersection.append(p)
209            elif p:
210                pattern_regular.append(p)
211
212        # if no regular pattern was given, hence only exclude and/or intersection
213        # make that magically work
214        if pattern_regular == []:
215            pattern_regular = ['all']
216
217        # when applying the host selectors, run those without the "&" or "!"
218        # first, then the &s, then the !s.
219        patterns = pattern_regular + pattern_intersection + pattern_exclude
220
221        hosts = []
222
223        for p in patterns:
224            # avoid resolving a pattern that is a plain host
225            if p in self._hosts_cache:
226                hosts.append(self.get_host(p))
227            else:
228                that = self.__get_hosts(p)
229                if p.startswith("!"):
230                    hosts = [ h for h in hosts if h not in that ]
231                elif p.startswith("&"):
232                    hosts = [ h for h in hosts if h in that ]
233                else:
234                    to_append = [ h for h in that if h.name not in [ y.name for y in hosts ] ]
235                    hosts.extend(to_append)
236        return hosts
237
238    def __get_hosts(self, pattern):
239        """ 
240        finds hosts that postively match a particular pattern.  Does not
241        take into account negative matches.
242        """
243
244        if pattern in self._pattern_cache:
245            return self._pattern_cache[pattern]
246
247        (name, enumeration_details) = self._enumeration_info(pattern)
248        hpat = self._hosts_in_unenumerated_pattern(name)
249        result = self._apply_ranges(pattern, hpat)
250        self._pattern_cache[pattern] = result
251        return result
252
253    def _enumeration_info(self, pattern):
254        """
255        returns (pattern, limits) taking a regular pattern and finding out
256        which parts of it correspond to start/stop offsets.  limits is
257        a tuple of (start, stop) or None
258        """
259
260        # Do not parse regexes for enumeration info
261        if pattern.startswith('~'):
262            return (pattern, None)
263
264        # The regex used to match on the range, which can be [x] or [x-y].
265        pattern_re = re.compile("^(.*)\[([-]?[0-9]+)(?:(?:-)([0-9]+))?\](.*)$")
266        m = pattern_re.match(pattern)
267        if m:
268            (target, first, last, rest) = m.groups()
269            first = int(first)
270            if last:
271                if first < 0:
272                    raise errors.AnsibleError("invalid range: negative indices cannot be used as the first item in a range")
273                last = int(last)
274            else:
275                last = first
276            return (target, (first, last))
277        else:
278            return (pattern, None)
279
280    def _apply_ranges(self, pat, hosts):
281        """
282        given a pattern like foo, that matches hosts, return all of hosts
283        given a pattern like foo[0:5], where foo matches hosts, return the first 6 hosts
284        """ 
285
286        # If there are no hosts to select from, just return the
287        # empty set. This prevents trying to do selections on an empty set.
288        # issue#6258
289        if not hosts:
290            return hosts
291
292        (loose_pattern, limits) = self._enumeration_info(pat)
293        if not limits:
294            return hosts
295
296        (left, right) = limits
297
298        if left == '':
299            left = 0
300        if right == '':
301            right = 0
302        left=int(left)
303        right=int(right)
304        try:
305            if left != right:
306                return hosts[left:right]
307            else:
308                return [ hosts[left] ]
309        except IndexError:
310            raise errors.AnsibleError("no hosts matching the pattern '%s' were found" % pat)
311
312    def _create_implicit_localhost(self, pattern):
313        new_host = Host(pattern)
314        new_host.set_variable("ansible_python_interpreter", sys.executable)
315        new_host.set_variable("ansible_connection", "local")
316        ungrouped = self.get_group("ungrouped")
317        if ungrouped is None:
318            self.add_group(Group('ungrouped'))
319            ungrouped = self.get_group('ungrouped')
320        ungrouped.add_host(new_host)
321        return new_host
322
323    def _hosts_in_unenumerated_pattern(self, pattern):
324        """ Get all host names matching the pattern """
325
326        results = []
327        hosts = []
328        hostnames = set()
329
330        # ignore any negative checks here, this is handled elsewhere
331        pattern = pattern.replace("!","").replace("&", "")
332
333        def __append_host_to_results(host):
334            if host not in results and host.name not in hostnames:
335                hostnames.add(host.name)
336                results.append(host)
337
338        groups = self.get_groups()
339        for group in groups:
340            if pattern == 'all':
341                for host in group.get_hosts():
342                    __append_host_to_results(host)
343            else:
344                if self._match(group.name, pattern):
345                    for host in group.get_hosts():
346                        __append_host_to_results(host)
347                else:
348                    matching_hosts = self._match_list(group.get_hosts(), 'name', pattern)
349                    for host in matching_hosts:
350                        __append_host_to_results(host)
351
352        if pattern in ["localhost", "127.0.0.1"] and len(results) == 0:
353            new_host = self._create_implicit_localhost(pattern)
354            results.append(new_host)
355        return results
356
357    def clear_pattern_cache(self):
358        ''' called exclusively by the add_host plugin to allow patterns to be recalculated '''
359        self._pattern_cache = {}
360
361    def groups_for_host(self, host):
362        if host in self._hosts_cache:
363            return self._hosts_cache[host].get_groups()
364        else:
365            return []
366
367    def groups_list(self):
368        if not self._groups_list:
369            groups = {}
370            for g in self.groups:
371                groups[g.name] = [h.name for h in g.get_hosts()]
372                ancestors = g.get_ancestors()
373                for a in ancestors:
374                    if a.name not in groups:
375                        groups[a.name] = [h.name for h in a.get_hosts()]
376            self._groups_list = groups
377        return self._groups_list
378
379    def get_groups(self):
380        return self.groups
381
382    def get_host(self, hostname):
383        if hostname not in self._hosts_cache:
384            self._hosts_cache[hostname] = self._get_host(hostname)
385        return self._hosts_cache[hostname]
386
387    def _get_host(self, hostname):
388        if hostname in ['localhost','127.0.0.1']:
389            for host in self.get_group('all').get_hosts():
390                if host.name in ['localhost', '127.0.0.1']:
391                    return host
392            return self._create_implicit_localhost(hostname)
393        else:
394            for group in self.groups:
395                for host in group.get_hosts():
396                    if hostname == host.name:
397                        return host
398        return None
399
400    def get_group(self, groupname):
401        for group in self.groups:
402            if group.name == groupname:
403                return group
404        return None
405
406    def get_group_variables(self, groupname, update_cached=False, vault_password=None):
407        if groupname not in self._vars_per_group or update_cached:
408            self._vars_per_group[groupname] = self._get_group_variables(groupname, vault_password=vault_password)
409        return self._vars_per_group[groupname]
410
411    def _get_group_variables(self, groupname, vault_password=None):
412
413        group = self.get_group(groupname)
414        if group is None:
415            raise Exception("group not found: %s" % groupname)
416
417        vars = {}
418
419        # plugin.get_group_vars retrieves just vars for specific group
420        vars_results = [ plugin.get_group_vars(group, vault_password=vault_password) for plugin in self._vars_plugins if hasattr(plugin, 'get_group_vars')]
421        for updated in vars_results:
422            if updated is not None:
423                vars = utils.combine_vars(vars, updated)
424
425        # get group variables set by Inventory Parsers
426        vars = utils.combine_vars(vars, group.get_variables())
427
428        # Read group_vars/ files
429        vars = utils.combine_vars(vars, self.get_group_vars(group))
430
431        return vars
432
433    def get_variables(self, hostname, update_cached=False, vault_password=None):
434        if hostname not in self._vars_per_host or update_cached:
435            self._vars_per_host[hostname] = self._get_variables(hostname, vault_password=vault_password)
436        return self._vars_per_host[hostname]
437
438    def _get_variables(self, hostname, vault_password=None):
439
440        host = self.get_host(hostname)
441        if host is None:
442            raise errors.AnsibleError("host not found: %s" % hostname)
443
444        vars = {}
445
446        # plugin.run retrieves all vars (also from groups) for host
447        vars_results = [ plugin.run(host, vault_password=vault_password) for plugin in self._vars_plugins if hasattr(plugin, 'run')]
448        for updated in vars_results:
449            if updated is not None:
450                vars = utils.combine_vars(vars, updated)
451
452        # plugin.get_host_vars retrieves just vars for specific host
453        vars_results = [ plugin.get_host_vars(host, vault_password=vault_password) for plugin in self._vars_plugins if hasattr(plugin, 'get_host_vars')]
454        for updated in vars_results:
455            if updated is not None:
456                vars = utils.combine_vars(vars, updated)
457
458        # get host variables set by Inventory Parsers
459        vars = utils.combine_vars(vars, host.get_variables())
460
461        # still need to check InventoryParser per host vars
462        # which actually means InventoryScript per host,
463        # which is not performant
464        if self.parser is not None:
465            vars = utils.combine_vars(vars, self.parser.get_host_variables(host))
466
467        # Read host_vars/ files
468        vars = utils.combine_vars(vars, self.get_host_vars(host))
469
470        return vars
471
472    def add_group(self, group):
473        if group.name not in self.groups_list():
474            self.groups.append(group)
475            self._groups_list = None  # invalidate internal cache 
476        else:
477            raise errors.AnsibleError("group already in inventory: %s" % group.name)
478
479    def list_hosts(self, pattern="all"):
480
481        """ return a list of hostnames for a pattern """
482
483        result = [ h.name for h in self.get_hosts(pattern) ]
484        if len(result) == 0 and pattern in ["localhost", "127.0.0.1"]:
485            result = [pattern]
486        return result
487
488    def list_groups(self):
489        return sorted([ g.name for g in self.groups ], key=lambda x: x)
490
491    # TODO: remove this function
492    def get_restriction(self):
493        return self._restriction
494
495    def restrict_to(self, restriction):
496        """ 
497        Restrict list operations to the hosts given in restriction.  This is used
498        to exclude failed hosts in main playbook code, don't use this for other
499        reasons.
500        """
501        if not isinstance(restriction, list):
502            restriction = [ restriction ]
503        self._restriction = restriction
504
505    def also_restrict_to(self, restriction):
506        """
507        Works like restict_to but offers an additional restriction.  Playbooks use this
508        to implement serial behavior.
509        """
510        if not isinstance(restriction, list):
511            restriction = [ restriction ]
512        self._also_restriction = restriction
513    
514    def subset(self, subset_pattern):
515        """ 
516        Limits inventory results to a subset of inventory that matches a given
517        pattern, such as to select a given geographic of numeric slice amongst
518        a previous 'hosts' selection that only select roles, or vice versa.  
519        Corresponds to --limit parameter to ansible-playbook
520        """        
521        if subset_pattern is None:
522            self._subset = None
523        else:
524            subset_pattern = subset_pattern.replace(',',':')
525            subset_pattern = subset_pattern.replace(";",":").split(":")
526            results = []
527            # allow Unix style @filename data
528            for x in subset_pattern:
529                if x.startswith("@"):
530                    fd = open(x[1:])
531                    results.extend(fd.read().split("\n"))
532                    fd.close()
533                else:
534                    results.append(x)
535            self._subset = results
536
537    def lift_restriction(self):
538        """ Do not restrict list operations """
539        self._restriction = None
540    
541    def lift_also_restriction(self):
542        """ Clears the also restriction """
543        self._also_restriction = None
544
545    def is_file(self):
546        """ did inventory come from a file? """
547        if not isinstance(self.host_list, basestring):
548            return False
549        return os.path.exists(self.host_list)
550
551    def basedir(self):
552        """ if inventory came from a file, what's the directory? """
553        if not self.is_file():
554            return None
555        dname = os.path.dirname(self.host_list)
556        if dname is None or dname == '' or dname == '.':
557            cwd = os.getcwd()
558            return os.path.abspath(cwd) 
559        return os.path.abspath(dname)
560
561    def src(self):
562        """ if inventory came from a file, what's the directory and file name? """
563        if not self.is_file():
564            return None
565        return self.host_list
566
567    def playbook_basedir(self):
568        """ returns the directory of the current playbook """
569        return self._playbook_basedir
570
571    def set_playbook_basedir(self, dir):
572        """
573        sets the base directory of the playbook so inventory can use it as a
574        basedir for host_ and group_vars, and other things.
575        """
576        # Only update things if dir is a different playbook basedir
577        if dir != self._playbook_basedir:
578            self._playbook_basedir = dir
579            # get group vars from group_vars/ files
580            for group in self.groups:
581                group.vars = utils.combine_vars(group.vars, self.get_group_vars(group, new_pb_basedir=True))
582            # get host vars from host_vars/ files
583            for host in self.get_hosts():
584                host.vars = utils.combine_vars(host.vars, self.get_host_vars(host, new_pb_basedir=True))
585
586    def get_host_vars(self, host, new_pb_basedir=False):
587        """ Read host_vars/ files """
588        return self._get_hostgroup_vars(host=host, group=None, new_pb_basedir=False)
589
590    def get_group_vars(self, group, new_pb_basedir=False):
591        """ Read group_vars/ files """
592        return self._get_hostgroup_vars(host=None, group=group, new_pb_basedir=False)
593
594    def _get_hostgroup_vars(self, host=None, group=None, new_pb_basedir=False):
595        """
596        Loads variables from group_vars/<groupname> and host_vars/<hostname> in directories parallel
597        to the inventory base directory or in the same directory as the playbook.  Variables in the playbook
598        dir will win over the inventory dir if files are in both.
599        """
600
601        results = {}
602        scan_pass = 0
603        _basedir = self.basedir()
604
605        # look in both the inventory base directory and the playbook base directory
606        # unless we do an update for a new playbook base dir
607        if not new_pb_basedir:
608            basedirs = [_basedir, self._playbook_basedir]
609        else:
610            basedirs = [self._playbook_basedir]
611
612        for basedir in basedirs:
613
614            # this can happen from particular API usages, particularly if not run
615            # from /usr/bin/ansible-playbook
616            if basedir is None:
617                continue
618
619            scan_pass = scan_pass + 1
620
621            # it's not an eror if the directory does not exist, keep moving
622            if not os.path.exists(basedir):
623                continue
624
625            # save work of second scan if the directories are the same
626            if _basedir == self._playbook_basedir and scan_pass != 1:
627                continue
628
629            if group and host is None:
630                # load vars in dir/group_vars/name_of_group
631                base_path = os.path.join(basedir, "group_vars/%s" % group.name)
632                results = utils.load_vars(base_path, results, vault_password=self._vault_password)
633
634            elif host and group is None:
635                # same for hostvars in dir/host_vars/name_of_host
636                base_path = os.path.join(basedir, "host_vars/%s" % host.name)
637                results = utils.load_vars(base_path, results, vault_password=self._vault_password)
638
639        # all done, results is a dictionary of variables for this particular host.
640        return results
641