summaryrefslogtreecommitdiff
path: root/bin
diff options
context:
space:
mode:
Diffstat (limited to 'bin')
-rwxr-xr-xbin/lookup-server203
1 files changed, 203 insertions, 0 deletions
diff --git a/bin/lookup-server b/bin/lookup-server
new file mode 100755
index 0000000..13bff59
--- /dev/null
+++ b/bin/lookup-server
@@ -0,0 +1,203 @@
+#!/usr/bin/env python3
+# vim:tabstop=4 softtabstop=4 shiftwidth=4 textwidth=160 smarttab expandtab colorcolumn=160
+
+import argparse
+import psycopg2
+import aiohttp
+from aiohttp import web
+from datetime import datetime, timedelta
+import dateutil.parser
+from geopy.distance import distance
+import json
+import os
+import pytz
+
+headers = {
+ "Access-Control-Allow-Origin": "*",
+ "Content-Type": "application/json; charset=utf-8",
+}
+
+conn = psycopg2.connect(
+ dbname=os.getenv("GEOLOOKUP_DBNAME", "geo_to_stations"),
+ user=os.getenv("GEOLOOKUP_DBUSER", "geo_to_stations"),
+ password=os.getenv("GEOLOOKUP_DBPASS"),
+ host=os.getenv("GEOLOOKUP_DBHOST", "localhost"),
+)
+
+
+def set_coarse_location(train):
+ now = datetime.now(pytz.utc)
+ train_evas = None
+ stopovers = train["previousStopovers"]
+
+ for i, stopover in enumerate(stopovers):
+ if stopover["departure"]:
+ stopover["departure"] = dateutil.parser.parse(stopover["departure"])
+
+ # start with origin. (planned)arrival is always null in a previousStopovers list
+ departure = stopover["departure"]
+ if i > 0 and departure and departure > now:
+ train_evas = (
+ int(stopovers[i - 1]["stop"]["id"]),
+ int(stopover["stop"]["id"]),
+ )
+ train_stops = (stopovers[i - 1]["stop"]["name"], stopover["stop"]["name"])
+ train_coords = (
+ (
+ stopovers[i - 1]["stop"]["location"]["latitude"],
+ stopovers[i - 1]["stop"]["location"]["longitude"],
+ ),
+ (
+ stopover["stop"]["location"]["latitude"],
+ stopover["stop"]["location"]["longitude"],
+ ),
+ )
+ train_times = (stopovers[i - 1]["departure"], departure)
+ break
+ if not train_evas:
+ train_evas = (int(train["stop"]["id"]), int(stopovers[-1]["stop"]["id"]))
+ train_stops = (train["stop"]["name"], stopovers[-1]["stop"]["name"])
+ train_coords = (
+ (
+ stopovers[-1]["stop"]["location"]["latitude"],
+ stopovers[-1]["stop"]["location"]["longitude"],
+ ),
+ (
+ train["stop"]["location"]["latitude"],
+ train["stop"]["location"]["longitude"],
+ ),
+ )
+ train_times = (stopovers[-1]["departure"], dateutil.parser.parse(train["when"]))
+
+ if not train_times[0]:
+ return
+
+ train["evas"] = train_evas
+ train["stop_names"] = train_stops
+ train["coords"] = train_coords
+ train["times"] = train_times
+
+ train["progress_ratio"] = 1 - (
+ (train["times"][1].timestamp() - now.timestamp())
+ / (train["times"][1].timestamp() - train["times"][0].timestamp())
+ )
+ train["progress_ratio"] = max(0, min(1, train["progress_ratio"]))
+
+ if train["progress_ratio"] == 0:
+ train["location"] = train["coarse_location"] = train["coords"][0]
+ elif train["progress_ratio"] == 1:
+ train["location"] = train["coarse_location"] = train["coords"][1]
+ else:
+ ratio = train["progress_ratio"]
+ coords = train["coords"]
+ train["coarse_location"] = (
+ coords[1][0] * ratio + coords[0][0] * (1 - ratio),
+ coords[1][1] * ratio + coords[0][1] * (1 - ratio),
+ )
+
+
+def calculate_distance(train, latlon):
+ train["distance"] = distance(train["coarse_location"], latlon).km
+
+
+def format_train(train):
+ return {
+ "line": train["line"]["name"],
+ "no": train["line"]["fahrtNr"],
+ "tripId": train["tripId"],
+ "location": train["coarse_location"],
+ "distance": train["distance"],
+ "stops": [
+ (train["evas"][0], train["stop_names"][0], train["times"][0].isoformat()),
+ (train["evas"][1], train["stop_names"][1], train["times"][1].isoformat()),
+ ],
+ }
+
+
+async def handle_search(request):
+ try:
+ lat = float(request.query.get("lat"))
+ lon = float(request.query.get("lon"))
+ except TypeError:
+ return web.HTTPBadRequest(text="lat/lon are mandatory")
+ except ValueError:
+ return web.HTTPBadRequest(text="lat/lon must be floating-point numbers")
+
+ lut_lat = round(lat * 1000)
+ lut_lon = round(lon * 1000)
+
+ evas = set()
+
+ with conn.cursor() as cur:
+ cur.execute(
+ "select stations from stations where lat between %s and %s and lon between %s and %s",
+ (lut_lat - 3, lut_lat + 3, lut_lon - 3, lut_lon + 3),
+ )
+ for eva_list in cur.fetchall():
+ evas.update(eva_list[0])
+
+ if not evas:
+ response = {"evas": list(), "trains": list()}
+ return web.Response(body=json.dumps(response), headers=headers)
+
+ arrivals = list()
+ trains = list()
+
+ # deliberately not parallelized to minimize load on transport.rest
+ for eva in evas:
+ async with aiohttp.ClientSession() as session:
+ async with session.get(
+ f"https://v5.db.transport.rest/stops/{eva}/arrivals?results=40&duration=120&stopovers=true&bus=false&subway=false&tram=false"
+ ) as response:
+ content = await response.text()
+ content = json.loads(content)
+ arrivals.append(content)
+
+ for train_list in arrivals:
+ for train in train_list:
+ is_candidate = False
+ for stop in train["previousStopovers"]:
+ if (
+ int(stop["stop"]["id"]) in evas
+ and stop["stop"]["id"] != train["stop"]["id"]
+ ):
+ is_candidate = True
+ break
+ if is_candidate:
+ trains.append(train)
+
+ seen = set()
+ trains = [
+ seen.add(train["line"]["fahrtNr"]) or train
+ for train in trains
+ if train["line"]["fahrtNr"] not in seen
+ ]
+
+ for train in trains:
+ set_coarse_location(train)
+
+ trains = list(filter(lambda train: "coarse_location" in train, trains))
+
+ for train in trains:
+ calculate_distance(train, (lat, lon))
+
+ trains = sorted(trains, key=lambda train: train["distance"])
+ trains = list(map(format_train, trains))
+
+ response = {"evas": list(evas), "trains": trains}
+
+ return web.Response(body=json.dumps(response, ensure_ascii=False), headers=headers)
+
+
+if __name__ == "__main__":
+
+ parser = argparse.ArgumentParser(
+ description="geolocation to train estimation service"
+ )
+ parser.add_argument("--port", type=int, metavar="PORT", default=8080)
+ parser.add_argument("--prefix", type=str, metavar="PATH", default="/")
+ args = parser.parse_args()
+
+ app = web.Application()
+ app.add_routes([web.get(f"{args.prefix}search", handle_search)])
+ web.run_app(app, host="localhost", port=args.port)