import sqlalchemy as sa
from sqlalchemy.dialects import postgresql
from sqlalchemy.ext.compiler import compiles
from sqlalchemy.sql.expression import ColumnElement, FunctionElement
from sqlalchemy.sql.functions import GenericFunction

from .functions.orm import quote


class array_get(FunctionElement):
    """SQL construct that subscripts an array expression.

    Usage: ``array_get(array_expr, index)`` where ``index`` is a literal
    Python ``int`` given as a zero-based position. It compiles to
    ``(array_expr)[index + 1]`` on every dialect.
    """

    name = 'array_get'


@compiles(array_get)
def compile_array_get(element, compiler, **kw):
    """Compile ``array_get`` into an array subscript expression.

    The zero-based index supplied by the caller is shifted by one when
    rendered, matching PostgreSQL's default one-based array subscripts.

    :raises TypeError: if not exactly two arguments are given, or if the
        second argument does not carry a literal (non-bool) ``int`` as its
        ``value`` attribute.
    """
    args = list(element.clauses)
    if len(args) != 2:
        raise TypeError(
            "Function 'array_get' expects two arguments (%d given)."
            % len(args)
        )
    index = getattr(args[1], 'value', None)
    # bool is a subclass of int; without this guard ``True`` would silently
    # render as subscript 2.
    if isinstance(index, bool) or not isinstance(index, int):
        raise TypeError(
            "Second argument should be an integer."
        )
    # Forward **kw so compile options (e.g. literal_binds) also apply to the
    # array expression being subscripted.
    return '(%s)[%d]' % (compiler.process(args[0], **kw), index + 1)


def _compile_named_call(element, compiler, **kw):
    """Render ``<element.name>(<element's arguments>)``, forwarding ``**kw``."""
    return "%s(%s)" % (element.name, compiler.process(element.clauses, **kw))


class row_to_json(GenericFunction):
    """``row_to_json(...)`` SQL function; result typed as ``postgresql.JSON``."""

    name = 'row_to_json'
    type = postgresql.JSON


@compiles(row_to_json, 'postgresql')
def compile_row_to_json(element, compiler, **kw):
    """Render ``row_to_json(<args>)`` for the PostgreSQL dialect."""
    return _compile_named_call(element, compiler, **kw)


class json_array_length(GenericFunction):
    """``json_array_length(...)`` SQL function; result typed as ``Integer``."""

    name = 'json_array_length'
    type = sa.Integer


@compiles(json_array_length, 'postgresql')
def compile_json_array_length(element, compiler, **kw):
    """Render ``json_array_length(<args>)`` for the PostgreSQL dialect."""
    return _compile_named_call(element, compiler, **kw)


class Asterisk(ColumnElement):
    """Column expression that renders as ``<selectable name>.*``.

    :param selectable: object exposing a ``name`` attribute (presumably a
        table or alias -- confirm against callers).
    """

    def __init__(self, selectable):
        self.selectable = selectable


@compiles(Asterisk)
def compile_asterisk(element, compiler, **kw):
    """Render ``<name>.*``, passing the selectable's name to ``quote``.

    ``quote`` is a project helper called with the active dialect; it
    presumably performs dialect-aware identifier quoting (confirm in
    ``.functions.orm``).
    """
    return '%s.*' % quote(compiler.dialect, element.selectable.name)