#
# Copyright (C) 2009-2020 the sqlparse authors and contributors
#
#
# This module is part of python-sqlparse and is released under
# the BSD License: https://opensource.org/licenses/BSD-3-Clause

from sqlparse import sql
from sqlparse import tokens as T
from sqlparse.utils import recurse, imt

# Token-type tuples shared by several of the grouping passes below.
T_NUMERICAL = (T.Number, T.Number.Integer, T.Number.Float)
T_STRING = (T.String, T.String.Single, T.String.Symbol)
T_NAME = (T.Name, T.Name.Placeholder)


def _group_matching(tlist, cls):
    """Group tokens that sit between a matching open and close marker.

    Walks ``tlist`` once, remembering the index of every token matching
    ``cls.M_OPEN``. When a token matching ``cls.M_CLOSE`` is met, the most
    recent open index is popped and the span open..close is replaced by
    one ``cls`` node via ``tlist.group_tokens``. Groups of another class
    are descended into recursively. ``tlist`` is modified in place.
    """
    opens = []
    # The loop iterates over a snapshot of tlist; every grouping shrinks
    # the live list, so snapshot indices are shifted by this offset.
    tidx_offset = 0
    for idx, token in enumerate(list(tlist)):
        tidx = idx - tidx_offset

        if token.is_whitespace:
            # ~50% of tokens will be whitespace. Checking for them early
            # avoids 3 comparisons for those tokens, but adds 1 more
            # comparison for the other ~50% of tokens...
            continue

        if token.is_group and not isinstance(token, cls):
            # Check inside previously grouped (i.e. parenthesis) if group
            # of different type is inside (i.e., case). Though ideally we
            # should check for all open/close tokens at once to avoid
            # recursion.
            _group_matching(token, cls)
            continue

        if token.match(*cls.M_OPEN):
            opens.append(tidx)

        elif token.match(*cls.M_CLOSE):
            try:
                open_idx = opens.pop()
            except IndexError:
                # this indicates invalid sql and unbalanced tokens.
                # instead of break, continue in case other "valid" groups
                # exist
                continue
            close_idx = tidx
            tlist.group_tokens(cls, open_idx, close_idx)
            tidx_offset += close_idx - open_idx


def group_brackets(tlist):
    """Group open/close pairs of ``sql.SquareBrackets`` in ``tlist``."""
    _group_matching(tlist, sql.SquareBrackets)


def group_parenthesis(tlist):
    """Group open/close pairs of ``sql.Parenthesis`` in ``tlist``."""
    _group_matching(tlist, sql.Parenthesis)


def group_case(tlist):
    """Group open/close pairs of ``sql.Case`` in ``tlist``."""
    _group_matching(tlist, sql.Case)


def group_if(tlist):
    """Group open/close pairs of ``sql.If`` in ``tlist``."""
    _group_matching(tlist, sql.If)


def group_for(tlist):
    """Group open/close pairs of ``sql.For`` in ``tlist``."""
    _group_matching(tlist, sql.For)


def group_begin(tlist):
    """Group open/close pairs of ``sql.Begin`` in ``tlist``."""
    _group_matching(tlist, sql.Begin)


def group_typecasts(tlist):
    """Group ``prev :: next`` spans into ``sql.Identifier`` nodes."""
    def match(token):
        return token.match(T.Punctuation, '::')

    def valid(token):
        return token is not None

    def post(tlist, pidx, tidx, nidx):
        return pidx, nidx

    valid_prev = valid_next = valid
    _group(tlist, sql.Identifier, match, valid_prev, valid_next, post)


def group_tzcasts(tlist):
    """Group ``prev <TZCast keyword> next`` spans into ``sql.Identifier``."""
    def match(token):
        return token.ttype == T.Keyword.TZCast

    def valid(token):
        return token is not None

    def post(tlist, pidx, tidx, nidx):
        return pidx, nidx

    _group(tlist, sql.Identifier, match, valid, valid, post)


def group_typed_literal(tlist):
    """Group typed literals into ``sql.TypedLiteral`` nodes.

    Runs two passes: the first groups a token matching
    ``TypedLiteral.M_OPEN`` with a following ``M_CLOSE`` token; the second
    extends an existing ``TypedLiteral`` with a following ``M_EXTEND``
    token.
    """
    # definitely not complete, see e.g.:
    # https://docs.microsoft.com/en-us/sql/odbc/reference/appendixes/interval-literal-syntax
    # https://docs.microsoft.com/en-us/sql/odbc/reference/appendixes/interval-literals
    # https://www.postgresql.org/docs/9.1/datatype-datetime.html
    # https://www.postgresql.org/docs/9.1/functions-datetime.html

    def match(token):
        return imt(token, m=sql.TypedLiteral.M_OPEN)

    def match_to_extend(token):
        return isinstance(token, sql.TypedLiteral)

    def valid_prev(token):
        return token is not None

    def valid_next(token):
        return token is not None and token.match(*sql.TypedLiteral.M_CLOSE)

    def valid_final(token):
        return token is not None and token.match(*sql.TypedLiteral.M_EXTEND)

    def post(tlist, pidx, tidx, nidx):
        # The group starts at the matched token itself, not at its
        # predecessor.
        return tidx, nidx

    _group(tlist, sql.TypedLiteral, match, valid_prev, valid_next,
           post, extend=False)
    _group(tlist, sql.TypedLiteral, match_to_extend, valid_prev, valid_final,
           post, extend=True)


def group_period(tlist):
    """Group ``prev . next`` spans into ``sql.Identifier`` nodes.

    If the token after the period is not an acceptable continuation, only
    ``prev .`` is grouped (see ``post``).
    """
    def match(token):
        return token.match(T.Punctuation, '.')

    def valid_prev(token):
        sqlcls = sql.SquareBrackets, sql.Identifier
        ttypes = T.Name, T.String.Symbol
        return imt(token, i=sqlcls, t=ttypes)

    def valid_next(token):
        # issue261, allow invalid next token
        return True

    def post(tlist, pidx, tidx, nidx):
        # next_ validation is being performed here. issue261
        sqlcls = sql.SquareBrackets, sql.Function
        ttypes = T.Name, T.String.Symbol, T.Wildcard
        next_ = tlist[nidx] if nidx is not None else None
        valid_next = imt(next_, i=sqlcls, t=ttypes)

        return (pidx, nidx) if valid_next else (pidx, tidx)

    _group(tlist, sql.Identifier, match, valid_prev, valid_next, post)


def group_as(tlist):
    """Group ``prev AS next`` spans into ``sql.Identifier`` nodes.

    The previous token must be NULL or a non-keyword; the next token must
    exist and must not be of type DML, DDL or CTE.
    """
    def match(token):
        return token.is_keyword and token.normalized == 'AS'

    def valid_prev(token):
        return token.normalized == 'NULL' or not token.is_keyword

    def valid_next(token):
        ttypes = T.DML, T.DDL, T.CTE
        return not imt(token, t=ttypes) and token is not None

    def post(tlist, pidx, tidx, nidx):
        return pidx, nidx

    _group(tlist, sql.Identifier, match, valid_prev, valid_next, post)


def group_assignment(tlist):
    """Group ``prev := next`` spans into ``sql.Assignment`` nodes.

    The group is stretched up to the next ``;`` after ``next`` when one is
    found.
    """
    def match(token):
        return token.match(T.Assignment, ':=')

    def valid(token):
        # NOTE(review): ``(T.Keyword)`` is not a one-element tuple, so the
        # membership test runs against ``T.Keyword`` itself - confirm this
        # is the intended check before "fixing" it to ``(T.Keyword,)``.
        return token is not None and token.ttype not in (T.Keyword)

    def post(tlist, pidx, tidx, nidx):
        m_semicolon = T.Punctuation, ';'
        snidx, _ = tlist.token_next_by(m=m_semicolon, idx=nidx)
        nidx = snidx or nidx
        return pidx, nidx

    valid_prev = valid_next = valid
    _group(tlist, sql.Assignment, match, valid_prev, valid_next, post)


def group_comparison(tlist):
    """Group ``prev <comparison operator> next`` into ``sql.Comparison``.

    Both operands must be one of the listed group classes / token types,
    or the keyword NULL.
    """
    sqlcls = (sql.Parenthesis, sql.Function, sql.Identifier,
              sql.Operation, sql.TypedLiteral)
    ttypes = T_NUMERICAL + T_STRING + T_NAME

    def match(token):
        return token.ttype == T.Operator.Comparison

    def valid(token):
        if imt(token, t=ttypes, i=sqlcls):
            return True
        elif token and token.is_keyword and token.normalized == 'NULL':
            return True
        else:
            return False

    def post(tlist, pidx, tidx, nidx):
        return pidx, nidx

    valid_prev = valid_next = valid
    _group(tlist, sql.Comparison, match,
           valid_prev, valid_next, post, extend=False)


@recurse(sql.Identifier)
def group_identifier(tlist):
    """Wrap every Name / String.Symbol token in its own ``sql.Identifier``."""
    ttypes = (T.String.Symbol, T.Name)

    tidx, token = tlist.token_next_by(t=ttypes)
    while token:
        tlist.group_tokens(sql.Identifier, tidx, tidx)
        tidx, token = tlist.token_next_by(t=ttypes, idx=tidx)


def group_arrays(tlist):
    """Group ``prev [ ... ]`` (a SquareBrackets node) into ``sql.Identifier``.

    Only the current level of ``tlist`` is processed (``recurse=False``).
    """
    sqlcls = sql.SquareBrackets, sql.Identifier, sql.Function
    ttypes = T.Name, T.String.Symbol

    def match(token):
        return isinstance(token, sql.SquareBrackets)

    def valid_prev(token):
        return imt(token, i=sqlcls, t=ttypes)

    def valid_next(token):
        return True

    def post(tlist, pidx, tidx, nidx):
        # The token after the brackets is never part of the group.
        return pidx, tidx

    _group(tlist, sql.Identifier, match,
           valid_prev, valid_next, post, extend=True, recurse=False)


def group_operator(tlist):
    """Group ``prev <operator or wildcard> next`` into ``sql.Operation``.

    The matched middle token has its ``ttype`` rewritten to ``T.Operator``
    (this turns a matched wildcard token into an operator token).
    """
    ttypes = T_NUMERICAL + T_STRING + T_NAME
    sqlcls = (sql.SquareBrackets, sql.Parenthesis, sql.Function,
              sql.Identifier, sql.Operation, sql.TypedLiteral)

    def match(token):
        return imt(token, t=(T.Operator, T.Wildcard))

    def valid(token):
        return imt(token, i=sqlcls, t=ttypes) \
            or (token and token.match(
                T.Keyword,
                ('CURRENT_DATE', 'CURRENT_TIME', 'CURRENT_TIMESTAMP')))

    def post(tlist, pidx, tidx, nidx):
        tlist[tidx].ttype = T.Operator
        return pidx, nidx

    valid_prev = valid_next = valid
    _group(tlist, sql.Operation, match,
           valid_prev, valid_next, post, extend=False)


def group_identifier_list(tlist):
    """Group ``prev , next`` spans into ``sql.IdentifierList`` nodes."""
    m_role = T.Keyword, ('null', 'role')
    sqlcls = (sql.Function, sql.Case, sql.Identifier, sql.Comparison,
              sql.IdentifierList, sql.Operation)
    ttypes = (T_NUMERICAL + T_STRING + T_NAME
              + (T.Keyword, T.Comment, T.Wildcard))

    def match(token):
        return token.match(T.Punctuation, ',')

    def valid(token):
        return imt(token, i=sqlcls, m=m_role, t=ttypes)

    def post(tlist, pidx, tidx, nidx):
        return pidx, nidx

    valid_prev = valid_next = valid
    _group(tlist, sql.IdentifierList, match,
           valid_prev, valid_next, post, extend=True)


@recurse(sql.Comment)
def group_comments(tlist):
    """Group runs of comment tokens into ``sql.Comment`` nodes.

    A run starts at a Comment token and continues over comment and
    whitespace tokens. It is grouped only when some other token follows
    it; the group then ends at the token directly before that one.
    """
    tidx, token = tlist.token_next_by(t=T.Comment)
    while token:
        eidx, end = tlist.token_not_matching(
            lambda tk: imt(tk, t=T.Comment) or tk.is_whitespace, idx=tidx)
        if end is not None:
            eidx, end = tlist.token_prev(eidx, skip_ws=False)
            tlist.group_tokens(sql.Comment, tidx, eidx)

        tidx, token = tlist.token_next_by(t=T.Comment, idx=tidx)


@recurse(sql.Where)
def group_where(tlist):
    """Group each WHERE clause into a ``sql.Where`` node.

    A clause runs from a token matching ``Where.M_OPEN`` up to the token
    before the next ``Where.M_CLOSE`` match, or to the last groupable
    token of ``tlist`` when there is no such match.
    """
    tidx, token = tlist.token_next_by(m=sql.Where.M_OPEN)
    while token:
        eidx, end = tlist.token_next_by(m=sql.Where.M_CLOSE, idx=tidx)

        if end is None:
            end = tlist._groupable_tokens[-1]
        else:
            end = tlist.tokens[eidx - 1]
        # TODO: convert this to eidx instead of end token.
        # i think above values are len(tlist) and eidx-1
        eidx = tlist.token_index(end)
        tlist.group_tokens(sql.Where, tidx, eidx)
        tidx, token = tlist.token_next_by(m=sql.Where.M_OPEN, idx=tidx)


@recurse()
def group_aliased(tlist):
    """Group an aliasable token directly followed by an Identifier.

    Aliasable tokens are the ``I_ALIAS`` group classes and Number tokens;
    the pair is grouped as a ``sql.Identifier``.
    """
    I_ALIAS = (sql.Parenthesis, sql.Function, sql.Case, sql.Identifier,
               sql.Operation, sql.Comparison)

    tidx, token = tlist.token_next_by(i=I_ALIAS, t=T.Number)
    while token:
        nidx, next_ = tlist.token_next(tidx)
        if isinstance(next_, sql.Identifier):
            tlist.group_tokens(sql.Identifier, tidx, nidx, extend=True)
        tidx, token = tlist.token_next_by(i=I_ALIAS, t=T.Number, idx=tidx)


@recurse(sql.Function)
def group_functions(tlist):
    """Group a Name token followed by a Parenthesis into ``sql.Function``.

    Nothing is grouped when ``tlist`` directly contains both a CREATE and
    a TABLE token; presumably ``name (...)`` is then a table definition
    rather than a call. The two keywords are compared case-insensitively
    because SQL keywords may be written in any case.
    """
    has_create = False
    has_table = False
    for tmp_token in tlist.tokens:
        upper_value = tmp_token.value.upper()
        if upper_value == 'CREATE':
            has_create = True
        if upper_value == 'TABLE':
            has_table = True
    if has_create and has_table:
        return

    tidx, token = tlist.token_next_by(t=T.Name)
    while token:
        nidx, next_ = tlist.token_next(tidx)
        if isinstance(next_, sql.Parenthesis):
            tlist.group_tokens(sql.Function, tidx, nidx)
        tidx, token = tlist.token_next_by(t=T.Name, idx=tidx)


def group_order(tlist):
    """Group together Identifier and Asc/Desc token"""
    tidx, token = tlist.token_next_by(t=T.Keyword.Order)
    while token:
        pidx, prev_ = tlist.token_prev(tidx)
        if imt(prev_, i=sql.Identifier, t=T.Number):
            tlist.group_tokens(sql.Identifier, pidx, tidx)
            # The new group now sits at pidx; resume the search there.
            tidx = pidx
        tidx, token = tlist.token_next_by(t=T.Keyword.Order, idx=tidx)
