"""This module defines specific functions for SQLite dialect."""
import os
from typing import Optional
from sqlalchemy import text
from sqlalchemy.ext.compiler import compiles
from sqlalchemy.sql import func
from sqlalchemy.sql import select
from geoalchemy2 import functions
from geoalchemy2.admin.dialects.common import _check_spatial_type
from geoalchemy2.admin.dialects.common import _format_select_args
from geoalchemy2.admin.dialects.common import _spatial_idx_name
from geoalchemy2.admin.dialects.common import setup_create_drop
from geoalchemy2.types import Geography
from geoalchemy2.types import Geometry
from geoalchemy2.types import _DummyGeometry
from geoalchemy2.utils import authorized_values_in_docstring
[docs]
def load_spatialite_driver(dbapi_conn, *args):
    """Load SpatiaLite extension in SQLite connection.

    .. Warning::
        The path to the SpatiaLite module should be set in the `SPATIALITE_LIBRARY_PATH`
        environment variable.

    Args:
        dbapi_conn: The DBAPI connection.
        *args: Extra positional arguments; they are accepted and ignored so the function
            can be called with more arguments than just the connection.

    Raises:
        RuntimeError: If the `SPATIALITE_LIBRARY_PATH` environment variable is not set.
    """
    library_path = os.environ.get("SPATIALITE_LIBRARY_PATH")
    if library_path is None:
        raise RuntimeError("The SPATIALITE_LIBRARY_PATH environment variable is not set.")

    dbapi_conn.enable_load_extension(True)
    try:
        dbapi_conn.load_extension(library_path)
    finally:
        # Always switch extension loading back off, even when loading fails, so that the
        # connection is never left with extension loading enabled.
        dbapi_conn.enable_load_extension(False)
# Journal modes accepted by the `journal_mode` argument of `init_spatialite`; the same list
# is handed to the `authorized_values_in_docstring` decorator applied to that function.
_JOURNAL_MODE_VALUES = ["DELETE", "TRUNCATE", "PERSIST", "MEMORY", "WAL", "OFF"]
[docs]
@authorized_values_in_docstring(JOURNAL_MODE_VALUES=_JOURNAL_MODE_VALUES)
def init_spatialite(
    dbapi_conn,
    *args,
    transaction: bool = False,
    init_mode: Optional[str] = None,
    journal_mode: Optional[str] = None,
):
    """Initialize internal SpatiaLite tables.

    Args:
        dbapi_conn: The DBAPI connection.
        init_mode: Can be `None` to load all EPSG SRIDs, `'WGS84'` to load only the ones related
            to WGS84 or `'EMPTY'` to not load any EPSG SRID.

            .. Note::

                It is possible to load other EPSG SRIDs afterwards using `InsertEpsgSrid(srid)`.

        transaction: If set to `True` the whole operation will be handled as a single Transaction
            (faster). The default value is `False` (slower, but safer).
        journal_mode: Change the journal mode to the given value. This can make the table creation
            much faster. The possible values are the following: <JOURNAL_MODE_VALUES>. See
            https://www.sqlite.org/pragma.html#pragma_journal_mode for more details.

            .. Warning::
                Some values, like 'MEMORY' or 'OFF', can lead to corrupted databases if the process
                is interrupted during initialization.

            .. Note::
                The original value is restored after the initialization.

    Raises:
        ValueError: If `transaction`, `init_mode` or `journal_mode` has an unsupported value.

    .. Note::
        When using this function as a listener it is not possible to pass the `transaction`,
        `init_mode` or `journal_mode` arguments directly. To do this you can either create another
        function that calls `init_spatialite` (or
        :func:`geoalchemy2.admin.dialects.sqlite.load_spatialite` if you also want to load the
        SpatiaLite drivers) with a hard-coded `init_mode` or just use a lambda::

            >>> sqlalchemy.event.listen(
            ...     engine,
            ...     "connect",
            ...     lambda x, y: init_spatialite(
            ...         x,
            ...         y,
            ...         transaction=True,
            ...         init_mode="EMPTY",
            ...         journal_mode="OFF",
            ...     )
            ... )
    """
    func_args = []

    # Check the value of the 'transaction' parameter
    if not isinstance(transaction, (bool, int)):
        raise ValueError("The 'transaction' argument must be True or False.")
    func_args.append(str(transaction))

    # Check the value of the 'init_mode' parameter (case-insensitive for strings)
    init_mode_values = ["WGS84", "EMPTY"]
    if isinstance(init_mode, str):
        init_mode = init_mode.upper()
    if init_mode is not None:
        if init_mode not in init_mode_values:
            raise ValueError("The 'init_mode' argument must be one of {}.".format(init_mode_values))
        func_args.append(f"'{init_mode}'")

    # Check the value of the 'journal_mode' parameter (case-insensitive for strings)
    if isinstance(journal_mode, str):
        journal_mode = journal_mode.upper()
    if journal_mode is not None and journal_mode not in _JOURNAL_MODE_VALUES:
        raise ValueError(
            "The 'journal_mode' argument must be one of {}.".format(_JOURNAL_MODE_VALUES)
        )

    # Only initialize when CheckSpatialMetaData() reports a value below 1
    if dbapi_conn.execute("SELECT CheckSpatialMetaData();").fetchone()[0] < 1:
        init_stmt = "SELECT InitSpatialMetaData({});".format(", ".join(func_args))

        if journal_mode is None:
            dbapi_conn.execute(init_stmt)
        else:
            current_journal_mode = dbapi_conn.execute("PRAGMA journal_mode").fetchone()[0]
            dbapi_conn.execute("PRAGMA journal_mode = {}".format(journal_mode))
            try:
                dbapi_conn.execute(init_stmt)
            finally:
                # Restore the original journal mode even if the initialization fails, so a
                # temporary (and possibly unsafe) mode never outlives this call.
                dbapi_conn.execute("PRAGMA journal_mode = {}".format(current_journal_mode))
