summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rwxr-xr-xbin/populate-lut198
1 files changed, 198 insertions, 0 deletions
diff --git a/bin/populate-lut b/bin/populate-lut
new file mode 100755
index 0000000..42ffa8b
--- /dev/null
+++ b/bin/populate-lut
@@ -0,0 +1,198 @@
+#!/usr/bin/env python3
+# vim:tabstop=4 softtabstop=4 shiftwidth=4 textwidth=160 smarttab expandtab colorcolumn=160
+
+from datetime import datetime, timedelta
+from geopy.distance import distance
+from progress.bar import Bar
+
+import csv
+import json
+import numpy as np
+import os
+import psycopg2
+import requests
+import sys
+
+
+class ProgressBar(Bar):
+ sma_window = 500
+ suffix = "%(percent).0f%% [%(elapsed_td)s/%(eta_td)s]"
+
+
+conn = psycopg2.connect(
+ dbname=os.getenv("GEOLOOKUP_DBNAME", "geo_to_stations"),
+ user=os.getenv("GEOLOOKUP_DBUSER", "geo_to_stations"),
+ password=os.getenv("GEOLOOKUP_DBPASS"),
+ host=os.getenv("GEOLOOKUP_DBHOST", "localhost"),
+)
+
+shape = dict()
+stops_by_latlon = dict()
+
+routes_by_shape_id = dict()
+trips_by_shape_id = dict()
+
+name_to_eva = dict()
+eva_to_name = dict()
+
+try:
+ with open("data/iris-stations.json", "r") as f:
+ for station in json.load(f):
+ name_to_eva[station["name"]] = int(station["eva"])
+ eva_to_name[int(station["eva"])] = station["name"]
+except FileNotFoundError:
+ print(
+ "populate-lut requires a list of IRIS stations. Please run the following commands:"
+ )
+ print()
+ print("mkdir -p data")
+ print(
+ "curl https://git.finalrewind.org/Travel-Status-DE-IRIS/plain/share/stations.json > data/iris-stations.json"
+ )
+ print()
+ sys.exit(1)
+
+try:
+ with open("data/nvbw/trips.txt", "r") as f:
+ pass
+ with open("data/nvbw/shapes.txt", "r") as f:
+ pass
+ with open("data/nvbw/stop_times.txt", "r") as f:
+ pass
+except FileNotFoundError:
+ print("populate-lut requires GTFS shapes of regional transit lines.")
+ print(
+ "At present, the best known resource is <https://www.nvbw.de/open-data/fahrplandaten/fahrplandaten-mit-liniennetz>."
+ )
+ print(
+ "(https://www.nvbw.de/fileadmin/user_upload/service/open_data/fahrplandaten_mit_liniennetz/bwspnv.zip)"
+ )
+ print("Please download and extract it to data/nvbw.")
+ sys.exit(1)
+
+print("Loading trips ...")
+with open("data/nvbw/trips.txt", "r") as f:
+ f.readline()
+ cr = csv.reader(f)
+ for row in cr:
+ route_id, trip_id, service_id, direction_id, block_id, shape_id = row
+ if shape_id not in routes_by_shape_id:
+ routes_by_shape_id[shape_id] = list()
+ routes_by_shape_id[shape_id].append(route_id)
+ if shape_id not in trips_by_shape_id:
+ trips_by_shape_id[shape_id] = list()
+ trips_by_shape_id[shape_id].append(trip_id)
+
+print("Loading shapes ...")
+with open("data/nvbw/shapes.txt", "r") as f:
+ f.readline()
+ cr = csv.reader(f)
+ prev_lat, prev_lon = None, None
+ prev_dist = 0
+ for row in cr:
+ shape_id, _, lat, lon, dist = row
+ if shape_id not in shape:
+ shape[shape_id] = list()
+ prev_dist = 0
+ lat = float(lat)
+ lon = float(lon)
+ dist = float(dist)
+ if dist > prev_dist and dist - prev_dist > 200:
+ # ensure shape entries are no more than 200m apart
+ for i in np.arange(200, dist - prev_dist, 200):
+ ratio = i / (dist - prev_dist)
+ assert 0 <= ratio <= 1
+ rel_lat = (prev_lat * ratio + lat * (1 - ratio)) / 2
+ rel_lon = (prev_lon * ratio + lon * (1 - ratio)) / 2
+ shape[shape_id].append((rel_lat, rel_lon, dist))
+ shape[shape_id].append((lat, lon, dist))
+ prev_dist = dist
+ prev_lat = lat
+ prev_lon = lon
+
+
+def add_stops(lat, lon, stops):
+ lut_lat_center = round(lat * 1000)
+ lut_lon_center = round(lon * 1000)
+
+ evas = list()
+ for stop in stops:
+ try:
+ evas.append(name_to_eva[stop])
+ except KeyError:
+ try:
+ evas.append(name_to_eva[stop.replace(" (", "(")])
+ except KeyError:
+ pass
+
+ for lut_lat in range(lut_lat_center - 0, lut_lat_center + 1):
+ for lut_lon in range(lut_lon_center - 0, lut_lon_center + 1):
+ if (lut_lat, lut_lon) not in stops_by_latlon:
+ stops_by_latlon[(lut_lat, lut_lon)] = set()
+ stops_by_latlon[(lut_lat, lut_lon)].update(evas)
+
+
+print("Loading stop_times ...")
+stops_by_tripid = dict()
+with open("data/nvbw/stop_times.txt", "r") as f:
+ f.readline()
+ cr = csv.reader(f)
+ for row in cr:
+ (
+ trip_id,
+ stop_id,
+ arrival_time,
+ departure_time,
+ stop_seq,
+ stop_headsign,
+ pickup_type,
+ dropoff_type,
+ dist,
+ ) = row
+ if trip_id not in stops_by_tripid:
+ stops_by_tripid[trip_id] = list()
+ stops_by_tripid[trip_id].append((stop_headsign, float(dist)))
+
+num_shapes = len(shape.keys())
+
+for shape_id in ProgressBar("Calculating neighoubrs", max=num_shapes).iter(
+ shape.keys()
+):
+ for trip_id in trips_by_shape_id[shape_id]:
+ stops = stops_by_tripid[trip_id]
+ first_stop = stops[0]
+ last_stop = stops[-1]
+ for lat, lon, shape_dist in shape[shape_id]:
+ assert first_stop[1] <= shape_dist <= last_stop[1]
+ for i, (stop_name, stop_dist) in enumerate(stops):
+ if (
+ stop_dist <= shape_dist
+ and i + 1 < len(stops)
+ and stops[i + 1][1] >= shape_dist
+ ):
+ add_stops(lat, lon, (stop_name, stops[i + 1][0]))
+
+num_latlons = len(stops_by_latlon.keys())
+
+with conn.cursor() as cur:
+ cur.execute("drop table if exists stations")
+ cur.execute(
+ """create table stations (
+ lat integer not null,
+ lon integer not null,
+ stations jsonb not null,
+ primary key (lat, lon)
+ )
+ """
+ )
+
+for (lat, lon), stops in ProgressBar("Inserting coordinates", max=num_latlons).iter(
+ stops_by_latlon.items()
+):
+ with conn.cursor() as cur:
+ cur.execute(
+ """insert into stations (lat, lon, stations) values (%s, %s, %s)""",
+ (lat, lon, json.dumps(list(stops))),
+ )
+
+conn.commit()