注解
单击此处下载完整的示例代码。
在INSERT或SELECT时自动使用函数
有时,应用程序希望在 INSERT 或 SELECT 时自动应用某个函数。例如,几何图形在数据库中以投影坐标存储,而应用程序需要使用经纬度坐标的几何图形。为了避免每次都在查询中手动添加 ST_Transform(),可以定义一个 TypeDecorator。
11 from pkg_resources import parse_version
12 import pytest
13
14 import sqlalchemy
15 from sqlalchemy import create_engine
16 from sqlalchemy import MetaData
17 from sqlalchemy import Column
18 from sqlalchemy import Integer
19 from sqlalchemy import func
20 from sqlalchemy.ext.declarative import declarative_base
21 from sqlalchemy.orm import sessionmaker
22 from sqlalchemy.types import TypeDecorator
23
24 from geoalchemy2.compat import PY3
25 from geoalchemy2 import Geometry
26 from geoalchemy2 import shape
27
28
# Engine for the PostgreSQL database used by this example; echo=True makes
# SQLAlchemy log every SQL statement it emits.
engine = create_engine('postgresql://gis:gis@localhost/gis', echo=True)
# The MetaData is bound to the engine so that create_all() / drop_all() can be
# called later in this module without passing a bind explicitly.
metadata = MetaData(bind=engine)
Base = declarative_base(metadata=metadata)
33
34
class TransformedGeometry(TypeDecorator):
    """This class is used to insert a ST_Transform() in each insert or select.

    Geometries are stored in the database with ``db_srid`` and exposed to the
    application with ``app_srid``. Bound values are wrapped in
    ``ST_Transform(..., db_srid)``. Selected columns are wrapped in
    ``ST_Transform(..., app_srid)``.
    """
    impl = Geometry

    # db_srid and app_srid change the generated SQL. They are stored as
    # attributes named after the __init__ arguments, which is what
    # SQLAlchemy >= 1.4 uses to build the cache key of a TypeDecorator.
    # Without this flag SQLAlchemy warns and disables statement caching for
    # this type. Older SQLAlchemy versions simply ignore the attribute.
    cache_ok = True

    def __init__(self, db_srid, app_srid, **kwargs):
        """
        :param db_srid: SRID of the geometries as stored in the database. It is
            also used as the ``srid`` of the underlying ``Geometry`` type.
        :param app_srid: SRID of the geometries seen by the application.
        :param kwargs: other arguments forwarded to ``Geometry`` (e.g.
            ``geometry_type``). A ``srid`` key is overridden by ``db_srid``.
        """
        kwargs["srid"] = db_srid
        # TypeDecorator.__init__ instantiates ``impl`` with these arguments.
        super(TransformedGeometry, self).__init__(**kwargs)
        self.app_srid = app_srid
        self.db_srid = db_srid

    def column_expression(self, col):
        """Transform the selected column to ``app_srid``.

        The column_expression() method is overridden to ensure that the
        SRID of the resulting WKBElement is correct. The binary output
        function of the impl is applied with a result type declared with
        ``app_srid``.
        """
        return getattr(func, self.impl.as_binary)(
            func.ST_Transform(col, self.app_srid),
            # srid could also be -1 so that the SRID is deduced from the
            # WKB data
            type_=self.__class__.impl(srid=self.app_srid),
        )

    def bind_expression(self, bindvalue):
        """Transform bound values to ``db_srid`` before they are stored."""
        return func.ST_Transform(
            self.impl.bind_expression(bindvalue), self.db_srid)
58
59
class ThreeDGeometry(TypeDecorator):
    """This class is used to insert a ST_Force3D() in each insert."""
    impl = Geometry

    # This decorator adds no state of its own. Declaring it cacheable avoids
    # the SQLAlchemy >= 1.4 warning and keeps statement caching enabled.
    # Older SQLAlchemy versions ignore the attribute.
    cache_ok = True

    def bind_expression(self, bindvalue):
        """Wrap the impl's bind expression in ST_Force3D().

        This way a 2D geometry given by the application is stored as a 3D one.
        """
        return func.ST_Force3D(self.impl.bind_expression(bindvalue))
66
67
class Point(Base):
    """Mapped table holding the same point in three differently typed columns."""
    __tablename__ = "point"

    id = Column(Integer, primary_key=True)

    # Plain geometry column: stored and read back with SRID 4326.
    raw_geom = Column(Geometry(geometry_type="POINT", srid=4326))

    # Stored with SRID 2154 in the database, seen with SRID 4326 by the app.
    geom = Column(TransformedGeometry(
        geometry_type="POINT",
        db_srid=2154,
        app_srid=4326,
    ))

    # Values bound on insert are passed through ST_Force3D().
    three_d_geom = Column(ThreeDGeometry(
        geometry_type="POINTZ",
        srid=4326,
        dimension=3,
    ))
