postgis_helper.py
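
    """Inspect a PostGIS-enabled PostgreSQL database: list its schemas, tables
    and views, count rows, stream rows as GeoJSON-like Feature documents, and
    infer a Python-side type for each column. Results are printed as JSON."""
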
    from collections import OrderedDict
    import json
    import logging
    from optparse import OptionParser
    import re
    from sqlalchemy import create_engine
    from sqlalchemy.engine import reflection
    from sqlalchemy import MetaData
    from sqlalchemy import select
    from sqlalchemy import Table
    
    # Then load the Geometry type
    
    from geoalchemy2 import Geometry
    
    # --------------------------------------------------------------------
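    # psycopg2 returns NUMERIC/DECIMAL columns as decimal.Decimal, which
    # json.dumps cannot serialise; the adapter below converts them to plain
    # floats at the driver level.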
    import psycopg2.extensions
    DEC2FLOAT = psycopg2.extensions.new_type(
        psycopg2.extensions.DECIMAL.values,
        'DEC2FLOAT',
        lambda value, curs: float(value) if value is not None else None)
    psycopg2.extensions.register_type(DEC2FLOAT)
    # --------------------------------------------------------------------
    
    
    TYPE_PRIORITY_ORDER = (
        'str',
        'float',
        'int',
        'bool',
        'datetime',
        'date',
        # ...
        )
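    # When a column's values map to more than one of the types above, the first
    # matching entry in TYPE_PRIORITY_ORDER wins (see Remote.field_type_detector).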
    
    
    
    VALID_IP_ADDRESS_REGEX = r'(([0-9]|[1-9][0-9]|1[0-9]{2}|2[0-4][0-9]|25[0-5])\.){3}([0-9]|[1-9][0-9]|1[0-9]{2}|2[0-4][0-9]|25[0-5])'

    VALID_HOSTNAME_ADDRESS = r'(([a-zA-Z0-9]|[a-zA-Z0-9][a-zA-Z0-9\-]*[a-zA-Z0-9])\.)*([A-Za-z0-9]|[A-Za-z0-9][A-Za-z0-9\-]*[A-Za-z0-9])'
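
    # These match a dotted-quad IPv4 address and a dot-separated hostname,
    # respectively; Remote.__init__ uses them to reject malformed host strings.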
    
    
    class Remote(object):
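        """Thin wrapper around a PostgreSQL/PostGIS database, using SQLAlchemy
        reflection and GeoAlchemy2 to inspect schemas, tables and geometries."""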
    
        def __init__(self, hostname='localhost', dbname='postgres',
                     username='postgres', password=None):

            valid_regex = '^({valid_ip}|{valid_host})$'.format(
                valid_ip=VALID_IP_ADDRESS_REGEX,
                valid_host=VALID_HOSTNAME_ADDRESS)
            if not re.match(valid_regex, hostname):
                raise Exception(
                    'Malformed hostname {hostname}'.format(hostname=hostname))
    
            connection_string = (
                'postgresql+psycopg2://{username}:{password}@{hostname}/{dbname}'
                ).format(hostname=hostname, dbname=dbname,
                         username=username, password=password)
    
            self.engine = create_engine(connection_string, echo=False)
            # Fail fast if the database is unreachable.
            with self.engine.connect():
                pass

            self.dbname = dbname
            self.metadata = MetaData(bind=self.engine)
            self.inspect = reflection.Inspector.from_engine(self.engine)
    
        def get_schema_names(self):
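            """Return every schema name except PostgreSQL's information_schema."""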
    
            return [name for name in self.inspect.get_schema_names()
                    if name != 'information_schema']
    
    
        def get_tables(self, schema=None):
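            """Yield every table and view of *schema*, skipping the PostGIS
            spatial_ref_sys table."""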
    
            for table_name in (self.inspect.get_table_names(schema=schema)
                               + self.inspect.get_view_names(schema=schema)):
    
                if table_name == 'spatial_ref_sys':
                    continue
                yield self.get_table(table_name, schema=schema)
    
        def get_table(self, name, schema=None):
            return Table(name, self.metadata, schema=schema, autoload=True)
    
        def count_entries(self, table):
            return self.engine.execute(table.count()).first()[0]
    
        def get_entries(self, table):
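            """Yield one GeoJSON-like Feature document per row of *table*,
            reprojecting the geometry to EPSG:4326 when needed."""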
            columns, geom = self.get_columns(table)
    
            fields = [table.c[col.name] for col in columns]
    
    
            if geom is not None:
                # Reproject to WGS84 (EPSG:4326) when the stored SRID differs,
                # then serialise the geometry as GeoJSON.
                if not geom.type.srid == 4326:
                    the_geom = table.c[geom.name].ST_Transform(4326).ST_AsGeoJSON()
                else:
                    the_geom = table.c[geom.name].ST_AsGeoJSON()
                fields.append(the_geom)

            selected = select(fields)
            for entry in self.engine.execute(selected):
                items = entry.items()
                properties = dict(items)
    
                geometry = None
                try:
                    # This fails if properties['ST_AsGeoJSON_1'] is None.
                    geometry = json.loads(properties['ST_AsGeoJSON_1'])
                except Exception:
                    pass

                try:
                    del properties['ST_AsGeoJSON_1']
                except KeyError:
                    pass
    
                document = {
                    'type': 'Feature',
                    'geometry': geometry,
                    'properties': properties,
                    }

                yield document
    
        def get_columns(self, table):
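            """Return (non-geometry columns, last geometry column or None)."""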
            common_columns, geometry_columns = [], []
            for column in dict(table.columns).values():
                if isinstance(column.type, Geometry):
                    geometry_columns.append(column)
                else:
                    common_columns.append(column)
            if geometry_columns:
                return common_columns, geometry_columns[-1]
            return common_columns, None
    
        def field_type_detector(self, table):
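            """Scan every row of *table* and return a dict mapping each column
            name to its detected type ('str', 'float', 'int', ...)."""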
    
            def evaluate(value):
                # TODO -> evaluate in es type
                if not value:
                    return None
                if isinstance(value, list):
                    # Arrays: Elasticsearch handles them natively, so evaluate
                    # the first element's type.
                    for elt in value:
                        return evaluate(elt)
                if isinstance(value, str):
                    if re.match("^\d+?\.\d+?$", value):
                        # TODO float/double
                        return 'float'
                    elif re.match("^-?(?!0)\d+$", value):
                        # TODO short/integer/long
                        return 'int'
                    # TODO: date... ip... binary... object... boolean...
                    else:
                        return 'str'  # = Elasticsearch 'text'
                else:
                    return value.__class__.__qualname__
    
            columns, geom = self.get_columns(table)
            fields = [table.c[col.name] for col in columns]
            selected = select(fields)
            detected = {}
            for entry in self.engine.execute(selected):  # iterates
                for col, val in entry.items():
                    if col not in detected:
                        detected[col] = {}
                    key = evaluate(val)
    
                    #key = type_utils.detect_type(val)
    
                    if key not in detected[col]:
                        detected[col].update({key: 0})
                    detected[col][key] += 1  # count (but not used for moment)
    
            data = {}
            for col, candidate in detected.items():
                candidate = tuple(candidate.items())
                try:
                    first = [t for x in TYPE_PRIORITY_ORDER
                             for t in candidate if t[0] == x]
                    first = first[0]
                except IndexError:
                    logging.warning("'{col}' is empty!".format(col=col))
                    data[col] = 'str'
                else:
                    # Get winner (first item in the ordered list)
                    data[col] = first[0]
                    logging.warning((
                        "Mixed type for {col}: {candidate}; chose '{chosen}'"
                        ).format(col=col, candidate=str(candidate), chosen=data[col]))
            return data
    
    
    def main(**kwargs):
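        """Print a JSON summary (row count, detected vs. declared column types)
        for each schema/table pair matching the optional filters."""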
    
        hostname = kwargs.get('hostname')
        dbname = kwargs.get('dbname')
        schema_name = kwargs.get('schema')
        table_name = kwargs.get('table')
        username = kwargs.get('username')
        password = kwargs.get('password')
    
        conn = Remote(hostname=hostname, dbname=dbname,
                      username=username, password=password)
    
        schema_names = conn.get_schema_names()
        for schema in schema_names:
            if schema_name and not schema_name == schema:
                continue
            for table in conn.get_tables(schema=schema):
                if table_name and not table_name == table.name:
                    continue
                count = conn.count_entries(table)
                ## -----------------------------------------------------
                print(count)
    
                # cnt = 0
                # for record in conn.get_entries(table):
                #     #print(json.dumps(record, indent=4, sort_keys=True, default=str))
                #     #records.append(record)
                #     cnt += 1
                #     print(cnt, count)
                #
                # exit(0)
    
    
                ## -----------------------------------------------------
                columns, _ = conn.get_columns(table)
                detected = conn.field_type_detector(table)
                compared = OrderedDict([
                    (col.name, OrderedDict([
                        ('py', detected.get(col.name)),
                        ('pg', str(col.type))
                        ])) for col in columns])
    
                resultat = OrderedDict([
                    ('schema', schema),
                    ('table', table.name),
                    ('count', count),
                    ('columns', compared),
                    ])
    
                print(json.dumps(resultat, indent=4))
    
    
    if __name__ == '__main__':
    
        parser = OptionParser()
        parser.add_option('--log', dest='loglevel', help="Log level", default='WARNING')
        parser.add_option('--hostname', dest='hostname', default='localhost')
        parser.add_option('--dbname', dest='dbname', default='postgres')
        parser.add_option('--schema', dest='schema', default=None)
        parser.add_option('--table', dest='table', default=None)
        parser.add_option('--username', dest='username', default='postgres')
        parser.add_option('--password', dest='password')
    
        opts, _ = parser.parse_args()
    
        loglevel = getattr(logging, opts.loglevel.upper(), None)
        if not isinstance(loglevel, int):
            raise ValueError('Invalid log level: {0}'.format(opts.loglevel))
    
        logging.basicConfig(format='%(levelname)s: %(message)s', level=loglevel)
        # logging.getLogger('sqlalchemy.engine').setLevel(loglevel)
    
        main(hostname=opts.hostname, dbname=opts.dbname,
             table=opts.table, schema=opts.schema,
             username=opts.username, password=opts.password)
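
    # Example invocation (hypothetical host, database and credentials):
    #   python postgis_helper.py --hostname localhost --dbname gis \
    #       --username postgres --password secret --schema public --log INFO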