12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
71278127912801281128212831284128512861287128812891290129112921293129412951296129712981299130013011302130313041305130613071308130913101311131213131314131513161317131813191320132113221323132413251326132713281329133013311332133313341335133613371338133913401341134213431344134513461347134813491350135113521353135413551356135713581359136013611362136313641365136613671368136913701371137213731374137513761377137813791380138113821383138413851386138713881389139013911392139313941395139613971398139914001401140214031404140514061407140814091410141114121413141414151416141714181419142014211422142314241425142614271428142914301431143214331434143514361437143814391440144114421443144414451446144714481449145014511452145314541455145614571458145914601461146214631464146514661467146814691470147114721473147414751476147714781479148014811482148314841485148614871488148914901491149214931494149514961497149814991500150115021503150415051506150715081509151015111512151315141515151615171518151915201521152215231524152515261527152815291530153115321533153415351536153715381539154015411542154315441545154615471548154915501551155215531554155515561557155815591560156115621563156415651566156715681569157015711572157315741575157615771578157915801581158215831584158515861587158815891590159115921593159415951596159715981599160016011602160316041605160616071608160916101611161216131614161516161617161816191620162116221623162416251626162716281629163016311632163316341635163616371638163916401641164216431644164516461647164816491650165116521653165416551656165716581659166016611662166316641665166616671668166916701671167216731674167516761677167816791680168116821683168416851686168716881689169016911692169316941695169616971698169917001701170217031704170517061707170817091710171117121713171417151716171717181719172017211722172317241725172617271728172917301731173217331734173517361737173817391740174117421743174417451746174717481749175017511752175317541755175617571758175917601761176217631764176517661767176817691770177117721773177417751776177
71778177917801781178217831784178517861787178817891790179117921793179417951796179717981799180018011802180318041805180618071808180918101811181218131814181518161817181818191820182118221823182418251826182718281829183018311832183318341835183618371838183918401841184218431844184518461847184818491850185118521853185418551856185718581859186018611862186318641865186618671868186918701871187218731874187518761877187818791880188118821883188418851886188718881889189018911892189318941895189618971898189919001901190219031904190519061907190819091910191119121913191419151916191719181919192019211922192319241925192619271928192919301931193219331934193519361937193819391940194119421943194419451946194719481949195019511952195319541955195619571958195919601961196219631964196519661967196819691970197119721973197419751976197719781979198019811982198319841985198619871988198919901991199219931994199519961997199819992000200120022003200420052006200720082009201020112012201320142015201620172018201920202021202220232024202520262027202820292030203120322033203420352036203720382039204020412042204320442045204620472048204920502051205220532054205520562057205820592060206120622063206420652066206720682069207020712072207320742075207620772078207920802081208220832084208520862087208820892090209120922093209420952096209720982099210021012102210321042105210621072108210921102111211221132114211521162117211821192120212121222123212421252126212721282129213021312132213321342135213621372138213921402141214221432144214521462147214821492150215121522153215421552156215721582159216021612162216321642165216621672168216921702171217221732174217521762177217821792180218121822183218421852186218721882189219021912192219321942195219621972198219922002201220222032204220522062207220822092210221122122213221422152216221722182219222022212222222322242225222622272228222922302231223222332234223522362237223822392240224122422243224422452246224722482249225022512252225322542255225622572258225922602261226222632264226522662267226822692270227122722273227422752276227
72278227922802281228222832284228522862287228822892290229122922293229422952296229722982299230023012302230323042305230623072308230923102311231223132314231523162317231823192320232123222323232423252326232723282329233023312332233323342335233623372338233923402341234223432344234523462347234823492350235123522353235423552356235723582359236023612362236323642365236623672368236923702371237223732374237523762377237823792380238123822383238423852386238723882389239023912392239323942395239623972398239924002401240224032404240524062407 |
- # orm/persistence.py
- # Copyright (C) 2005-2022 the SQLAlchemy authors and contributors
- # <see AUTHORS file>
- #
- # This module is part of SQLAlchemy and is released under
- # the MIT License: https://www.opensource.org/licenses/mit-license.php
- """private module containing functions used to emit INSERT, UPDATE
- and DELETE statements on behalf of a :class:`_orm.Mapper` and its descending
- mappers.
- The functions here are called only by the unit of work functions
- in unitofwork.py.
- """
- from itertools import chain
- from itertools import groupby
- import operator
- from . import attributes
- from . import evaluator
- from . import exc as orm_exc
- from . import loading
- from . import sync
- from .base import NO_VALUE
- from .base import state_str
- from .. import exc as sa_exc
- from .. import future
- from .. import sql
- from .. import util
- from ..engine import result as _result
- from ..sql import coercions
- from ..sql import expression
- from ..sql import operators
- from ..sql import roles
- from ..sql import select
- from ..sql import sqltypes
- from ..sql.base import _entity_namespace_key
- from ..sql.base import CompileState
- from ..sql.base import Options
- from ..sql.dml import DeleteDMLState
- from ..sql.dml import UpdateDMLState
- from ..sql.elements import BooleanClauseList
- from ..sql.selectable import LABEL_STYLE_TABLENAME_PLUS_COL
def _bulk_insert(
    mapper,
    mappings,
    session_transaction,
    isstates,
    return_defaults,
    render_nulls,
):
    """Emit INSERT statements for a collection of mappings (or states,
    when ``isstates`` is set) as part of a bulk-save operation.
    """
    base_mapper = mapper.base_mapper

    if session_transaction.session.connection_callable:
        raise NotImplementedError(
            "connection_callable / per-instance sharding "
            "not supported in bulk_insert()"
        )

    if isstates:
        if return_defaults:
            # keep each state paired with its dict so that generated
            # identity values can be assigned back at the end
            states = [(state, state.dict) for state in mappings]
            mappings = [pair[1] for pair in states]
        else:
            mappings = [state.dict for state in mappings]
    else:
        mappings = list(mappings)

    connection = session_transaction.connection(base_mapper)

    for table, super_mapper in base_mapper._sorted_tables.items():
        if not mapper.isa(super_mapper):
            continue

        collected = _collect_insert_commands(
            table,
            ((None, mapping, mapper, connection) for mapping in mappings),
            bulk=True,
            return_defaults=return_defaults,
            render_nulls=render_nulls,
        )
        # re-emit each record with the state slot forced to None and the
        # outer mapper/connection in place; bulk mode has no state objects
        records = (
            (
                None,
                state_dict,
                params,
                mapper,
                connection,
                value_params,
                has_all_pks,
                has_all_defaults,
            )
            for (
                state,
                state_dict,
                params,
                mp,
                conn,
                value_params,
                has_all_pks,
                has_all_defaults,
            ) in collected
        )

        _emit_insert_statements(
            base_mapper,
            None,
            super_mapper,
            table,
            records,
            bookkeeping=return_defaults,
        )

    if return_defaults and isstates:
        # assign identity keys back onto the states using the values now
        # present in each dict
        identity_cls = mapper._identity_class
        identity_props = [prop.key for prop in mapper._identity_key_props]
        for state, dict_ in states:
            state.key = (
                identity_cls,
                tuple(dict_[key] for key in identity_props),
            )
def _bulk_update(
    mapper, mappings, session_transaction, isstates, update_changed_only
):
    """Emit UPDATE statements for a collection of mappings (or states,
    when ``isstates`` is set) as part of a bulk-save operation.
    """
    base_mapper = mapper.base_mapper

    search_keys = mapper._primary_key_propkeys
    if mapper._version_id_prop:
        search_keys = {mapper._version_id_prop.key}.union(search_keys)

    def _changed_dict(mapper, state):
        # restrict to attributes with pending changes, plus the primary
        # key / version-id keys needed to locate the row
        return {
            k: v
            for k, v in state.dict.items()
            if k in state.committed_state or k in search_keys
        }

    if isstates:
        if update_changed_only:
            mappings = [_changed_dict(mapper, state) for state in mappings]
        else:
            mappings = [state.dict for state in mappings]
    else:
        mappings = list(mappings)

    if session_transaction.session.connection_callable:
        raise NotImplementedError(
            "connection_callable / per-instance sharding "
            "not supported in bulk_update()"
        )

    connection = session_transaction.connection(base_mapper)

    version_prop = mapper._version_id_prop

    for table, super_mapper in base_mapper._sorted_tables.items():
        if not mapper.isa(super_mapper):
            continue

        records = _collect_update_commands(
            None,
            table,
            (
                (
                    None,
                    mapping,
                    mapper,
                    connection,
                    mapping[version_prop.key] if version_prop else None,
                )
                for mapping in mappings
            ),
            bulk=True,
        )
        _emit_update_statements(
            base_mapper,
            None,
            super_mapper,
            table,
            records,
            bookkeeping=False,
        )
def save_obj(base_mapper, states, uowtransaction, single=False):
    """Issue ``INSERT`` and/or ``UPDATE`` statements for a list
    of objects.

    This is called within the context of a UOWTransaction during a
    flush operation, given a list of states to be flushed.  The
    base mapper in an inheritance hierarchy handles the inserts/
    updates for all descendant mappers.
    """

    # if batch=false, call _save_obj separately for each object
    if not single and not base_mapper.batch:
        for state in _sort_states(base_mapper, states):
            save_obj(base_mapper, [state], uowtransaction, single=True)
        return

    states_to_update = []
    states_to_insert = []

    # partition states into INSERT vs. UPDATE: a state that already has
    # an identity, or that is taking over an existing row ("row switch"),
    # becomes an UPDATE; everything else an INSERT
    for (
        state,
        dict_,
        mapper,
        connection,
        has_identity,
        row_switch,
        update_version_id,
    ) in _organize_states_for_save(base_mapper, states, uowtransaction):
        if has_identity or row_switch:
            states_to_update.append(
                (state, dict_, mapper, connection, update_version_id)
            )
        else:
            states_to_insert.append((state, dict_, mapper, connection))

    # emit statements table-by-table across the inheritance hierarchy,
    # in the order given by _sorted_tables
    for table, mapper in base_mapper._sorted_tables.items():
        if table not in mapper._pks_by_table:
            continue
        insert = _collect_insert_commands(table, states_to_insert)

        update = _collect_update_commands(
            uowtransaction, table, states_to_update
        )

        _emit_update_statements(
            base_mapper,
            uowtransaction,
            mapper,
            table,
            update,
        )

        _emit_insert_statements(
            base_mapper,
            uowtransaction,
            mapper,
            table,
            insert,
        )

    # finalize all saved states in one pass; the trailing boolean flags
    # whether the record was an UPDATE (True) or an INSERT (False)
    _finalize_insert_update_commands(
        base_mapper,
        uowtransaction,
        chain(
            (
                (state, state_dict, mapper, connection, False)
                for (state, state_dict, mapper, connection) in states_to_insert
            ),
            (
                (state, state_dict, mapper, connection, True)
                for (
                    state,
                    state_dict,
                    mapper,
                    connection,
                    update_version_id,
                ) in states_to_update
            ),
        ),
    )
def post_update(base_mapper, states, uowtransaction, post_update_cols):
    """Issue UPDATE statements on behalf of a relationship() which
    specifies post_update.
    """
    states_to_update = list(
        _organize_states_for_post_update(base_mapper, states, uowtransaction)
    )

    for table, mapper in base_mapper._sorted_tables.items():
        if table not in mapper._pks_by_table:
            continue

        # lazily extend each record with the committed version id value
        # (or None when versioning is not in effect for this mapper)
        records = (
            (
                state,
                state_dict,
                sub_mapper,
                connection,
                (
                    mapper._get_committed_state_attr_by_column(
                        state, state_dict, mapper.version_id_col
                    )
                    if mapper.version_id_col is not None
                    else None
                ),
            )
            for (state, state_dict, sub_mapper, connection) in states_to_update
            if table in sub_mapper._pks_by_table
        )

        records = _collect_post_update_commands(
            base_mapper, uowtransaction, table, records, post_update_cols
        )

        _emit_post_update_statements(
            base_mapper,
            uowtransaction,
            mapper,
            table,
            records,
        )
def delete_obj(base_mapper, states, uowtransaction):
    """Issue ``DELETE`` statements for a list of objects.

    This is called within the context of a UOWTransaction during a
    flush operation.
    """
    states_to_delete = list(
        _organize_states_for_delete(base_mapper, states, uowtransaction)
    )

    table_to_mapper = base_mapper._sorted_tables

    # process tables in reverse dependency order so dependent rows are
    # removed before the rows they reference
    for table in reversed(list(table_to_mapper)):
        mapper = table_to_mapper[table]
        if table not in mapper._pks_by_table:
            continue
        if mapper.inherits and mapper.passive_deletes:
            # skip inherited tables when passive_deletes is in effect
            continue

        commands = _collect_delete_commands(
            base_mapper, uowtransaction, table, states_to_delete
        )

        _emit_delete_statements(
            base_mapper,
            uowtransaction,
            mapper,
            table,
            commands,
        )

    # fire after_delete events once all DELETE statements have gone out
    for (
        state,
        state_dict,
        mapper,
        connection,
        update_version_id,
    ) in states_to_delete:
        mapper.dispatch.after_delete(mapper, connection, state)
def _organize_states_for_save(base_mapper, states, uowtransaction):
    """Make an initial pass across a set of states for INSERT or
    UPDATE.

    This includes splitting out into distinct lists for
    each, calling before_insert/before_update, obtaining
    key information for each state including its dictionary,
    mapper, the connection to use for the execution per state,
    and the identity flag.
    """

    for state, dict_, mapper, connection in _connections_for_states(
        base_mapper, uowtransaction, states
    ):

        # a state with an identity key already exists in the database
        has_identity = bool(state.key)

        instance_key = state.key or mapper._identity_key_from_state(state)

        row_switch = update_version_id = None

        # call before_XXX extensions
        if not has_identity:
            mapper.dispatch.before_insert(mapper, connection, state)
        else:
            mapper.dispatch.before_update(mapper, connection, state)

        if mapper._validate_polymorphic_identity:
            mapper._validate_polymorphic_identity(mapper, state, dict_)

        # detect if we have a "pending" instance (i.e. has
        # no instance_key attached to it), and another instance
        # with the same identity key already exists as persistent.
        # convert to an UPDATE if so.
        if (
            not has_identity
            and instance_key in uowtransaction.session.identity_map
        ):
            instance = uowtransaction.session.identity_map[instance_key]
            existing = attributes.instance_state(instance)

            if not uowtransaction.was_already_deleted(existing):
                if not uowtransaction.is_deleted(existing):
                    # conflicting persistent instance is not being
                    # deleted in this flush; warn rather than row-switch
                    util.warn(
                        "New instance %s with identity key %s conflicts "
                        "with persistent instance %s"
                        % (state_str(state), instance_key, state_str(existing))
                    )
                else:
                    base_mapper._log_debug(
                        "detected row switch for identity %s. "
                        "will update %s, remove %s from "
                        "transaction",
                        instance_key,
                        state_str(state),
                        state_str(existing),
                    )

                    # remove the "delete" flag from the existing element
                    uowtransaction.remove_state_actions(existing)
                    row_switch = existing

        # for an UPDATE (pre-existing identity or row switch), capture the
        # committed version id value; on a row switch, read it from the
        # existing state being replaced
        if (has_identity or row_switch) and mapper.version_id_col is not None:
            update_version_id = mapper._get_committed_state_attr_by_column(
                row_switch if row_switch else state,
                row_switch.dict if row_switch else dict_,
                mapper.version_id_col,
            )

        yield (
            state,
            dict_,
            mapper,
            connection,
            has_identity,
            row_switch,
            update_version_id,
        )
def _organize_states_for_post_update(base_mapper, states, uowtransaction):
    """Make an initial pass across a set of states for UPDATE
    corresponding to post_update.

    This includes obtaining key information for each state
    including its dictionary, mapper, the connection to use for
    the execution per state.
    """
    # unlike the save/delete variants, no events or version-id capture
    # happen here; the raw (state, dict, mapper, connection) tuples from
    # _connections_for_states are passed through directly
    return _connections_for_states(base_mapper, uowtransaction, states)
def _organize_states_for_delete(base_mapper, states, uowtransaction):
    """Make an initial pass across a set of states for DELETE.

    This includes calling out before_delete and obtaining
    key information for each state including its dictionary,
    mapper, the connection to use for the execution per state.
    """
    for state, dict_, mapper, connection in _connections_for_states(
        base_mapper, uowtransaction, states
    ):
        mapper.dispatch.before_delete(mapper, connection, state)

        if mapper.version_id_col is None:
            update_version_id = None
        else:
            # capture the committed version id so the DELETE can be
            # qualified against the expected row version
            update_version_id = mapper._get_committed_state_attr_by_column(
                state, dict_, mapper.version_id_col
            )

        yield (state, dict_, mapper, connection, update_version_id)
def _collect_insert_commands(
    table,
    states_to_insert,
    bulk=False,
    return_defaults=False,
    render_nulls=False,
):
    """Identify sets of values to use in INSERT statements for a
    list of states.
    """
    for state, state_dict, mapper, connection in states_to_insert:
        if table not in mapper._pks_by_table:
            continue

        # plain column-key -> value parameters
        params = {}
        # column -> SQL expression parameters, rendered inline in VALUES
        value_params = {}

        propkey_to_col = mapper._propkey_to_col[table]

        eval_none = mapper._insert_cols_evaluating_none[table]

        for propkey in set(propkey_to_col).intersection(state_dict):
            value = state_dict[propkey]
            col = propkey_to_col[propkey]

            # omit None values unless the column treats None as
            # significant or render_nulls is requested
            if value is None and col not in eval_none and not render_nulls:
                continue
            elif not bulk and (
                hasattr(value, "__clause_element__")
                or isinstance(value, sql.ClauseElement)
            ):
                # SQL expression value; routed to value_params so it is
                # rendered as part of the statement, not a bound parameter
                value_params[col] = (
                    value.__clause_element__()
                    if hasattr(value, "__clause_element__")
                    else value
                )
            else:
                params[col.key] = value

        if not bulk:
            # for all the columns that have no default and we don't have
            # a value and where "None" is not a special value, add
            # explicit None to the INSERT.  This is a legacy behavior
            # which might be worth removing, as it should not be necessary
            # and also produces confusion, given that "missing" and None
            # now have distinct meanings
            for colkey in (
                mapper._insert_cols_as_none[table]
                .difference(params)
                .difference([c.key for c in value_params])
            ):
                params[colkey] = None

        if not bulk or return_defaults:
            # params are in terms of Column key objects, so
            # compare to pk_keys_by_table
            has_all_pks = mapper._pk_keys_by_table[table].issubset(params)

            if mapper.base_mapper.eager_defaults:
                has_all_defaults = mapper._server_default_cols[table].issubset(
                    params
                )
            else:
                has_all_defaults = True
        else:
            # plain bulk insert: skip the bookkeeping flags entirely
            has_all_defaults = has_all_pks = True

        # generate a new version id value for freshly inserted rows when
        # a generator is configured and the column lives in this table
        if (
            mapper.version_id_generator is not False
            and mapper.version_id_col is not None
            and mapper.version_id_col in mapper._cols_by_table[table]
        ):
            params[mapper.version_id_col.key] = mapper.version_id_generator(
                None
            )

        yield (
            state,
            state_dict,
            params,
            mapper,
            connection,
            value_params,
            has_all_pks,
            has_all_defaults,
        )
