Commit f8e9882c authored by Xavier Thompson's avatar Xavier Thompson

slapformat: WIP

parent 08b37183
......@@ -25,16 +25,18 @@
# Foundation, Inc., 59 Temple Place - Suite 330, Boston, MA 02111-1307, USA.
#
##############################################################################
from __future__ import annotations
import configparser
import ipaddress
import logging
import netifaces
import os
import subprocess
from collections import defaultdict
from netifaces import AF_INET, AF_INET6
from typing import List, Union
def do_format(conf):
# load configuration
......@@ -50,7 +52,9 @@ def do_format(conf):
class UsageError(Exception):
pass
def __init__(self, message):
self.message = message
Exception.__init__(self, message)
class Parameters(object):
......@@ -60,6 +64,7 @@ class Parameters(object):
software_root: str
partition_amount: int
class Options(object):
input_definition_file: str = None
output_definition_file: str = None
......@@ -104,8 +109,6 @@ class FormatConfig(Parameters, Options):
CHECK_FILES = ['key_file', 'cert_file', 'input_definition_file']
NORMALIZE_PATHS = ['instance_root', 'software_root']
logger : logging.Logger
def __init__(self, logger):
self.logger = logger
......@@ -116,23 +119,24 @@ class FormatConfig(Parameters, Options):
def parse(self, name, value, t):
if not isinstance(value, str):
if not isinstance(value, t):
self.error("Option %s takes type %s, not %r", name, t.__name__, value)
if type(value).__name__ != t:
self.error("Option %s takes type %s, not %r", name, t, value)
return value
if t in (int,):
if t == 'int':
try:
return t(value)
return int(value)
except ValueError:
self.error("Option %s takes type %s, not %r", name, t.__name__, value)
if t is bool:
self.error("Option %s takes type %s, not %r", name, t, value)
if t == 'bool':
try:
return {'true': True, 'false': False}[value.lower()]
except KeyError:
return bool(('false', 'true').index(value.lower()))
except IndexError:
self.error("Option %r must be 'true' or 'false', not %r", name, value)
return value
def get(self, option):
try:
return gettatr(self, option)
return getattr(self, option)
except AttributeError:
self.error("Parameter %r is not defined", option)
......@@ -163,21 +167,34 @@ class FormatConfig(Parameters, Options):
for option in self.CHECK_FILES:
path = getattr(self, option)
if path is not None and not os.path.exists(path):
if path:
if not os.path.exists(path):
self.error("File %r does not exist or is not readable", path)
setattr(self, option, os.path.abspath(path))
for option in self.NORMALIZE_PATHS:
setattr(self, option, os.path.abspath(getattr(self, option)))
path = getattr(self, option)
if path:
setattr(self, option, os.path.abspath(path))
# XXX Check command line tools + Logs
class Definition(defaultdict):
def __init__(self, definition_file):
super().__init__(lambda: defaultdict(type(None)))
if definition_file:
configp = configparser.ConfigParser(interpolation=None)
configp.read(conf.input_definition_file)
for s in configp.sections():
self[s] = dict(configp.items(s))
class Computer(object):
reference : str
interface : Interface
partitions : List[Partition]
address : Union[ipaddress.IPv4Interface, ipaddress.IPv6Interface]
partitions : list[Partition]
address : ipaddress.IPv4Interface or ipaddress.IPv6Interface
user : User
conf : FormatConfig
......@@ -185,39 +202,18 @@ class Computer(object):
self.conf = conf
self.reference = conf.computer_id
self.interface = Interface(conf)
self.address = self.interface.getComputerIPv6Addr()
self.user = User(conf.software_user, conf.software_root)
definition = None
if conf.input_definition_file:
definition = configparser.ConfigParser(interpolation=None)
definition.read(conf.input_definition_file)
if definition.has_option('computer', 'address')
address = definition.get('computer', 'address')
self.address = ipaddress.ip_interface(address)
if definition.has_option('computer', 'software_user')
user = definition.get('computer', 'software_user')
self.user = User(user, conf.software_root)
definition = Definition(conf.input_definition_file)
computer = definition['computer']
addr = computer['address']
address = ipaddress.ip_interface(addr) if addr else None
self.address = address or self.interface.getComputerIPv6Addr()
username = computer['software_user'] or conf.software_user
self.user = User(username, conf.software_root)
amount = conf.partition_amount
self.partitions = [Partition(i, conf, definition) for i in range(amount)]
self.partitions = [Partition(i, self, definition) for i in range(amount)]
def validate(self):
conf = self.conf
addresses = {4 : [], 6 : []}
networks = {4 : [], 6 : []}
for p in self.partitions:
addresses[4].extend(p.ipv4_list)
addresses[6]extend(p.ipv6_list)
networks[6]append(p.ipv6_range)
ipv4_addresses.sort()
ipv6_addresses.sort()
ipv4_networks.sort()
ipv6_networks.sort()
for network_list in networks.values()
for i, n in enumerate(network_list[:-1])
if n.overlaps(network_list[i + 1]):
self.conv.warning(
"Network configurations overlap"
)
pass
def format(self):
pass
......@@ -248,7 +244,7 @@ class Interface(object):
# XXX allow ipv4_local_network to be None ?
return ipaddress.IPv4Network(cidr, strict=False)
def getPartitionIPv4(self, index):
def getPartitionIPv4Addr(self, index):
return self.ipv4_network[index + 2]
def getIPv6Network(self):
......@@ -284,7 +280,7 @@ class Interface(object):
if prefixlen > 128:
self.conf.error("IPv6 network %s is too small for IPv6 ranges", network)
bits = 128 - network.prefixlen
addr = network[(1 << (bits - 2)) + (i << (128 - prefixlen))]
addr = network[(1 << (bits - 2)) + (index << (128 - prefixlen))]
return ipaddress.IPv6Network((addr, prefixlen))
......@@ -293,54 +289,47 @@ class Partition(object):
index: int
path : str
user : User
ipv4_list: List[ipaddress.IPv4Interface]
ipv6_list: List[ipaddress.IPv6Interface]
ipv4_list: list[ipaddress.IPv4Interface]
ipv6_list: list[ipaddress.IPv6Interface]
ipv6_range: ipaddress.IPv6Network
tap : Tap
tun : Tun
def __init__(self, index, computer, definition=None):
self.from_conf(index, computer)
if definition:
self.from_definition(index, computer, definition)
def from_definition(cls, index, computer, definition):
conf = computer.conf
section = 'partition_%d' % index
options = {}
if definition.has_section('default'):
options.update(definition.items('default'))
if definition.has_section(section):
options.update(definition.items(section))
if 'pathname' in options:
self.reference = options['pathname']
self.path = os.path.join(conf.instance_root, self.reference)
if 'user' in options:
self.user = User(options['user'], self.path)
if 'address' in options:
address_list = [ipaddress.ip_interface(a) for a in options['address']]
for v in (4, 6):
ip_list = [ip for ip in ip_addresses if ip.version == v]
if ip_list:
setattr(self, 'ipv%d_list' % v, ip_list)
# tap = Tap(computer_definition.get(section, 'network_interface'))
# tun = Tun.load(conf, index)
def from_conf(self, index, computer):
i = str(index)
conf = computer.conf
self.reference = '%s%d' % (conf.partition_base_name, index)
interface = computer.interface
section = definition['partition_' + i]
options = defaultdict(type(None), section, **definition['default'])
# Reference, path & user
self.reference = options['pathname'] or conf.partition_base_name + i
self.path = os.path.join(conf.instance_root, self.reference)
self.user = User('%s%d' % (conf.user_base_name, index), self.path)
self.ipv4_list = [computer.interface.getPartitionIPv4(index)]
self.ipv6_list = [computer.interface.getPartitionIPv6(index)]
# XXX Tap & tun
self.user = User(options['user'] or conf.user_base_name + i, self.path)
# IPv4 & IPv6 addresses
ipv4_list = ipv6_list = None
addresses = options['address']
if addresses:
address_list = [ipaddress.ip_interface(a) for a in addresses.split(' ')]
ipv4_list = [ip for ip in ip_addresses if ip.version == 4]
ipv6_list = [ip for ip in ip_addresses if ip.version == 6]
self.ipv4_list = ipv4_list or [interface.getPartitionIPv4Addr(index)]
self.ipv6_list = ipv6_list or [interface.getPartitionIPv6Addr(index)]
# IPv6 range
ipv6_range = options['ipv6_range']
if ipv6_range:
self.ipv6_range = ipaddress.IPv6Network(ipv6_range, strict=False)
elif conf.ipv6_range:
self.ipv6_range = interface.getPartitionIPv6Range(index)
# Tap & Tun
# XXX
def format(self):
pass
def createPath(self):
self.path = os.path.abspath(self.path)
owner = self.user if self.user else User('root')
if not os.path.exists(self.path):
os.mkdir(self.path, 0o750)
owner_pw = pwd.getpwnam(owner.name)
os.mkdir(self.path)
owner_pw = pwd.getpwnam(self.user.name)
os.chown(self.path, owner_pw.pw_uid, owner_pw.pw_gid)
os.chmod(self.path, 0o750)
......@@ -348,7 +337,7 @@ class Partition(object):
class User(object):
name: str
path: str
groups: List[str]
groups: list[str]
SHELL = '/bin/sh'
......@@ -390,7 +379,6 @@ class User(object):
except KeyError:
return False
# Utilities
def callAndRead(argument_list, raise_on_error=True):
......@@ -407,5 +395,70 @@ def callAndRead(argument_list, raise_on_error=True):
return popen.returncode, result
# Tracing
class OS(object):
"""Wrap parts of the 'os' module to provide logging of performed actions."""
_os = os
def __init__(self, conf):
self._dry_run = conf.dry_run
self._logger = conf.logger
add = self._addWrapper
add('chown')
add('chmod')
add('makedirs')
add('mkdir')
def _addWrapper(self, name):
def wrapper(*args, **kw):
arg_list = [repr(x) for x in args] + [
'%s=%r' % (x, y) for x, y in six.iteritems(kw)
]
self._logger.debug('%s(%s)' % (name, ', '.join(arg_list)))
if not self._dry_run:
getattr(self._os, name)(*args, **kw)
setattr(self, name, wrapper)
def __getattr__(self, name):
return getattr(self._os, name)
tracing_monkeypatch_mark = []
def tracing_monkeypatch(conf):
pass
"""Substitute os module and callAndRead function with tracing wrappers."""
# This function is called again if "slapos node boot" failed.
# Don't wrap the logging method again, otherwise the output becomes double.
if tracing_monkeypatch_mark:
return
global os
global callAndRead
original_callAndRead = callAndRead
os = OS(conf)
if conf.dry_run:
def dry_callAndRead(argument_list, raise_on_error=True):
return 0, ''
applied_callAndRead = dry_callAndRead
def fake_getpwnam(user):
class result(object):
pw_uid = 12345
pw_gid = 54321
return result
pwd.getpwnam = fake_getpwnam
else:
applied_callAndRead = original_callAndRead
def logging_callAndRead(argument_list, raise_on_error=True):
conf.logger.debug(' '.join(argument_list))
return applied_callAndRead(argument_list, raise_on_error)
callAndRead = logging_callAndRead
# Put a mark. This function was called once.
tracing_monkeypatch_mark.append(None)
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment