diff options
author | Daniel Friesel <derf@finalrewind.org> | 2021-03-28 18:58:37 +0200 |
---|---|---|
committer | Daniel Friesel <derf@finalrewind.org> | 2021-03-28 18:58:37 +0200 |
commit | 06bd4dc075b1d2901d4503ba1fcf232a238eccc4 (patch) | |
tree | bf3cd14d9a19bed091b3d1cd0e18daa655868b15 |
initial commit
-rwxr-xr-x | bin/lookup-server | 203 |
1 files changed, 203 insertions, 0 deletions
diff --git a/bin/lookup-server b/bin/lookup-server new file mode 100755 index 0000000..13bff59 --- /dev/null +++ b/bin/lookup-server @@ -0,0 +1,203 @@ +#!/usr/bin/env python3 +# vim:tabstop=4 softtabstop=4 shiftwidth=4 textwidth=160 smarttab expandtab colorcolumn=160 + +import argparse +import psycopg2 +import aiohttp +from aiohttp import web +from datetime import datetime, timedelta +import dateutil.parser +from geopy.distance import distance +import json +import os +import pytz + +headers = { + "Access-Control-Allow-Origin": "*", + "Content-Type": "application/json; charset=utf-8", +} + +conn = psycopg2.connect( + dbname=os.getenv("GEOLOOKUP_DBNAME", "geo_to_stations"), + user=os.getenv("GEOLOOKUP_DBUSER", "geo_to_stations"), + password=os.getenv("GEOLOOKUP_DBPASS"), + host=os.getenv("GEOLOOKUP_DBHOST", "localhost"), +) + + +def set_coarse_location(train): + now = datetime.now(pytz.utc) + train_evas = None + stopovers = train["previousStopovers"] + + for i, stopover in enumerate(stopovers): + if stopover["departure"]: + stopover["departure"] = dateutil.parser.parse(stopover["departure"]) + + # start with origin. (planned)arrival is always null in a previousStopovers list + departure = stopover["departure"] + if i > 0 and departure and departure > now: + train_evas = ( + int(stopovers[i - 1]["stop"]["id"]), + int(stopover["stop"]["id"]), + ) + train_stops = (stopovers[i - 1]["stop"]["name"], stopover["stop"]["name"]) + train_coords = ( + ( + stopovers[i - 1]["stop"]["location"]["latitude"], + stopovers[i - 1]["stop"]["location"]["longitude"], + ), + ( + stopover["stop"]["location"]["latitude"], + stopover["stop"]["location"]["longitude"], + ), + ) + train_times = (stopovers[i - 1]["departure"], departure) + break + if not train_evas: + train_evas = (int(train["stop"]["id"]), int(stopovers[-1]["stop"]["id"])) + train_stops = (train["stop"]["name"], stopovers[-1]["stop"]["name"]) + train_coords = ( + ( + stopovers[-1]["stop"]["location"]["latitude"], + stopovers[-1]["stop"]["location"]["longitude"], + ), + ( + train["stop"]["location"]["latitude"], + train["stop"]["location"]["longitude"], + ), + ) + train_times = (stopovers[-1]["departure"], dateutil.parser.parse(train["when"])) + + if not train_times[0]: + return + + train["evas"] = train_evas + train["stop_names"] = train_stops + train["coords"] = train_coords + train["times"] = train_times + + train["progress_ratio"] = 1 - ( + (train["times"][1].timestamp() - now.timestamp()) + / (train["times"][1].timestamp() - train["times"][0].timestamp()) + ) + train["progress_ratio"] = max(0, min(1, train["progress_ratio"])) + + if train["progress_ratio"] == 0: + train["location"] = train["coarse_location"] = train["coords"][0] + elif train["progress_ratio"] == 1: + train["location"] = train["coarse_location"] = train["coords"][1] + else: + ratio = train["progress_ratio"] + coords = train["coords"] + train["coarse_location"] = ( + coords[1][0] * ratio + coords[0][0] * (1 - ratio), + coords[1][1] * ratio + coords[0][1] * (1 - ratio), + ) + + +def calculate_distance(train, latlon): + train["distance"] = distance(train["coarse_location"], latlon).km + + +def format_train(train): + return { + "line": train["line"]["name"], + "no": train["line"]["fahrtNr"], + "tripId": train["tripId"], + "location": train["coarse_location"], + "distance": train["distance"], + "stops": [ + (train["evas"][0], train["stop_names"][0], train["times"][0].isoformat()), + (train["evas"][1], train["stop_names"][1], train["times"][1].isoformat()), + ], + } + + +async def handle_search(request): + try: + lat = float(request.query.get("lat")) + lon = float(request.query.get("lon")) + except TypeError: + return web.HTTPBadRequest(text="lat/lon are mandatory") + except ValueError: + return web.HTTPBadRequest(text="lat/lon must be floating-point numbers") + + lut_lat = round(lat * 1000) + lut_lon = round(lon * 1000) + + evas = set() + + with conn.cursor() as cur: + cur.execute( + "select stations from stations where lat between %s and %s and lon between %s and %s", + (lut_lat - 3, lut_lat + 3, lut_lon - 3, lut_lon + 3), + ) + for eva_list in cur.fetchall(): + evas.update(eva_list[0]) + + if not evas: + response = {"evas": list(), "trains": list()} + return web.Response(body=json.dumps(response), headers=headers) + + arrivals = list() + trains = list() + + # deliberately not parallelized to minimize load on transport.rest + for eva in evas: + async with aiohttp.ClientSession() as session: + async with session.get( + f"https://v5.db.transport.rest/stops/{eva}/arrivals?results=40&duration=120&stopovers=true&bus=false&subway=false&tram=false" + ) as response: + content = await response.text() + content = json.loads(content) + arrivals.append(content) + + for train_list in arrivals: + for train in train_list: + is_candidate = False + for stop in train["previousStopovers"]: + if ( + int(stop["stop"]["id"]) in evas + and stop["stop"]["id"] != train["stop"]["id"] + ): + is_candidate = True + break + if is_candidate: + trains.append(train) + + seen = set() + trains = [ + seen.add(train["line"]["fahrtNr"]) or train + for train in trains + if train["line"]["fahrtNr"] not in seen + ] + + for train in trains: + set_coarse_location(train) + + trains = list(filter(lambda train: "coarse_location" in train, trains)) + + for train in trains: + calculate_distance(train, (lat, lon)) + + trains = sorted(trains, key=lambda train: train["distance"]) + trains = list(map(format_train, trains)) + + response = {"evas": list(evas), "trains": trains} + + return web.Response(body=json.dumps(response, ensure_ascii=False), headers=headers) + + +if __name__ == "__main__": + + parser = argparse.ArgumentParser( + description="geolocation to train estimation service" + ) + parser.add_argument("--port", type=int, metavar="PORT", default=8080) + parser.add_argument("--prefix", type=str, metavar="PATH", default="/") + args = parser.parse_args() + + app = web.Application() + app.add_routes([web.get(f"{args.prefix}search", handle_search)]) + web.run_app(app, host="localhost", port=args.port) |