def _collect_update_commands(
    uowtransaction, table, states_to_update, bulk=False
):
    """Identify sets of values to use in UPDATE statements for a
    list of states.

    This function works intricately with the history system
    to determine exactly what values should be updated
    as well as how the row should be matched within an UPDATE
    statement.  Includes some tricky scenarios where the primary
    key of an object might have been changed.
    """

    for (
        state,
        state_dict,
        mapper,
        connection,
        update_version_id,
    ) in states_to_update:

        if table not in mapper._pks_by_table:
            continue

        pks = mapper._pks_by_table[table]

        # column -> SQL expression parameters, rendered inline in SET
        value_params = {}

        propkey_to_col = mapper._propkey_to_col[table]

        if bulk:
            # keys here are mapped attribute keys, so
            # look at mapper attribute keys for pk
            params = dict(
                (propkey_to_col[propkey].key, state_dict[propkey])
                for propkey in set(propkey_to_col)
                .intersection(state_dict)
                .difference(mapper._pk_attr_keys_by_table[table])
            )
            has_all_defaults = True
        else:
            params = {}
            # only attributes present in committed_state (i.e. with
            # pending history) are candidates for the SET clause
            for propkey in set(propkey_to_col).intersection(
                state.committed_state
            ):
                value = state_dict[propkey]
                col = propkey_to_col[propkey]

                if hasattr(value, "__clause_element__") or isinstance(
                    value, sql.ClauseElement
                ):
                    value_params[col] = (
                        value.__clause_element__()
                        if hasattr(value, "__clause_element__")
                        else value
                    )
                # guard against values that generate non-__nonzero__
                # objects for __eq__()
                elif (
                    state.manager[propkey].impl.is_equal(
                        value, state.committed_state[propkey]
                    )
                    is not True
                ):
                    params[col.key] = value

            if mapper.base_mapper.eager_defaults:
                has_all_defaults = (
                    mapper._server_onupdate_default_cols[table]
                ).issubset(params)
            else:
                has_all_defaults = True

        if (
            update_version_id is not None
            and mapper.version_id_col in mapper._cols_by_table[table]
        ):

            if not bulk and not (params or value_params):
                # HACK: check for history in other tables, in case the
                # history is only in a different table than the one
                # where the version_id_col is.  This logic was lost
                # from 0.9 -> 1.0.0 and restored in 1.0.6.
                for prop in mapper._columntoproperty.values():
                    history = state.manager[prop.key].impl.get_history(
                        state, state_dict, attributes.PASSIVE_NO_INITIALIZE
                    )
                    if history.added:
                        break
                else:
                    # no net change, break
                    continue

            col = mapper.version_id_col
            no_params = not params and not value_params
            # the WHERE criteria matches against the *old* version id,
            # bound under the column's _label key
            params[col._label] = update_version_id

            if (
                bulk or col.key not in params
            ) and mapper.version_id_generator is not False:
                val = mapper.version_id_generator(update_version_id)
                params[col.key] = val
            elif mapper.version_id_generator is False and no_params:
                # no version id generator, no values set on the table,
                # and version id wasn't manually incremented.
                # set version id to itself so we get an UPDATE
                # statement
                params[col.key] = update_version_id

        elif not (params or value_params):
            # nothing to update on this table for this state
            continue

        has_all_pks = True
        expect_pk_cascaded = False
        if bulk:
            # keys here are mapped attribute keys, so
            # look at mapper attribute keys for pk
            pk_params = dict(
                (propkey_to_col[propkey]._label, state_dict.get(propkey))
                for propkey in set(propkey_to_col).intersection(
                    mapper._pk_attr_keys_by_table[table]
                )
            )
        else:
            pk_params = {}
            for col in pks:
                propkey = mapper._columntoproperty[col].key

                history = state.manager[propkey].impl.get_history(
                    state, state_dict, attributes.PASSIVE_OFF
                )

                if history.added:
                    if (
                        not history.deleted
                        or ("pk_cascaded", state, col)
                        in uowtransaction.attributes
                    ):
                        # the new PK value is already the effective one
                        # (e.g. set via a CASCADE from a related UPDATE);
                        # use it to match the row, and don't SET it
                        expect_pk_cascaded = True
                        pk_params[col._label] = history.added[0]
                        params.pop(col.key, None)
                    else:
                        # else, use the old value to locate the row
                        pk_params[col._label] = history.deleted[0]
                        if col in value_params:
                            has_all_pks = False
                else:
                    pk_params[col._label] = history.unchanged[0]
                    if pk_params[col._label] is None:
                        raise orm_exc.FlushError(
                            "Can't update table %s using NULL for primary "
                            "key value on column %s" % (table, col)
                        )

        if params or value_params:
            params.update(pk_params)
            yield (
                state,
                state_dict,
                params,
                mapper,
                connection,
                value_params,
                has_all_defaults,
                has_all_pks,
            )
        elif expect_pk_cascaded:
            # no UPDATE occurs on this table, but we expect that CASCADE rules
            # have changed the primary key of the row; propagate this event to
            # other columns that expect to have been modified. this normally
            # occurs after the UPDATE is emitted however we invoke it here
            # explicitly in the absence of our invoking an UPDATE
            for m, equated_pairs in mapper._table_to_equated[table]:
                sync.populate(
                    state,
                    m,
                    state,
                    m,
                    equated_pairs,
                    uowtransaction,
                    mapper.passive_updates,
                )
def _collect_post_update_commands(
    base_mapper, uowtransaction, table, states_to_update, post_update_cols
):
    """Identify sets of values to use in UPDATE statements for a
    list of states within a post_update operation.
    """
    for (
        state,
        state_dict,
        mapper,
        connection,
        update_version_id,
    ) in states_to_update:

        # assert table in mapper._pks_by_table

        pks = mapper._pks_by_table[table]
        params = {}
        has_changes = False

        for col in mapper._cols_by_table[table]:
            if col in pks:
                # primary key values locate the row; these are keyed by
                # the column's _label to match the cached UPDATE's
                # bindparam names
                params[col._label] = mapper._get_state_attr_by_column(
                    state, state_dict, col, passive=attributes.PASSIVE_OFF
                )
            elif col in post_update_cols or col.onupdate is not None:
                prop = mapper._columntoproperty[col]
                history = state.manager[prop.key].impl.get_history(
                    state, state_dict, attributes.PASSIVE_NO_INITIALIZE
                )
                if history.added:
                    params[col.key] = history.added[0]
                    has_changes = True

        if not has_changes:
            # nothing to update for this state on this table
            continue

        if (
            update_version_id is not None
            and mapper.version_id_col in mapper._cols_by_table[table]
        ):
            col = mapper.version_id_col
            params[col._label] = update_version_id

            if (
                bool(state.key)
                and col.key not in params
                and mapper.version_id_generator is not False
            ):
                params[col.key] = mapper.version_id_generator(
                    update_version_id
                )

        yield state, state_dict, mapper, connection, params
def _collect_delete_commands(
    base_mapper, uowtransaction, table, states_to_delete
):
    """Identify values to use in DELETE statements for a list of
    states to be deleted."""
    for (
        state,
        state_dict,
        mapper,
        connection,
        update_version_id,
    ) in states_to_delete:

        if table not in mapper._pks_by_table:
            continue

        params = {}
        for col in mapper._pks_by_table[table]:
            # use committed (pre-flush) primary key values to match the row
            value = mapper._get_committed_state_attr_by_column(
                state, state_dict, col
            )
            params[col.key] = value
            if value is None:
                raise orm_exc.FlushError(
                    "Can't delete from table %s "
                    "using NULL for primary "
                    "key value on column %s" % (table, col)
                )

        if (
            update_version_id is not None
            and mapper.version_id_col in mapper._cols_by_table[table]
        ):
            # also qualify the DELETE against the expected version id
            params[mapper.version_id_col.key] = update_version_id

        yield params, connection
def _emit_update_statements(
    base_mapper,
    uowtransaction,
    mapper,
    table,
    update,
    bookkeeping=True,
):
    """Emit UPDATE statements corresponding to value lists collected
    by _collect_update_commands()."""
    needs_version_id = (
        mapper.version_id_col is not None
        and mapper.version_id_col in mapper._cols_by_table[table]
    )

    execution_options = {"compiled_cache": base_mapper._compiled_cache}

    def update_stmt():
        # build the cached UPDATE statement: WHERE matches each primary
        # key column (and the version id column, when versioning is in
        # effect) against bindparams named by the column's _label
        clauses = BooleanClauseList._construct_raw(operators.and_)

        for col in mapper._pks_by_table[table]:
            clauses.clauses.append(
                col == sql.bindparam(col._label, type_=col.type)
            )

        if needs_version_id:
            clauses.clauses.append(
                mapper.version_id_col
                == sql.bindparam(
                    mapper.version_id_col._label,
                    type_=mapper.version_id_col.type,
                )
            )

        stmt = table.update().where(clauses)
        return stmt

    cached_stmt = base_mapper._memo(("update", table), update_stmt)

    # group records sharing the same connection / parameter-key set /
    # flags so each group can be executed with one statement form
    # (NOTE(review): groupby only merges adjacent records — assumes the
    # incoming stream is ordered accordingly)
    for (
        (connection, paramkeys, hasvalue, has_all_defaults, has_all_pks),
        records,
    ) in groupby(
        update,
        lambda rec: (
            rec[4],  # connection
            set(rec[2]),  # set of parameter keys
            bool(rec[5]),  # whether or not we have "value" parameters
            rec[6],  # has_all_defaults
            rec[7],  # has all pks
        ),
    ):
        rows = 0
        records = list(records)

        statement = cached_stmt
        return_defaults = False

        # adorn the statement with RETURNING when generated values must
        # come back: missing pks, server defaults to fetch eagerly, or a
        # version id column
        if not has_all_pks:
            statement = statement.return_defaults()
            return_defaults = True
        elif (
            bookkeeping
            and not has_all_defaults
            and mapper.base_mapper.eager_defaults
        ):
            statement = statement.return_defaults()
            return_defaults = True
        elif mapper.version_id_col is not None:
            statement = statement.return_defaults(mapper.version_id_col)
            return_defaults = True

        # whether the dialect reports trustworthy rowcounts for single-
        # and multi-row executions, used for stale-row detection below
        assert_singlerow = (
            connection.dialect.supports_sane_rowcount
            if not return_defaults
            else connection.dialect.supports_sane_rowcount_returning
        )

        assert_multirow = (
            assert_singlerow
            and connection.dialect.supports_sane_multi_rowcount
        )
        allow_multirow = has_all_defaults and not needs_version_id

        if hasvalue:
            # records with inline SQL expression values must be executed
            # one at a time, since each statement's VALUES differ
            for (
                state,
                state_dict,
                params,
                mapper,
                connection,
                value_params,
                has_all_defaults,
                has_all_pks,
            ) in records:
                c = connection._execute_20(
                    statement.values(value_params),
                    params,
                    execution_options=execution_options,
                )
                if bookkeeping:
                    _postfetch(
                        mapper,
                        uowtransaction,
                        table,
                        state,
                        state_dict,
                        c,
                        c.context.compiled_parameters[0],
                        value_params,
                        True,
                        c.returned_defaults,
                    )
                rows += c.rowcount
                check_rowcount = assert_singlerow
        else:
            if not allow_multirow:
                check_rowcount = assert_singlerow
                for (
                    state,
                    state_dict,
                    params,
                    mapper,
                    connection,
                    value_params,
                    has_all_defaults,
                    has_all_pks,
                ) in records:
                    c = connection._execute_20(
                        statement, params, execution_options=execution_options
                    )

                    # TODO: why with bookkeeping=False?
                    if bookkeeping:
                        _postfetch(
                            mapper,
                            uowtransaction,
                            table,
                            state,
                            state_dict,
                            c,
                            c.context.compiled_parameters[0],
                            value_params,
                            True,
                            c.returned_defaults,
                        )
                    rows += c.rowcount
            else:
                # executemany path: one execution for the whole group
                multiparams = [rec[2] for rec in records]

                check_rowcount = assert_multirow or (
                    assert_singlerow and len(multiparams) == 1
                )

                c = connection._execute_20(
                    statement, multiparams, execution_options=execution_options
                )

                rows += c.rowcount

                for (
                    state,
                    state_dict,
                    params,
                    mapper,
                    connection,
                    value_params,
                    has_all_defaults,
                    has_all_pks,
                ) in records:
                    if bookkeeping:
                        _postfetch(
                            mapper,
                            uowtransaction,
                            table,
                            state,
                            state_dict,
                            c,
                            c.context.compiled_parameters[0],
                            value_params,
                            True,
                            c.returned_defaults
                            if not c.context.executemany
                            else None,
                        )

        if check_rowcount:
            # stale-data detection: every record must have matched a row
            if rows != len(records):
                raise orm_exc.StaleDataError(
                    "UPDATE statement on table '%s' expected to "
                    "update %d row(s); %d were matched."
                    % (table.description, len(records), rows)
                )

        elif needs_version_id:
            util.warn(
                "Dialect %s does not support updated rowcount "
                "- versioning cannot be verified."
                % c.dialect.dialect_description
            )
def _emit_insert_statements(
    base_mapper,
    uowtransaction,
    mapper,
    table,
    insert,
    bookkeeping=True,
):
    """Emit INSERT statements corresponding to value lists collected
    by _collect_insert_commands().

    Records are grouped by (connection, parameter keys, presence of SQL
    expression values, has_all_pks, has_all_defaults) so that compatible
    rows can share a single executemany() call; otherwise statements are
    emitted row by row so that generated defaults / primary keys can be
    fetched back per row.
    """

    # statement construction is memoized per (base_mapper, table)
    cached_stmt = base_mapper._memo(("insert", table), table.insert)

    execution_options = {"compiled_cache": base_mapper._compiled_cache}

    for (
        (connection, pkeys, hasvalue, has_all_pks, has_all_defaults),
        records,
    ) in groupby(
        insert,
        lambda rec: (
            rec[4],  # connection
            set(rec[2]),  # parameter keys
            bool(rec[5]),  # whether we have "value" parameters
            rec[6],  # has_all_pks
            rec[7],  # has_all_defaults
        ),
    ):

        statement = cached_stmt

        if (
            not bookkeeping
            or (
                has_all_defaults
                or not base_mapper.eager_defaults
                or not connection.dialect.implicit_returning
            )
            and has_all_pks
            and not hasvalue
        ):
            # the "we don't need newly generated values back" section.
            # here we have all the PKs, all the defaults or we don't want
            # to fetch them, or the dialect doesn't support RETURNING at all
            # so we have to post-fetch / use lastrowid anyway.
            records = list(records)
            multiparams = [rec[2] for rec in records]

            c = connection._execute_20(
                statement, multiparams, execution_options=execution_options
            )

            if bookkeeping:
                # pair each record with the compiled parameters actually
                # sent for that row, in order
                for (
                    (
                        state,
                        state_dict,
                        params,
                        mapper_rec,
                        conn,
                        value_params,
                        has_all_pks,
                        has_all_defaults,
                    ),
                    last_inserted_params,
                ) in zip(records, c.context.compiled_parameters):
                    if state:
                        _postfetch(
                            mapper_rec,
                            uowtransaction,
                            table,
                            state,
                            state_dict,
                            c,
                            last_inserted_params,
                            value_params,
                            False,
                            # returned_defaults is single-row only; not
                            # meaningful for an executemany
                            c.returned_defaults
                            if not c.context.executemany
                            else None,
                        )
                    else:
                        # bulk-save path: no InstanceState present
                        _postfetch_bulk_save(mapper_rec, state_dict, table)

        else:
            # here, we need defaults and/or pk values back.

            records = list(records)
            # executemany with RETURNING is only usable when the dialect
            # supports it and no per-row SQL expression values are present
            if (
                not hasvalue
                and connection.dialect.insert_executemany_returning
                and len(records) > 1
            ):
                do_executemany = True
            else:
                do_executemany = False

            if not has_all_defaults and base_mapper.eager_defaults:
                statement = statement.return_defaults()
            elif mapper.version_id_col is not None:
                statement = statement.return_defaults(mapper.version_id_col)
            elif do_executemany:
                statement = statement.return_defaults(*table.primary_key)

            if do_executemany:
                multiparams = [rec[2] for rec in records]

                c = connection._execute_20(
                    statement, multiparams, execution_options=execution_options
                )

                if bookkeeping:
                    for (
                        (
                            state,
                            state_dict,
                            params,
                            mapper_rec,
                            conn,
                            value_params,
                            has_all_pks,
                            has_all_defaults,
                        ),
                        last_inserted_params,
                        inserted_primary_key,
                        returned_defaults,
                    ) in util.zip_longest(
                        records,
                        c.context.compiled_parameters,
                        c.inserted_primary_key_rows,
                        # returned_defaults_rows may be None; pad with ()
                        c.returned_defaults_rows or (),
                    ):
                        # fill in any PK values that weren't already set
                        for pk, col in zip(
                            inserted_primary_key,
                            mapper._pks_by_table[table],
                        ):
                            prop = mapper_rec._columntoproperty[col]
                            if state_dict.get(prop.key) is None:
                                state_dict[prop.key] = pk

                        if state:
                            _postfetch(
                                mapper_rec,
                                uowtransaction,
                                table,
                                state,
                                state_dict,
                                c,
                                last_inserted_params,
                                value_params,
                                False,
                                returned_defaults,
                            )
                        else:
                            _postfetch_bulk_save(mapper_rec, state_dict, table)
            else:
                # one INSERT per record so that inserted_primary_key /
                # returned_defaults can be read back for each row
                for (
                    state,
                    state_dict,
                    params,
                    mapper_rec,
                    connection,
                    value_params,
                    has_all_pks,
                    has_all_defaults,
                ) in records:
                    if value_params:
                        # per-row SQL expression values are rendered
                        # inline into this row's statement
                        result = connection._execute_20(
                            statement.values(value_params),
                            params,
                            execution_options=execution_options,
                        )
                    else:
                        result = connection._execute_20(
                            statement,
                            params,
                            execution_options=execution_options,
                        )

                    primary_key = result.inserted_primary_key

                    for pk, col in zip(
                        primary_key, mapper._pks_by_table[table]
                    ):
                        prop = mapper_rec._columntoproperty[col]
                        # overwrite when the PK came from a SQL expression
                        # in value_params, else only fill in missing values
                        if (
                            col in value_params
                            or state_dict.get(prop.key) is None
                        ):
                            state_dict[prop.key] = pk

                    if bookkeeping:
                        if state:
                            _postfetch(
                                mapper_rec,
                                uowtransaction,
                                table,
                                state,
                                state_dict,
                                result,
                                result.context.compiled_parameters[0],
                                value_params,
                                False,
                                result.returned_defaults
                                if not result.context.executemany
                                else None,
                            )
                        else:
                            _postfetch_bulk_save(mapper_rec, state_dict, table)
def _emit_post_update_statements(
    base_mapper, uowtransaction, mapper, table, update
):
    """Emit UPDATE statements corresponding to value lists collected
    by _collect_post_update_commands().

    Rows are executed in the original state order to guarantee row access
    order, grouped by (connection, parameter keys) so executemany() can be
    used when versioning rules permit it.  Raises StaleDataError when a
    reliable rowcount disagrees with the number of records.
    """

    execution_options = {"compiled_cache": base_mapper._compiled_cache}

    # version checking applies only when the version column belongs
    # to this table
    needs_version_id = (
        mapper.version_id_col is not None
        and mapper.version_id_col in mapper._cols_by_table[table]
    )

    def update_stmt():
        # WHERE criteria: all PK columns, plus the version id column
        # when versioning is in play
        clauses = BooleanClauseList._construct_raw(operators.and_)

        for col in mapper._pks_by_table[table]:
            clauses.clauses.append(
                col == sql.bindparam(col._label, type_=col.type)
            )

        if needs_version_id:
            clauses.clauses.append(
                mapper.version_id_col
                == sql.bindparam(
                    mapper.version_id_col._label,
                    type_=mapper.version_id_col.type,
                )
            )

        stmt = table.update().where(clauses)

        if mapper.version_id_col is not None:
            stmt = stmt.return_defaults(mapper.version_id_col)

        return stmt

    statement = base_mapper._memo(("post_update", table), update_stmt)

    # execute each UPDATE in the order according to the original
    # list of states to guarantee row access order, but
    # also group them into common (connection, cols) sets
    # to support executemany().
    for key, records in groupby(
        update,
        lambda rec: (rec[3], set(rec[4])),  # connection  # parameter keys
    ):
        rows = 0

        records = list(records)
        connection = key[0]

        assert_singlerow = (
            connection.dialect.supports_sane_rowcount
            if mapper.version_id_col is None
            else connection.dialect.supports_sane_rowcount_returning
        )
        assert_multirow = (
            assert_singlerow
            and connection.dialect.supports_sane_multi_rowcount
        )
        # executemany is only safe when versioning is off or the dialect
        # reports trustworthy multi-row counts
        allow_multirow = not needs_version_id or assert_multirow

        if not allow_multirow:
            check_rowcount = assert_singlerow
            for state, state_dict, mapper_rec, connection, params in records:
                c = connection._execute_20(
                    statement, params, execution_options=execution_options
                )

                _postfetch_post_update(
                    mapper_rec,
                    uowtransaction,
                    table,
                    state,
                    state_dict,
                    c,
                    c.context.compiled_parameters[0],
                )
                rows += c.rowcount
        else:
            multiparams = [
                params
                for state, state_dict, mapper_rec, conn, params in records
            ]

            # a single-row executemany still yields a usable rowcount
            check_rowcount = assert_multirow or (
                assert_singlerow and len(multiparams) == 1
            )

            c = connection._execute_20(
                statement, multiparams, execution_options=execution_options
            )

            rows += c.rowcount
            for state, state_dict, mapper_rec, connection, params in records:
                _postfetch_post_update(
                    mapper_rec,
                    uowtransaction,
                    table,
                    state,
                    state_dict,
                    c,
                    c.context.compiled_parameters[0],
                )

        if check_rowcount:
            if rows != len(records):
                raise orm_exc.StaleDataError(
                    "UPDATE statement on table '%s' expected to "
                    "update %d row(s); %d were matched."
                    % (table.description, len(records), rows)
                )

        elif needs_version_id:
            util.warn(
                "Dialect %s does not support updated rowcount "
                "- versioning cannot be verified."
                % c.dialect.dialect_description
            )
def _emit_delete_statements(
    base_mapper, uowtransaction, mapper, table, delete
):
    """Emit DELETE statements corresponding to value lists collected
    by _collect_delete_commands().

    Deletes are grouped per connection.  When a version id column is in
    play and the dialect can't report multi-row counts, rows are deleted
    one at a time so each versioned row can be verified.
    """

    need_version_id = (
        mapper.version_id_col is not None
        and mapper.version_id_col in mapper._cols_by_table[table]
    )

    def delete_stmt():
        # WHERE criteria: all PK columns, plus the version id column
        # when versioning is in play
        clauses = BooleanClauseList._construct_raw(operators.and_)

        for col in mapper._pks_by_table[table]:
            clauses.clauses.append(
                col == sql.bindparam(col.key, type_=col.type)
            )

        if need_version_id:
            clauses.clauses.append(
                mapper.version_id_col
                == sql.bindparam(
                    mapper.version_id_col.key, type_=mapper.version_id_col.type
                )
            )

        return table.delete().where(clauses)

    statement = base_mapper._memo(("delete", table), delete_stmt)
    for connection, recs in groupby(delete, lambda rec: rec[1]):  # connection
        del_objects = [params for params, connection in recs]

        execution_options = {"compiled_cache": base_mapper._compiled_cache}
        expected = len(del_objects)
        # -1 means "rowcount not checked"
        rows_matched = -1
        only_warn = False

        if (
            need_version_id
            and not connection.dialect.supports_sane_multi_rowcount
        ):
            if connection.dialect.supports_sane_rowcount:
                rows_matched = 0
                # execute deletes individually so that versioned
                # rows can be verified
                for params in del_objects:
                    c = connection._execute_20(
                        statement, params, execution_options=execution_options
                    )
                    rows_matched += c.rowcount
            else:
                util.warn(
                    "Dialect %s does not support deleted rowcount "
                    "- versioning cannot be verified."
                    % connection.dialect.dialect_description
                )
                connection._execute_20(
                    statement, del_objects, execution_options=execution_options
                )
        else:
            c = connection._execute_20(
                statement, del_objects, execution_options=execution_options
            )

            if not need_version_id:
                only_warn = True

            rows_matched = c.rowcount

        if (
            base_mapper.confirm_deleted_rows
            and rows_matched > -1
            and expected != rows_matched
            and (
                connection.dialect.supports_sane_multi_rowcount
                or len(del_objects) == 1
            )
        ):
            # TODO: why does this "only warn" if versioning is turned off,
            # whereas the UPDATE raises?
            if only_warn:
                util.warn(
                    "DELETE statement on table '%s' expected to "
                    "delete %d row(s); %d were matched. Please set "
                    "confirm_deleted_rows=False within the mapper "
                    "configuration to prevent this warning."
                    % (table.description, expected, rows_matched)
                )
            else:
                raise orm_exc.StaleDataError(
                    "DELETE statement on table '%s' expected to "
                    "delete %d row(s); %d were matched. Please set "
                    "confirm_deleted_rows=False within the mapper "
                    "configuration to prevent this warning."
                    % (table.description, expected, rows_matched)
                )
def _finalize_insert_update_commands(base_mapper, uowtransaction, states):
    """finalize state on states that have been inserted or updated,
    including calling after_insert/after_update events.

    For each (state, dict, mapper, connection, has_identity) tuple:
    expires stale read-only attributes, eagerly loads server defaults /
    version ids when configured, then fires the after_insert or
    after_update mapper event.
    """
    for state, state_dict, mapper, connection, has_identity in states:

        if mapper._readonly_props:
            # expire readonly props that either were flushed (and not
            # deferred-unloaded) or are plain expired attributes
            readonly = state.unmodified_intersection(
                [
                    p.key
                    for p in mapper._readonly_props
                    if (
                        p.expire_on_flush
                        and (not p.deferred or p.key in state.dict)
                    )
                    or (
                        not p.expire_on_flush
                        and not p.deferred
                        and p.key not in state.dict
                    )
                ]
            )
            if readonly:
                state._expire_attributes(state.dict, readonly)

        # if eager_defaults option is enabled, load
        # all expired cols.  Else if we have a version_id_col, make sure
        # it isn't expired.
        toload_now = []

        if base_mapper.eager_defaults:
            toload_now.extend(
                state._unloaded_non_object.intersection(
                    mapper._server_default_plus_onupdate_propkeys
                )
            )

        if (
            mapper.version_id_col is not None
            and mapper.version_id_generator is False
        ):
            if mapper._version_id_prop.key in state.unloaded:
                toload_now.extend([mapper._version_id_prop.key])

        if toload_now:
            state.key = base_mapper._identity_key_from_state(state)
            stmt = future.select(mapper).set_label_style(
                LABEL_STYLE_TABLENAME_PLUS_COL
            )
            loading.load_on_ident(
                uowtransaction.session,
                stmt,
                state.key,
                refresh_state=state,
                only_load_props=toload_now,
            )

        # call after_XXX extensions
        if not has_identity:
            mapper.dispatch.after_insert(mapper, connection, state)
        else:
            mapper.dispatch.after_update(mapper, connection, state)

        if (
            mapper.version_id_generator is False
            and mapper.version_id_col is not None
        ):
            # server-side versioning must have produced a value
            if state_dict[mapper._version_id_prop.key] is None:
                raise orm_exc.FlushError(
                    "Instance does not contain a non-NULL version value"
                )
