aboutsummaryrefslogtreecommitdiff
path: root/src/db_handler.py
blob: 83cdad8c48aeee565e25d8255d7f39c08b3e556d (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
import os
import logging

import psycopg
import yaml

# Module-scoped logger; naming it after __name__ places it under the
# package's logger hierarchy so handlers/levels can be configured per module.
logger = logging.getLogger(name=__name__)

class IllegalInstructionException(Exception):
    """Signal a request this module refuses to run against the database.

    For example, ``get_all`` raises it when asked for a category or country
    that is not among the values preloaded from the ``poi`` table.
    """


# Connection parameters for psycopg.connect(), read once at import time and
# reused by every query helper in this module.
# NOTE(review): the path is relative, so it resolves against the process's
# current working directory rather than this file's location.
# An explicit encoding keeps the read independent of the platform locale.
with open(os.path.join('configs', 'database.yml'), 'r', encoding='utf-8') as file:
    db_con_params = yaml.safe_load(file)

# Preload the known categories and countries so that get_all() can validate
# its arguments before building a query.
with psycopg.connect(**db_con_params) as conn:
    with conn.cursor() as cur:
        # Subclasses without a ';' (presumably excluding multi-valued tags —
        # confirm) that occur more than once within some class.
        cur.execute("""
        SELECT subclass
        FROM poi
        WHERE subclass NOT LIKE '%;%'
        GROUP BY class, subclass
        HAVING COUNT(*) > 1
        ;
        """)
        # The query yields one row per (class, subclass) pair, so a subclass
        # shared by several classes would otherwise be listed more than once.
        # dict.fromkeys drops the repeats while keeping the query's row order.
        categories = list(dict.fromkeys(row[0] for row in cur.fetchall()))
        print(f"Loaded: {len(categories)} categories")

        # Non-null countries that have more than one POI.
        cur.execute("""
        SELECT country
        FROM poi
        WHERE country is not null GROUP BY country
        HAVING COUNT(*) > 1
        ;""")
        countries = [row[0] for row in cur.fetchall()]

        print(f"Loaded countries: {countries}")


def get_chains() -> (list[dict]):
    """Fetch every row of the ``chain`` table.

    A new connection is opened with the module-level ``db_con_params``.
    Rows are produced by ``psycopg.rows.dict_row``, so each entry is a dict
    keyed by the selected columns, ``name`` and ``color``.
    """
    sql = """
            SELECT name, color FROM chain
            """
    with psycopg.connect(
        **db_con_params, row_factory=psycopg.rows.dict_row
    ) as connection, connection.cursor() as cursor:
        cursor.execute(sql)
        return cursor.fetchall()


def get_all(country: str, category: str) -> list[dict]:
    """Return every POI of one category in one country, with its Voronoi cell.

    Both arguments are checked against the module-level ``categories`` and
    ``countries`` lists, which are loaded at import time, before any
    connection is opened.

    Args:
        country: Value matched against ``poi.country``.
        category: Value matched against ``poi.subclass``.

    Returns:
        One dict per POI with keys ``osm_id``, ``name``, ``brand``, ``lat``,
        ``long`` (coordinates transformed to SRID 4326) and ``polygon`` (the
        POI's Voronoi cell in SRID 4326, as GeoJSON text).

    Raises:
        IllegalInstructionException: If ``category`` is unknown (checked
            first) or ``country`` is unknown.
    """
    # Guard clauses: reject unknown values before touching the database.
    if category not in categories:
        raise IllegalInstructionException("Category not found")
    if country not in countries:
        raise IllegalInstructionException("Country not found")

    with psycopg.connect(**db_con_params, row_factory=psycopg.rows.dict_row) as conn:
        with conn.cursor() as cur:
            # The CTE restricts the table to the requested subclass and
            # country. The subquery builds Voronoi cells over those POIs'
            # geometries, and each POI is joined to the cell containing it.
            # NOTE(review): PostGIS documents that ST_VoronoiPolygons returns
            # an empty collection for a single input point. A country/category
            # pair with exactly one POI therefore returns no rows. Confirm
            # that this is intended.
            cur.execute("""
WITH filtered AS (
SELECT osm_id, name, brand, geom, class, subclass
FROM poi
WHERE subclass = %(subclass)s
AND country = %(country)s
)

SELECT
 filtered.osm_id,
 filtered.name,
 filtered.brand,
 ST_Y(ST_Transform(filtered.geom, 4326)) AS lat,
 ST_X(ST_Transform(filtered.geom, 4326)) AS long,
 ST_AsGeoJSON(ST_Transform(polygon.geom, 4326)) as polygon
 FROM filtered
 JOIN
 (
   SELECT (ST_DUMP(ST_VoronoiPolygons(ST_Collect(geom)))).geom as geom
   FROM filtered
 ) polygon ON ST_Contains(polygon.geom, filtered.geom)
;
            """, {'subclass': category, 'country': country})

            # Named `rows` rather than `all` to avoid shadowing the builtin.
            rows = cur.fetchall()
            return rows