Commit 244a3b61 authored by Michael Tremer

Add download script to automatically update the database

Signed-off-by: Michael Tremer <michael.tremer@ipfire.org>
parent bbdb2e0a
......@@ -12,6 +12,7 @@ Makefile.in
/configure
/libtool
/stamp-h1
/src/python/location-downloader
/src/python/location-query
/test.db
/testdata.db
......@@ -210,12 +210,15 @@ uninstall-perl:
$(DESTDIR)/$(prefix)/man/man3/Location.3pm
bin_SCRIPTS = \
src/python/location-downloader \
src/python/location-query
EXTRA_DIST += \
src/python/location-downloader.in \
src/python/location-query.in
CLEANFILES += \
src/python/location-downloader \
src/python/location-query
# ------------------------------------------------------------------------------
......
#!/usr/bin/python3
###############################################################################
# #
# libloc - A library to determine the location of someone on the Internet #
# #
# Copyright (C) 2019 IPFire Development Team <info@ipfire.org> #
# #
# This library is free software; you can redistribute it and/or #
# modify it under the terms of the GNU Lesser General Public #
# License as published by the Free Software Foundation; either #
# version 2.1 of the License, or (at your option) any later version. #
# #
# This library is distributed in the hope that it will be useful, #
# but WITHOUT ANY WARRANTY; without even the implied warranty of #
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU #
# Lesser General Public License for more details. #
# #
###############################################################################
import argparse
import gettext
import lzma
import os
import random
import shutil
import sys
import tempfile
import time
import urllib.error
import urllib.parse
import urllib.request
# Load our location module
import location
import logging
# Configure the root logger so that messages of level INFO and above are shown
logging.basicConfig(level=logging.INFO)
# Name of the xz-compressed database file; it is resolved relative to the
# base URL of each mirror when downloading
DATABASE_FILENAME = "test.db.xz"
# Base URLs of the mirrors that are tried when downloading the database.
# The trailing slash matters because the filename is joined onto these URLs.
MIRRORS = (
"https://location.ipfire.org/databases/",
"https://people.ipfire.org/~ms/location/",
)
# i18n
def _(singular, plural=None, n=None):
if plural:
return gettext.dngettext("libloc", singular, plural, n)
return gettext.dgettext("libloc", singular)
class NotModifiedError(Exception):
    """
    Signals that the server reported the requested file as unchanged
    (HTTP status 304), so there is nothing new to download.
    """
class Downloader(object):
def __init__(self, mirrors):
self.mirrors = list(mirrors)
# Randomize mirrors
random.shuffle(self.mirrors)
# Get proxies from environment
self.proxies = self._get_proxies()
def _get_proxies(self):
proxies = {}
for protocol in ("https", "http"):
proxy = os.environ.get("%s_proxy" % protocol, None)
if proxy:
proxies[protocol] = proxy
return proxies
def _make_request(self, url, baseurl=None, headers={}):
if baseurl:
url = urllib.parse.urljoin(baseurl, url)
req = urllib.request.Request(url, method="GET")
# Update headers
headers.update({
"User-Agent" : "location-downloader/%s" % location.__version__,
})
# Set headers
for header in headers:
req.add_header(header, headers[header])
# Set proxies
for protocol in self.proxies:
req.set_proxy(self.proxies[protocol], protocol)
return req
def _send_request(self, req, **kwargs):
# Log request headers
logging.debug("HTTP %s Request to %s" % (req.method, req.host))
logging.debug(" URL: %s" % req.full_url)
logging.debug(" Headers:")
for k, v in req.header_items():
logging.debug(" %s: %s" % (k, v))
try:
res = urllib.request.urlopen(req, **kwargs)
except urllib.error.HTTPError as e:
# Log response headers
logging.debug("HTTP Response: %s" % e.code)
logging.debug(" Headers:")
for header in e.headers:
logging.debug(" %s: %s" % (header, e.headers[header]))
# Handle 304
if e.code == 304:
raise NotModifiedError() from e
# Raise all other errors
raise e
# Log response headers
logging.debug("HTTP Response: %s" % res.code)
logging.debug(" Headers:")
for k, v in res.getheaders():
logging.debug(" %s: %s" % (k, v))
return res
def download(self, url, mtime=None, **kwargs):
headers = {}
if mtime:
headers["If-Modified-Since"] = time.strftime(
"%a, %d %b %Y %H:%M:%S GMT", time.gmtime(mtime),
)
t = tempfile.NamedTemporaryFile(delete=False)
with t:
# Try all mirrors
for mirror in self.mirrors:
# Prepare HTTP request
req = self._make_request(url, baseurl=mirror, headers=headers)
try:
with self._send_request(req) as res:
decompressor = lzma.LZMADecompressor()
# Read all data
while True:
buf = res.read(1024)
if not buf:
break
# Decompress data
buf = decompressor.decompress(buf)
if buf:
t.write(buf)
# Write all data to disk
t.flush()
# Nothing to do when the database on the server is up to date
except NotModifiedError:
logging.info("Local database is up to date")
return
# Catch decompression errors
except lzma.LZMAError as e:
logging.warning("Could not decompress downloaded file: %s" % e)
continue
# XXX what do we catch here?
except urllib.error.HTTPError as e:
if e.code == 404:
continue
# Truncate the target file and drop downloaded content
try:
t.truncate()
except OSError:
pass
raise e
# Return temporary file
return t
raise FileNotFoundError(url)
class CLI(object):
    """
    Command line interface of the location downloader.
    """

    def __init__(self):
        self.downloader = Downloader(mirrors=MIRRORS)

    def parse_cli(self):
        """
        Parses the command line and returns the resulting namespace.

        Prints the usage and exits with code 2 if no action was given.
        """
        parser = argparse.ArgumentParser(
            description=_("Location Downloader Command Line Interface"),
        )
        subparsers = parser.add_subparsers()

        # Global configuration flags
        parser.add_argument("--debug", action="store_true",
            help=_("Enable debug output"))

        # version
        parser.add_argument("--version", action="version",
            version="%%(prog)s %s" % location.__version__)

        # database
        parser.add_argument("--database", "-d",
            default="@databasedir@/database.db", help=_("Path to database"),
        )

        # Update
        update = subparsers.add_parser("update", help=_("Update database"))
        update.set_defaults(func=self.handle_update)

        args = parser.parse_args()

        # Enable debug logging
        if args.debug:
            # logging.basicConfig() does nothing once the root logger has been
            # configured, so the level has to be changed on the logger itself
            logging.getLogger().setLevel(logging.DEBUG)

        # Print usage if no action was given
        if "func" not in args:
            parser.print_usage()
            sys.exit(2)

        return args

    def run(self):
        """
        Parses the command line, calls the selected action and exits
        the process with the code that the action has returned
        (or zero if it returned nothing).
        """
        # Parse command line arguments
        args = self.parse_cli()

        # Call function
        ret = args.func(args)

        # Return with exit code
        if ret:
            sys.exit(ret)

        # Otherwise just exit
        sys.exit(0)

    def handle_update(self, ns):
        """
        Downloads a new database and installs it at ns.database if it
        is newer than the one that is already there.

        Returns 1 if nothing could be downloaded or the downloaded
        database is not newer, and 0 otherwise.
        """
        mtime = None

        # Open database
        try:
            db = location.Database(ns.database)

            # Get mtime of the old file
            mtime = os.path.getmtime(ns.database)

        except FileNotFoundError:
            db = None

        # Try downloading a new database
        try:
            t = self.downloader.download(DATABASE_FILENAME, mtime=mtime)

        # If no file could be downloaded, log a message
        except FileNotFoundError:
            logging.error("Could not download a new database")
            return 1

        # If we have not received a new file, there is nothing to do
        if not t:
            return 0

        try:
            # Save old database creation time
            created_at = db.created_at if db else 0

            # Try opening the downloaded file (errors are passed on)
            db = location.Database(t.name)

            # Check if the downloaded file is newer
            if db.created_at <= created_at:
                logging.warning("Downloaded database is older than the current version")
                return 1

            logging.info("Downloaded new database from %s" % (time.strftime(
                "%a, %d %b %Y %H:%M:%S GMT", time.gmtime(db.created_at),
            )))

            # Write temporary file to destination
            shutil.copyfile(t.name, ns.database)

        finally:
            # Remove temporary file, no matter whether it has been installed
            os.unlink(t.name)

        return 0
def main():
    """
    Entry point: creates the command line interface and runs it.
    """
    # Run the command line interface
    c = CLI()
    c.run()


# Only start the command line interface when executed as a script, so that
# importing this file has no side effects
if __name__ == "__main__":
    main()
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment