# -*- coding: utf-8 -*-

# This code is part of Ansible, but is an independent component.
# This particular file snippet, and this file snippet only, is BSD licensed.
# Modules you write using this snippet, which is embedded dynamically by Ansible
# still belong to the author of the module, and may assign their own license
# to the complete work.
#
# Copyright (c) 2014, Toshio Kuratomi
#
# Simplified BSD License (see licenses/simplified_bsd.txt or https://opensource.org/licenses/BSD-2-Clause)

from __future__ import (absolute_import, division, print_function)
__metaclass__ = type

import re


# Input patterns for is_input_dangerous function:
#
# 1. '"' in string and '--' in string or
# "'" in string and '--' in string
PATTERN_1 = re.compile(r'(\'|\").*--')

# 2. union \ intersect \ except + select
PATTERN_2 = re.compile(r'(UNION|INTERSECT|EXCEPT).*SELECT', re.IGNORECASE)

# 3. ';' and any KEY_WORDS
PATTERN_3 = re.compile(r';.*(SELECT|UPDATE|INSERT|DELETE|DROP|TRUNCATE|ALTER)', re.IGNORECASE)


class SQLParseError(Exception):
    """Raised when an identifier cannot be parsed or has too many parts."""
    pass


class UnclosedQuoteError(SQLParseError):
    """Raised when a quoted identifier has no terminating quote character."""
    pass


# maps a type of identifier to the maximum number of dot levels that are
# allowed to specify that identifier.  For example, a database column can be
# specified by up to 4 levels: database.schema.table.column
_PG_IDENTIFIER_TO_DOT_LEVEL = dict(
    database=1,
    schema=2,
    table=3,
    column=4,
    role=1,
    tablespace=1,
    sequence=3,
    publication=1,
)
_MYSQL_IDENTIFIER_TO_DOT_LEVEL = dict(database=1, table=2, column=3, role=1, vars=1)


def _find_end_quote(identifier, quote_char):
    """Return the index of the closing quote in ``identifier``.

    ``identifier`` is the text that follows an opening quote.  A quote
    character immediately followed by a second one is an escaped quote and
    is skipped; the first quote character that is not doubled is the
    closing quote.

    :param identifier: text to scan (without the opening quote).
    :param quote_char: the quote character to look for.
    :returns: index, relative to the start of ``identifier`` as passed in,
        of the closing quote.
    :raises UnclosedQuoteError: if no closing quote is found.
    """
    accumulate = 0
    while True:
        try:
            quote = identifier.index(quote_char)
        except ValueError:
            raise UnclosedQuoteError
        accumulate = accumulate + quote
        try:
            next_char = identifier[quote + 1]
        except IndexError:
            # The quote is the very last character: it closes the identifier.
            return accumulate
        if next_char != quote_char:
            return accumulate
        # Doubled (escaped) quote: drop both characters and keep scanning,
        # tracking the offset so the result stays relative to the input.
        identifier = identifier[quote + 2:]
        accumulate = accumulate + 2


def _quote_fragment(fragment, quote_char):
    """Double every embedded ``quote_char`` and wrap ``fragment`` in quotes."""
    escaped = fragment.replace(quote_char, quote_char * 2)
    return ''.join((quote_char, escaped, quote_char))


def _identifier_parse(identifier, quote_char):
    """Split a dotted identifier into a list of quoted fragments.

    A leading fragment that is already correctly quoted is kept verbatim.
    An unquoted leading fragment is quoted, with embedded quote characters
    doubled.  The remainder after the separating dot is parsed recursively.
    An unquoted identifier whose only usable dot is its first or last
    character is quoted as a single fragment, dot included.

    :param identifier: the raw identifier, e.g. ``schema.table``.
    :param quote_char: quote character of the target SQL dialect.
    :returns: list of quoted fragments, in order.
    :raises SQLParseError: if ``identifier`` is empty, or if a quoted
        fragment is followed by anything other than a dot.
    """
    if not identifier:
        raise SQLParseError('Identifier name unspecified or unquoted trailing dot')

    already_quoted = False
    if identifier.startswith(quote_char):
        already_quoted = True
        try:
            end_quote = _find_end_quote(identifier[1:], quote_char=quote_char) + 1
        except UnclosedQuoteError:
            # No matching close quote: treat the whole thing as unquoted so
            # the stray quote gets escaped below.
            already_quoted = False
        else:
            if end_quote < len(identifier) - 1:
                if identifier[end_quote + 1] == '.':
                    dot = end_quote + 1
                    first_identifier = identifier[:dot]
                    next_identifier = identifier[dot + 1:]
                    further_identifiers = _identifier_parse(next_identifier, quote_char)
                    further_identifiers.insert(0, first_identifier)
                else:
                    raise SQLParseError('User escaped identifiers must escape extra quotes')
            else:
                further_identifiers = [identifier]

    if not already_quoted:
        try:
            dot = identifier.index('.')
        except ValueError:
            further_identifiers = [_quote_fragment(identifier, quote_char)]
        else:
            if dot == 0 or dot >= len(identifier) - 1:
                # Leading or trailing dot: the dot is part of the name.
                further_identifiers = [_quote_fragment(identifier, quote_char)]
            else:
                first_identifier = identifier[:dot]
                next_identifier = identifier[dot + 1:]
                further_identifiers = _identifier_parse(next_identifier, quote_char)
                further_identifiers.insert(0, _quote_fragment(first_identifier, quote_char))

    return further_identifiers


