123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454 |
- #
- # Copyright (C) 2009-2020 the sqlparse authors and contributors
- # <see AUTHORS file>
- #
- # This module is part of python-sqlparse and is released under
- # the BSD License: https://opensource.org/licenses/BSD-3-Clause
- from sqlparse import sql
- from sqlparse import tokens as T
- from sqlparse.utils import recurse, imt
# Token-type shorthands shared by several grouping functions below.
T_NUMERICAL = (T.Number, T.Number.Integer, T.Number.Float)
T_STRING = (T.String, T.String.Single, T.String.Symbol)
T_NAME = (T.Name, T.Name.Placeholder)
def _group_matching(tlist, cls):
    """Groups Tokens that have beginning and end.

    Scans ``tlist`` for tokens matching ``cls.M_OPEN`` / ``cls.M_CLOSE``
    and wraps each balanced pair (and everything between) in a ``cls``
    group, modifying ``tlist`` in place.
    """
    opens = []  # stack of indices of currently-open M_OPEN tokens
    # grouping collapses several tokens into one, so indices into the
    # mutated tlist lag behind the enumeration of the original snapshot
    tidx_offset = 0
    for idx, token in enumerate(list(tlist)):
        tidx = idx - tidx_offset

        if token.is_whitespace:
            # ~50% of tokens will be whitespace. Will checking early
            # for them avoid 3 comparisons, but then add 1 more comparison
            # for the other ~50% of tokens...
            continue

        if token.is_group and not isinstance(token, cls):
            # Check inside previously grouped (i.e. parenthesis) if group
            # of different type is inside (i.e., case). though ideally should
            # should check for all open/close tokens at once to avoid recursion
            _group_matching(token, cls)
            continue

        if token.match(*cls.M_OPEN):
            opens.append(tidx)

        elif token.match(*cls.M_CLOSE):
            try:
                open_idx = opens.pop()
            except IndexError:
                # this indicates invalid sql and unbalanced tokens.
                # instead of break, continue in case other "valid" groups exist
                continue
            close_idx = tidx
            tlist.group_tokens(cls, open_idx, close_idx)
            # the grouped span collapsed into a single token
            tidx_offset += close_idx - open_idx
def group_brackets(tlist):
    """Group ``sql.SquareBrackets.M_OPEN``/``M_CLOSE`` pairs in place."""
    _group_matching(tlist, sql.SquareBrackets)
def group_parenthesis(tlist):
    """Group ``sql.Parenthesis.M_OPEN``/``M_CLOSE`` pairs in place."""
    _group_matching(tlist, sql.Parenthesis)
def group_case(tlist):
    """Group ``sql.Case.M_OPEN``/``M_CLOSE`` pairs in place."""
    _group_matching(tlist, sql.Case)
def group_if(tlist):
    """Group ``sql.If.M_OPEN``/``M_CLOSE`` pairs in place."""
    _group_matching(tlist, sql.If)
def group_for(tlist):
    """Group ``sql.For.M_OPEN``/``M_CLOSE`` pairs in place."""
    _group_matching(tlist, sql.For)
def group_begin(tlist):
    """Group ``sql.Begin.M_OPEN``/``M_CLOSE`` pairs in place."""
    _group_matching(tlist, sql.Begin)
def group_typecasts(tlist):
    """Merge ``expr::type`` typecasts into a single Identifier group."""
    def is_cast_op(token):
        return token.match(T.Punctuation, '::')

    def has_neighbor(token):
        return token is not None

    def span(tlist, pidx, tidx, nidx):
        # group from the token before '::' through the token after it
        return pidx, nidx

    _group(tlist, sql.Identifier, is_cast_op,
           has_neighbor, has_neighbor, span)
def group_tzcasts(tlist):
    """Merge a TZCast keyword with its neighbors into an Identifier."""
    def is_tzcast(token):
        return token.ttype == T.Keyword.TZCast

    def present(token):
        return token is not None

    def span(tlist, pidx, tidx, nidx):
        return pidx, nidx

    _group(tlist, sql.Identifier, is_tzcast, present, present, span)
def group_typed_literal(tlist):
    """Group typed literals (``sql.TypedLiteral.M_OPEN`` keyword followed
    by an ``M_CLOSE`` literal) into ``sql.TypedLiteral`` groups."""
    # definitely not complete, see e.g.:
    # https://docs.microsoft.com/en-us/sql/odbc/reference/appendixes/interval-literal-syntax
    # https://docs.microsoft.com/en-us/sql/odbc/reference/appendixes/interval-literals
    # https://www.postgresql.org/docs/9.1/datatype-datetime.html
    # https://www.postgresql.org/docs/9.1/functions-datetime.html
    def match(token):
        return imt(token, m=sql.TypedLiteral.M_OPEN)

    def match_to_extend(token):
        return isinstance(token, sql.TypedLiteral)

    def valid_prev(token):
        return token is not None

    def valid_next(token):
        return token is not None and token.match(*sql.TypedLiteral.M_CLOSE)

    def valid_final(token):
        return token is not None and token.match(*sql.TypedLiteral.M_EXTEND)

    def post(tlist, pidx, tidx, nidx):
        return tidx, nidx

    # pass 1: create TypedLiteral groups from an opening keyword up to
    # the closing literal token
    _group(tlist, sql.TypedLiteral, match, valid_prev, valid_next,
           post, extend=False)
    # pass 2: extend the groups created above with trailing M_EXTEND tokens
    _group(tlist, sql.TypedLiteral, match_to_extend, valid_prev, valid_final,
           post, extend=True)
def group_period(tlist):
    """Group dotted names (``a.b``, ``tbl."col"``, ``a.*``) into
    Identifier groups."""
    def is_dot(token):
        return token.match(T.Punctuation, '.')

    def prev_ok(token):
        before_classes = sql.SquareBrackets, sql.Identifier
        before_types = T.Name, T.String.Symbol
        return imt(token, i=before_classes, t=before_types)

    def next_ok(token):
        # issue261, allow invalid next token
        return True

    def span(tlist, pidx, tidx, nidx):
        # next_ validation is being performed here. issue261
        after_classes = sql.SquareBrackets, sql.Function
        after_types = T.Name, T.String.Symbol, T.Wildcard
        following = tlist[nidx] if nidx is not None else None
        if imt(following, i=after_classes, t=after_types):
            return pidx, nidx
        return pidx, tidx

    _group(tlist, sql.Identifier, is_dot, prev_ok, next_ok, span)
def group_as(tlist):
    """Group ``<expr> AS <alias>`` spans into Identifier groups."""
    def is_as_keyword(token):
        return token.is_keyword and token.normalized == 'AS'

    def prev_ok(token):
        # NULL may be aliased; any other keyword before AS is rejected
        return token.normalized == 'NULL' or not token.is_keyword

    def next_ok(token):
        excluded = T.DML, T.DDL, T.CTE
        return not imt(token, t=excluded) and token is not None

    def span(tlist, pidx, tidx, nidx):
        return pidx, nidx

    _group(tlist, sql.Identifier, is_as_keyword, prev_ok, next_ok, span)
def group_assignment(tlist):
    """Group ``target := value`` spans into ``sql.Assignment`` groups.

    The grouped span is extended up to the next semicolon (when present)
    so the full statement body lands inside the Assignment group.
    """
    def match(token):
        return token.match(T.Assignment, ':=')

    def valid(token):
        # NOTE: ``T.Keyword`` is a _TokenType whose ``in`` matches the type
        # and all of its subtypes.  The previous spelling ``(T.Keyword)``
        # looked like a 1-tuple but was just a parenthesized expression;
        # the parentheses are dropped to make the containment check explicit.
        # Behavior is unchanged.
        return token is not None and token.ttype not in T.Keyword

    def post(tlist, pidx, tidx, nidx):
        m_semicolon = T.Punctuation, ';'
        snidx, _ = tlist.token_next_by(m=m_semicolon, idx=nidx)
        # extend to the terminating semicolon when one exists
        nidx = snidx or nidx
        return pidx, nidx

    valid_prev = valid_next = valid
    _group(tlist, sql.Assignment, match, valid_prev, valid_next, post)
def group_comparison(tlist):
    """Group ``lhs <op> rhs`` spans into ``sql.Comparison`` groups."""
    operand_classes = (sql.Parenthesis, sql.Function, sql.Identifier,
                       sql.Operation, sql.TypedLiteral)
    operand_types = T_NUMERICAL + T_STRING + T_NAME

    def is_comparison_op(token):
        return token.ttype == T.Operator.Comparison

    def is_operand(token):
        # literals/names and already-grouped expressions qualify,
        # as does the NULL keyword
        if imt(token, t=operand_types, i=operand_classes):
            return True
        if token and token.is_keyword and token.normalized == 'NULL':
            return True
        return False

    def span(tlist, pidx, tidx, nidx):
        return pidx, nidx

    _group(tlist, sql.Comparison, is_comparison_op,
           is_operand, is_operand, span, extend=False)
@recurse(sql.Identifier)
def group_identifier(tlist):
    """Wrap each bare Name / quoted-symbol token in an Identifier group."""
    name_types = (T.String.Symbol, T.Name)
    idx, tok = tlist.token_next_by(t=name_types)
    while tok:
        tlist.group_tokens(sql.Identifier, idx, idx)
        idx, tok = tlist.token_next_by(t=name_types, idx=idx)
def group_arrays(tlist):
    """Attach subscripts (``name[...]``) to the preceding identifier."""
    def is_subscript(token):
        return isinstance(token, sql.SquareBrackets)

    def prev_ok(token):
        subject_classes = sql.SquareBrackets, sql.Identifier, sql.Function
        subject_types = T.Name, T.String.Symbol
        return imt(token, i=subject_classes, t=subject_types)

    def next_ok(token):
        return True

    def span(tlist, pidx, tidx, nidx):
        # group only [subject .. bracket]; the following token is untouched
        return pidx, tidx

    _group(tlist, sql.Identifier, is_subscript,
           prev_ok, next_ok, span, extend=True, recurse=False)
def group_operator(tlist):
    """Group ``lhs <op> rhs`` arithmetic expressions into ``sql.Operation``
    groups."""
    ttypes = T_NUMERICAL + T_STRING + T_NAME
    sqlcls = (sql.SquareBrackets, sql.Parenthesis, sql.Function,
              sql.Identifier, sql.Operation, sql.TypedLiteral)

    def match(token):
        return imt(token, t=(T.Operator, T.Wildcard))

    def valid(token):
        # operands are literals/names, already-grouped expressions, or the
        # parameterless CURRENT_* keywords
        return imt(token, i=sqlcls, t=ttypes) \
            or (token and token.match(
                T.Keyword,
                ('CURRENT_DATE', 'CURRENT_TIME', 'CURRENT_TIMESTAMP')))

    def post(tlist, pidx, tidx, nidx):
        # a Wildcard '*' used between operands is multiplication; retype
        # it to a plain Operator before grouping
        tlist[tidx].ttype = T.Operator
        return pidx, nidx

    valid_prev = valid_next = valid
    _group(tlist, sql.Operation, match,
           valid_prev, valid_next, post, extend=False)
def group_identifier_list(tlist):
    """Group comma-separated expressions into ``sql.IdentifierList``
    groups."""
    m_role = T.Keyword, ('null', 'role')
    item_classes = (sql.Function, sql.Case, sql.Identifier, sql.Comparison,
                    sql.IdentifierList, sql.Operation)
    item_types = (T_NUMERICAL + T_STRING + T_NAME
                  + (T.Keyword, T.Comment, T.Wildcard))

    def is_comma(token):
        return token.match(T.Punctuation, ',')

    def is_list_item(token):
        return imt(token, i=item_classes, m=m_role, t=item_types)

    def span(tlist, pidx, tidx, nidx):
        return pidx, nidx

    _group(tlist, sql.IdentifierList, is_comma,
           is_list_item, is_list_item, span, extend=True)
@recurse(sql.Comment)
def group_comments(tlist):
    """Group runs of consecutive comment tokens into ``sql.Comment``
    groups."""
    tidx, token = tlist.token_next_by(t=T.Comment)
    while token:
        # find the first token that is neither a comment nor whitespace
        eidx, end = tlist.token_not_matching(
            lambda tk: imt(tk, t=T.Comment) or tk.is_whitespace, idx=tidx)
        if end is not None:
            # step back to the token just before it, keeping whitespace
            # inside the grouped comment run
            eidx, end = tlist.token_prev(eidx, skip_ws=False)
            tlist.group_tokens(sql.Comment, tidx, eidx)

        tidx, token = tlist.token_next_by(t=T.Comment, idx=tidx)
@recurse(sql.Where)
def group_where(tlist):
    """Group a WHERE keyword and its condition tokens into a ``sql.Where``
    group, ending right before the next clause keyword."""
    tidx, token = tlist.token_next_by(m=sql.Where.M_OPEN)
    while token:
        eidx, end = tlist.token_next_by(m=sql.Where.M_CLOSE, idx=tidx)

        if end is None:
            # no closing keyword: the clause runs to the last groupable token
            end = tlist._groupable_tokens[-1]
        else:
            # stop at the token just before the closing keyword
            end = tlist.tokens[eidx - 1]
        # TODO: convert this to eidx instead of end token.
        # i think above values are len(tlist) and eidx-1
        eidx = tlist.token_index(end)
        tlist.group_tokens(sql.Where, tidx, eidx)
        tidx, token = tlist.token_next_by(m=sql.Where.M_OPEN, idx=tidx)
@recurse()
def group_aliased(tlist):
    """Merge an expression with a directly following alias Identifier."""
    aliasable = (sql.Parenthesis, sql.Function, sql.Case, sql.Identifier,
                 sql.Operation, sql.Comparison)
    idx, tok = tlist.token_next_by(i=aliasable, t=T.Number)
    while tok:
        next_idx, next_tok = tlist.token_next(idx)
        if isinstance(next_tok, sql.Identifier):
            tlist.group_tokens(sql.Identifier, idx, next_idx, extend=True)
        idx, tok = tlist.token_next_by(i=aliasable, t=T.Number, idx=idx)
@recurse(sql.Function)
def group_functions(tlist):
    """Group ``name(...)`` sequences into ``sql.Function`` groups.

    Grouping is skipped inside CREATE TABLE statements, where column
    definitions like ``col VARCHAR(10)`` would otherwise be misread as
    function calls.
    """
    has_create = False
    has_table = False
    for tmp_token in tlist.tokens:
        # compare case-insensitively: lowercase 'create table' DDL must
        # be detected as well (the old ``== 'CREATE'`` check missed it)
        value = tmp_token.value.upper()
        if value == 'CREATE':
            has_create = True
        elif value == 'TABLE':
            has_table = True
        if has_create and has_table:
            # CREATE TABLE statement: don't group functions at all
            return
    tidx, token = tlist.token_next_by(t=T.Name)
    while token:
        nidx, next_ = tlist.token_next(tidx)
        if isinstance(next_, sql.Parenthesis):
            tlist.group_tokens(sql.Function, tidx, nidx)
        tidx, token = tlist.token_next_by(t=T.Name, idx=tidx)