[docs]
def load_spatialite(*args, **kwargs):
    """Load SpatiaLite extension in SQLite DB and initialize internal tables.

    See :func:`geoalchemy2.admin.dialects.sqlite.load_spatialite_driver` and
    :func:`geoalchemy2.admin.dialects.sqlite.init_spatialite` functions for details about
    arguments.

    Args:
        *args: Positional arguments, passed to both functions.
        **kwargs: Keyword arguments, passed to `init_spatialite` only.
    """
    # The extension is loaded first, then the internal tables are initialized.
    load_spatialite_driver(*args)
    init_spatialite(*args, **kwargs)
def _get_spatialite_attrs(bind, table_name, col_name):
    """Fetch the `geometry_columns` entry of a column, without its first two fields.

    The table and column names are matched case-insensitively.

    Args:
        bind: The object used to execute the query.
        table_name: The name of the table owning the column.
        col_name: The name of the column.

    Returns:
        The matching row sliced from its third field onwards, or `None` when the column
        has no entry in `geometry_columns`.
    """
    query = text(
        """SELECT * FROM "geometry_columns"
WHERE LOWER(f_table_name) = LOWER(:table_name)
AND LOWER(f_geometry_column) = LOWER(:column_name)
"""
    ).bindparams(table_name=table_name, column_name=col_name)
    row = bind.execute(query).fetchone()

    # If the column is not registered as a spatial column we ignore it
    return None if row is None else row[2:]
[docs]
def get_spatialite_version(bind):
    """Return the value reported by `spatialite_version()` for the given bind.

    Args:
        bind: The object used to execute the query.

    Returns:
        The first field of the first row returned by `SELECT spatialite_version();`.
    """
    version_query = text("SELECT spatialite_version();")
    first_row = bind.execute(version_query).fetchone()
    return first_row[0]
[docs]
def _setup_dummy_type(table, gis_cols):
    """Swap the type of the given columns for a dummy one and restore the saved columns.

    The real type of each column is kept in its `_actual_type` attribute so it can be put
    back later.

    Args:
        table: The table being processed; its `info["_saved_columns"]` entry is assigned
            to `table.columns`.
        gis_cols: The columns whose type is replaced by a `_DummyGeometry` instance.
    """
    for gis_col in gis_cols:
        # Remember the real type, then expose a dummy GEOMETRY type instead
        real_type = gis_col.type
        gis_col._actual_type = real_type
        gis_col.type = _DummyGeometry()

    saved_columns = table.info["_saved_columns"]
    table.columns = saved_columns
[docs]
def get_col_dim(col):
    """Return the coordinate dimension string of the column type.

    Args:
        col: A column whose `type` exposes `dimension` and `geometry_type`.

    Returns:
        `"XYZM"` when the dimension is 4 and `"XY"` when it is 2. For any other dimension,
        `"XYM"` if the geometry type name ends with `"M"`, otherwise `"XYZ"`.
    """
    col_type = col.type
    if col_type.dimension == 4:
        return "XYZM"
    if col_type.dimension == 2:
        return "XY"
    # Remaining case: the suffix of the geometry type tells M apart from Z
    return "XYM" if col_type.geometry_type.endswith("M") else "XYZ"
[docs]
def create_spatial_index(bind, table, col):
    """Run `CreateSpatialIndex()` for the given column of the given table.

    Args:
        bind: The object used to execute the statement.
        table: The table owning the column; only its `name` is used.
        col: The column to index; only its `name` is used.
    """
    index_call = func.CreateSpatialIndex(table.name, col.name)
    stmt = select(*_format_select_args(index_call)).execution_options(autocommit=True)
    bind.execute(stmt)
[docs]
def disable_spatial_index(bind, table, col):
    """Disable spatial indexes if present.

    The column is first checked with `CheckSpatialIndex()`. When that check returns a
    non-NULL value the index is disabled with `DisableSpatialIndex()` and the table named
    by `_spatial_idx_name()` is dropped.

    Args:
        bind: The object used to execute the statements.
        table: The table owning the column; only its `name` is used.
        col: The column whose spatial index is disabled; only its `name` is used.
    """
    check_stmt = select(*_format_select_args(func.CheckSpatialIndex(table.name, col.name)))
    if bind.execute(check_stmt).fetchone()[0] is not None:
        disable_stmt = select(*_format_select_args(func.DisableSpatialIndex(table.name, col.name)))
        disable_stmt = disable_stmt.execution_options(autocommit=True)
        bind.execute(disable_stmt)

        # Quote the index table name (doubling embedded double quotes) so that table or
        # column names requiring quoting do not produce an invalid DROP statement.
        idx_table_name = _spatial_idx_name(table.name, col.name).replace('"', '""')
        bind.execute(text('DROP TABLE IF EXISTS "{}";'.format(idx_table_name)))
