#!/usr/bin/env python3
# vim:tabstop=4 softtabstop=4 shiftwidth=4 textwidth=160 smarttab expandtab colorcolumn=160
"""Geolocation to train estimation service.

Exposes a single ``search`` endpoint that takes a ``lat``/``lon`` pair, looks up
nearby station IDs in a PostgreSQL table, asks transport.rest for the arrivals
at each of those stations and returns the trains whose estimated position is
closest to the requested coordinates.
"""

import argparse
import json
import math
import os
from datetime import datetime, timedelta

import aiohttp
import dateutil.parser
import psycopg2
import pytz
from aiohttp import web
from geopy.distance import distance

headers = {
    "Access-Control-Allow-Origin": "*",
    "Content-Type": "application/json; charset=utf-8",
}

conn = psycopg2.connect(
    dbname=os.getenv("GEOLOOKUP_DBNAME", "geo_to_stations"),
    user=os.getenv("GEOLOOKUP_DBUSER", "geo_to_stations"),
    password=os.getenv("GEOLOOKUP_DBPASS"),
    host=os.getenv("GEOLOOKUP_DBHOST", "localhost"),
)
conn.autocommit = True
conn.set_session(readonly=True)


def set_coarse_location(train):
    """Estimate where *train* currently is and store the result on the dict.

    Walks ``train["previousStopovers"]``, parsing each ``departure``/``arrival``
    string into a datetime in place, and picks the first stopover (after the
    origin) whose timestamp lies in the future. The train is assumed to be
    between that stopover and the one before it; its position is linearly
    interpolated between the two stops' coordinates.

    On success the keys ``evas``, ``stop_names``, ``coords``, ``times``,
    ``progress_ratio``, ``coarse_location`` and ``preferred`` are set
    (``location`` is additionally set when the train is exactly at one of the
    two stops). If no position can be determined -- unparsable timestamp, no
    future stopover, or no departure time at the previous stop -- the function
    returns without setting ``coarse_location``; callers use that as the
    "no estimate" signal.

    Returns:
        None. The result is communicated by mutating *train*.
    """
    now = datetime.now(pytz.utc)
    train_evas = None
    train_stops = None
    train_coords = None
    train_times = None

    stopovers = train["previousStopovers"]

    # includes train["stop"] -- but with arrival instead of departure
    for i, stopover in enumerate(stopovers):
        ts = None
        # Order matters: an arrival, when present, takes precedence over the departure.
        for key in ("departure", "arrival"):
            if stopover[key]:
                try:
                    stopover[key] = dateutil.parser.parse(stopover[key])
                except (TypeError, ValueError, OverflowError):
                    # dateutil signals malformed strings with ValueError (ParserError) or
                    # OverflowError, not just TypeError. Treat all of them as "no estimate".
                    return
                ts = stopover[key]

        # start with origin. (planned)arrival is always null in a previousStopovers list except for the last entry
        # (which is the stop where arrivals were requested)
        if i > 0 and ts and ts > now:
            previous = stopovers[i - 1]
            train_evas = (
                int(previous["stop"]["id"]),
                int(stopover["stop"]["id"]),
            )
            train_stops = (previous["stop"]["name"], stopover["stop"]["name"])
            train_coords = (
                (
                    previous["stop"]["location"]["latitude"],
                    previous["stop"]["location"]["longitude"],
                ),
                (
                    stopover["stop"]["location"]["latitude"],
                    stopover["stop"]["location"]["longitude"],
                ),
            )
            # XXX known bug: we're saving departure at i-1 and (possibly) departure at i. For a more accurate coarse position estimate later on,
            # we need to track departure at i-1 and arrival at i. But we don't always have it.
            train_times = (previous["departure"], ts)
            break

    if train_evas is None:
        return

    if not train_times[0]:
        return

    train["evas"] = train_evas
    train["stop_names"] = train_stops
    train["coords"] = train_coords
    train["times"] = train_times

    span = train_times[1].timestamp() - train_times[0].timestamp()
    if span > 0:
        remaining = train_times[1].timestamp() - now.timestamp()
        ratio = max(0, min(1, 1 - remaining / span))
    else:
        # The previous departure is not earlier than the next timestamp, which itself lies in
        # the future: the train has not left the previous stop yet. Also avoids dividing by zero.
        ratio = 0
    train["progress_ratio"] = ratio

    if ratio == 0:
        train["location"] = train["coarse_location"] = train_coords[0]
    elif ratio == 1:
        train["location"] = train["coarse_location"] = train_coords[1]
    else:
        train["coarse_location"] = (
            train_coords[1][0] * ratio + train_coords[0][0] * (1 - ratio),
            train_coords[1][1] * ratio + train_coords[0][1] * (1 - ratio),
        )

    # we can compare departure at previous stop with arrival at this stop. this is most accurate for position estimation.
    train["preferred"] = train_evas[1] == int(train["stop"]["id"])


def calculate_distance(train, latlon):
    """Store the distance in km between the train's coarse location and *latlon* in ``train["distance"]``."""
    train["distance"] = distance(train["coarse_location"], latlon).km


def format_train(train):
    """Build the JSON-serializable summary of *train* that is sent to the client.

    Expects the keys written by ``set_coarse_location`` and
    ``calculate_distance`` to be present.

    The line name is split on whitespace and its first word is used as the
    train type. Names that do not consist of exactly two words (e.g. a single
    word, or a type followed by several words) are handled instead of raising
    ``ValueError``; for two-word names the output is unchanged.

    Returns:
        dict with ``line``, ``train``, ``tripId``, ``location``, ``distance``
        (rounded to one decimal) and ``stops`` (two ``(eva, name, "HH:MM")``
        tuples for the stops the train is between).
    """
    name_parts = (train["line"]["name"] or "").split()
    train_type = name_parts[0] if name_parts else ""
    train_no = train["line"]["fahrtNr"]
    return {
        "line": " ".join(name_parts),
        "train": f"{train_type} {train_no}".strip(),
        "tripId": train["tripId"],
        "location": train["coarse_location"],
        "distance": round(train["distance"], 1),
        "stops": [
            (
                train["evas"][0],
                train["stop_names"][0],
                train["times"][0].strftime("%H:%M"),
            ),
            (
                train["evas"][1],
                train["stop_names"][1],
                train["times"][1].strftime("%H:%M"),
            ),
        ],
    }


async def handle_search(request):
    """Handle ``GET <prefix>search?lat=..&lon=..``.

    Looks up station IDs near the given coordinates, fetches the arrivals for
    each of them from transport.rest, estimates each candidate train's position
    and responds with the station IDs plus up to ten trains sorted by distance.

    Raises:
        web.HTTPBadRequest: if ``lat``/``lon`` are missing, not parsable as
            floats, or not finite numbers.
    """
    try:
        lat = float(request.query.get("lat"))
        lon = float(request.query.get("lon"))
    except TypeError:
        # float(None): the parameter is absent
        raise web.HTTPBadRequest(text="lat/lon are mandatory") from None
    except ValueError:
        raise web.HTTPBadRequest(text="lat/lon must be floating-point numbers") from None

    # float() accepts "nan" and "inf", but round() below cannot convert them to an integer.
    if not (math.isfinite(lat) and math.isfinite(lon)):
        raise web.HTTPBadRequest(text="lat/lon must be floating-point numbers")

    # The lookup is done on coordinates scaled by 1000 and rounded to integers,
    # within a window of +/- 3 of those units around the requested position.
    lut_lat = round(lat * 1000)
    lut_lon = round(lon * 1000)

    evas = set()

    # NOTE: psycopg2 is synchronous, so this query blocks the event loop while it runs.
    with conn.cursor() as cur:
        cur.execute(
            "select stations from stations where lat between %s and %s and lon between %s and %s",
            (lut_lat - 3, lut_lat + 3, lut_lon - 3, lut_lon + 3),
        )
        for eva_list in cur.fetchall():
            evas.update(eva_list[0])

    if not evas:
        payload = {"evas": list(), "trains": list()}
        return web.Response(body=json.dumps(payload), headers=headers)

    arrivals = []

    # deliberately not parallelized to minimize load on transport.rest
    # (one shared session is reused for the sequential requests)
    async with aiohttp.ClientSession() as session:
        for eva in evas:
            async with session.get(
                f"https://v5.db.transport.rest/stops/{eva}/arrivals?results=40&duration=120&stopovers=true&bus=false&subway=false&tram=false"
            ) as upstream:
                if upstream.status != 200:
                    # best effort: skip stations the upstream API could not answer for
                    continue
                content = json.loads(await upstream.text())
            # Anything other than a list of arrivals (e.g. an error object) cannot be processed below.
            if isinstance(content, list):
                arrivals.append(content)

    # A train is a candidate if it has already called at one of the nearby
    # stations other than the one whose arrival board it was found on.
    trains = []
    for train_list in arrivals:
        for train in train_list:
            if any(
                int(stop["stop"]["id"]) in evas
                and stop["stop"]["id"] != train["stop"]["id"]
                for stop in train.get("previousStopovers") or []
            ):
                trains.append(train)

    for train in trains:
        set_coarse_location(train)

    # set_coarse_location leaves "coarse_location" unset when it cannot produce an estimate
    trains = [train for train in trains if "coarse_location" in train]

    for train in trains:
        calculate_distance(train, (lat, lon))

    trains = sorted(
        trains, key=lambda train: 0 if train["preferred"] else train["distance"]
    )

    # remove duplicates. for now, we keep the preferred version, or the one with the lowest estimated distance.
    # later on, we'll need to request polylines and perform accurate calculations.
    seen = set()
    unique_trains = []
    for train in trains:
        train_no = train["line"]["fahrtNr"]
        if train_no not in seen:
            seen.add(train_no)
            unique_trains.append(train)

    unique_trains.sort(key=lambda train: train["distance"])
    formatted = [format_train(train) for train in unique_trains[:10]]

    payload = {"evas": list(evas), "trains": formatted}

    return web.Response(body=json.dumps(payload, ensure_ascii=False), headers=headers)


if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="geolocation to train estimation service"
    )
    parser.add_argument("--port", type=int, metavar="PORT", default=8080)
    parser.add_argument("--prefix", type=str, metavar="PATH", default="/")
    args = parser.parse_args()

    app = web.Application()
    app.add_routes([web.get(f"{args.prefix}search", handle_search)])
    web.run_app(app, host="localhost", port=args.port)