def group_order(tlist):
    """Group together Identifier and Asc/Desc token"""
    idx, tok = tlist.token_next_by(t=T.Keyword.Order)
    while tok:
        prev_idx, prev_tok = tlist.token_prev(idx)
        if imt(prev_tok, i=sql.Identifier, t=T.Number):
            tlist.group_tokens(sql.Identifier, prev_idx, idx)
            # restart the search from the newly created group
            idx = prev_idx
        idx, tok = tlist.token_next_by(t=T.Keyword.Order, idx=idx)
@recurse()
def align_comments(tlist):
    """Fold a Comment into the TokenList group that precedes it."""
    idx, tok = tlist.token_next_by(i=sql.Comment)
    while tok:
        prev_idx, prev_tok = tlist.token_prev(idx)
        if isinstance(prev_tok, sql.TokenList):
            tlist.group_tokens(sql.TokenList, prev_idx, idx, extend=True)
            idx = prev_idx
        idx, tok = tlist.token_next_by(i=sql.Comment, idx=idx)
def group_values(tlist):
    """Group a VALUES keyword and its parenthesized tuples into a
    ``sql.Values`` group."""
    start_idx, token = tlist.token_next_by(m=(T.Keyword, 'VALUES'))
    cursor = start_idx
    last_paren_idx = -1
    while token:
        if isinstance(token, sql.Parenthesis):
            # remember the most recent value tuple seen so far
            last_paren_idx = cursor
        cursor, token = tlist.token_next(cursor)
    if last_paren_idx != -1:
        tlist.group_tokens(sql.Values, start_idx, last_paren_idx, extend=True)
def group(stmt):
    """Run every grouping function over *stmt* in place and return it.

    The order of the list matters: later groupers operate on the
    structures created by earlier ones (e.g. identifier lists can only
    be built after individual identifiers exist).
    """
    for func in [
        group_comments,

        # _group_matching
        group_brackets,
        group_parenthesis,
        group_case,
        group_if,
        group_for,
        group_begin,

        group_functions,
        group_where,
        group_period,
        group_arrays,
        group_identifier,
        group_order,
        group_typecasts,
        group_tzcasts,
        group_typed_literal,
        group_operator,
        group_comparison,
        group_as,
        group_aliased,
        group_assignment,

        align_comments,
        group_identifier_list,
        group_values,
    ]:
        func(stmt)
    return stmt
def _group(tlist, cls, match,
           valid_prev=lambda t: True,
           valid_next=lambda t: True,
           post=None,
           extend=True,
           recurse=True
           ):
    """Groups together tokens that are joined by a middle token. i.e. x < y

    :param tlist: token list, modified in place
    :param cls: group class to create (e.g. ``sql.Comparison``)
    :param match: predicate identifying the middle token
    :param valid_prev: predicate validating the token before the match
    :param valid_next: predicate validating the token after the match
    :param post: callable returning the ``(from_idx, to_idx)`` span to group
    :param extend: forwarded to ``group_tokens``
    :param recurse: whether to descend into sub-groups of other types
    """
    # grouping collapses spans into single tokens, so indices into the
    # mutated tlist lag behind the enumeration of the original snapshot
    tidx_offset = 0
    pidx, prev_ = None, None
    for idx, token in enumerate(list(tlist)):
        tidx = idx - tidx_offset
        if tidx < 0:  # tidx shouldn't get negative
            continue

        if token.is_whitespace:
            continue

        if recurse and token.is_group and not isinstance(token, cls):
            _group(token, cls, match, valid_prev, valid_next, post, extend)

        if match(token):
            nidx, next_ = tlist.token_next(tidx)
            if prev_ and valid_prev(prev_) and valid_next(next_):
                from_idx, to_idx = post(tlist, pidx, tidx, nidx)
                grp = tlist.group_tokens(cls, from_idx, to_idx, extend=extend)

                tidx_offset += to_idx - from_idx
                # the new group becomes the "previous" token for the next hit
                pidx, prev_ = from_idx, grp
                continue

        pidx, prev_ = tidx, token
|