Commit f1af4683 authored by Julien Muchembled's avatar Julien Muchembled Committed by Xavier Thompson

[fix] Rewrite 'urlretrieve' helper to fix various download-related issues

- Py3: stop using legacy API of urllib.request and
       fix download of http(s) URLs containing user:passwd@
- Py2: avoid OOM when downloading huge files

This is implemented as a method in case we want to make it configurable
via [buildout].
parent 0a6d8c19
......@@ -49,3 +49,16 @@ class UserError(Exception):
def __str__(self):
return " ".join(map(str, self.args))
# Used for Python 2-3 compatibility
if str is bytes: # BBB Py2
bytes2str = str2bytes = lambda s: s
def unicode2str(s):
return s.encode('utf-8')
else:
def bytes2str(s):
return s.decode()
def str2bytes(s):
return s.encode()
def unicode2str(s):
return s
......@@ -20,35 +20,17 @@ except ImportError:
try:
# Python 3
from urllib.request import urlretrieve
from urllib.parse import urlparse
from urllib.request import Request, urlopen
from urllib.parse import urlparse, urlunparse
except ImportError:
# Python 2
import base64
from urlparse import urlparse
from urlparse import urlunparse
import urllib2
def urlretrieve(url, tmp_path):
"""Work around Python issue 24599 including basic auth support
"""
scheme, netloc, path, params, query, frag = urlparse(url)
auth, host = urllib2.splituser(netloc)
if auth:
url = urlunparse((scheme, host, path, params, query, frag))
req = urllib2.Request(url)
base64string = base64.encodestring(auth)[:-1]
basic = "Basic " + base64string
req.add_header("Authorization", basic)
else:
req = urllib2.Request(url)
url_obj = urllib2.urlopen(req)
with open(tmp_path, 'wb') as fp:
fp.write(url_obj.read())
return tmp_path, url_obj.info()
from urllib2 import Request, urlopen
from zc.buildout.easy_install import realpath
from base64 import b64encode
from contextlib import closing
import logging
import os
import os.path
......@@ -56,6 +38,7 @@ import re
import shutil
import tempfile
import zc.buildout
from . import bytes2str, str2bytes
from .rmtree import rmtree
......@@ -216,7 +199,7 @@ class Download(object):
if not path:
handle, tmp_path = tempfile.mkstemp(prefix='buildout-')
os.close(handle)
tmp_path, headers = urlretrieve(url, tmp_path)
tmp_path, headers = self.urlretrieve(url, tmp_path)
if not check_md5sum(tmp_path, md5sum):
raise ChecksumError(
'MD5 checksum mismatch downloading %r' % url)
......@@ -257,6 +240,22 @@ class Download(object):
url_host, url_port = parsed[-2:]
return '%s:%s' % (url_host, url_port)
def urlretrieve(self, url, tmp_path):
parsed_url = urlparse(url)
req = url
if parsed_url.scheme in ('http', 'https'):
auth_host = parsed_url.netloc.rsplit('@', 1)
if len(auth_host) > 1:
auth = auth_host[0]
url = parsed_url._replace(netloc=auth_host[1]).geturl()
req = Request(url)
req.add_header("Authorization",
"Basic " + bytes2str(b64encode(str2bytes(auth))))
with closing(urlopen(req)) as src:
with open(tmp_path, 'wb') as dst:
shutil.copyfileobj(src, dst)
return tmp_path, src.info()
def check_md5sum(path, md5sum):
"""Tell whether the MD5 checksum of the file at path matches.
......
......@@ -23,6 +23,7 @@ except ImportError:
from BaseHTTPServer import HTTPServer, BaseHTTPRequestHandler
from urllib2 import urlopen
import base64
import errno
import logging
import multiprocessing
......@@ -415,6 +416,23 @@ class Handler(BaseHTTPRequestHandler):
self.__server.__log = False
return k()
if self.path.startswith('/private/'):
auth = self.headers.get('Authorization')
if auth and auth.startswith('Basic ') and \
self.path[9:].encode() == base64.b64decode(
self.headers.get('Authorization')[6:]):
return k()
# But not returning 401+WWW-Authenticate, we check that the client
# skips auth challenge, which is not free (in terms of performance)
# and useless for what we support.
self.send_response(403, 'Forbidden')
out = '<html><body>Forbidden</body></html>'.encode()
self.send_header('Content-Length', str(len(out)))
self.send_header('Content-Type', 'text/html')
self.end_headers()
self.wfile.write(out)
return
path = os.path.abspath(os.path.join(self.tree, *self.path.split('/')))
if not (
((path == self.tree) or path.startswith(self.tree+os.path.sep))
......
......@@ -126,6 +126,19 @@ This is a foo text.
>>> remove(path)
HTTP basic authentication:
>>> download = Download()
>>> user_url = server_url.replace('/localhost:', '/%s@localhost:') + 'private/'
>>> path, is_temp = download(user_url % 'foo:' + 'foo:')
>>> is_temp; remove(path)
True
>>> path, is_temp = download(user_url % 'foo:bar' + 'foo:bar')
>>> is_temp; remove(path)
True
>>> download(user_url % 'bar:' + 'foo:')
Traceback (most recent call last):
UserError: Error downloading ...: HTTP Error 403: Forbidden
Downloading using the download cache
------------------------------------
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment