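"""Inspect the tables of a PostGIS database.

Lists schemas, tables and views, counts rows, streams entries as GeoJSON
Features, and guesses a per-column type from the data itself.
"""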
from collections import OrderedDict
import json
import logging
from optparse import OptionParser
import re
from sqlalchemy import create_engine
from sqlalchemy.engine import reflection
from sqlalchemy import MetaData
from sqlalchemy import select
from sqlalchemy import Table
# Then load the Geometry type
from geoalchemy2 import Geometry
# --------------------------------------------------------------------
import psycopg2.extensions
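# psycopg2 returns NUMERIC/DECIMAL columns as decimal.Decimal, which
# json.dumps() cannot serialize; this adapter coerces them to float.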
DEC2FLOAT = psycopg2.extensions.new_type(
    psycopg2.extensions.DECIMAL.values,
    'DEC2FLOAT',
    lambda value, curs: float(value) if value is not None else None)
psycopg2.extensions.register_type(DEC2FLOAT)
# --------------------------------------------------------------------
TYPE_PRIORITY_ORDER = (
    'str',
    'float',
    'int',
    'bool',
    'datetime',
    'date',
    # ...
)
VALID_IP_ADDRESS_REGEX = r'(([0-9]|[1-9][0-9]|1[0-9]{2}|2[0-4][0-9]|25[0-5])\.){3}([0-9]|[1-9][0-9]|1[0-9]{2}|2[0-4][0-9]|25[0-5])'
VALID_HOSTNAME_ADDRESS = r'(([a-zA-Z0-9]|[a-zA-Z0-9][a-zA-Z0-9\-]*[a-zA-Z0-9])\.)*([A-Za-z0-9]|[A-Za-z0-9][A-Za-z0-9\-]*[A-Za-z0-9])'
class Remote(object):

    def __init__(self, hostname='localhost', dbname='postgres',
                 username='postgres', password=None):
        # Accept either an IPv4 address or a hostname.
        valid_regex = '^({valid_ip}|{valid_host})$'.format(
            valid_ip=VALID_IP_ADDRESS_REGEX,
            valid_host=VALID_HOSTNAME_ADDRESS)
        if not re.match(valid_regex, hostname):
            raise ValueError(
                'Malformed hostname {hostname}'.format(hostname=hostname))
        connection_string = (
            'postgresql+psycopg2://{username}:{password}@{hostname}/{dbname}'
        ).format(hostname=hostname, dbname=dbname,
                 username=username, password=password)
        self.engine = create_engine(connection_string, echo=False)
        # Fail fast if the database is unreachable.
        with self.engine.connect():
            pass
        self.dbname = dbname
        self.metadata = MetaData(bind=self.engine)
        self.inspect = reflection.Inspector.from_engine(self.engine)
    def get_schema_names(self):
        return [name for name in self.inspect.get_schema_names()
                if name != 'information_schema']
    def get_tables(self, schema=None):
        table_names = (self.inspect.get_table_names(schema=schema)
                       + self.inspect.get_view_names(schema=schema))
        for table_name in table_names:
            # 'spatial_ref_sys' is PostGIS metadata, not user data.
            if table_name == 'spatial_ref_sys':
                continue
            yield self.get_table(table_name, schema=schema)
    def get_table(self, name, schema=None):
        return Table(name, self.metadata, schema=schema, autoload=True)

    def count_entries(self, table):
        return self.engine.execute(table.count()).first()[0]
    def get_entries(self, table, limit=None):
        """Yield table rows as GeoJSON Feature dicts (geometry in EPSG:4326)."""
        columns, geom = self.get_columns(table)
        fields = [table.c[col.name] for col in columns]
        if geom is not None:
            # Reproject to WGS84 when needed, then serialize as GeoJSON.
            if geom.type.srid != 4326:
                the_geom = table.c[geom.name].ST_Transform(4326).ST_AsGeoJSON()
            else:
                the_geom = table.c[geom.name].ST_AsGeoJSON()
            fields.append(the_geom)
        selected = select(fields)
        if limit is not None:
            selected = selected.limit(limit)
        for entry in self.engine.execute(selected):
            properties = dict(entry.items())
            # 'ST_AsGeoJSON_1' is the label SQLAlchemy gives the geometry
            # expression; json.loads() fails (TypeError) when it is None,
            # and pop() raises KeyError when the table has no geometry.
            geometry = None
            try:
                geometry = json.loads(properties.pop('ST_AsGeoJSON_1'))
            except (KeyError, TypeError):
                pass
            document = {
                'type': 'Feature',
                'properties': properties
            }
            if geometry is not None:
                document['geometry'] = geometry
            yield document
    def get_columns(self, table):
        """Split columns into (non-geometry columns, last geometry column or None)."""
        common_columns, geometry_columns = [], []
        # Preserve the original column order, to stay "loyal" to data producers!
        for column in table.columns:
            if isinstance(column.type, Geometry):
                geometry_columns.append(column)
            else:
                common_columns.append(column)
        if geometry_columns:
            return common_columns, geometry_columns[-1]
        return common_columns, None
    def field_type_detector(self, table):
        """Guess a type name for each column by inspecting its values."""

        def evaluate(value):
            # TODO -> evaluate in es type
            if not value:
                return None
            if isinstance(value, list):
                # Arrays: ES doesn't care, evaluate the first element only.
                for elt in value:
                    return evaluate(elt)
            if isinstance(value, str):
                if re.match(r'^\d+?\.\d+?$', value):
                    # TODO float/double
                    return 'float'
                elif re.match(r'^-?(?!0)\d+$', value):
                    # TODO short/integer/long
                    return 'int'
                # TODO: date... ip... binary... object... boolean...
                else:
                    return 'str'  # =text
            else:
                return value.__class__.__qualname__

        columns, _ = self.get_columns(table)
        fields = [table.c[col.name] for col in columns]
        selected = select(fields)
        detected = {}
        for entry in self.engine.execute(selected):  # iterates over all rows
            for col, val in entry.items():
                if col not in detected:
                    detected[col] = {}
                key = evaluate(val)
                if key not in detected[col]:
                    detected[col][key] = 0
                detected[col][key] += 1  # count (but not used for the moment)
        data = {}
        for col, candidate in detected.items():
            candidate = tuple(candidate.items())
            try:
                # Get the winner: the first detected type, in the order
                # defined by TYPE_PRIORITY_ORDER.
                first = [t for x in TYPE_PRIORITY_ORDER
                         for t in candidate if t[0] == x][0]
            except IndexError:
                logging.warning("'{col}' is empty!".format(col=col))
                data[col] = 'str'
            else:
                data[col] = first[0]
                if len(candidate) > 1:
                    logging.warning((
                        'Mixed types for {col}: {candidate}, keeping {chosen}'
                    ).format(col=col, candidate=str(candidate),
                             chosen=data[col]))
        return data
def main(**kwargs):
    hostname = kwargs.get('hostname')
    dbname = kwargs.get('dbname')
    schema_name = kwargs.get('schema')
    table_name = kwargs.get('table')
    username = kwargs.get('username')
    password = kwargs.get('password')
    conn = Remote(hostname=hostname, dbname=dbname,
                  username=username, password=password)
    schema_names = conn.get_schema_names()
    for schema in schema_names:
        if schema_name and schema_name != schema:
            logging.debug(f"not {schema_name}, skipping {schema}")
            continue
        for table in conn.get_tables(schema=schema):
            if table_name and table_name != table.name:
                continue
            count = conn.count_entries(table)
            columns, _ = conn.get_columns(table)
            detected = conn.field_type_detector(table)
            # Compare the detected (Python) type with the declared
            # (PostgreSQL) type, preserving the column order.
            compared = OrderedDict([
                (col.name, OrderedDict([
                    ('py', detected.get(col.name)),
                    ('pg', str(col.type))
                ])) for col in columns])
            resultat = OrderedDict([
                ('schema', schema),
                ('table', table.name),
                ('count', count),
                ('columns', compared),
            ])
            print(json.dumps(resultat, indent=4))
if __name__ == '__main__':
    parser = OptionParser()
    parser.add_option('--log', dest='loglevel', help='Log level', default='WARNING')
    parser.add_option('--hostname', dest='hostname', default='localhost')
    parser.add_option('--dbname', dest='dbname', default='postgres')
    parser.add_option('--schema', dest='schema', default=None)
    parser.add_option('--table', dest='table', default=None)
    parser.add_option('--username', dest='username', default='postgres')
    parser.add_option('--password', dest='password')
    opts, _ = parser.parse_args()
    loglevel = getattr(logging, opts.loglevel.upper(), None)
    if not isinstance(loglevel, int):
        raise ValueError('Invalid log level: {0}'.format(opts.loglevel))
    logging.basicConfig(format='%(levelname)s: %(message)s', level=loglevel)
    # logging.getLogger('sqlalchemy.engine').setLevel(loglevel)
    main(hostname=opts.hostname, dbname=opts.dbname,
         table=opts.table, schema=opts.schema,
         username=opts.username, password=opts.password)
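# Example invocation (hypothetical database name, schema, table and
# credentials — adapt to your own setup):
#   python postgis_helper.py --hostname localhost --dbname gis \
#       --schema public --table roads --username postgres \
#       --password secret --log INFO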