77
78
# Single session shared by every test in this module.
_session_factory = sessionmaker(bind=engine)
session = _session_factory()
80
81
def check_wkb(wkb, x, y, ndigits=5):
    """Assert that the point geometry in ``wkb`` has the given coordinates.

    The coordinates are rounded to ``ndigits`` decimal places before being
    compared, so ``x`` and ``y`` should be given with that precision.

    :param wkb: WKB element holding a point, as loaded from a geometry column.
    :param x: expected X coordinate.
    :param y: expected Y coordinate.
    :param ndigits: number of decimal places kept before comparing
        (default 5, the precision previously hard-coded here).
    :raises AssertionError: if a rounded coordinate differs from the
        expected value. The message shows both values.
    """
    pt = shape.to_shape(wkb)
    got_x = round(pt.x, ndigits)
    got_y = round(pt.y, ndigits)
    assert got_x == x, "x: expected %r, got %r" % (x, got_x)
    assert got_y == y, "y: expected %r, got %r" % (y, got_y)
86
87
class TestTypeDecorator():
    """Tests for the TransformedGeometry and ThreeDGeometry decorators.

    Each test starts from freshly created tables, and the assertions rely on
    the first inserted row getting ``id == 1``.
    """

    def setup_method(self):
        # pytest's xunit-style per-test hook. The bare nose-style ``setup``
        # name is no longer called by pytest 8, which would skip this step.
        metadata.drop_all(checkfirst=True)
        metadata.create_all()

    def teardown_method(self):
        session.rollback()
        metadata.drop_all()

    def _create_one_point(self):
        """Insert one point into every geometry column and return its id."""
        # Create new point instance
        p = Point()
        p.raw_geom = "SRID=4326;POINT(5 45)"
        p.geom = "SRID=4326;POINT(5 45)"
        p.three_d_geom = "SRID=4326;POINT(5 45)"  # Insert 2D geometry into 3D column

        # Insert point
        session.add(p)
        session.flush()
        # Expire the instance so that attributes are reloaded from the
        # database on next access.
        session.expire(p)

        return p.id

    def test_transform(self):
        """Geometries are stored with SRID 2154 but read back with SRID 4326."""
        self._create_one_point()

        # Query the point and check the result
        pt = session.query(Point).one()
        assert pt.id == 1
        assert pt.raw_geom.srid == 4326
        check_wkb(pt.raw_geom, 5, 45)

        assert pt.geom.srid == 4326
        check_wkb(pt.geom, 5, 45)

        # Check that the data is correct in DB using raw query.
        # Raw SQL must be wrapped in text(): passing a plain string to
        # Session.execute() is deprecated in SQLAlchemy 1.4 and an error in 2.0.
        q = "SELECT id, ST_AsEWKT(geom) AS geom FROM point;"
        res_q = session.execute(sqlalchemy.text(q)).fetchone()
        assert res_q.id == 1
        assert res_q.geom == "SRID=2154;POINT(857581.899319668 6435414.7478354)"

        # Compare geom, raw_geom with auto transform and explicit transform
        pt_trans = session.query(
            Point,
            Point.raw_geom,
            func.ST_Transform(Point.raw_geom, 2154).label("trans")
        ).one()

        # Element 0: the mapped Point instance.
        assert pt_trans[0].id == 1

        assert pt_trans[0].geom.srid == 4326
        check_wkb(pt_trans[0].geom, 5, 45)

        assert pt_trans[0].raw_geom.srid == 4326
        check_wkb(pt_trans[0].raw_geom, 5, 45)

        # Element 1: the raw_geom column selected on its own.
        assert pt_trans[1].srid == 4326
        check_wkb(pt_trans[1], 5, 45)

        # Element 2: raw_geom explicitly transformed to SRID 2154.
        assert pt_trans[2].srid == 2154
        check_wkb(pt_trans[2], 857581.89932, 6435414.74784)

    @pytest.mark.skipif(
        not PY3 and parse_version(str(sqlalchemy.__version__)) < parse_version("1.3"),
        reason="Need sqlalchemy >= 1.3")
    def test_force_3d(self):
        """A 2D point inserted into the ThreeDGeometry column comes back as 3D."""
        self._create_one_point()

        # Query the point and check the result
        pt = session.query(Point).one()

        assert pt.id == 1
        assert pt.three_d_geom.srid == 4326
        # EWKB hex of POINT Z (5 45 0) with SRID 4326 (little-endian).
        assert pt.three_d_geom.desc.lower() == (
            '01010000a0e6100000000000000000144000000000008046400000000000000000')
脚本的总运行时间: (0分0.000秒)