def pg_quote_identifier(identifier, id_type):
    """Quote ``identifier`` using double quotes, one fragment per dot level.

    :param identifier: raw, possibly dotted and partially quoted, identifier.
    :param id_type: a key of ``_PG_IDENTIFIER_TO_DOT_LEVEL``
        (e.g. ``'table'``); an unknown key raises ``KeyError``.
    :returns: the quoted fragments joined with dots.
    :raises SQLParseError: if the identifier cannot be parsed or has more
        fragments than ``id_type`` allows.
    """
    identifier_fragments = _identifier_parse(identifier, quote_char='"')
    if len(identifier_fragments) > _PG_IDENTIFIER_TO_DOT_LEVEL[id_type]:
        raise SQLParseError('PostgreSQL does not support %s with more than %i dots' % (id_type, _PG_IDENTIFIER_TO_DOT_LEVEL[id_type]))
    return '.'.join(identifier_fragments)


def mysql_quote_identifier(identifier, id_type):
    """Quote ``identifier`` using backticks, one fragment per dot level.

    A fragment consisting solely of ``*`` is emitted unquoted.

    :param identifier: raw, possibly dotted and partially quoted, identifier.
    :param id_type: a key of ``_MYSQL_IDENTIFIER_TO_DOT_LEVEL``
        (e.g. ``'table'``); an unknown key raises ``KeyError``.
    :returns: the quoted fragments joined with dots.
    :raises SQLParseError: if the identifier cannot be parsed or contains
        more dots than the value registered for ``id_type``.
    """
    identifier_fragments = _identifier_parse(identifier, quote_char='`')
    if (len(identifier_fragments) - 1) > _MYSQL_IDENTIFIER_TO_DOT_LEVEL[id_type]:
        raise SQLParseError('MySQL does not support %s with more than %i dots' % (id_type, _MYSQL_IDENTIFIER_TO_DOT_LEVEL[id_type]))

    special_cased_fragments = [
        '*' if fragment == '`*`' else fragment
        for fragment in identifier_fragments
    ]
    return '.'.join(special_cased_fragments)


def is_input_dangerous(string):
    """Check if the passed string is potentially dangerous.

    Can be used to prevent SQL injections.

    Note: use this function only when you can't use
      psycopg2's cursor.execute method parametrized
      (typically with DDL queries).

    :param string: the text to inspect; a falsy value is never dangerous.
    :returns: True if any of PATTERN_1, PATTERN_2 or PATTERN_3 matches
        somewhere in ``string``, otherwise False.
    """
    if not string:
        return False

    return any(pattern.search(string)
               for pattern in (PATTERN_1, PATTERN_2, PATTERN_3))


def _checkable_strings(elem):
    """Yield the string form of every value in ``elem`` that must be checked.

    ``None`` and booleans are skipped, lists are walked element by element
    (recursively), strings are yielded as is and anything else is converted
    with ``str()``.
    """
    if elem is None or isinstance(elem, bool):
        return
    if isinstance(elem, list):
        for item in elem:
            for text in _checkable_strings(item):
                yield text
    elif isinstance(elem, str):
        yield elem
    else:
        yield str(elem)


def check_input(module, *args):
    """Wrapper for is_input_dangerous function.

    Every positional argument is inspected.  List arguments are checked
    element by element using the same rules as top-level arguments, so
    non-string list elements (numbers, None, booleans) no longer make the
    regular expression search raise TypeError.

    :param module: object providing ``fail_json(msg=...)``; it is called
        once, listing every dangerous value, if any was found.
    :param args: values to check.
    """
    dangerous_elements = [
        text
        for elem in args
        for text in _checkable_strings(elem)
        if is_input_dangerous(text)
    ]

    if dangerous_elements:
        module.fail_json(msg="Passed input '%s' is "
                             "potentially dangerous" % ', '.join(dangerous_elements))