[docs]
def reflect_geometry_column(inspector, table, column_info):
    """Reflect a column of type Geometry with SQLite dialect.

    Nothing is done when the reflected type is not a `Geometry` or when the column has no
    entry in the SpatiaLite metadata. Otherwise the geometry type, dimension, SRID and
    spatial index flag of `column_info["type"]` are updated in place.

    Args:
        inspector: The inspector; its `bind` is used to query the metadata.
        table: The reflected table; only its `name` is used.
        column_info: The reflected column description, read through its `"type"` and
            `"name"` keys.
    """
    column_type = column_info.get("type")
    if not isinstance(column_type, Geometry):
        return

    # Get geometry type, SRID and spatial index from the SpatiaLite metadata
    spatialite_attrs = _get_spatialite_attrs(inspector.bind, table.name, column_info["name"])
    if spatialite_attrs is None:
        return
    geom_type, coord_dim, srid, has_spatial_index = spatialite_attrs

    if isinstance(geom_type, int):
        # Integer code: the last digit selects the base type and, for codes of 1000 and
        # above, the first digit tells whether Z and/or M are present.
        type_code = str(geom_type)
        leading_digit = type_code[0] if geom_type >= 1000 else ""
        with_z = leading_digit in ("1", "3")
        with_m = leading_digit in ("2", "3")
        base_type_names = {
            "0": "GEOMETRY",
            "1": "POINT",
            "2": "LINESTRING",
            "3": "POLYGON",
            "4": "MULTIPOINT",
            "5": "MULTILINESTRING",
            "6": "MULTIPOLYGON",
            "7": "GEOMETRYCOLLECTION",
        }
        geom_type = base_type_names[type_code[-1]]
        if with_z:
            geom_type += "Z"
        if with_m:
            geom_type += "M"
    else:
        # Textual type: append each axis found in the coordinate dimension unless it is
        # already among the last two characters (re-evaluated after each append).
        for axis in ("Z", "M"):
            if axis in coord_dim and axis not in geom_type[-2:]:
                geom_type += axis

    # Known dimension strings become integers, anything else is kept unchanged
    dimension_by_name = {
        "XY": 2,
        "XYZ": 3,
        "XYM": 3,
        "XYZM": 4,
    }
    coord_dim = dimension_by_name.get(coord_dim, coord_dim)

    # Set attributes
    column_type.geometry_type = geom_type
    column_type.dimension = coord_dim
    column_type.srid = srid
    column_type.spatial_index = bool(has_spatial_index)

    # Spatial indexes are not automatically reflected with SQLite dialect
    column_type._spatial_index_reflected = False
[docs]
def before_create(table, bind, **kw):
    """Handle spatial indexes during the before_create event.

    Every index covering at least one Geometry column is removed from `table.indexes`.
    Such an index is queued in `table.info["_after_create_indexes"]` when, for one of the
    Geometry columns it covers, its name differs from the spatial index name of that
    column or the column type has no truthy `spatial_index` attribute. Finally the
    Geometry columns get a dummy type through `_setup_dummy_type`.

    Args:
        table: The table about to be created.
        bind: The bind used for the creation.
        **kw: Extra keyword arguments; they are ignored.
    """
    dialect, gis_cols, _regular_cols = setup_create_drop(table, bind)

    # Remove the spatial indexes from the table metadata because they should not be
    # created during the table.create() step since the associated columns do not exist
    # at this time.
    deferred_indexes = []
    table.info["_after_create_indexes"] = deferred_indexes
    for idx in set(table.indexes):
        covers_geometry = False
        must_defer = False
        for col in table.info["_saved_columns"]:
            if _check_spatial_type(col.type, Geometry, dialect) and col in idx.columns.values():
                covers_geometry = True
                if idx.name != _spatial_idx_name(table.name, col.name) or not getattr(
                    col.type, "spatial_index", False
                ):
                    must_defer = True

        # Remove and queue each index at most once, even when it covers several Geometry
        # columns: a second set.remove() of the same index would raise KeyError.
        if covers_geometry:
            table.indexes.remove(idx)
        if must_defer:
            deferred_indexes.append(idx)

    _setup_dummy_type(table, gis_cols)