@recurse()
def align_comments(tlist):
    """Attach each Comment node to a TokenList that directly precedes it.

    For every ``sql.Comment`` in ``tlist`` whose previous token (as
    returned by ``token_prev``) is a ``sql.TokenList``, both are passed to
    ``group_tokens(..., extend=True)``.
    """
    comment_idx, comment = tlist.token_next_by(i=sql.Comment)
    while comment:
        before_idx, before = tlist.token_prev(comment_idx)
        if isinstance(before, sql.TokenList):
            tlist.group_tokens(
                sql.TokenList, before_idx, comment_idx, extend=True)
            # Continue searching from the position of the merged group.
            comment_idx = before_idx
        comment_idx, comment = tlist.token_next_by(
            i=sql.Comment, idx=comment_idx)


def group_values(tlist):
    """Group a VALUES keyword and its parenthesised rows as ``sql.Values``.

    The group spans from the first VALUES keyword to the last Parenthesis
    found after it; nothing happens when either is missing.
    """
    cursor, current = tlist.token_next_by(m=(T.Keyword, 'VALUES'))
    first_idx = cursor
    last_paren_idx = None
    while current:
        if isinstance(current, sql.Parenthesis):
            last_paren_idx = cursor
        cursor, current = tlist.token_next(cursor)

    if last_paren_idx is not None:
        tlist.group_tokens(
            sql.Values, first_idx, last_paren_idx, extend=True)


def group(stmt):
    """Run every grouping pass over ``stmt`` in order and return ``stmt``."""
    # Order matters: each pass works on the groups built by earlier ones.
    passes = (
        group_comments,

        # _group_matching
        group_brackets,
        group_parenthesis,
        group_case,
        group_if,
        group_for,
        group_begin,

        group_functions,
        group_where,
        group_period,
        group_arrays,
        group_identifier,
        group_order,
        group_typecasts,
        group_tzcasts,
        group_typed_literal,
        group_operator,
        group_comparison,
        group_as,
        group_aliased,
        group_assignment,

        align_comments,
        group_identifier_list,
        group_values,
    )
    for grouping_pass in passes:
        grouping_pass(stmt)
    return stmt


def _group(tlist, cls, match,
           valid_prev=lambda t: True,
           valid_next=lambda t: True,
           post=None,
           extend=True,
           recurse=True
           ):
    """Group tokens joined by a middle token (e.g. ``x < y``) into ``cls``.

    ``match`` selects the middle token. ``valid_prev`` / ``valid_next``
    vet the neighbouring tokens. ``post(tlist, pidx, tidx, nidx)`` returns
    the (from, to) index pair that is handed to ``tlist.group_tokens``.
    With ``recurse`` set, groups of another class are processed first.
    """
    # Iteration runs over a snapshot; ``shift`` maps snapshot positions to
    # positions in the live, shrinking list.
    shift = 0
    left_idx, left = None, None
    for snapshot_idx, token in enumerate(list(tlist)):
        cur_idx = snapshot_idx - shift
        # tidx shouldn't get negative
        if cur_idx < 0 or token.is_whitespace:
            continue

        if recurse and token.is_group and not isinstance(token, cls):
            _group(token, cls, match, valid_prev, valid_next, post, extend)

        if match(token):
            right_idx, right = tlist.token_next(cur_idx)
            if left and valid_prev(left) and valid_next(right):
                start, stop = post(tlist, left_idx, cur_idx, right_idx)
                merged = tlist.group_tokens(cls, start, stop, extend=extend)

                shift += stop - start
                # The fresh group becomes the "previous" token, so chains
                # like a.b.c can keep growing.
                left_idx, left = start, merged
                continue

        left_idx, left = cur_idx, token