asyncmy.py 10.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328
  1. # mysql/asyncmy.py
  2. # Copyright (C) 2005-2022 the SQLAlchemy authors and contributors <see AUTHORS
  3. # file>
  4. #
  5. # This module is part of SQLAlchemy and is released under
  6. # the MIT License: https://www.opensource.org/licenses/mit-license.php
r"""
.. dialect:: mysql+asyncmy
    :name: asyncmy
    :dbapi: asyncmy
    :connectstring: mysql+asyncmy://user:password@host:port/dbname[?key=value&key=value...]
    :url: https://github.com/long2ice/asyncmy

.. note:: The asyncmy dialect was added in September 2021 to provide
   MySQL/MariaDB asyncio compatibility, given that the :ref:`aiomysql`
   database driver has become unmaintained; however, asyncmy is itself
   very new.

Using a special asyncio mediation layer, the asyncmy dialect is usable
as the backend for the :ref:`SQLAlchemy asyncio <asyncio_toplevel>`
extension package.

This dialect should normally be used only with the
:func:`_asyncio.create_async_engine` engine creation function::

    from sqlalchemy.ext.asyncio import create_async_engine
    engine = create_async_engine("mysql+asyncmy://user:pass@hostname/dbname?charset=utf8mb4")

"""  # noqa
from collections import deque

from .pymysql import MySQLDialect_pymysql
from ... import pool
from ... import util
from ...engine import AdaptedConnection
from ...util.concurrency import asynccontextmanager
from ...util.concurrency import asyncio
from ...util.concurrency import await_fallback
from ...util.concurrency import await_only
  32. class AsyncAdapt_asyncmy_cursor:
  33. server_side = False
  34. __slots__ = (
  35. "_adapt_connection",
  36. "_connection",
  37. "await_",
  38. "_cursor",
  39. "_rows",
  40. )
  41. def __init__(self, adapt_connection):
  42. self._adapt_connection = adapt_connection
  43. self._connection = adapt_connection._connection
  44. self.await_ = adapt_connection.await_
  45. cursor = self._connection.cursor()
  46. self._cursor = self.await_(cursor.__aenter__())
  47. self._rows = []
  48. @property
  49. def description(self):
  50. return self._cursor.description
  51. @property
  52. def rowcount(self):
  53. return self._cursor.rowcount
  54. @property
  55. def arraysize(self):
  56. return self._cursor.arraysize
  57. @arraysize.setter
  58. def arraysize(self, value):
  59. self._cursor.arraysize = value
  60. @property
  61. def lastrowid(self):
  62. return self._cursor.lastrowid
  63. def close(self):
  64. # note we aren't actually closing the cursor here,
  65. # we are just letting GC do it. to allow this to be async
  66. # we would need the Result to change how it does "Safe close cursor".
  67. # MySQL "cursors" don't actually have state to be "closed" besides
  68. # exhausting rows, which we already have done for sync cursor.
  69. # another option would be to emulate aiosqlite dialect and assign
  70. # cursor only if we are doing server side cursor operation.
  71. self._rows[:] = []
  72. def execute(self, operation, parameters=None):
  73. return self.await_(self._execute_async(operation, parameters))
  74. def executemany(self, operation, seq_of_parameters):
  75. return self.await_(
  76. self._executemany_async(operation, seq_of_parameters)
  77. )
  78. async def _execute_async(self, operation, parameters):
  79. async with self._adapt_connection._mutex_and_adapt_errors():
  80. if parameters is None:
  81. result = await self._cursor.execute(operation)
  82. else:
  83. result = await self._cursor.execute(operation, parameters)
  84. if not self.server_side:
  85. # asyncmy has a "fake" async result, so we have to pull it out
  86. # of that here since our default result is not async.
  87. # we could just as easily grab "_rows" here and be done with it
  88. # but this is safer.
  89. self._rows = list(await self._cursor.fetchall())
  90. return result
  91. async def _executemany_async(self, operation, seq_of_parameters):
  92. async with self._adapt_connection._mutex_and_adapt_errors():
  93. return await self._cursor.executemany(operation, seq_of_parameters)
  94. def setinputsizes(self, *inputsizes):
  95. pass
  96. def __iter__(self):
  97. while self._rows:
  98. yield self._rows.pop(0)
  99. def fetchone(self):
  100. if self._rows:
  101. return self._rows.pop(0)
  102. else:
  103. return None
  104. def fetchmany(self, size=None):
  105. if size is None:
  106. size = self.arraysize
  107. retval = self._rows[0:size]
  108. self._rows[:] = self._rows[size:]
  109. return retval
  110. def fetchall(self):
  111. retval = self._rows[:]
  112. self._rows[:] = []
  113. return retval
  114. class AsyncAdapt_asyncmy_ss_cursor(AsyncAdapt_asyncmy_cursor):
  115. __slots__ = ()
  116. server_side = True
  117. def __init__(self, adapt_connection):
  118. self._adapt_connection = adapt_connection
  119. self._connection = adapt_connection._connection
  120. self.await_ = adapt_connection.await_
  121. cursor = self._connection.cursor(
  122. adapt_connection.dbapi.asyncmy.cursors.SSCursor
  123. )
  124. self._cursor = self.await_(cursor.__aenter__())
  125. def close(self):
  126. if self._cursor is not None:
  127. self.await_(self._cursor.close())
  128. self._cursor = None
  129. def fetchone(self):
  130. return self.await_(self._cursor.fetchone())
  131. def fetchmany(self, size=None):
  132. return self.await_(self._cursor.fetchmany(size=size))
  133. def fetchall(self):
  134. return self.await_(self._cursor.fetchall())
  135. class AsyncAdapt_asyncmy_connection(AdaptedConnection):
  136. await_ = staticmethod(await_only)
  137. __slots__ = ("dbapi", "_connection", "_execute_mutex")
  138. def __init__(self, dbapi, connection):
  139. self.dbapi = dbapi
  140. self._connection = connection
  141. self._execute_mutex = asyncio.Lock()
  142. @asynccontextmanager
  143. async def _mutex_and_adapt_errors(self):
  144. async with self._execute_mutex:
  145. try:
  146. yield
  147. except AttributeError:
  148. raise self.dbapi.InternalError(
  149. "network operation failed due to asyncmy attribute error"
  150. )
  151. def ping(self, reconnect):
  152. assert not reconnect
  153. return self.await_(self._do_ping())
  154. async def _do_ping(self):
  155. async with self._mutex_and_adapt_errors():
  156. return await self._connection.ping(False)
  157. def character_set_name(self):
  158. return self._connection.character_set_name()
  159. def autocommit(self, value):
  160. self.await_(self._connection.autocommit(value))
  161. def cursor(self, server_side=False):
  162. if server_side:
  163. return AsyncAdapt_asyncmy_ss_cursor(self)
  164. else:
  165. return AsyncAdapt_asyncmy_cursor(self)
  166. def rollback(self):
  167. self.await_(self._connection.rollback())
  168. def commit(self):
  169. self.await_(self._connection.commit())
  170. def close(self):
  171. # it's not awaitable.
  172. self._connection.close()
  173. class AsyncAdaptFallback_asyncmy_connection(AsyncAdapt_asyncmy_connection):
  174. __slots__ = ()
  175. await_ = staticmethod(await_fallback)
  176. def _Binary(x):
  177. """Return x as a binary type."""
  178. return bytes(x)
  179. class AsyncAdapt_asyncmy_dbapi:
  180. def __init__(self, asyncmy):
  181. self.asyncmy = asyncmy
  182. self.paramstyle = "format"
  183. self._init_dbapi_attributes()
  184. def _init_dbapi_attributes(self):
  185. for name in (
  186. "Warning",
  187. "Error",
  188. "InterfaceError",
  189. "DataError",
  190. "DatabaseError",
  191. "OperationalError",
  192. "InterfaceError",
  193. "IntegrityError",
  194. "ProgrammingError",
  195. "InternalError",
  196. "NotSupportedError",
  197. ):
  198. setattr(self, name, getattr(self.asyncmy.errors, name))
  199. STRING = util.symbol("STRING")
  200. NUMBER = util.symbol("NUMBER")
  201. BINARY = util.symbol("BINARY")
  202. DATETIME = util.symbol("DATETIME")
  203. TIMESTAMP = util.symbol("TIMESTAMP")
  204. Binary = staticmethod(_Binary)
  205. def connect(self, *arg, **kw):
  206. async_fallback = kw.pop("async_fallback", False)
  207. if util.asbool(async_fallback):
  208. return AsyncAdaptFallback_asyncmy_connection(
  209. self,
  210. await_fallback(self.asyncmy.connect(*arg, **kw)),
  211. )
  212. else:
  213. return AsyncAdapt_asyncmy_connection(
  214. self,
  215. await_only(self.asyncmy.connect(*arg, **kw)),
  216. )
  217. class MySQLDialect_asyncmy(MySQLDialect_pymysql):
  218. driver = "asyncmy"
  219. supports_statement_cache = True
  220. supports_server_side_cursors = True
  221. _sscursor = AsyncAdapt_asyncmy_ss_cursor
  222. is_async = True
  223. @classmethod
  224. def dbapi(cls):
  225. return AsyncAdapt_asyncmy_dbapi(__import__("asyncmy"))
  226. @classmethod
  227. def get_pool_class(cls, url):
  228. async_fallback = url.query.get("async_fallback", False)
  229. if util.asbool(async_fallback):
  230. return pool.FallbackAsyncAdaptedQueuePool
  231. else:
  232. return pool.AsyncAdaptedQueuePool
  233. def create_connect_args(self, url):
  234. return super(MySQLDialect_asyncmy, self).create_connect_args(
  235. url, _translate_args=dict(username="user", database="db")
  236. )
  237. def is_disconnect(self, e, connection, cursor):
  238. if super(MySQLDialect_asyncmy, self).is_disconnect(
  239. e, connection, cursor
  240. ):
  241. return True
  242. else:
  243. str_e = str(e).lower()
  244. return (
  245. "not connected" in str_e or "network operation failed" in str_e
  246. )
  247. def _found_rows_client_flag(self):
  248. from asyncmy.constants import CLIENT
  249. return CLIENT.FOUND_ROWS
  250. def get_driver_connection(self, connection):
  251. return connection._connection
  252. dialect = MySQLDialect_asyncmy