def _postfetch_post_update(
    mapper, uowtransaction, table, state, dict_, result, params
):
    """Copy prefetched column values into the state's dict and expire
    postfetched columns, after a post-update UPDATE statement.

    No-op when the state has been marked deleted in this unit of work.
    """
    if uowtransaction.is_deleted(state):
        return

    prefetch_cols = result.context.compiled.prefetch
    postfetch_cols = result.context.compiled.postfetch

    if (
        mapper.version_id_col is not None
        and mapper.version_id_col in mapper._cols_by_table[table]
    ):
        # the version column value was rendered as a bind param;
        # treat it as prefetched
        prefetch_cols = list(prefetch_cols) + [mapper.version_id_col]

    refresh_flush = bool(mapper.class_manager.dispatch.refresh_flush)
    if refresh_flush:
        load_evt_attrs = []

    for c in prefetch_cols:
        if c.key in params and c in mapper._columntoproperty:
            dict_[mapper._columntoproperty[c].key] = params[c.key]
            if refresh_flush:
                load_evt_attrs.append(mapper._columntoproperty[c].key)

    if refresh_flush and load_evt_attrs:
        mapper.class_manager.dispatch.refresh_flush(
            state, uowtransaction, load_evt_attrs
        )

    if postfetch_cols:
        # values generated server-side; expire so they reload on access
        state._expire_attributes(
            state.dict,
            [
                mapper._columntoproperty[c].key
                for c in postfetch_cols
                if c in mapper._columntoproperty
            ],
        )
def _postfetch(
    mapper,
    uowtransaction,
    table,
    state,
    dict_,
    result,
    params,
    value_params,
    isupdate,
    returned_defaults,
):
    """Expire attributes in need of newly persisted database state,
    after an INSERT or UPDATE statement has proceeded for that
    state.

    RETURNING values and prefetched bind values are copied directly into
    the instance dict; postfetched (server-generated) columns are
    expired; newly generated PKs are synchronized down the inheritance
    chain.
    """

    prefetch_cols = result.context.compiled.prefetch
    postfetch_cols = result.context.compiled.postfetch
    returning_cols = result.context.compiled.returning

    if (
        mapper.version_id_col is not None
        and mapper.version_id_col in mapper._cols_by_table[table]
    ):
        prefetch_cols = list(prefetch_cols) + [mapper.version_id_col]

    refresh_flush = bool(mapper.class_manager.dispatch.refresh_flush)
    if refresh_flush:
        load_evt_attrs = []

    if returning_cols:
        row = returned_defaults
        if row is not None:
            for row_value, col in zip(row, returning_cols):
                # pk cols returned from insert are handled
                # distinctly, don't step on the values here
                if col.primary_key and result.context.isinsert:
                    continue

                # note that columns can be in the "return defaults" that are
                # not mapped to this mapper, typically because they are
                # "excluded", which can be specified directly or also occurs
                # when using declarative w/ single table inheritance
                prop = mapper._columntoproperty.get(col)
                if prop:
                    dict_[prop.key] = row_value
                    if refresh_flush:
                        load_evt_attrs.append(prop.key)

    for c in prefetch_cols:
        if c.key in params and c in mapper._columntoproperty:
            dict_[mapper._columntoproperty[c].key] = params[c.key]
            if refresh_flush:
                load_evt_attrs.append(mapper._columntoproperty[c].key)

    if refresh_flush and load_evt_attrs:
        mapper.class_manager.dispatch.refresh_flush(
            state, uowtransaction, load_evt_attrs
        )

    if isupdate and value_params:
        # explicitly suit the use case specified by
        # [ticket:3801], PK SQL expressions for UPDATE on non-RETURNING
        # database which are set to themselves in order to do a version bump.
        # NOTE(review): this extends the compiled statement's postfetch
        # list in place rather than a copy — confirm this is intentional
        # with respect to the compiled cache.
        postfetch_cols.extend(
            [
                col
                for col in value_params
                if col.primary_key and col not in returning_cols
            ]
        )

    if postfetch_cols:
        state._expire_attributes(
            state.dict,
            [
                mapper._columntoproperty[c].key
                for c in postfetch_cols
                if c in mapper._columntoproperty
            ],
        )

    # synchronize newly inserted ids from one table to the next
    # TODO: this still goes a little too often.  would be nice to
    # have definitive list of "columns that changed" here
    for m, equated_pairs in mapper._table_to_equated[table]:
        sync.populate(
            state,
            m,
            state,
            m,
            equated_pairs,
            uowtransaction,
            mapper.passive_updates,
        )
def _postfetch_bulk_save(mapper, dict_, table):
    """Synchronize inherited key values for a bulk-save row dict.

    Bulk-save counterpart to the tail of :func:`_postfetch`: propagates
    equated column values for ``table`` down the inheritance chain,
    operating on a plain dict rather than an InstanceState.
    """
    equated = mapper._table_to_equated[table]
    for target_mapper, pairs in equated:
        sync.bulk_populate_inherit_keys(dict_, target_mapper, pairs)
def _connections_for_states(base_mapper, uowtransaction, states):
    """Return an iterator of (state, state.dict, mapper, connection).

    The states are sorted according to _sort_states, then paired
    with the connection they should be using for the given
    unit of work transaction.
    """

    # if session has a connection callable, organize individual states
    # with the connection to use for update; otherwise resolve a single
    # connection for the whole batch up front
    per_state_connection = uowtransaction.session.connection_callable
    if not per_state_connection:
        per_state_connection = None
        connection = uowtransaction.transaction.connection(base_mapper)

    for state in _sort_states(base_mapper, states):
        if per_state_connection:
            connection = per_state_connection(base_mapper, state.obj())
        yield state, state.dict, state.manager.mapper, connection
