在INSERT或SELECT时自动使用函数

有时,应用程序希望在 INSERT 或 SELECT 中自动应用某个函数。例如,几何图形在数据库中以投影坐标存储,而应用程序可能需要经纬度坐标的几何图形。为了避免每次都用 ST_Transform() 调整查询,可以定义一个 TypeDecorator。

 11 from pkg_resources import parse_version
 12 import pytest
 13
 14 import sqlalchemy
 15 from sqlalchemy import create_engine
 16 from sqlalchemy import MetaData
 17 from sqlalchemy import Column
 18 from sqlalchemy import Integer
 19 from sqlalchemy import func
 20 from sqlalchemy.ext.declarative import declarative_base
 21 from sqlalchemy.orm import sessionmaker
 22 from sqlalchemy.types import TypeDecorator
 23
 24 from geoalchemy2.compat import PY3
 25 from geoalchemy2 import Geometry
 26 from geoalchemy2 import shape
 27
 28
 29 engine = create_engine('postgresql://gis:gis@localhost/gis', echo=True)
 30 metadata = MetaData(engine)
 31
 32 Base = declarative_base(metadata=metadata)
 33
 34
 35 class TransformedGeometry(TypeDecorator):
 36     """This class is used to insert a ST_Transform() in each insert or select."""
 37     impl = Geometry
 38
 39     def __init__(self, db_srid, app_srid, **kwargs):
 40         kwargs["srid"] = db_srid
 41         self.impl = self.__class__.impl(**kwargs)
 42         self.app_srid = app_srid
 43         self.db_srid = db_srid
 44
 45     def column_expression(self, col):
 46         """The column_expression() method is overrided to ensure that the
 47         SRID of the resulting WKBElement is correct"""
 48         return getattr(func, self.impl.as_binary)(
 49             func.ST_Transform(col, self.app_srid),
 50             type_=self.__class__.impl(srid=self.app_srid)
 51             # srid could also be -1 so that the SRID is deduced from the
 52             # WKB data
 53         )
 54
 55     def bind_expression(self, bindvalue):
 56         return func.ST_Transform(
 57             self.impl.bind_expression(bindvalue), self.db_srid)
 58
 59
 60 class ThreeDGeometry(TypeDecorator):
 61     """This class is used to insert a ST_Force3D() in each insert."""
 62     impl = Geometry
 63
 64     def bind_expression(self, bindvalue):
 65         return func.ST_Force3D(self.impl.bind_expression(bindvalue))
 66
 67
 68 class Point(Base):
 69     __tablename__ = "point"
 70     id = Column(Integer, primary_key=True)
 71     raw_geom = Column(Geometry(srid=4326, geometry_type="POINT"))
 72     geom = Column(
 73         TransformedGeometry(
 74             db_srid=2154, app_srid=4326, geometry_type="POINT"))
 75     three_d_geom = Column(
 76         ThreeDGeometry(srid=4326, geometry_type="POINTZ", dimension=3))
 77
 78
 79 session = sessionmaker(bind=engine)()
 80
 81
 82 def check_wkb(wkb, x, y):
 83     pt = shape.to_shape(wkb)
 84     assert round(pt.x, 5) == x
 85     assert round(pt.y, 5) == y
 86
 87
 88 class TestTypeDecorator():
 89
 90     def setup(self):
 91         metadata.drop_all(checkfirst=True)
 92         metadata.create_all()
 93
 94     def teardown(self):
 95         session.rollback()
 96         metadata.drop_all()
 97
 98     def _create_one_point(self):
 99         # Create new point instance
100         p = Point()
101         p.raw_geom = "SRID=4326;POINT(5 45)"
102         p.geom = "SRID=4326;POINT(5 45)"
103         p.three_d_geom = "SRID=4326;POINT(5 45)"  # Insert 2D geometry into 3D column
104
105         # Insert point
106         session.add(p)
107         session.flush()
108         session.expire(p)
109
110         return p.id
111
112     def test_transform(self):
113         self._create_one_point()
114
115         # Query the point and check the result
116         pt = session.query(Point).one()
117         assert pt.id == 1
118         assert pt.raw_geom.srid == 4326
119         check_wkb(pt.raw_geom, 5, 45)
120
121         assert pt.geom.srid == 4326
122         check_wkb(pt.geom, 5, 45)
123
124         # Check that the data is correct in DB using raw query
125         q = "SELECT id, ST_AsEWKT(geom) AS geom FROM point;"
126         res_q = session.execute(q).fetchone()
127         assert res_q.id == 1
128         assert res_q.geom == "SRID=2154;POINT(857581.899319668 6435414.7478354)"
129
130         # Compare geom, raw_geom with auto transform and explicit transform
131         pt_trans = session.query(
132             Point,
133             Point.raw_geom,
134             func.ST_Transform(Point.raw_geom, 2154).label("trans")
135         ).one()
136
137         assert pt_trans[0].id == 1
138
139         assert pt_trans[0].geom.srid == 4326
140         check_wkb(pt_trans[0].geom, 5, 45)
141
142         assert pt_trans[0].raw_geom.srid == 4326
143         check_wkb(pt_trans[0].raw_geom, 5, 45)
144
145         assert pt_trans[1].srid == 4326
146         check_wkb(pt_trans[1], 5, 45)
147
148         assert pt_trans[2].srid == 2154
149         check_wkb(pt_trans[2], 857581.89932, 6435414.74784)
150
151     @pytest.mark.skipif(
152         not PY3 and parse_version(str(sqlalchemy.__version__)) < parse_version("1.3"),
153         reason="Need sqlalchemy >= 1.3")
154     def test_force_3d(self):
155         self._create_one_point()
156
157         # Query the point and check the result
158         pt = session.query(Point).one()
159
160         assert pt.id == 1
161         assert pt.three_d_geom.srid == 4326
162         assert pt.three_d_geom.desc.lower() == (
163             '01010000a0e6100000000000000000144000000000008046400000000000000000')

脚本的总运行时间: (0分0.000秒)

Gallery generated by Sphinx-Gallery