aboutsummaryrefslogtreecommitdiff
path: root/src/db_handler.py
blob: 335e9e4296af5ae6cd641bfeb5d850d8ff30d8fd (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
import os
import logging

import psycopg
import yaml

# Module-scoped logger, named after this module's import path via __name__.
logger = logging.getLogger(name=__name__)

class IllegalInstructionException(Exception):
    """Signal that a request refers to data this module does not serve.

    ``get_all`` raises it when the requested category or country is not
    among the values loaded when the module was imported.
    """


# Connection parameters for psycopg.connect(), loaded once at import time.
# NOTE(review): the path is resolved against the current working directory,
# not this file's location — presumably the app is started from the repo root.
with open(os.path.join('configs', 'database.yml'), 'r', encoding='utf-8') as file:
    db_con_params = yaml.safe_load(file)

# Load the whitelists that get_all() validates against. This runs at import
# time, so importing this module requires a reachable database.
with psycopg.connect(**db_con_params) as conn:
    with conn.cursor() as cur:

        # Top-level categories: every non-NULL class with more than one POI.
        # NULL is excluded (as it is for countries below) so ``categories``
        # only ever holds strings.
        cur.execute("""
SELECT class
FROM poi
WHERE class IS NOT NULL
GROUP BY class
HAVING COUNT(*) > 1
ORDER BY class
;
        """)
        classes = cur.fetchall()
        categories = [category[0] for category in classes]

        # "class:subclass" categories. Subclass values containing ';' are
        # skipped; NULL subclasses never satisfy NOT LIKE, and NULL classes are
        # excluded so no "None:..." entries are produced.
        cur.execute("""
SELECT class, subclass
FROM poi
WHERE subclass NOT LIKE '%;%'
AND class IS NOT NULL
GROUP BY class, subclass
HAVING COUNT(*) > 1
ORDER BY class, subclass
;
        """)
        sub_cat = cur.fetchall()
        categories += [f"{category[0]}:{category[1]}" for category in sub_cat]
        print(f"Loaded: {len(categories)} categories")

        # NOTE(review): the COUNT(*) > 1 thresholds above are global, not per
        # country, so a category may still have fewer than two POIs in a given
        # country.
        cur.execute("""
        SELECT country
        FROM poi
        WHERE country is not null GROUP BY country
        HAVING COUNT(*) > 1
        ;""")
        countries = [country[0] for country in cur.fetchall()]

        print(f"Loaded countries: {countries}")


def get_chains() -> list[dict]:
    """Fetch every row of the ``chain`` table.

    A dedicated connection is opened for this call, using psycopg's
    ``dict_row`` row factory.

    Returns:
        One dict per row, keyed by column name (``name`` and ``color``).
    """
    chain_query = """
            SELECT name, color FROM chain
            """
    with psycopg.connect(**db_con_params, row_factory=psycopg.rows.dict_row) as connection:
        with connection.cursor() as cursor:
            cursor.execute(chain_query)
            result = cursor.fetchall()
    return result


def get_all(country: str, category: str) -> list[dict]:
    """Return every POI of one category in one country, with its Voronoi cell.

    Args:
        country: Value matched against ``poi.country``. Must appear in the
            module-level ``countries`` list.
        category: Either ``"<class>"`` or ``"<class>:<subclass>"``, matched
            against ``poi.class`` (and ``poi.subclass``). Must appear in the
            module-level ``categories`` list.

    Returns:
        One dict per matching POI with keys ``osm_id``, ``name``, ``brand``,
        ``lat``/``long`` (the POI point transformed to EPSG:4326) and
        ``polygon`` (GeoJSON text, in EPSG:4326, of the Voronoi cell that
        contains the POI; the diagram is built from the filtered POIs only).

    Raises:
        IllegalInstructionException: if ``category`` or ``country`` is not one
            of the values loaded at import time.
    """
    # Whitelist validation. The isinstance check keeps a non-string value
    # (e.g. None) away from the string handling below even if it ended up in
    # ``categories``.
    if not isinstance(category, str) or category not in categories:
        raise IllegalInstructionException("Category not found")
    if country not in countries:
        raise IllegalInstructionException("Country not found")

    # Split on the first ':' only: a subclass that itself contains ':' would
    # otherwise make a two-name unpacking raise ValueError after validation.
    # NOTE(review): assumes class names never contain ':' — confirm.
    class_, separator, sub_class = category.partition(':')
    params = {'country': country, 'class': class_}

    # Only constant SQL fragments are concatenated; every caller-supplied
    # value travels as a bound parameter.
    query = """
WITH filtered AS (
SELECT osm_id, name, brand, geom, class, subclass
FROM poi
WHERE class = %(class)s
"""
    if separator:
        query += " AND subclass = %(subclass)s "
        params['subclass'] = sub_class

    # The Voronoi diagram is computed over the filtered points and each POI is
    # joined back to the cell that contains it.
    query += """AND country = %(country)s
)

SELECT
 filtered.osm_id,
 filtered.name,
 filtered.brand,
 ST_Y(ST_Transform(filtered.geom, 4326)) AS lat,
 ST_X(ST_Transform(filtered.geom, 4326)) AS long,
 ST_AsGeoJSON(ST_Transform(polygon.geom, 4326)) as polygon
 FROM filtered
 JOIN
 (
   SELECT (ST_DUMP(ST_VoronoiPolygons(ST_Collect(geom)))).geom as geom
   FROM filtered
 ) polygon ON ST_Contains(polygon.geom, filtered.geom)
;
    """

    with psycopg.connect(**db_con_params, row_factory=psycopg.rows.dict_row) as conn:
        with conn.cursor() as cur:
            cur.execute(query, params)
            rows = cur.fetchall()
    return rows