- def _sort_states(mapper, states):
- pending = set(states)
- persistent = set(s for s in pending if s.key is not None)
- pending.difference_update(persistent)
- try:
- persistent_sorted = sorted(
- persistent, key=mapper._persistent_sortkey_fn
- )
- except TypeError as err:
- util.raise_(
- sa_exc.InvalidRequestError(
- "Could not sort objects by primary key; primary key "
- "values must be sortable in Python (was: %s)" % err
- ),
- replace_context=err,
- )
- return (
- sorted(pending, key=operator.attrgetter("insert_order"))
- + persistent_sorted
- )
# shared immutable empty mapping, used as a safe default for the
# BulkUDCompileState option attributes below
_EMPTY_DICT = util.immutabledict()
class BulkUDCompileState(CompileState):
    """Base compile state for ORM-enabled bulk UPDATE / DELETE.

    Hosts the pre/post-execution hooks that implement the
    ``synchronize_session`` strategies ("evaluate", "fetch", False).
    """

    class default_update_options(Options):
        # per-execution options for bulk update/delete, merged from
        # execution_options under key "_sa_orm_update_options"
        _synchronize_session = "evaluate"
        _autoflush = True
        _subject_mapper = None
        _resolved_values = _EMPTY_DICT
        _resolved_keys_as_propnames = _EMPTY_DICT
        _value_evaluators = _EMPTY_DICT
        _matched_objects = None
        _matched_rows = None
        _refresh_identity_token = None

    @classmethod
    def orm_pre_session_exec(
        cls,
        session,
        statement,
        params,
        execution_options,
        bind_arguments,
        is_reentrant_invoke,
    ):
        """Validate options, autoflush, and run the pre-synchronize step.

        Returns (statement, execution_options) with the resolved update
        options merged into execution_options.
        """
        if is_reentrant_invoke:
            return statement, execution_options

        (
            update_options,
            execution_options,
        ) = BulkUDCompileState.default_update_options.from_execution_options(
            "_sa_orm_update_options",
            {"synchronize_session"},
            execution_options,
            statement._execution_options,
        )

        # NOTE: renamed from "sync" so as not to shadow the module-level
        # "sync" import used elsewhere in this module
        sync_strategy = update_options._synchronize_session
        if sync_strategy is not None:
            if sync_strategy not in ("evaluate", "fetch", False):
                raise sa_exc.ArgumentError(
                    "Valid strategies for session synchronization "
                    "are 'evaluate', 'fetch', False"
                )

        bind_arguments["clause"] = statement
        try:
            plugin_subject = statement._propagate_attrs["plugin_subject"]
        except KeyError:
            assert False, "statement had 'orm' plugin but no plugin_subject"
        else:
            bind_arguments["mapper"] = plugin_subject.mapper

        update_options += {"_subject_mapper": plugin_subject.mapper}

        if update_options._autoflush:
            session._autoflush()

        statement = statement._annotate(
            {"synchronize_session": update_options._synchronize_session}
        )

        # this stage of the execution is called before the do_orm_execute event
        # hook.  meaning for an extension like horizontal sharding, this step
        # happens before the extension splits out into multiple backends and
        # runs only once.  if we do pre_sync_fetch, we execute a SELECT
        # statement, which the horizontal sharding extension splits amongst the
        # shards and combines the results together.
        if update_options._synchronize_session == "evaluate":
            update_options = cls._do_pre_synchronize_evaluate(
                session,
                statement,
                params,
                execution_options,
                bind_arguments,
                update_options,
            )
        elif update_options._synchronize_session == "fetch":
            update_options = cls._do_pre_synchronize_fetch(
                session,
                statement,
                params,
                execution_options,
                bind_arguments,
                update_options,
            )

        return (
            statement,
            util.immutabledict(execution_options).union(
                {"_sa_orm_update_options": update_options}
            ),
        )

    @classmethod
    def orm_setup_cursor_result(
        cls,
        session,
        statement,
        params,
        execution_options,
        bind_arguments,
        result,
    ):
        """Run the post-synchronize step for the chosen strategy."""

        # this stage of the execution is called after the
        # do_orm_execute event hook.  meaning for an extension like
        # horizontal sharding, this step happens *within* the horizontal
        # sharding event handler which calls session.execute() re-entrantly
        # and will occur for each backend individually.
        # the sharding extension then returns its own merged result from the
        # individual ones we return here.
        update_options = execution_options["_sa_orm_update_options"]
        if update_options._synchronize_session == "evaluate":
            cls._do_post_synchronize_evaluate(session, result, update_options)
        elif update_options._synchronize_session == "fetch":
            cls._do_post_synchronize_fetch(session, result, update_options)

        return result

    @classmethod
    def _adjust_for_extra_criteria(cls, global_attributes, ext_info):
        """Apply extra criteria filtering.

        For all distinct single-table-inheritance mappers represented in the
        table being updated or deleted, produce additional WHERE criteria such
        that only the appropriate subtypes are selected from the total results.

        Additionally, add WHERE criteria originating from LoaderCriteriaOptions
        collected from the statement.

        """

        return_crit = ()

        adapter = ext_info._adapter if ext_info.is_aliased_class else None

        if (
            "additional_entity_criteria",
            ext_info.mapper,
        ) in global_attributes:
            return_crit += tuple(
                ae._resolve_where_criteria(ext_info)
                for ae in global_attributes[
                    ("additional_entity_criteria", ext_info.mapper)
                ]
                if ae.include_aliases or ae.entity is ext_info
            )

        if ext_info.mapper._single_table_criterion is not None:
            return_crit += (ext_info.mapper._single_table_criterion,)

        if adapter:
            return_crit = tuple(adapter.traverse(crit) for crit in return_crit)

        return return_crit

    @classmethod
    def _do_pre_synchronize_evaluate(
        cls,
        session,
        statement,
        params,
        execution_options,
        bind_arguments,
        update_options,
    ):
        """Collect in-session objects matching the criteria in Python.

        Compiles the WHERE criteria (and SET values, for UPDATE) into
        Python evaluators; raises InvalidRequestError when criteria
        cannot be evaluated in Python.
        """
        mapper = update_options._subject_mapper
        target_cls = mapper.class_

        value_evaluators = resolved_keys_as_propnames = _EMPTY_DICT

        try:
            evaluator_compiler = evaluator.EvaluatorCompiler(target_cls)
            crit = ()
            if statement._where_criteria:
                crit += statement._where_criteria

            global_attributes = {}
            for opt in statement._with_options:
                if opt._is_criteria_option:
                    opt.get_global_criteria(global_attributes)

            if global_attributes:
                crit += cls._adjust_for_extra_criteria(
                    global_attributes, mapper
                )

            if crit:
                eval_condition = evaluator_compiler.process(*crit)
            else:
                # no criteria at all: every candidate object matches

                def eval_condition(obj):
                    return True

        except evaluator.UnevaluatableError as err:
            util.raise_(
                sa_exc.InvalidRequestError(
                    'Could not evaluate current criteria in Python: "%s". '
                    "Specify 'fetch' or False for the "
                    "synchronize_session execution option." % err
                ),
                from_=err,
            )

        if statement.__visit_name__ == "lambda_element":
            # ._resolved is called on every LambdaElement in order to
            # generate the cache key, so this access does not add
            # additional expense
            effective_statement = statement._resolved
        else:
            effective_statement = statement

        if effective_statement.__visit_name__ == "update":
            resolved_values = cls._get_resolved_values(
                mapper, effective_statement
            )
            value_evaluators = {}
            resolved_keys_as_propnames = cls._resolved_keys_as_propnames(
                mapper, resolved_values
            )
            for key, value in resolved_keys_as_propnames:
                try:
                    _evaluator = evaluator_compiler.process(
                        coercions.expect(roles.ExpressionElementRole, value)
                    )
                except evaluator.UnevaluatableError:
                    # un-evaluatable SET values are simply expired later
                    pass
                else:
                    value_evaluators[key] = _evaluator

        # TODO: detect when the where clause is a trivial primary key match.
        matched_objects = [
            state.obj()
            for state in session.identity_map.all_states()
            if state.mapper.isa(mapper)
            and not state.expired
            and eval_condition(state.obj())
            and (
                update_options._refresh_identity_token is None
                # TODO: coverage for the case where horizontal sharding
                # invokes an update() or delete() given an explicit identity
                # token up front
                or state.identity_token
                == update_options._refresh_identity_token
            )
        ]
        return update_options + {
            "_matched_objects": matched_objects,
            "_value_evaluators": value_evaluators,
            "_resolved_keys_as_propnames": resolved_keys_as_propnames,
        }

    @classmethod
    def _get_resolved_values(cls, mapper, statement):
        """Return the statement's SET values as a list of (key, value)."""
        if statement._multi_values:
            return []
        elif statement._ordered_values:
            return list(statement._ordered_values)
        elif statement._values:
            return list(statement._values.items())
        else:
            return []

    @classmethod
    def _resolved_keys_as_propnames(cls, mapper, resolved_values):
        """Normalize SET keys (attributes / columns) to property names."""
        values = []
        for k, v in resolved_values:
            if isinstance(k, attributes.QueryableAttribute):
                values.append((k.key, v))
                continue
            elif hasattr(k, "__clause_element__"):
                k = k.__clause_element__()

            if mapper and isinstance(k, expression.ColumnElement):
                try:
                    attr = mapper._columntoproperty[k]
                except orm_exc.UnmappedColumnError:
                    # column isn't mapped by this mapper; skip it
                    pass
                else:
                    values.append((attr.key, v))
            else:
                raise sa_exc.InvalidRequestError(
                    "Invalid expression type: %r" % k
                )
        return values

    @classmethod
    def _do_pre_synchronize_fetch(
        cls,
        session,
        statement,
        params,
        execution_options,
        bind_arguments,
        update_options,
    ):
        """SELECT the primary keys of rows matching the criteria up front.

        Skipped (returns a null result) when the dialect supports full
        RETURNING, in which case the DML statement itself reports the
        matched primary keys.
        """
        mapper = update_options._subject_mapper

        select_stmt = (
            select(*(mapper.primary_key + (mapper.select_identity_token,)))
            .select_from(mapper)
            .options(*statement._with_options)
        )
        select_stmt._where_criteria = statement._where_criteria

        def skip_for_full_returning(orm_context):
            bind = orm_context.session.get_bind(**orm_context.bind_arguments)
            if bind.dialect.full_returning:
                return _result.null_result()
            else:
                return None

        result = session.execute(
            select_stmt,
            params,
            execution_options,
            bind_arguments,
            _add_event=skip_for_full_returning,
        )
        matched_rows = result.fetchall()

        value_evaluators = _EMPTY_DICT

        if statement.__visit_name__ == "lambda_element":
            # ._resolved is called on every LambdaElement in order to
            # generate the cache key, so this access does not add
            # additional expense
            effective_statement = statement._resolved
        else:
            effective_statement = statement

        if effective_statement.__visit_name__ == "update":
            target_cls = mapper.class_
            evaluator_compiler = evaluator.EvaluatorCompiler(target_cls)
            resolved_values = cls._get_resolved_values(
                mapper, effective_statement
            )

            # fix: this was previously computed twice in a row; a single
            # call is sufficient
            resolved_keys_as_propnames = cls._resolved_keys_as_propnames(
                mapper, resolved_values
            )

            value_evaluators = {}
            for key, value in resolved_keys_as_propnames:
                try:
                    _evaluator = evaluator_compiler.process(
                        coercions.expect(roles.ExpressionElementRole, value)
                    )
                except evaluator.UnevaluatableError:
                    pass
                else:
                    value_evaluators[key] = _evaluator

        else:
            resolved_keys_as_propnames = _EMPTY_DICT

        return update_options + {
            "_value_evaluators": value_evaluators,
            "_matched_rows": matched_rows,
            "_resolved_keys_as_propnames": resolved_keys_as_propnames,
        }
- @CompileState.plugin_for("orm", "update")
- class BulkORMUpdate(UpdateDMLState, BulkUDCompileState):
- @classmethod
- def create_for_statement(cls, statement, compiler, **kw):
- self = cls.__new__(cls)
- ext_info = statement.table._annotations["parententity"]
- self.mapper = mapper = ext_info.mapper
- self.extra_criteria_entities = {}
- self._resolved_values = cls._get_resolved_values(mapper, statement)
- extra_criteria_attributes = {}
- for opt in statement._with_options:
- if opt._is_criteria_option:
- opt.get_global_criteria(extra_criteria_attributes)
- if not statement._preserve_parameter_order and statement._values:
- self._resolved_values = dict(self._resolved_values)
- new_stmt = sql.Update.__new__(sql.Update)
- new_stmt.__dict__.update(statement.__dict__)
- new_stmt.table = mapper.local_table
- # note if the statement has _multi_values, these
- # are passed through to the new statement, which will then raise
- # InvalidRequestError because UPDATE doesn't support multi_values
- # right now.
- if statement._ordered_values:
- new_stmt._ordered_values = self._resolved_values
- elif statement._values:
- new_stmt._values = self._resolved_values
- new_crit = cls._adjust_for_extra_criteria(
- extra_criteria_attributes, mapper
- )
- if new_crit:
- new_stmt = new_stmt.where(*new_crit)
- # if we are against a lambda statement we might not be the
- # topmost object that received per-execute annotations
- if (
- compiler._annotations.get("synchronize_session", None) == "fetch"
- and compiler.dialect.full_returning
- ):
- if new_stmt._returning:
- raise sa_exc.InvalidRequestError(
- "Can't use synchronize_session='fetch' "
- "with explicit returning()"
- )
- new_stmt = new_stmt.returning(*mapper.primary_key)
- UpdateDMLState.__init__(self, new_stmt, compiler, **kw)
- return self
- @classmethod
- def _get_crud_kv_pairs(cls, statement, kv_iterator):
- plugin_subject = statement._propagate_attrs["plugin_subject"]
- core_get_crud_kv_pairs = UpdateDMLState._get_crud_kv_pairs
- if not plugin_subject or not plugin_subject.mapper:
- return core_get_crud_kv_pairs(statement, kv_iterator)
- mapper = plugin_subject.mapper
- values = []
- for k, v in kv_iterator:
- k = coercions.expect(roles.DMLColumnRole, k)
- if isinstance(k, util.string_types):
- desc = _entity_namespace_key(mapper, k, default=NO_VALUE)
- if desc is NO_VALUE:
- values.append(
- (
- k,
- coercions.expect(
- roles.ExpressionElementRole,
- v,
- type_=sqltypes.NullType(),
- is_crud=True,
- ),
- )
- )
- else:
- values.extend(
- core_get_crud_kv_pairs(
- statement, desc._bulk_update_tuples(v)
- )
- )
- elif "entity_namespace" in k._annotations:
- k_anno = k._annotations
- attr = _entity_namespace_key(
- k_anno["entity_namespace"], k_anno["proxy_key"]
- )
- values.extend(
- core_get_crud_kv_pairs(
- statement, attr._bulk_update_tuples(v)
- )
- )
- else:
- values.append(
- (
- k,
- coercions.expect(
- roles.ExpressionElementRole,
- v,
- type_=sqltypes.NullType(),
- is_crud=True,
- ),
- )
- )
- return values
- @classmethod
- def _do_post_synchronize_evaluate(cls, session, result, update_options):
- states = set()
- evaluated_keys = list(update_options._value_evaluators.keys())
- values = update_options._resolved_keys_as_propnames
- attrib = set(k for k, v in values)
- for obj in update_options._matched_objects:
- state, dict_ = (
- attributes.instance_state(obj),
- attributes.instance_dict(obj),
- )
- # the evaluated states were gathered across all identity tokens.
- # however the post_sync events are called per identity token,
- # so filter.
- if (
- update_options._refresh_identity_token is not None
- and state.identity_token
- != update_options._refresh_identity_token
- ):
- continue
- # only evaluate unmodified attributes
- to_evaluate = state.unmodified.intersection(evaluated_keys)
- for key in to_evaluate:
- if key in dict_:
- dict_[key] = update_options._value_evaluators[key](obj)
- state.manager.dispatch.refresh(state, None, to_evaluate)
- state._commit(dict_, list(to_evaluate))
- to_expire = attrib.intersection(dict_).difference(to_evaluate)
- if to_expire:
- state._expire_attributes(dict_, to_expire)
- states.add(state)
- session._register_altered(states)
- @classmethod
- def _do_post_synchronize_fetch(cls, session, result, update_options):
- target_mapper = update_options._subject_mapper
- states = set()
- evaluated_keys = list(update_options._value_evaluators.keys())
- if result.returns_rows:
- matched_rows = [
- tuple(row) + (update_options._refresh_identity_token,)
- for row in result.all()
- ]
- else:
- matched_rows = update_options._matched_rows
- objs = [
- session.identity_map[identity_key]
- for identity_key in [
- target_mapper.identity_key_from_primary_key(
- list(primary_key),
- identity_token=identity_token,
- )
- for primary_key, identity_token in [
- (row[0:-1], row[-1]) for row in matched_rows
- ]
- if update_options._refresh_identity_token is None
- or identity_token == update_options._refresh_identity_token
- ]
- if identity_key in session.identity_map
- ]
- values = update_options._resolved_keys_as_propnames
- attrib = set(k for k, v in values)
- for obj in objs:
- state, dict_ = (
- attributes.instance_state(obj),
- attributes.instance_dict(obj),
- )
- to_evaluate = state.unmodified.intersection(evaluated_keys)
- for key in to_evaluate:
- if key in dict_:
- dict_[key] = update_options._value_evaluators[key](obj)
- state.manager.dispatch.refresh(state, None, to_evaluate)
- state._commit(dict_, list(to_evaluate))
- to_expire = attrib.intersection(dict_).difference(to_evaluate)
- if to_expire:
- state._expire_attributes(dict_, to_expire)
- states.add(state)
- session._register_altered(states)
@CompileState.plugin_for("orm", "delete")
class BulkORMDelete(DeleteDMLState, BulkUDCompileState):
    """Compile-state plugin for ORM-enabled bulk DELETE statements.

    Augments a Core DELETE with mapper-level extra criteria and, for
    "fetch" synchronization on dialects with full RETURNING support,
    adds a RETURNING of the primary key so deleted objects can be
    removed from the session without a separate SELECT.
    """

    @classmethod
    def create_for_statement(cls, statement, compiler, **kw):
        """Build the compile state for an ORM DELETE statement.

        Resolves the target mapper from the statement's annotated table,
        applies any global criteria options, and conditionally appends a
        primary-key RETURNING clause before delegating to
        ``DeleteDMLState.__init__``.
        """
        self = cls.__new__(cls)

        ext_info = statement.table._annotations["parententity"]
        self.mapper = mapper = ext_info.mapper

        self.extra_criteria_entities = {}

        # collect global criteria (e.g. with_loader_criteria options)
        extra_criteria_attributes = {}
        for opt in statement._with_options:
            if opt._is_criteria_option:
                opt.get_global_criteria(extra_criteria_attributes)

        new_crit = cls._adjust_for_extra_criteria(
            extra_criteria_attributes, mapper
        )
        if new_crit:
            statement = statement.where(*new_crit)

        # with "fetch" synchronization and full RETURNING support,
        # return the primary keys of deleted rows so post-synchronize
        # can avoid an up-front SELECT
        if (
            mapper
            and compiler._annotations.get("synchronize_session", None)
            == "fetch"
            and compiler.dialect.full_returning
        ):
            statement = statement.returning(*mapper.primary_key)

        DeleteDMLState.__init__(self, statement, compiler, **kw)

        return self

    @classmethod
    def _do_post_synchronize_evaluate(cls, session, result, update_options):
        """Remove all in-memory matched objects from the session."""
        session._remove_newly_deleted(
            [
                attributes.instance_state(obj)
                for obj in update_options._matched_objects
            ]
        )

    @classmethod
    def _do_post_synchronize_fetch(cls, session, result, update_options):
        """Remove deleted rows' objects from the session in "fetch" mode.

        Primary keys come from the RETURNING result when available,
        otherwise from the rows pre-fetched before the DELETE ran.
        """
        target_mapper = update_options._subject_mapper

        if result.returns_rows:
            matched_rows = [
                tuple(row) + (update_options._refresh_identity_token,)
                for row in result.all()
            ]
        else:
            matched_rows = update_options._matched_rows

        # collect every deleted state first and remove them with a
        # single call (resolves the previous inline TODO to invoke
        # _remove_newly_deleted once rather than once per row)
        deleted_states = []
        for row in matched_rows:
            # each row is (primary key..., identity_token)
            identity_key = target_mapper.identity_key_from_primary_key(
                list(row[0:-1]),
                identity_token=row[-1],
            )
            if identity_key in session.identity_map:
                deleted_states.append(
                    attributes.instance_state(
                        session.identity_map[identity_key]
                    )
                )

        if deleted_states:
            session._remove_newly_deleted(deleted_states)
|