[docs]
def after_create(table, bind, **kw):
    """Handle spatial indexes during the after_create event.

    The saved columns are put back on the table, the Geometry columns get their real type
    back and are registered with `RecoverGeometryColumn()`, the requested spatial indexes
    are created, and the indexes queued in `table.info["_after_create_indexes"]` are
    re-attached to the table and created.

    Args:
        table: The table that was just created.
        bind: The bind used for the creation.
        **kw: Extra keyword arguments; they are ignored.
    """
    dialect = bind.dialect
    table.columns = table.info.pop("_saved_columns")

    # Add the managed Geometry columns with RecoverGeometryColumn()
    for col in table.columns:
        if not _check_spatial_type(col.type, Geometry, dialect):
            continue
        col.type = col._actual_type
        del col._actual_type
        dimension = get_col_dim(col)
        recover_call = func.RecoverGeometryColumn(
            table.name, col.name, col.type.srid, col.type.geometry_type, dimension
        )
        recover_stmt = select(*_format_select_args(recover_call))
        bind.execute(recover_stmt.execution_options(autocommit=True))

    # Add spatial indexes for the Geometry and Geography columns
    # TODO: Check that the Geography type makes sense here
    for col in table.columns:
        is_spatial = _check_spatial_type(col.type, (Geometry, Geography), dialect)
        if is_spatial and col.type.spatial_index is True:
            create_spatial_index(bind, table, col)

    # Re-attach and create the indexes queued during the before_create event
    deferred_indexes = table.info.pop("_after_create_indexes")
    for idx in deferred_indexes:
        table.indexes.add(idx)
        idx.create(bind=bind)
[docs]
def before_drop(table, bind, **kw):
    """Handle spatial indexes during the before_drop event.

    For each spatial column returned by `setup_create_drop`, the spatial index is disabled
    if present and `DiscardGeometryColumn()` is run for the column.

    Args:
        table: The table about to be dropped.
        bind: The bind used for the drop.
        **kw: Extra keyword arguments; they are ignored.
    """
    _dialect, gis_cols, _regular_cols = setup_create_drop(table, bind)

    for gis_col in gis_cols:
        # Disable spatial indexes if present
        disable_spatial_index(bind, table, gis_col)

        discard_call = func.DiscardGeometryColumn(table.name, gis_col.name)
        discard_stmt = select(*_format_select_args(discard_call))
        bind.execute(discard_stmt.execution_options(autocommit=True))
[docs]
def after_drop(table, bind, **kw):
    """Handle spatial indexes during the after_drop event.

    The `"_saved_columns"` entry is removed from `table.info` and assigned back to
    `table.columns`.

    Args:
        table: The table that was just dropped.
        bind: The bind used for the drop; it is not used here.
        **kw: Extra keyword arguments; they are ignored.
    """
    saved_columns = table.info.pop("_saved_columns")
    table.columns = saved_columns
# Define compiled versions for functions in SpatiaLite whose names don't have
# the ST_ prefix.
# Keys are attribute names looked up on `geoalchemy2.functions`; values are the function
# names emitted when compiling for the "sqlite" dialect.
_SQLITE_FUNCTIONS = {
    "ST_GeomFromEWKT": "GeomFromEWKT",
    "ST_GeomFromEWKB": "GeomFromEWKB",
    "ST_AsBinary": "AsBinary",
    "ST_AsEWKB": "AsEWKB",
    "ST_AsGeoJSON": "AsGeoJSON",
}
def _compiles_sqlite(cls, fn):
    """Register a "sqlite" compilation rule rendering the function `cls` as `fn(...)`.

    Args:
        cls: Name of the attribute of `geoalchemy2.functions` to register the rule for.
        fn: Function name to emit in the compiled SQL.
    """

    def _compile_sqlite(element, compiler, **kw):
        # Render the clauses as usual and wrap them in the SQLite-specific function name
        rendered_clauses = compiler.process(element.clauses, **kw)
        return f"{fn}({rendered_clauses})"

    function_class = getattr(functions, cls)
    register = compiles(function_class, "sqlite")
    register(_compile_sqlite)
[docs]
def register_sqlite_mapping(mapping):
    """Register compilation mappings for the given functions.

    Args:
        mapping: Should have the following form::

            {
                "function_name_1": "sqlite_function_name_1",
                "function_name_2": "sqlite_function_name_2",
                ...
            }
    """
    # One "sqlite" compilation rule is registered per entry of the mapping
    for function_name, sqlite_function_name in mapping.items():
        _compiles_sqlite(function_name, sqlite_function_name)
# Register the default mappings as soon as this module is imported.
register_sqlite_mapping(_SQLITE_FUNCTIONS)