Commit 40ef4a8f authored by Pedro Oliveira's avatar Pedro Oliveira Committed by GitHub

Merge pull request #1 from pedrofran12/PIM-DM

Pim dm implementation
parents bcefc317 dec96dcb
from time import time
try:
from threading import _Timer as Timer
except ImportError:
from threading import Timer
class RemainingTimer(Timer):
def __init__(self, interval, function):
super().__init__(interval, function)
self.start_time = time()
def time_remaining(self):
delta_time = time() - self.start_time
return self.interval - delta_time
'''
def test():
print("ola")
x = RemainingTimer(10, test)
x.start()
from time import sleep
for i in range(0, 10):
print(x.time_remaining())
sleep(1)
'''
"""Generic linux daemon base class for python 3.x."""
import sys, os, time, atexit, signal
class Daemon:
"""A generic Daemon class.
Usage: subclass the Daemon class and override the run() method."""
def __init__(self, pidfile): self.pidfile = pidfile
def daemonize(self):
"""Deamonize class. UNIX double fork mechanism."""
try:
pid = os.fork()
if pid > 0:
# exit first parent
sys.exit(0)
except OSError as err:
sys.stderr.write('fork #1 failed: {0}\n'.format(err))
sys.exit(1)
# decouple from parent environment
#os.chdir('/')
#os.setsid()
#os.umask(0)
# do second fork
try:
pid = os.fork()
if pid > 0:
# exit from second parent
sys.exit(0)
except OSError as err:
sys.stderr.write('fork #2 failed: {0}\n'.format(err))
sys.exit(1)
# redirect standard file descriptors
sys.stdout.flush()
sys.stderr.flush()
si = open(os.devnull, 'r')
so = open('stdout', 'a+')
se = open('stderror', 'a+')
os.dup2(si.fileno(), sys.stdin.fileno())
os.dup2(so.fileno(), sys.stdout.fileno())
os.dup2(se.fileno(), sys.stderr.fileno())
# write pidfile
atexit.register(self.delpid)
pid = str(os.getpid())
with open(self.pidfile, 'w+') as f:
f.write(pid + '\n')
def delpid(self):
os.remove(self.pidfile)
def start(self):
"""Start the Daemon."""
# Check for a pidfile to see if the Daemon already runs
if self.is_running():
message = "pidfile {0} already exist. " + \
"Daemon already running?\n"
sys.stderr.write(message.format(self.pidfile))
sys.exit(1)
# Start the Daemon
self.daemonize()
self.run()
def stop(self):
"""Stop the Daemon."""
# Get the pid from the pidfile
try:
with open(self.pidfile, 'r') as pf:
pid = int(pf.read().strip())
except IOError:
pid = None
if not pid:
message = "pidfile {0} does not exist. " + \
"Daemon not running?\n"
sys.stderr.write(message.format(self.pidfile))
return # not an error in a restart
# Try killing the Daemon process
try:
while 1:
#os.killpg(os.getpgid(pid), signal.SIGTERM)
os.kill(pid, signal.SIGTERM)
time.sleep(0.1)
except OSError as err:
e = str(err.args)
if e.find("No such process") > 0:
if os.path.exists(self.pidfile):
os.remove(self.pidfile)
else:
print(str(err.args))
sys.exit(1)
def restart(self):
"""Restart the Daemon."""
self.stop()
self.start()
def run(self):
"""You should override this method when you subclass Daemon.
It will be called after the process has been daemonized by
start() or restart()."""
def is_running(self):
try:
with open(self.pidfile, 'r') as pf:
pid = int(pf.read().strip())
except IOError:
return False
""" Check For the existence of a unix pid. """
try:
os.kill(pid, 0)
return True
except:
return False
import socket
from abc import ABCMeta, abstractmethod
import threading
import random
import netifaces
from Packet.ReceivedPacket import ReceivedPacket
import Main
import traceback
from RWLock.RWLock import RWLockWrite
class Interface(metaclass=ABCMeta):
MCAST_GRP = '224.0.0.13'
def __init__(self, interface_name, recv_socket, send_socket, vif_index):
self.interface_name = interface_name
# virtual interface index for the multicast routing table
self.vif_index = vif_index
# set receive socket and send socket
self._send_socket = send_socket
self._recv_socket = recv_socket
self.interface_enabled = False
def _enable(self):
self.interface_enabled = True
# run receive method in background
receive_thread = threading.Thread(target=self.receive)
receive_thread.daemon = True
receive_thread.start()
def receive(self):
while self.interface_enabled:
try:
(raw_bytes, _) = self._recv_socket.recvfrom(256 * 1024)
if raw_bytes:
self._receive(raw_bytes)
except Exception:
traceback.print_exc()
continue
@abstractmethod
def _receive(self, raw_bytes):
raise NotImplementedError
def send(self, data: bytes, group_ip: str):
if self.interface_enabled and data:
self._send_socket.sendto(data, (group_ip, 0))
def remove(self):
self.interface_enabled = False
try:
self._recv_socket.shutdown(socket.SHUT_RDWR)
except Exception:
pass
self._recv_socket.close()
self._send_socket.close()
def is_enabled(self):
return self.interface_enabled
@abstractmethod
def get_ip(self):
raise NotImplementedError
\ No newline at end of file
import socket
import struct
from ipaddress import IPv4Address
from ctypes import create_string_buffer, addressof
import netifaces
from Packet.ReceivedPacket import ReceivedPacket
from Interface import Interface
from utils import Version_1_Membership_Report, Version_2_Membership_Report, Leave_Group, Membership_Query
if not hasattr(socket, 'SO_BINDTODEVICE'):
socket.SO_BINDTODEVICE = 25
class InterfaceIGMP(Interface):
ETH_P_IP = 0x0800 # Internet Protocol packet
SO_ATTACH_FILTER = 26
FILTER_IGMP = [
struct.pack('HBBI', 0x28, 0, 0, 0x0000000c),
struct.pack('HBBI', 0x15, 0, 3, 0x00000800),
struct.pack('HBBI', 0x30, 0, 0, 0x00000017),
struct.pack('HBBI', 0x15, 0, 1, 0x00000002),
struct.pack('HBBI', 0x6, 0, 0, 0x00040000),
struct.pack('HBBI', 0x6, 0, 0, 0x00000000),
]
def __init__(self, interface_name: str, vif_index: int):
# SEND SOCKET
snd_s = socket.socket(socket.AF_INET, socket.SOCK_RAW, socket.IPPROTO_IGMP)
# bind to interface
snd_s.setsockopt(socket.SOL_SOCKET, socket.SO_BINDTODEVICE, str(interface_name + "\0").encode('utf-8'))
# RECEIVE SOCKET
rcv_s = socket.socket(socket.AF_PACKET, socket.SOCK_RAW, socket.htons(InterfaceIGMP.ETH_P_IP))
# receive only IGMP packets by setting a BPF filter
bpf_filter = b''.join(InterfaceIGMP.FILTER_IGMP)
b = create_string_buffer(bpf_filter)
mem_addr_of_filters = addressof(b)
fprog = struct.pack('HL', len(InterfaceIGMP.FILTER_IGMP), mem_addr_of_filters)
rcv_s.setsockopt(socket.SOL_SOCKET, InterfaceIGMP.SO_ATTACH_FILTER, fprog)
# bind to interface
rcv_s.bind((interface_name, 0x0800))
super().__init__(interface_name=interface_name, recv_socket=rcv_s, send_socket=snd_s, vif_index=vif_index)
self.interface_enabled = True
from igmp.RouterState import RouterState
self.interface_state = RouterState(self)
super()._enable()
def get_ip(self):
return netifaces.ifaddresses(self.interface_name)[netifaces.AF_INET][0]['addr']
@property
def ip_interface(self):
return self.get_ip()
def send(self, data: bytes, address: str="224.0.0.1"):
super().send(data, address)
def _receive(self, raw_bytes):
if raw_bytes:
raw_bytes = raw_bytes[14:]
packet = ReceivedPacket(raw_bytes, self)
ip_src = packet.ip_header.ip_src
if not (ip_src == "0.0.0.0" or IPv4Address(ip_src).is_multicast):
self.PKT_FUNCTIONS.get(packet.payload.get_igmp_type(), InterfaceIGMP.receive_unknown_type)(self, packet)
###########################################
# Recv packets
###########################################
def receive_version_1_membership_report(self, packet):
ip_dst = packet.ip_header.ip_dst
igmp_group = packet.payload.group_address
if ip_dst == igmp_group and IPv4Address(igmp_group).is_multicast:
self.interface_state.receive_v1_membership_report(packet)
def receive_version_2_membership_report(self, packet):
ip_dst = packet.ip_header.ip_dst
igmp_group = packet.payload.group_address
if ip_dst == igmp_group and IPv4Address(igmp_group).is_multicast:
self.interface_state.receive_v2_membership_report(packet)
def receive_leave_group(self, packet):
ip_dst = packet.ip_header.ip_dst
igmp_group = packet.payload.group_address
if ip_dst == "224.0.0.2" and IPv4Address(igmp_group).is_multicast:
self.interface_state.receive_leave_group(packet)
def receive_membership_query(self, packet):
ip_dst = packet.ip_header.ip_dst
igmp_group = packet.payload.group_address
if ip_dst == igmp_group or (ip_dst == "224.0.0.1" and igmp_group == "0.0.0.0"):
self.interface_state.receive_query(packet)
def receive_unknown_type(self, packet):
return
PKT_FUNCTIONS = {
Version_1_Membership_Report: receive_version_1_membership_report,
Version_2_Membership_Report: receive_version_2_membership_report,
Leave_Group: receive_leave_group,
Membership_Query: receive_membership_query,
}
##################
def remove(self):
super().remove()
self.interface_state.remove()
import random
from Interface import Interface
from Packet.ReceivedPacket import ReceivedPacket
import Main
import traceback
from RWLock.RWLock import RWLockWrite
from Packet.PacketPimHelloOptions import *
from Packet.PacketPimHello import PacketPimHello
from Packet.PacketPimHeader import PacketPimHeader
from Packet.Packet import Packet
from utils import HELLO_HOLD_TIME_TIMEOUT
from threading import Timer
from tree.globals import REFRESH_INTERVAL
import socket
import netifaces
import logging
class InterfacePim(Interface):
MCAST_GRP = '224.0.0.13'
PROPAGATION_DELAY = 0.5
OVERRIDE_INTERNAL = 2.5
HELLO_PERIOD = 30
TRIGGERED_HELLO_PERIOD = 5
LOGGER = logging.getLogger('pim.Interface')
def __init__(self, interface_name: str, vif_index:int, state_refresh_capable:bool=False):
# generation id
self.generation_id = random.getrandbits(32)
# When PIM is enabled on an interface or when a router first starts, the Hello Timer (HT)
# MUST be set to random value between 0 and Triggered_Hello_Delay
self.hello_timer = None
# state refresh capable
self._state_refresh_capable = state_refresh_capable
self._neighbors_state_refresh_capable = False
# todo: lan delay enabled
self._lan_delay_enabled = False
# todo: propagation delay
self._propagation_delay = self.PROPAGATION_DELAY
# todo: override interval
self._override_interval = self.OVERRIDE_INTERNAL
# pim neighbors
self._had_neighbors = False
self.neighbors = {}
self.neighbors_lock = RWLockWrite()
self.interface_logger = logging.LoggerAdapter(InterfacePim.LOGGER, {'vif': vif_index, 'interfacename': interface_name})
# SOCKET
ip_interface = netifaces.ifaddresses(interface_name)[netifaces.AF_INET][0]['addr']
self.ip_interface = ip_interface
s = socket.socket(socket.AF_INET, socket.SOCK_RAW, socket.IPPROTO_PIM)
# allow other sockets to bind this port too
s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
# explicitly join the multicast group on the interface specified
#s.setsockopt(socket.SOL_IP, socket.IP_ADD_MEMBERSHIP, socket.inet_aton(Interface.MCAST_GRP) + socket.inet_aton(ip_interface))
s.setsockopt(socket.IPPROTO_IP, socket.IP_ADD_MEMBERSHIP,
socket.inet_aton(Interface.MCAST_GRP) + socket.inet_aton(ip_interface))
s.setsockopt(socket.SOL_SOCKET, 25, str(interface_name + '\0').encode('utf-8'))
# set socket output interface
s.setsockopt(socket.SOL_IP, socket.IP_MULTICAST_IF, socket.inet_aton(ip_interface))
# set socket TTL to 1
s.setsockopt(socket.IPPROTO_IP, socket.IP_MULTICAST_TTL, 1)
s.setsockopt(socket.IPPROTO_IP, socket.IP_TTL, 1)
# don't receive outgoing packets
s.setsockopt(socket.IPPROTO_IP, socket.IP_MULTICAST_LOOP, 0)
super().__init__(interface_name, s, s, vif_index)
super()._enable()
self.force_send_hello()
def get_ip(self):
return self.ip_interface
def _receive(self, raw_bytes):
if raw_bytes:
packet = ReceivedPacket(raw_bytes, self)
self.PKT_FUNCTIONS[packet.payload.get_pim_type()](self, packet)
def send(self, data: bytes, group_ip: str=MCAST_GRP):
super().send(data=data, group_ip=group_ip)
#Random interval for initial Hello message on bootup or triggered Hello message to a rebooting neighbor
def force_send_hello(self):
if self.hello_timer is not None:
self.hello_timer.cancel()
hello_timer_time = random.uniform(0, self.TRIGGERED_HELLO_PERIOD)
self.hello_timer = Timer(hello_timer_time, self.send_hello)
self.hello_timer.start()
def send_hello(self):
self.interface_logger.debug('Send Hello message')
self.hello_timer.cancel()
pim_payload = PacketPimHello()
pim_payload.add_option(PacketPimHelloHoldtime(holdtime=3.5 * self.HELLO_PERIOD))
pim_payload.add_option(PacketPimHelloGenerationID(self.generation_id))
# TODO implementar LANPRUNEDELAY e OVERRIDE_INTERVAL por interface e nas maquinas de estados ler valor de interface e nao do globals.py
#pim_payload.add_option(PacketPimHelloLANPruneDelay(lan_prune_delay=self._propagation_delay, override_interval=self._override_interval))
if self._state_refresh_capable:
pim_payload.add_option(PacketPimHelloStateRefreshCapable(REFRESH_INTERVAL))
ph = PacketPimHeader(pim_payload)
packet = Packet(payload=ph)
self.send(packet.bytes())
# reschedule hello_timer
self.hello_timer = Timer(self.HELLO_PERIOD, self.send_hello)
self.hello_timer.start()
def remove(self):
self.hello_timer.cancel()
self.hello_timer = None
# send pim_hello timeout message
pim_payload = PacketPimHello()
pim_payload.add_option(PacketPimHelloHoldtime(holdtime=HELLO_HOLD_TIME_TIMEOUT))
pim_payload.add_option(PacketPimHelloGenerationID(self.generation_id))
ph = PacketPimHeader(pim_payload)
packet = Packet(payload=ph)
self.send(packet.bytes())
Main.kernel.interface_change_number_of_neighbors()
super().remove()
def check_number_of_neighbors(self):
has_neighbors = len(self.neighbors) > 0
if has_neighbors != self._had_neighbors:
self._had_neighbors = has_neighbors
Main.kernel.interface_change_number_of_neighbors()
def new_or_reset_neighbor(self, neighbor_ip):
Main.kernel.new_or_reset_neighbor(self.vif_index, neighbor_ip)
'''
def add_neighbor(self, ip, random_number, hello_hold_time):
with self.neighbors_lock.genWlock():
if ip not in self.neighbors:
print("ADD NEIGHBOR")
from Neighbor import Neighbor
self.neighbors[ip] = Neighbor(self, ip, random_number, hello_hold_time)
self.force_send_hello()
self.check_number_of_neighbors()
'''
def get_neighbors(self):
with self.neighbors_lock.genRlock():
return self.neighbors.values()
def get_neighbor(self, ip):
with self.neighbors_lock.genRlock():
return self.neighbors.get(ip)
def remove_neighbor(self, ip):
with self.neighbors_lock.genWlock():
del self.neighbors[ip]
self.interface_logger.debug("Remove neighbor: " + ip)
self.check_number_of_neighbors()
def is_state_refresh_enabled(self):
return self._state_refresh_capable
# check if Interface is StateRefreshCapable
def is_state_refresh_capable(self):
with self.neighbors_lock.genWlock():
if len(self.neighbors) == 0:
return False
state_refresh_capable = True
for neighbor in list(self.neighbors.values()):
state_refresh_capable &= neighbor.state_refresh_capable
return state_refresh_capable
'''
def change_interface(self):
# check if ip change was already applied to interface
old_ip_address = self.ip_interface
new_ip_interface = netifaces.ifaddresses(self.interface_name)[netifaces.AF_INET][0]['addr']
if old_ip_address == new_ip_interface:
return
self._send_socket.setsockopt(socket.SOL_IP, socket.IP_MULTICAST_IF, socket.inet_aton(new_ip_interface))
self._recv_socket.setsockopt(socket.IPPROTO_IP, socket.IP_DROP_MEMBERSHIP,
socket.inet_aton(Interface.MCAST_GRP) + socket.inet_aton(old_ip_address))
self._recv_socket.setsockopt(socket.IPPROTO_IP, socket.IP_ADD_MEMBERSHIP,
socket.inet_aton(Interface.MCAST_GRP) + socket.inet_aton(new_ip_interface))
self.ip_interface = new_ip_interface
'''
###########################################
# Recv packets
###########################################
def receive_hello(self, packet):
ip = packet.ip_header.ip_src
print("ip = ", ip)
options = packet.payload.payload.get_options()
if (1 in options) and (20 in options):
hello_hold_time = options[1].holdtime
generation_id = options[20].generation_id
else:
raise Exception
state_refresh_capable = (21 in options)
with self.neighbors_lock.genWlock():
if ip not in self.neighbors:
if hello_hold_time == 0:
return
print("ADD NEIGHBOR")
from Neighbor import Neighbor
self.neighbors[ip] = Neighbor(self, ip, generation_id, hello_hold_time, state_refresh_capable)
self.force_send_hello()
self.check_number_of_neighbors()
self.new_or_reset_neighbor(ip)
return
else:
neighbor = self.neighbors[ip]
neighbor.receive_hello(generation_id, hello_hold_time, state_refresh_capable)
def receive_assert(self, packet):
pkt_assert = packet.payload.payload # type: PacketPimAssert
source = pkt_assert.source_address
group = pkt_assert.multicast_group_address
source_group = (source, group)
try:
Main.kernel.get_routing_entry(source_group).recv_assert_msg(self.vif_index, packet)
except:
traceback.print_exc()
def receive_join_prune(self, packet):
pkt_join_prune = packet.payload.payload # type: PacketPimJoinPrune
join_prune_groups = pkt_join_prune.groups
for group in join_prune_groups:
multicast_group = group.multicast_group
joined_src_addresses = group.joined_src_addresses
pruned_src_addresses = group.pruned_src_addresses
for source_address in joined_src_addresses:
source_group = (source_address, multicast_group)
try:
Main.kernel.get_routing_entry(source_group).recv_join_msg(self.vif_index, packet)
except:
traceback.print_exc()
continue
for source_address in pruned_src_addresses:
source_group = (source_address, multicast_group)
try:
Main.kernel.get_routing_entry(source_group).recv_prune_msg(self.vif_index, packet)
except:
traceback.print_exc()
continue
def receive_graft(self, packet):
pkt_join_prune = packet.payload.payload # type: PacketPimGraft
join_prune_groups = pkt_join_prune.groups
for group in join_prune_groups:
multicast_group = group.multicast_group
joined_src_addresses = group.joined_src_addresses
for source_address in joined_src_addresses:
source_group = (source_address, multicast_group)
try:
Main.kernel.get_routing_entry(source_group).recv_graft_msg(self.vif_index, packet)
except:
traceback.print_exc()
continue
def receive_graft_ack(self, packet):
pkt_join_prune = packet.payload.payload # type: PacketPimGraftAck
join_prune_groups = pkt_join_prune.groups
for group in join_prune_groups:
multicast_group = group.multicast_group
joined_src_addresses = group.joined_src_addresses
for source_address in joined_src_addresses:
source_group = (source_address, multicast_group)
try:
Main.kernel.get_routing_entry(source_group).recv_graft_ack_msg(self.vif_index, packet)
except:
traceback.print_exc()
continue
def receive_state_refresh(self, packet):
if not self.is_state_refresh_enabled():
return
pkt_state_refresh = packet.payload.payload # type: PacketPimStateRefresh
source = pkt_state_refresh.source_address
group = pkt_state_refresh.multicast_group_adress
source_group = (source, group)
try:
Main.kernel.get_routing_entry(source_group).recv_state_refresh_msg(self.vif_index, packet)
except:
traceback.print_exc()
PKT_FUNCTIONS = {
0: receive_hello,
3: receive_join_prune,
5: receive_assert,
6: receive_graft,
7: receive_graft_ack,
9: receive_state_refresh,
}
import socket
import struct
from threading import Lock, Thread
import traceback
import ipaddress
from RWLock.RWLock import RWLockWrite
import Main
import UnicastRouting
from InterfacePIM import InterfacePim
from InterfaceIGMP import InterfaceIGMP
from tree.KernelEntry import KernelEntry
class Kernel:
# MRT
MRT_BASE = 200
MRT_INIT = (MRT_BASE) # /* Activate the kernel mroute code */
MRT_DONE = (MRT_BASE + 1) # /* Shutdown the kernel mroute */
MRT_ADD_VIF = (MRT_BASE + 2) # /* Add a virtual interface */
MRT_DEL_VIF = (MRT_BASE + 3) # /* Delete a virtual interface */
MRT_ADD_MFC = (MRT_BASE + 4) # /* Add a multicast forwarding entry */
MRT_DEL_MFC = (MRT_BASE + 5) # /* Delete a multicast forwarding entry */
MRT_VERSION = (MRT_BASE + 6) # /* Get the kernel multicast version */
MRT_ASSERT = (MRT_BASE + 7) # /* Activate PIM assert mode */
MRT_PIM = (MRT_BASE + 8) # /* enable PIM code */
MRT_TABLE = (MRT_BASE + 9) # /* Specify mroute table ID */
#MRT_ADD_MFC_PROXY = (MRT_BASE + 10) # /* Add a (*,*|G) mfc entry */
#MRT_DEL_MFC_PROXY = (MRT_BASE + 11) # /* Del a (*,*|G) mfc entry */
#MRT_MAX = (MRT_BASE + 11)
# Max Number of Virtual Interfaces
MAXVIFS = 32
# SIGNAL MSG TYPE
IGMPMSG_NOCACHE = 1
IGMPMSG_WRONGVIF = 2
IGMPMSG_WHOLEPKT = 3 # NOT USED ON PIM-DM
# Interface flags
VIFF_TUNNEL = 0x1 # IPIP tunnel
VIFF_SRCRT = 0x2 # NI
VIFF_REGISTER = 0x4 # register vif
VIFF_USE_IFINDEX = 0x8 # use vifc_lcl_ifindex instead of vifc_lcl_addr to find an interface
def __init__(self):
# Kernel is running
self.running = True
# KEY : interface_ip, VALUE : vif_index
self.vif_dic = {}
self.vif_index_to_name_dic = {}
self.vif_name_to_index_dic = {}
# KEY : source_ip, VALUE : {group_ip: KernelEntry}
self.routing = {}
s = socket.socket(socket.AF_INET, socket.SOCK_RAW, socket.IPPROTO_IGMP)
# MRT INIT
s.setsockopt(socket.IPPROTO_IP, Kernel.MRT_INIT, 1)
# MRT PIM
s.setsockopt(socket.IPPROTO_IP, Kernel.MRT_PIM, 0)
s.setsockopt(socket.IPPROTO_IP, Kernel.MRT_ASSERT, 1)
self.socket = s
self.rwlock = RWLockWrite()
self.interface_lock = Lock()
# Create register interface
# todo useless in PIM-DM... useful in PIM-SM
#self.create_virtual_interface("0.0.0.0", "pimreg", index=0, flags=Kernel.VIFF_REGISTER)
self.pim_interface = {} # name: interface_pim
self.igmp_interface = {} # name: interface_igmp
# logs
self.interface_logger = Main.logger.getChild('KernelInterface')
self.tree_logger = Main.logger.getChild('KernelTree')
# receive signals from kernel with a background thread
handler_thread = Thread(target=self.handler)
handler_thread.daemon = True
handler_thread.start()
'''
Structure to create/remove virtual interfaces
struct vifctl {
vifi_t vifc_vifi; /* Index of VIF */
unsigned char vifc_flags; /* VIFF_ flags */
unsigned char vifc_threshold; /* ttl limit */
unsigned int vifc_rate_limit; /* Rate limiter values (NI) */
union {
struct in_addr vifc_lcl_addr; /* Local interface address */
int vifc_lcl_ifindex; /* Local interface index */
};
struct in_addr vifc_rmt_addr; /* IPIP tunnel addr */
};
'''
def create_virtual_interface(self, ip_interface: str or bytes, interface_name: str, index, flags=0x0):
if type(ip_interface) is str:
ip_interface = socket.inet_aton(ip_interface)
struct_mrt_add_vif = struct.pack("HBBI 4s 4s", index, flags, 1, 0, ip_interface,
socket.inet_aton("0.0.0.0"))
self.socket.setsockopt(socket.IPPROTO_IP, Kernel.MRT_ADD_VIF, struct_mrt_add_vif)
self.vif_dic[socket.inet_ntoa(ip_interface)] = index
self.vif_index_to_name_dic[index] = interface_name
self.vif_name_to_index_dic[interface_name] = index
with self.rwlock.genWlock():
for source_dict in list(self.routing.values()):
for kernel_entry in list(source_dict.values()):
kernel_entry.new_interface(index)
self.interface_logger.debug('Create virtual interface: %s -> %d', interface_name, index)
return index
def create_pim_interface(self, interface_name: str, state_refresh_capable:bool):
with self.interface_lock:
pim_interface = self.pim_interface.get(interface_name)
igmp_interface = self.igmp_interface.get(interface_name)
vif_already_exists = pim_interface or igmp_interface
if pim_interface:
# already exists
return
elif igmp_interface:
index = igmp_interface.vif_index
else:
index = list(range(0, self.MAXVIFS) - self.vif_index_to_name_dic.keys())[0]
ip_interface = None
if interface_name not in self.pim_interface:
pim_interface = InterfacePim(interface_name, index, state_refresh_capable)
self.pim_interface[interface_name] = pim_interface
ip_interface = pim_interface.ip_interface
if not vif_already_exists:
self.create_virtual_interface(ip_interface=ip_interface, interface_name=interface_name, index=index)
def create_igmp_interface(self, interface_name: str):
with self.interface_lock:
pim_interface = self.pim_interface.get(interface_name)
igmp_interface = self.igmp_interface.get(interface_name)
vif_already_exists = pim_interface or igmp_interface
if igmp_interface:
# already exists
return
elif pim_interface:
index = pim_interface.vif_index
else:
index = list(range(0, self.MAXVIFS) - self.vif_index_to_name_dic.keys())[0]
ip_interface = None
if interface_name not in self.igmp_interface:
igmp_interface = InterfaceIGMP(interface_name, index)
self.igmp_interface[interface_name] = igmp_interface
ip_interface = igmp_interface.ip_interface
if not vif_already_exists:
self.create_virtual_interface(ip_interface=ip_interface, interface_name=interface_name, index=index)
def remove_interface(self, interface_name, igmp:bool=False, pim:bool=False):
with self.interface_lock:
ip_interface = None
pim_interface = self.pim_interface.get(interface_name)
igmp_interface = self.igmp_interface.get(interface_name)
if (igmp and not igmp_interface) or (pim and not pim_interface) or (not igmp and not pim):
return
if pim:
pim_interface = self.pim_interface.pop(interface_name)
ip_interface = pim_interface.ip_interface
pim_interface.remove()
elif igmp:
igmp_interface = self.igmp_interface.pop(interface_name)
ip_interface = igmp_interface.ip_interface
igmp_interface.remove()
if (not self.igmp_interface.get(interface_name) and not self.pim_interface.get(interface_name)):
self.remove_virtual_interface(ip_interface)
def remove_virtual_interface(self, ip_interface):
#with self.interface_lock:
index = self.vif_dic[ip_interface]
struct_vifctl = struct.pack("HBBI 4s 4s", index, 0, 0, 0, socket.inet_aton("0.0.0.0"), socket.inet_aton("0.0.0.0"))
self.socket.setsockopt(socket.IPPROTO_IP, Kernel.MRT_DEL_VIF, struct_vifctl)
del self.vif_dic[ip_interface]
del self.vif_name_to_index_dic[self.vif_index_to_name_dic[index]]
interface_name = self.vif_index_to_name_dic.pop(index)
# alterar MFC's para colocar a 0 esta interface
with self.rwlock.genWlock():
for source_dict in list(self.routing.values()):
for kernel_entry in list(source_dict.values()):
kernel_entry.remove_interface(index)
self.interface_logger.debug('Remove virtual interface: %s -> %d', interface_name, index)
'''
/* Cache manipulation structures for mrouted and PIMd */
struct mfcctl {
struct in_addr mfcc_origin; /* Origin of mcast */
struct in_addr mfcc_mcastgrp; /* Group in question */
vifi_t mfcc_parent; /* Where it arrived */
unsigned char mfcc_ttls[MAXVIFS]; /* Where it is going */
unsigned int mfcc_pkt_cnt; /* pkt count for src-grp */
unsigned int mfcc_byte_cnt;
unsigned int mfcc_wrong_if;
int mfcc_expire;
};
'''
def set_multicast_route(self, kernel_entry: KernelEntry):
source_ip = socket.inet_aton(kernel_entry.source_ip)
group_ip = socket.inet_aton(kernel_entry.group_ip)
outbound_interfaces = kernel_entry.get_outbound_interfaces_indexes()
if len(outbound_interfaces) != Kernel.MAXVIFS:
raise Exception
#outbound_interfaces_and_other_parameters = list(kernel_entry.outbound_interfaces) + [0]*4
outbound_interfaces_and_other_parameters = outbound_interfaces + [0]*4
#outbound_interfaces, 0, 0, 0, 0 <- only works with python>=3.5
#struct_mfcctl = struct.pack("4s 4s H " + "B"*Kernel.MAXVIFS + " IIIi", source_ip, group_ip, inbound_interface_index, *outbound_interfaces, 0, 0, 0, 0)
struct_mfcctl = struct.pack("4s 4s H " + "B"*Kernel.MAXVIFS + " IIIi", source_ip, group_ip, kernel_entry.inbound_interface_index, *outbound_interfaces_and_other_parameters)
self.socket.setsockopt(socket.IPPROTO_IP, Kernel.MRT_ADD_MFC, struct_mfcctl)
def set_flood_multicast_route(self, source_ip, group_ip, inbound_interface_index):
source_ip = socket.inet_aton(source_ip)
group_ip = socket.inet_aton(group_ip)
outbound_interfaces = [1]*self.MAXVIFS
outbound_interfaces[inbound_interface_index] = 0
#outbound_interfaces_and_other_parameters = list(kernel_entry.outbound_interfaces) + [0]*4
outbound_interfaces_and_other_parameters = outbound_interfaces + [0]*3 + [20]
#outbound_interfaces, 0, 0, 0, 0 <- only works with python>=3.5
#struct_mfcctl = struct.pack("4s 4s H " + "B"*Kernel.MAXVIFS + " IIIi", source_ip, group_ip, inbound_interface_index, *outbound_interfaces, 0, 0, 0, 0)
struct_mfcctl = struct.pack("4s 4s H " + "B"*Kernel.MAXVIFS + " IIIi", source_ip, group_ip, inbound_interface_index, *outbound_interfaces_and_other_parameters)
self.socket.setsockopt(socket.IPPROTO_IP, Kernel.MRT_ADD_MFC, struct_mfcctl)
def remove_multicast_route(self, kernel_entry: KernelEntry):
source_ip = socket.inet_aton(kernel_entry.source_ip)
group_ip = socket.inet_aton(kernel_entry.group_ip)
outbound_interfaces_and_other_parameters = [0] + [0]*Kernel.MAXVIFS + [0]*4
struct_mfcctl = struct.pack("4s 4s H " + "B"*Kernel.MAXVIFS + " IIIi", source_ip, group_ip, *outbound_interfaces_and_other_parameters)
self.socket.setsockopt(socket.IPPROTO_IP, Kernel.MRT_DEL_MFC, struct_mfcctl)
self.routing[kernel_entry.source_ip].pop(kernel_entry.group_ip)
if len(self.routing[kernel_entry.source_ip]) == 0:
self.routing.pop(kernel_entry.source_ip)
def exit(self):
self.running = False
# MRT DONE
self.socket.setsockopt(socket.IPPROTO_IP, Kernel.MRT_DONE, 1)
self.socket.close()
'''
/* This is the format the mroute daemon expects to see IGMP control
* data. Magically happens to be like an IP packet as per the original
*/
struct igmpmsg {
__u32 unused1,unused2;
unsigned char im_msgtype; /* What is this */
unsigned char im_mbz; /* Must be zero */
unsigned char im_vif; /* Interface (this ought to be a vifi_t!) */
unsigned char unused3;
struct in_addr im_src,im_dst;
};
'''
def handler(self):
while self.running:
try:
msg = self.socket.recv(20)
(_, _, im_msgtype, im_mbz, im_vif, _, im_src, im_dst) = struct.unpack("II B B B B 4s 4s", msg[:20])
print((im_msgtype, im_mbz, socket.inet_ntoa(im_src), socket.inet_ntoa(im_dst)))
if im_mbz != 0:
continue
print(im_msgtype)
print(im_mbz)
print(im_vif)
print(socket.inet_ntoa(im_src))
print(socket.inet_ntoa(im_dst))
#print((im_msgtype, im_mbz, socket.inet_ntoa(im_src), socket.inet_ntoa(im_dst)))
ip_src = socket.inet_ntoa(im_src)
ip_dst = socket.inet_ntoa(im_dst)
if im_msgtype == Kernel.IGMPMSG_NOCACHE:
print("IGMP NO CACHE")
self.igmpmsg_nocache_handler(ip_src, ip_dst, im_vif)
elif im_msgtype == Kernel.IGMPMSG_WRONGVIF:
print("WRONG VIF HANDLER")
self.igmpmsg_wrongvif_handler(ip_src, ip_dst, im_vif)
#elif im_msgtype == Kernel.IGMPMSG_WHOLEPKT:
# print("IGMP_WHOLEPKT")
# self.igmpmsg_wholepacket_handler(ip_src, ip_dst)
else:
raise Exception
except Exception:
traceback.print_exc()
continue
# receive multicast (S,G) packet and multicast routing table has no (S,G) entry
def igmpmsg_nocache_handler(self, ip_src, ip_dst, iif):
source_group_pair = (ip_src, ip_dst)
self.get_routing_entry(source_group_pair, create_if_not_existent=True).recv_data_msg(iif)
# receive multicast (S,G) packet in a outbound_interface
def igmpmsg_wrongvif_handler(self, ip_src, ip_dst, iif):
source_group_pair = (ip_src, ip_dst)
self.get_routing_entry(source_group_pair, create_if_not_existent=True).recv_data_msg(iif)
''' useless in PIM-DM... useful in PIM-SM
def igmpmsg_wholepacket_handler(self, ip_src, ip_dst):
#kernel_entry = self.routing[(ip_src, ip_dst)]
source_group_pair = (ip_src, ip_dst)
self.get_routing_entry(source_group_pair, create_if_not_existent=True).recv_data_msg()
#kernel_entry.recv_data_msg(iif)
'''
def get_routing_entry(self, source_group: tuple, create_if_not_existent=True):
ip_src = source_group[0]
ip_dst = source_group[1]
with self.rwlock.genRlock():
if ip_src in self.routing and ip_dst in self.routing[ip_src]:
return self.routing[ip_src][ip_dst]
with self.rwlock.genWlock():
if ip_src in self.routing and ip_dst in self.routing[ip_src]:
return self.routing[ip_src][ip_dst]
elif create_if_not_existent:
kernel_entry = KernelEntry(ip_src, ip_dst)
if ip_src not in self.routing:
self.routing[ip_src] = {}
iif = UnicastRouting.check_rpf(ip_src)
self.set_flood_multicast_route(ip_src, ip_dst, iif)
self.routing[ip_src][ip_dst] = kernel_entry
return kernel_entry
else:
return None
# notify KernelEntries about changes at the unicast routing table
def notify_unicast_changes(self, subnet):
with self.rwlock.genWlock():
for source_ip in list(self.routing.keys()):
source_ip_obj = ipaddress.ip_address(source_ip)
if source_ip_obj not in subnet:
continue
for group_ip in list(self.routing[source_ip].keys()):
self.routing[source_ip][group_ip].network_update()
# notify about changes at the interface (IP)
'''
def notify_interface_change(self, interface_name):
with self.interface_lock:
# check if interface was already added
if interface_name not in self.vif_name_to_index_dic:
return
print("trying to change ip")
pim_interface = self.pim_interface.get(interface_name)
if pim_interface:
old_ip = pim_interface.get_ip()
pim_interface.change_interface()
new_ip = pim_interface.get_ip()
if old_ip != new_ip:
self.vif_dic[new_ip] = self.vif_dic.pop(old_ip)
igmp_interface = self.igmp_interface.get(interface_name)
if igmp_interface:
igmp_interface.change_interface()
'''
# When interface changes number of neighbors verify if olist changes and prune/forward respectively
def interface_change_number_of_neighbors(self):
with self.rwlock.genRlock():
for groups_dict in self.routing.values():
for entry in groups_dict.values():
entry.change_at_number_of_neighbors()
# When new neighbor connects try to resend last state refresh msg (if AssertWinner)
def new_or_reset_neighbor(self, vif_index, neighbor_ip):
with self.rwlock.genRlock():
for groups_dict in self.routing.values():
for entry in groups_dict.values():
entry.new_or_reset_neighbor(vif_index, neighbor_ip)
import netifaces
import time
from prettytable import PrettyTable
import sys
import logging, logging.handlers
from TestLogger import RootFilter
from Kernel import Kernel
import UnicastRouting
interfaces = {} # interfaces with multicast routing enabled
igmp_interfaces = {} # igmp interfaces
kernel = None
unicast_routing = None
logger = None
def add_pim_interface(interface_name, state_refresh_capable:bool=False):
kernel.create_pim_interface(interface_name=interface_name, state_refresh_capable=state_refresh_capable)
def add_igmp_interface(interface_name):
kernel.create_igmp_interface(interface_name=interface_name)
'''
def add_interface(interface_name, pim=False, igmp=False):
#if pim is True and interface_name not in interfaces:
# interface = InterfacePim(interface_name)
# interfaces[interface_name] = interface
# interface.create_virtual_interface()
#if igmp is True and interface_name not in igmp_interfaces:
# interface = InterfaceIGMP(interface_name)
# igmp_interfaces[interface_name] = interface
kernel.create_interface(interface_name=interface_name, pim=pim, igmp=igmp)
#if pim:
# interfaces[interface_name] = kernel.pim_interface[interface_name]
#if igmp:
# igmp_interfaces[interface_name] = kernel.igmp_interface[interface_name]
'''
def remove_interface(interface_name, pim=False, igmp=False):
#if pim is True and ((interface_name in interfaces) or interface_name == "*"):
# if interface_name == "*":
# interface_name_list = list(interfaces.keys())
# else:
# interface_name_list = [interface_name]
# for if_name in interface_name_list:
# interface_obj = interfaces.pop(if_name)
# interface_obj.remove()
# #interfaces[if_name].remove()
# #del interfaces[if_name]
# print("removido interface")
# print(interfaces)
#if igmp is True and ((interface_name in igmp_interfaces) or interface_name == "*"):
# if interface_name == "*":
# interface_name_list = list(igmp_interfaces.keys())
# else:
# interface_name_list = [interface_name]
# for if_name in interface_name_list:
# igmp_interfaces[if_name].remove()
# del igmp_interfaces[if_name]
# print("removido interface")
# print(igmp_interfaces)
kernel.remove_interface(interface_name, pim=pim, igmp=igmp)
def list_neighbors():
interfaces_list = interfaces.values()
t = PrettyTable(['Interface', 'Neighbor IP', 'Hello Hold Time', "Generation ID", "Uptime"])
check_time = time.time()
for interface in interfaces_list:
for neighbor in interface.get_neighbors():
uptime = check_time - neighbor.time_of_last_update
uptime = 0 if (uptime < 0) else uptime
t.add_row(
[interface.interface_name, neighbor.ip, neighbor.hello_hold_time, neighbor.generation_id, time.strftime("%H:%M:%S", time.gmtime(uptime))])
print(t)
return str(t)
def list_enabled_interfaces():
global interfaces
t = PrettyTable(['Interface', 'IP', 'PIM/IGMP Enabled', 'IGMP State'])
for interface in netifaces.interfaces():
try:
# TODO: fix same interface with multiple ips
ip = netifaces.ifaddresses(interface)[netifaces.AF_INET][0]['addr']
pim_enabled = interface in interfaces
igmp_enabled = interface in igmp_interfaces
enabled = str(pim_enabled) + "/" + str(igmp_enabled)
if igmp_enabled:
state = igmp_interfaces[interface].interface_state.print_state()
else:
state = "-"
t.add_row([interface, ip, enabled, state])
except Exception:
continue
print(t)
return str(t)
def list_state():
state_text = "IGMP State:\n" + list_igmp_state() + "\n\n\n\n" + "Multicast Routing State:\n" + list_routing_state()
return state_text
def list_igmp_state():
t = PrettyTable(['Interface', 'RouterState', 'Group Adress', 'GroupState'])
for (interface_name, interface_obj) in list(igmp_interfaces.items()):
interface_state = interface_obj.interface_state
state_txt = interface_state.print_state()
print(interface_state.group_state.items())
for (group_addr, group_state) in list(interface_state.group_state.items()):
print(group_addr)
group_state_txt = group_state.print_state()
t.add_row([interface_name, state_txt, group_addr, group_state_txt])
return str(t)
def list_routing_state():
routing_entries = []
for a in list(kernel.routing.values()):
for b in list(a.values()):
routing_entries.append(b)
vif_indexes = kernel.vif_index_to_name_dic.keys()
t = PrettyTable(['SourceIP', 'GroupIP', 'Interface', 'PruneState', 'AssertState', 'LocalMembership', "Is Forwarding?"])
for entry in routing_entries:
ip = entry.source_ip
group = entry.group_ip
upstream_if_index = entry.inbound_interface_index
for index in vif_indexes:
interface_state = entry.interface_state[index]
interface_name = kernel.vif_index_to_name_dic[index]
local_membership = type(interface_state._local_membership_state).__name__
try:
assert_state = type(interface_state._assert_state).__name__
if index != upstream_if_index:
prune_state = type(interface_state._prune_state).__name__
is_forwarding = interface_state.is_forwarding()
else:
prune_state = type(interface_state._graft_prune_state).__name__
is_forwarding = "upstream"
except:
prune_state = "-"
assert_state = "-"
is_forwarding = "-"
t.add_row([ip, group, interface_name, prune_state, assert_state, local_membership, is_forwarding])
return str(t)
def stop():
remove_interface("*", pim=True, igmp=True)
kernel.exit()
unicast_routing.stop()
def test(router_name, server_logger_ip):
global logger
socketHandler = logging.handlers.SocketHandler(server_logger_ip,
logging.handlers.DEFAULT_TCP_LOGGING_PORT)
# don't bother with a formatter, since a socket handler sends the event as
# an unformatted pickle
socketHandler.addFilter(RootFilter(router_name))
logger.addHandler(socketHandler)
def main():
# logging
global logger
logger = logging.getLogger('pim')
logger.setLevel(logging.DEBUG)
logger.addHandler(logging.StreamHandler(sys.stdout))
global kernel
kernel = Kernel()
global unicast_routing
unicast_routing = UnicastRouting.UnicastRouting()
global interfaces
global igmp_interfaces
interfaces = kernel.pim_interface
igmp_interfaces = kernel.igmp_interface
from threading import Timer
import time
from utils import HELLO_HOLD_TIME_NO_TIMEOUT, HELLO_HOLD_TIME_TIMEOUT, TYPE_CHECKING
from threading import Lock, RLock
import Main
import logging
if TYPE_CHECKING:
from InterfacePIM import InterfacePim
class Neighbor:
LOGGER = logging.getLogger('pim.Interface.Neighbor')
def __init__(self, contact_interface: "InterfacePim", ip, generation_id: int, hello_hold_time: int,
state_refresh_capable: bool):
if hello_hold_time == HELLO_HOLD_TIME_TIMEOUT:
raise Exception
logger_info = dict(contact_interface.interface_logger.extra)
logger_info['neighbor_ip'] = ip
self.neighbor_logger = logging.LoggerAdapter(self.LOGGER, logger_info)
self.neighbor_logger.debug('Monitoring new neighbor ' + ip + ' with GenerationID: ' + str(generation_id) +
'; HelloHoldTime: ' + str(hello_hold_time) + '; StateRefreshCapable: ' +
str(state_refresh_capable))
self.contact_interface = contact_interface
self.ip = ip
self.generation_id = generation_id
# todo lan prune delay
# todo override interval
self.state_refresh_capable = state_refresh_capable
self.neighbor_liveness_timer = None
self.hello_hold_time = None
self.set_hello_hold_time(hello_hold_time)
self.time_of_last_update = time.time()
self.neighbor_lock = Lock()
self.tree_interface_nlt_subscribers = []
self.tree_interface_nlt_subscribers_lock = RLock()
def set_hello_hold_time(self, hello_hold_time: int):
self.hello_hold_time = hello_hold_time
if self.neighbor_liveness_timer is not None:
self.neighbor_liveness_timer.cancel()
if hello_hold_time == HELLO_HOLD_TIME_TIMEOUT:
self.remove()
self.neighbor_logger.debug('Detected neighbor removal of ' + self.ip)
elif hello_hold_time != HELLO_HOLD_TIME_NO_TIMEOUT:
self.neighbor_logger.debug('Neighbor Liveness Timer reseted of ' + self.ip)
self.neighbor_liveness_timer = Timer(hello_hold_time, self.remove)
self.neighbor_liveness_timer.start()
else:
self.neighbor_liveness_timer = None
def set_generation_id(self, generation_id):
# neighbor restarted
if self.generation_id != generation_id:
self.neighbor_logger.debug('Detected reset of ' + self.ip + '... new GenerationID: ' + str(generation_id))
self.generation_id = generation_id
self.contact_interface.force_send_hello()
self.reset()
"""
def heartbeat(self):
if (self.hello_hold_time != HELLO_HOLD_TIME_TIMEOUT) and \
(self.hello_hold_time != HELLO_HOLD_TIME_NO_TIMEOUT):
print("HEARTBEAT")
if self.neighbor_liveness_timer is not None:
self.neighbor_liveness_timer.cancel()
self.neighbor_liveness_timer = Timer(self.hello_hold_time, self.remove)
self.neighbor_liveness_timer.start()
self.time_of_last_update = time.time()
"""
def remove(self):
print('HELLO TIMER EXPIRED... remove neighbor')
if self.neighbor_liveness_timer is not None:
self.neighbor_liveness_timer.cancel()
self.neighbor_logger.debug('Neighbor Liveness Timer expired of ' + self.ip)
self.contact_interface.remove_neighbor(self.ip)
# notify interfaces which have this neighbor as AssertWinner
with self.tree_interface_nlt_subscribers_lock:
for tree_if in self.tree_interface_nlt_subscribers:
tree_if.assert_winner_nlt_expires()
def reset(self):
self.contact_interface.new_or_reset_neighbor(self.ip)
def receive_hello(self, generation_id, hello_hold_time, state_refresh_capable):
self.neighbor_logger.debug('Receive Hello message with HelloHoldTime: ' + str(hello_hold_time) +
'; GenerationID: ' + str(generation_id) + '; StateRefreshCapable: ' +
str(state_refresh_capable) + ' from neighbor ' + self.ip)
if hello_hold_time == HELLO_HOLD_TIME_TIMEOUT:
self.set_hello_hold_time(hello_hold_time)
else:
self.time_of_last_update = time.time()
self.set_generation_id(generation_id)
self.set_hello_hold_time(hello_hold_time)
if state_refresh_capable != self.state_refresh_capable:
self.state_refresh_capable = state_refresh_capable
def subscribe_nlt_expiration(self, tree_if):
with self.tree_interface_nlt_subscribers_lock:
if tree_if not in self.tree_interface_nlt_subscribers:
self.tree_interface_nlt_subscribers.append(tree_if)
def unsubscribe_nlt_expiration(self, tree_if):
with self.tree_interface_nlt_subscribers_lock:
if tree_if in self.tree_interface_nlt_subscribers:
self.tree_interface_nlt_subscribers.remove(tree_if)
from .PacketIpHeader import PacketIpHeader
from .PacketPayload import PacketPayload
class Packet(object):
def __init__(self, ip_header: PacketIpHeader = None, payload: PacketPayload = None):
self.ip_header = ip_header
self.payload = payload
def bytes(self) -> bytes:
return self.payload.bytes()
import struct
from utils import checksum
import socket
from .PacketPayload import PacketPayload
'''
0 1 2 3
0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
| Type | Max Resp Time | Checksum |
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
| Group Address |
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
| Resv |S| QRV | QQIC | Number of Sources (N) |
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
| Source Address [1] |
+- -+
| Source Address [2] |
+- . -+
. . .
. . .
+- -+
| Source Address [N] |
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
'''
class PacketIGMPHeader(PacketPayload):
IGMP_TYPE = 2
IGMP_HDR = "! BB H 4s"
IGMP_HDR_LEN = struct.calcsize(IGMP_HDR)
IGMP3_SRC_ADDR_HDR = "! BB H "
IGMP3_SRC_ADDR_HDR_LEN = struct.calcsize(IGMP3_SRC_ADDR_HDR)
IPv4_HDR = "! 4s"
IPv4_HDR_LEN = struct.calcsize(IPv4_HDR)
Membership_Query = 0x11
Version_2_Membership_Report = 0x16
Leave_Group = 0x17
Version_1_Membership_Report = 0x12
def __init__(self, type: int, max_resp_time: int, group_address: str="0.0.0.0"):
# todo check type
self.type = type
self.max_resp_time = max_resp_time
self.group_address = group_address
def get_igmp_type(self):
return self.type
def bytes(self) -> bytes:
# obter mensagem e criar checksum
msg_without_chcksum = struct.pack(PacketIGMPHeader.IGMP_HDR, self.type, self.max_resp_time, 0,
socket.inet_aton(self.group_address))
igmp_checksum = checksum(msg_without_chcksum)
msg = msg_without_chcksum[0:2] + struct.pack("! H", igmp_checksum) + msg_without_chcksum[4:]
return msg
def __len__(self):
return len(self.bytes())
@staticmethod
def parse_bytes(data: bytes):
#print("parseIGMPHdr: ", data)
igmp_hdr = data[0:PacketIGMPHeader.IGMP_HDR_LEN]
(type, max_resp_time, rcv_checksum, group_address) = struct.unpack(PacketIGMPHeader.IGMP_HDR, igmp_hdr)
#print(type, max_resp_time, rcv_checksum, group_address)
msg_to_checksum = data[0:2] + b'\x00\x00' + data[4:]
#print("checksum calculated: " + str(checksum(msg_to_checksum)))
if checksum(msg_to_checksum) != rcv_checksum:
#print("wrong checksum")
raise Exception("wrong checksum")
igmp_hdr = igmp_hdr[PacketIGMPHeader.IGMP_HDR_LEN:]
group_address = socket.inet_ntoa(group_address)
pkt = PacketIGMPHeader(type, max_resp_time, group_address)
return pkt
\ No newline at end of file
import struct
import socket
'''
0 1 2 3
0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
|Version| IHL |Type of Service| Total Length |
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
| Identification |Flags| Fragment Offset |
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
| Time to Live | Protocol | Header Checksum |
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
| Source Address |
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
| Destination Address |
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
| Options | Padding |
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
'''
class PacketIpHeader:
IP_HDR = "! BBH HH BBH 4s 4s"
IP_HDR_LEN = struct.calcsize(IP_HDR)
def __init__(self, ver, hdr_len, ttl, proto, ip_src, ip_dst):
self.version = ver
self.hdr_length = hdr_len
self.ttl = ttl
self.proto = proto
self.ip_src = ip_src
self.ip_dst = ip_dst
def __len__(self):
return self.hdr_length
@staticmethod
def parse_bytes(data: bytes):
(verhlen, tos, iplen, ipid, frag, ttl, proto, cksum, src, dst) = \
struct.unpack(PacketIpHeader.IP_HDR, data)
ver = (verhlen & 0xf0) >> 4
hlen = (verhlen & 0x0f) * 4
'''
"VER": ver,
"HLEN": hlen,
"TOS": tos,
"IPLEN": iplen,
"IPID": ipid,
"FRAG": frag,
"TTL": ttl,
"PROTO": proto,
"CKSUM": cksum,
"SRC": socket.inet_ntoa(src),
"DST": socket.inet_ntoa(dst)
'''
src_ip = socket.inet_ntoa(src)
dst_ip = socket.inet_ntoa(dst)
return PacketIpHeader(ver, hlen, ttl, proto, src_ip, dst_ip)
import abc
class PacketPayload(object):
__metaclass__ = abc.ABCMeta
@abc.abstractmethod
def bytes(self) -> bytes:
"""Get packet payload in bytes format"""
@abc.abstractmethod
def __len__(self):
"""Get packet payload length"""
@staticmethod
@abc.abstractmethod
def parse_bytes(data: bytes):
"""From bytes create a object payload"""
import struct
import socket
from Packet.PacketPimEncodedGroupAddress import PacketPimEncodedGroupAddress
from Packet.PacketPimEncodedUnicastAddress import PacketPimEncodedUnicastAddress
from tree.globals import ASSERT_CANCEL_METRIC
'''
0 1 2 3
0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
|PIM Ver| Type | Reserved | Checksum |
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
| Multicast Group Address (Encoded Group Format) |
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
| Source Address (Encoded Unicast Format) |
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
|R| Metric Preference |
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
| Metric |
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
'''
class PacketPimAssert:
PIM_TYPE = 5
PIM_HDR_ASSERT = "! %ss %ss LL"
PIM_HDR_ASSERT_WITHOUT_ADDRESS = "! LL"
PIM_HDR_ASSERT_v4 = PIM_HDR_ASSERT % (PacketPimEncodedGroupAddress.PIM_ENCODED_GROUP_ADDRESS_HDR_LEN, PacketPimEncodedUnicastAddress.PIM_ENCODED_UNICAST_ADDRESS_HDR_LEN)
PIM_HDR_ASSERT_v6 = PIM_HDR_ASSERT % (PacketPimEncodedGroupAddress.PIM_ENCODED_GROUP_ADDRESS_HDR_LEN_IPv6, PacketPimEncodedUnicastAddress.PIM_ENCODED_UNICAST_ADDRESS_HDR_LEN_IPV6)
PIM_HDR_ASSERT_WITHOUT_ADDRESS_LEN = struct.calcsize(PIM_HDR_ASSERT_WITHOUT_ADDRESS)
PIM_HDR_ASSERT_v4_LEN = struct.calcsize(PIM_HDR_ASSERT_v4)
PIM_HDR_ASSERT_v6_LEN = struct.calcsize(PIM_HDR_ASSERT_v6)
def __init__(self, multicast_group_address: str or bytes, source_address: str or bytes, metric_preference: int or float, metric: int or float):
if type(multicast_group_address) is bytes:
multicast_group_address = socket.inet_ntoa(multicast_group_address)
if type(source_address) is bytes:
source_address = socket.inet_ntoa(source_address)
if metric_preference > 0x7FFFFFFF:
metric_preference = 0x7FFFFFFF
if metric > ASSERT_CANCEL_METRIC:
metric = ASSERT_CANCEL_METRIC
self.multicast_group_address = multicast_group_address
self.source_address = source_address
self.metric_preference = metric_preference
self.metric = metric
def bytes(self) -> bytes:
multicast_group_address = PacketPimEncodedGroupAddress(self.multicast_group_address).bytes()
source_address = PacketPimEncodedUnicastAddress(self.source_address).bytes()
msg = multicast_group_address + source_address + struct.pack(PacketPimAssert.PIM_HDR_ASSERT_WITHOUT_ADDRESS,
0x7FFFFFFF & self.metric_preference,
self.metric)
return msg
def __len__(self):
return len(self.bytes())
@staticmethod
def parse_bytes(data: bytes):
multicast_group_addr_obj = PacketPimEncodedGroupAddress.parse_bytes(data)
multicast_group_addr_len = len(multicast_group_addr_obj)
data = data[multicast_group_addr_len:]
source_addr_obj = PacketPimEncodedUnicastAddress.parse_bytes(data)
source_addr_len = len(source_addr_obj)
data = data[source_addr_len:]
(metric_preference, metric) = struct.unpack(PacketPimAssert.PIM_HDR_ASSERT_WITHOUT_ADDRESS, data[:PacketPimAssert.PIM_HDR_ASSERT_WITHOUT_ADDRESS_LEN])
pim_payload = PacketPimAssert(multicast_group_addr_obj.group_address, source_addr_obj.unicast_address, 0x7FFFFFFF & metric_preference, metric)
return pim_payload
import ipaddress
import struct
import socket
'''
0 1 2 3
0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
| Addr Family | Encoding Type |B| Reserved |Z| Mask Len |
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
| Group Multicast Address
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+...
'''
class PacketPimEncodedGroupAddress:
PIM_ENCODED_GROUP_ADDRESS_HDR = "! BBBB %s"
PIM_ENCODED_GROUP_ADDRESS_HDR_WITHOUT_GROUP_MULTICAST_ADDRESS = "! BBBB"
IPV4_HDR = "4s"
IPV6_HDR = "16s"
# TODO ver melhor versao ip
PIM_ENCODED_GROUP_ADDRESS_HDR_WITHOUT_GROUP_ADDRESS_LEN = struct.calcsize(PIM_ENCODED_GROUP_ADDRESS_HDR_WITHOUT_GROUP_MULTICAST_ADDRESS)
PIM_ENCODED_GROUP_ADDRESS_HDR_LEN = struct.calcsize(PIM_ENCODED_GROUP_ADDRESS_HDR % IPV4_HDR)
PIM_ENCODED_GROUP_ADDRESS_HDR_LEN_IPv6 = struct.calcsize(PIM_ENCODED_GROUP_ADDRESS_HDR % IPV6_HDR)
FAMILY_RESERVED = 0
FAMILY_IPV4 = 1
FAMILY_IPV6 = 2
RESERVED = 0
def __init__(self, group_address, mask_len=None):
if type(group_address) not in (str, bytes):
raise Exception
if type(group_address) is bytes:
group_address = socket.inet_ntoa(group_address)
self.group_address = group_address
self.mask_len = mask_len
def bytes(self) -> bytes:
(string_ip_hdr, hdr_addr_family, socket_family) = PacketPimEncodedGroupAddress.get_ip_info(self.group_address)
mask_len = self.mask_len
if mask_len is None:
mask_len = 8 * struct.calcsize(string_ip_hdr)
ip = socket.inet_pton(socket_family, self.group_address)
msg = struct.pack(PacketPimEncodedGroupAddress.PIM_ENCODED_GROUP_ADDRESS_HDR % string_ip_hdr, hdr_addr_family, 0,
PacketPimEncodedGroupAddress.RESERVED, mask_len, ip)
return msg
@staticmethod
def get_ip_info(ip):
version = ipaddress.ip_address(ip).version
if version == 4:
return (PacketPimEncodedGroupAddress.IPV4_HDR, PacketPimEncodedGroupAddress.FAMILY_IPV4, socket.AF_INET)
elif version == 6:
return (PacketPimEncodedGroupAddress.IPV6_HDR, PacketPimEncodedGroupAddress.FAMILY_IPV6, socket.AF_INET6)
else:
raise Exception
def __len__(self):
version = ipaddress.ip_address(self.group_address).version
if version == 4:
return self.PIM_ENCODED_GROUP_ADDRESS_HDR_LEN
elif version == 6:
return self.PIM_ENCODED_GROUP_ADDRESS_HDR_LEN_IPv6
else:
raise Exception
@staticmethod
def parse_bytes(data: bytes):
data_without_group_addr = data[0:PacketPimEncodedGroupAddress.PIM_ENCODED_GROUP_ADDRESS_HDR_WITHOUT_GROUP_ADDRESS_LEN]
(addr_family, encoding, _, mask_len) = struct.unpack(PacketPimEncodedGroupAddress.PIM_ENCODED_GROUP_ADDRESS_HDR_WITHOUT_GROUP_MULTICAST_ADDRESS, data_without_group_addr)
data_group_addr = data[PacketPimEncodedGroupAddress.PIM_ENCODED_GROUP_ADDRESS_HDR_WITHOUT_GROUP_ADDRESS_LEN:]
ip = None
if addr_family == PacketPimEncodedGroupAddress.FAMILY_IPV4:
(ip,) = struct.unpack("! " + PacketPimEncodedGroupAddress.IPV4_HDR, data_group_addr[:4])
ip = socket.inet_ntop(socket.AF_INET, ip)
elif addr_family == PacketPimEncodedGroupAddress.FAMILY_IPV6:
(ip,) = struct.unpack("! " + PacketPimEncodedGroupAddress.IPV6_HDR, data_group_addr[:16])
ip = socket.inet_ntop(socket.AF_INET6, ip)
if encoding != 0:
print("unknown encoding")
raise Exception
return PacketPimEncodedGroupAddress(ip, mask_len)
import ipaddress
import struct
import socket
'''
0 1 2 3
0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
| Addr Family | Encoding Type | Rsrvd |S|W|R| Mask Len |
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
| Source Address
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+...
'''
class PacketPimEncodedSourceAddress:
PIM_ENCODED_SOURCE_ADDRESS_HDR = "! BBBB %s"
PIM_ENCODED_SOURCE_ADDRESS_HDR_WITHOUT_SOURCE_ADDRESS = "! BBBB"
IPV4_HDR = "4s"
IPV6_HDR = "16s"
# TODO ver melhor versao ip
PIM_ENCODED_SOURCE_ADDRESS_HDR_WITHOUT_SOURCE_ADDRESS_LEN = struct.calcsize(PIM_ENCODED_SOURCE_ADDRESS_HDR_WITHOUT_SOURCE_ADDRESS)
PIM_ENCODED_SOURCE_ADDRESS_HDR_LEN = struct.calcsize(PIM_ENCODED_SOURCE_ADDRESS_HDR % IPV4_HDR)
PIM_ENCODED_SOURCE_ADDRESS_HDR_LEN_IPV6 = struct.calcsize(PIM_ENCODED_SOURCE_ADDRESS_HDR % IPV6_HDR)
FAMILY_RESERVED = 0
FAMILY_IPV4 = 1
FAMILY_IPV6 = 2
RESERVED_AND_SWR_BITS = 0
def __init__(self, source_address, mask_len=None):
if type(source_address) not in (str, bytes):
raise Exception
if type(source_address) is bytes:
source_address = socket.inet_ntoa(source_address)
self.source_address = source_address
self.mask_len = mask_len
def bytes(self) -> bytes:
(string_ip_hdr, hdr_addr_family, socket_family) = PacketPimEncodedSourceAddress.get_ip_info(self.source_address)
mask_len = self.mask_len
if mask_len is None:
mask_len = 8 * struct.calcsize(string_ip_hdr)
ip = socket.inet_pton(socket_family, self.source_address)
msg = struct.pack(PacketPimEncodedSourceAddress.PIM_ENCODED_SOURCE_ADDRESS_HDR % string_ip_hdr, hdr_addr_family, 0,
PacketPimEncodedSourceAddress.RESERVED_AND_SWR_BITS, mask_len, ip)
return msg
@staticmethod
def get_ip_info(ip):
version = ipaddress.ip_address(ip).version
if version == 4:
return (PacketPimEncodedSourceAddress.IPV4_HDR, PacketPimEncodedSourceAddress.FAMILY_IPV4, socket.AF_INET)
elif version == 6:
return (PacketPimEncodedSourceAddress.IPV6_HDR, PacketPimEncodedSourceAddress.FAMILY_IPV6, socket.AF_INET6)
else:
raise Exception
def __len__(self):
version = ipaddress.ip_address(self.source_address).version
if version == 4:
return self.PIM_ENCODED_SOURCE_ADDRESS_HDR_LEN
elif version == 6:
return self.PIM_ENCODED_SOURCE_ADDRESS_HDR_LEN_IPV6
else:
raise Exception
@staticmethod
def parse_bytes(data: bytes):
data_without_source_addr = data[0:PacketPimEncodedSourceAddress.PIM_ENCODED_SOURCE_ADDRESS_HDR_WITHOUT_SOURCE_ADDRESS_LEN]
(addr_family, encoding, _, mask_len) = struct.unpack(PacketPimEncodedSourceAddress.PIM_ENCODED_SOURCE_ADDRESS_HDR_WITHOUT_SOURCE_ADDRESS, data_without_source_addr)
data_source_addr = data[PacketPimEncodedSourceAddress.PIM_ENCODED_SOURCE_ADDRESS_HDR_WITHOUT_SOURCE_ADDRESS_LEN:]
ip = None
if addr_family == PacketPimEncodedSourceAddress.FAMILY_IPV4:
(ip,) = struct.unpack("! " + PacketPimEncodedSourceAddress.IPV4_HDR, data_source_addr[:4])
ip = socket.inet_ntop(socket.AF_INET, ip)
elif addr_family == PacketPimEncodedSourceAddress.FAMILY_IPV6:
(ip,) = struct.unpack("! " + PacketPimEncodedSourceAddress.IPV6_HDR, data_source_addr[:16])
ip = socket.inet_ntop(socket.AF_INET6, ip)
if encoding != 0:
print("unknown encoding")
raise Exception
return PacketPimEncodedSourceAddress(ip, mask_len)
import ipaddress
import struct
import socket
'''
0 1 2 3
0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
| Addr Family | Encoding Type | Unicast Address
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+...
'''
class PacketPimEncodedUnicastAddress:
PIM_ENCODED_UNICAST_ADDRESS_HDR = "! BB %s"
PIM_ENCODED_UNICAST_ADDRESS_HDR_WITHOUT_UNICAST_ADDRESS = "! BB"
IPV4_HDR = "4s"
IPV6_HDR = "16s"
# TODO ver melhor versao ip
PIM_ENCODED_UNICAST_ADDRESS_HDR_WITHOUT_UNICAST_ADDRESS_LEN = struct.calcsize(PIM_ENCODED_UNICAST_ADDRESS_HDR_WITHOUT_UNICAST_ADDRESS)
PIM_ENCODED_UNICAST_ADDRESS_HDR_LEN = struct.calcsize(PIM_ENCODED_UNICAST_ADDRESS_HDR % IPV4_HDR)
PIM_ENCODED_UNICAST_ADDRESS_HDR_LEN_IPV6 = struct.calcsize(PIM_ENCODED_UNICAST_ADDRESS_HDR % IPV6_HDR)
FAMILY_RESERVED = 0
FAMILY_IPV4 = 1
FAMILY_IPV6 = 2
def __init__(self, unicast_address):
if type(unicast_address) not in (str, bytes):
raise Exception
if type(unicast_address) is bytes:
unicast_address = socket.inet_ntoa(unicast_address)
self.unicast_address = unicast_address
def bytes(self) -> bytes:
(string_ip_hdr, hdr_addr_family, socket_family) = PacketPimEncodedUnicastAddress.get_ip_info(self.unicast_address)
ip = socket.inet_pton(socket_family, self.unicast_address)
msg = struct.pack(PacketPimEncodedUnicastAddress.PIM_ENCODED_UNICAST_ADDRESS_HDR % string_ip_hdr, hdr_addr_family, 0, ip)
return msg
@staticmethod
def get_ip_info(ip):
version = ipaddress.ip_address(ip).version
if version == 4:
return (PacketPimEncodedUnicastAddress.IPV4_HDR, PacketPimEncodedUnicastAddress.FAMILY_IPV4, socket.AF_INET)
elif version == 6:
return (PacketPimEncodedUnicastAddress.IPV6_HDR, PacketPimEncodedUnicastAddress.FAMILY_IPV6, socket.AF_INET6)
else:
raise Exception
def __len__(self):
version = ipaddress.ip_address(self.unicast_address).version
if version == 4:
return self.PIM_ENCODED_UNICAST_ADDRESS_HDR_LEN
elif version == 6:
return self.PIM_ENCODED_UNICAST_ADDRESS_HDR_LEN_IPV6
else:
raise Exception
@staticmethod
def parse_bytes(data: bytes):
data_without_unicast_addr = data[0:PacketPimEncodedUnicastAddress.PIM_ENCODED_UNICAST_ADDRESS_HDR_WITHOUT_UNICAST_ADDRESS_LEN]
(addr_family, encoding) = struct.unpack(PacketPimEncodedUnicastAddress.PIM_ENCODED_UNICAST_ADDRESS_HDR_WITHOUT_UNICAST_ADDRESS, data_without_unicast_addr)
data_unicast_addr = data[PacketPimEncodedUnicastAddress.PIM_ENCODED_UNICAST_ADDRESS_HDR_WITHOUT_UNICAST_ADDRESS_LEN:]
if addr_family == PacketPimEncodedUnicastAddress.FAMILY_IPV4:
(ip,) = struct.unpack("! " + PacketPimEncodedUnicastAddress.IPV4_HDR, data_unicast_addr[:4])
ip = socket.inet_ntop(socket.AF_INET, ip)
elif addr_family == PacketPimEncodedUnicastAddress.FAMILY_IPV6:
(ip,) = struct.unpack("! " + PacketPimEncodedUnicastAddress.IPV6_HDR, data_unicast_addr[:16])
ip = socket.inet_ntop(socket.AF_INET6, ip)
if encoding != 0:
print("unknown encoding")
raise Exception
return PacketPimEncodedUnicastAddress(ip)
from Packet.PacketPimJoinPrune import PacketPimJoinPrune
'''
0 1 2 3
0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
| Upstream Neighbor Address (Encoded Unicast Format) |
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
| Reserved | Num Groups | Hold Time |
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
| Multicast Group Address 1 (Encoded Group Format) |
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
| Number of Joined Sources | Number of Pruned Sources |
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
| Joined Source Address 1 (Encoded Source Format) |
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
| . |
| . |
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
| Joined Source Address n (Encoded Source Format) |
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
| Pruned Source Address 1 (Encoded Source Format) |
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
| . |
| . |
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
| Pruned Source Address n (Encoded Source Format) |
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
| . |
| . |
| . |
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
| Multicast Group Address m (Encoded Group Format) |
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
| Number of Joined Sources | Number of Pruned Sources |
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
| Joined Source Address 1 (Encoded Source Format) |
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
| . |
| . |
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
'''
class PacketPimGraft(PacketPimJoinPrune):
PIM_TYPE = 6
def __init__(self, upstream_neighbor_address, holdtime=0):
super().__init__(upstream_neighbor_address=upstream_neighbor_address, hold_time=holdtime)
from Packet.PacketPimJoinPrune import PacketPimJoinPrune
'''
0 1 2 3
0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
| Upstream Neighbor Address (Encoded Unicast Format) |
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
| Reserved | Num Groups | Hold Time |
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
| Multicast Group Address 1 (Encoded Group Format) |
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
| Number of Joined Sources | Number of Pruned Sources |
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
| Joined Source Address 1 (Encoded Source Format) |
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
| . |
| . |
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
| Joined Source Address n (Encoded Source Format) |
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
| Pruned Source Address 1 (Encoded Source Format) |
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
| . |
| . |
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
| Pruned Source Address n (Encoded Source Format) |
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
| . |
| . |
| . |
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
| Multicast Group Address m (Encoded Group Format) |
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
| Number of Joined Sources | Number of Pruned Sources |
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
| Joined Source Address 1 (Encoded Source Format) |
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
| . |
| . |
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
'''
class PacketPimGraftAck(PacketPimJoinPrune):
PIM_TYPE = 7
def __init__(self, upstream_neighbor_address, holdtime=0):
super().__init__(upstream_neighbor_address, hold_time=holdtime)
import struct
from Packet.PacketPimHello import PacketPimHello
from Packet.PacketPimJoinPrune import PacketPimJoinPrune
from Packet.PacketPimAssert import PacketPimAssert
from Packet.PacketPimGraft import PacketPimGraft
from Packet.PacketPimGraftAck import PacketPimGraftAck
from Packet.PacketPimStateRefresh import PacketPimStateRefresh
from utils import checksum
from .PacketPayload import PacketPayload
'''
0 1 2 3
0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
|PIM Ver| Type | Reserved | Checksum |
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
'''
class PacketPimHeader(PacketPayload):
PIM_VERSION = 2
PIM_HDR = "! BB H"
PIM_HDR_LEN = struct.calcsize(PIM_HDR)
PIM_MSG_TYPES = {0: PacketPimHello,
3: PacketPimJoinPrune,
5: PacketPimAssert,
6: PacketPimGraft,
7: PacketPimGraftAck,
9: PacketPimStateRefresh
}
def __init__(self, payload):
self.payload = payload
def get_pim_type(self):
return self.payload.PIM_TYPE
def bytes(self) -> bytes:
# obter mensagem e criar checksum
pim_vrs_type = (PacketPimHeader.PIM_VERSION << 4) + self.get_pim_type()
msg_without_chcksum = struct.pack(PacketPimHeader.PIM_HDR, pim_vrs_type, 0, 0)
msg_without_chcksum += self.payload.bytes()
pim_checksum = checksum(msg_without_chcksum)
msg = msg_without_chcksum[0:2] + struct.pack("! H", pim_checksum) + msg_without_chcksum[4:]
return msg
def __len__(self):
return len(self.bytes())
@staticmethod
def parse_bytes(data: bytes):
print("parsePimHdr: ", data)
pim_hdr = data[0:PacketPimHeader.PIM_HDR_LEN]
(pim_ver_type, reserved, rcv_checksum) = struct.unpack(PacketPimHeader.PIM_HDR, pim_hdr)
print(pim_ver_type, reserved, rcv_checksum)
pim_version = (pim_ver_type & 0xF0) >> 4
pim_type = pim_ver_type & 0x0F
if pim_version != PacketPimHeader.PIM_VERSION:
print("Version of PIM packet received not known (!=2)")
raise Exception
msg_to_checksum = data[0:2] + b'\x00\x00' + data[4:]
if checksum(msg_to_checksum) != rcv_checksum:
print("wrong checksum")
print("checksum calculated: " + str(checksum(msg_to_checksum)))
print("checksum recv: " + str(rcv_checksum))
raise Exception
pim_payload = data[PacketPimHeader.PIM_HDR_LEN:]
pim_payload = PacketPimHeader.PIM_MSG_TYPES[pim_type].parse_bytes(pim_payload)
return PacketPimHeader(pim_payload)
import struct
from abc import ABCMeta, abstractstaticmethod
from .PacketPimHelloOptions import PacketPimHelloOptions, PacketPimHelloStateRefreshCapable, PacketPimHelloGenerationID, PacketPimHelloLANPruneDelay, PacketPimHelloHoldtime
'''
0 1 2 3
0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
| Option Type | Option Length |
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
| Option Value |
| ... |
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
| . |
| . |
| . |
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
| Option Type | Option Length |
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
| Option Value |
| ... |
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
'''
class PacketPimHello:
PIM_TYPE = 0
PIM_HDR_OPTS = "! HH"
PIM_HDR_OPTS_LEN = struct.calcsize(PIM_HDR_OPTS)
PIM_MSG_TYPES_LENGTH = {1: 2,
2: 4,
20: 4,
21: 4,
}
# todo: pensar melhor na implementacao state refresh capable option...
def __init__(self):
self.options = {}
'''
def add_option(self, option_type: int, option_value: int or float):
option_value = int(option_value)
# if option_value requires more bits than the bits available for that field: option value will have all field bits = 1
if option_type in self.PIM_MSG_TYPES_LENGTH and self.PIM_MSG_TYPES_LENGTH[option_type] * 8 < option_value.bit_length():
option_value = (1 << (self.PIM_MSG_TYPES_LENGTH[option_type] * 8)) - 1
self.options[option_type] = option_value
'''
def add_option(self, option: 'PacketPimHelloOptions'):
#if option_type in self.PIM_MSG_TYPES_LENGTH and self.PIM_MSG_TYPES_LENGTH[option_type] * 8 < option_value.bit_length():
# option_value = (1 << (self.PIM_MSG_TYPES_LENGTH[option_type] * 8)) - 1
self.options[option.type] = option
def get_options(self):
return self.options
'''
def bytes(self) -> bytes:
res = b''
for (option_type, option_value) in self.options.items():
option_length = PacketPimHello.PIM_MSG_TYPES_LENGTH[option_type]
type_length_hdr = struct.pack(PacketPimHello.PIM_HDR_OPTS, option_type, option_length)
res += type_length_hdr + struct.pack("! " + str(option_length) + "s", option_value.to_bytes(option_length, byteorder='big'))
return res
'''
def bytes(self) -> bytes:
res = b''
for option in self.options.values():
res += option.bytes()
return res
def __len__(self):
return len(self.bytes())
'''
@staticmethod
def parse_bytes(data: bytes):
pim_payload = PacketPimHello()
while data != b'':
(option_type, option_length) = struct.unpack(PacketPimHello.PIM_HDR_OPTS,
data[:PacketPimHello.PIM_HDR_OPTS_LEN])
print(option_type, option_length)
data = data[PacketPimHello.PIM_HDR_OPTS_LEN:]
print(data)
(option_value,) = struct.unpack("! " + str(option_length) + "s", data[:option_length])
option_value_number = int.from_bytes(option_value, byteorder='big')
print("option value: ", option_value_number)
#options_list.append({"OPTION TYPE": option_type,
# "OPTION LENGTH": option_length,
# "OPTION VALUE": option_value_number
# })
pim_payload.add_option(option_type, option_value_number)
data = data[option_length:]
return pim_payload
'''
@staticmethod
def parse_bytes(data: bytes):
pim_payload = PacketPimHello()
while data != b'':
option = PacketPimHelloOptions.parse_bytes(data)
option_length = len(option)
data = data[option_length:]
pim_payload.add_option(option)
return pim_payload
import struct
from abc import ABCMeta
import math
class PacketPimHelloOptions(metaclass=ABCMeta):
PIM_HDR_OPTS = "! HH"
PIM_HDR_OPTS_LEN = struct.calcsize(PIM_HDR_OPTS)
'''
0 1 2 3
0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
| Type | Length |
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
'''
def __init__(self, type: int, length: int):
self.type = type
self.length = length
def bytes(self) -> bytes:
return struct.pack(PacketPimHelloOptions.PIM_HDR_OPTS, self.type, self.length)
def __len__(self):
return self.PIM_HDR_OPTS_LEN + self.length
@staticmethod
def parse_bytes(data: bytes, type:int = None, length:int = None):
(type, length) = struct.unpack(PacketPimHelloOptions.PIM_HDR_OPTS,
data[:PacketPimHelloOptions.PIM_HDR_OPTS_LEN])
#print("TYPE:", type)
#print("LENGTH:", length)
data = data[PacketPimHelloOptions.PIM_HDR_OPTS_LEN:]
#return PIM_MSG_TYPES[type](data)
return PIM_MSG_TYPES.get(type, PacketPimHelloUnknown).parse_bytes(data, type, length)
class PacketPimHelloStateRefreshCapable(PacketPimHelloOptions):
PIM_HDR_OPT = "! BBH"
PIM_HDR_OPT_LEN = struct.calcsize(PIM_HDR_OPT)
'''
0 1 2 3
0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
| Version = 1 | Interval | Reserved |
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
'''
VERSION = 1
def __init__(self, interval: int):
super().__init__(type=21, length=4)
self.interval = interval
def bytes(self) -> bytes:
return super().bytes() + struct.pack(self.PIM_HDR_OPT, self.VERSION, self.interval, 0)
@staticmethod
def parse_bytes(data: bytes, type:int = None, length:int = None):
if type is None or length is None:
raise Exception
(version, interval, _) = struct.unpack(PacketPimHelloStateRefreshCapable.PIM_HDR_OPT,
data[:PacketPimHelloStateRefreshCapable.PIM_HDR_OPT_LEN])
return PacketPimHelloStateRefreshCapable(interval)
class PacketPimHelloLANPruneDelay(PacketPimHelloOptions):
PIM_HDR_OPT = "! HH"
PIM_HDR_OPT_LEN = struct.calcsize(PIM_HDR_OPT)
'''
0 1 2 3
0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
|T| LAN Prune Delay | Override Interval |
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
'''
def __init__(self, lan_prune_delay: float, override_interval: float):
super().__init__(type=2, length=4)
self.lan_prune_delay = 0x7FFF & math.ceil(lan_prune_delay)
self.override_interval = math.ceil(override_interval)
def bytes(self) -> bytes:
return super().bytes() + struct.pack(self.PIM_HDR_OPT, self.lan_prune_delay, self.override_interval)
@staticmethod
def parse_bytes(data: bytes, type:int = None, length:int = None):
if type is None or length is None:
raise Exception
(lan_prune_delay, override_interval) = struct.unpack(PacketPimHelloLANPruneDelay.PIM_HDR_OPT,
data[:PacketPimHelloLANPruneDelay.PIM_HDR_OPT_LEN])
lan_prune_delay = lan_prune_delay & 0x7FFF
return PacketPimHelloLANPruneDelay(lan_prune_delay=lan_prune_delay, override_interval=override_interval)
class PacketPimHelloHoldtime(PacketPimHelloOptions):
PIM_HDR_OPT = "! H"
PIM_HDR_OPT_LEN = struct.calcsize(PIM_HDR_OPT)
'''
0 1 2 3
0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
| Hold Time |
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
'''
def __init__(self, holdtime: int or float):
super().__init__(type=1, length=2)
self.holdtime = int(holdtime)
def bytes(self) -> bytes:
return super().bytes() + struct.pack(self.PIM_HDR_OPT, self.holdtime)
@staticmethod
def parse_bytes(data: bytes, type:int = None, length:int = None):
if type is None or length is None:
raise Exception
(holdtime, ) = struct.unpack(PacketPimHelloHoldtime.PIM_HDR_OPT,
data[:PacketPimHelloHoldtime.PIM_HDR_OPT_LEN])
#print("HOLDTIME:", holdtime)
return PacketPimHelloHoldtime(holdtime=holdtime)
class PacketPimHelloGenerationID(PacketPimHelloOptions):
PIM_HDR_OPT = "! L"
PIM_HDR_OPT_LEN = struct.calcsize(PIM_HDR_OPT)
'''
0 1 2 3
0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
| Generation ID |
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
'''
def __init__(self, generation_id: int):
super().__init__(type=20, length=4)
self.generation_id = generation_id
def bytes(self) -> bytes:
return super().bytes() + struct.pack(self.PIM_HDR_OPT, self.generation_id)
@staticmethod
def parse_bytes(data: bytes, type:int = None, length:int = None):
if type is None or length is None:
raise Exception
(generation_id, ) = struct.unpack(PacketPimHelloGenerationID.PIM_HDR_OPT,
data[:PacketPimHelloGenerationID.PIM_HDR_OPT_LEN])
#print("GenerationID:", generation_id)
return PacketPimHelloGenerationID(generation_id=generation_id)
class PacketPimHelloUnknown(PacketPimHelloOptions):
PIM_HDR_OPT = "! L"
PIM_HDR_OPT_LEN = struct.calcsize(PIM_HDR_OPT)
'''
0 1 2 3
0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
| Unknown |
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
'''
def __init__(self, type, length):
super().__init__(type=type, length=length)
#print("PIM Hello Option Unknown... TYPE=", type, "LENGTH=", length)
def bytes(self) -> bytes:
raise Exception
@staticmethod
def parse_bytes(data: bytes, type:int = None, length:int = None):
if type is None or length is None:
raise Exception
return PacketPimHelloUnknown(type, length)
PIM_MSG_TYPES = {1: PacketPimHelloHoldtime,
2: PacketPimHelloLANPruneDelay,
20: PacketPimHelloGenerationID,
21: PacketPimHelloStateRefreshCapable,
}
import struct
import socket
from Packet.PacketPimEncodedUnicastAddress import PacketPimEncodedUnicastAddress
from Packet.PacketPimJoinPruneMulticastGroup import PacketPimJoinPruneMulticastGroup
'''
0 1 2 3
0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
| Upstream Neighbor Address (Encoded Unicast Format) |
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
| Reserved | Num Groups | Hold Time |
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
'''
class PacketPimJoinPrune:
PIM_TYPE = 3
PIM_HDR_JOIN_PRUNE = "! %ss BBH"
PIM_HDR_JOIN_PRUNE_WITHOUT_ADDRESS = "! BBH"
PIM_HDR_JOIN_PRUNE_v4 = PIM_HDR_JOIN_PRUNE % PacketPimEncodedUnicastAddress.PIM_ENCODED_UNICAST_ADDRESS_HDR_LEN
PIM_HDR_JOIN_PRUNE_v6 = PIM_HDR_JOIN_PRUNE % PacketPimEncodedUnicastAddress.PIM_ENCODED_UNICAST_ADDRESS_HDR_LEN_IPV6
PIM_HDR_JOIN_PRUNE_WITHOUT_ADDRESS_LEN = struct.calcsize(PIM_HDR_JOIN_PRUNE_WITHOUT_ADDRESS)
PIM_HDR_JOIN_PRUNE_v4_LEN = struct.calcsize(PIM_HDR_JOIN_PRUNE_v4)
PIM_HDR_JOIN_PRUNE_v6_LEN = struct.calcsize(PIM_HDR_JOIN_PRUNE_v6)
def __init__(self, upstream_neighbor_address, hold_time):
if type(upstream_neighbor_address) not in (str, bytes):
raise Exception
if type(upstream_neighbor_address) is bytes:
upstream_neighbor_address = socket.inet_ntoa(upstream_neighbor_address)
self.groups = []
self.upstream_neighbor_address = upstream_neighbor_address
self.hold_time = hold_time
def add_multicast_group(self, group: PacketPimJoinPruneMulticastGroup):
# TODO verificar se grupo ja esta na msg
self.groups.append(group)
def bytes(self) -> bytes:
upstream_neighbor_address = PacketPimEncodedUnicastAddress(self.upstream_neighbor_address).bytes()
msg = upstream_neighbor_address + struct.pack(PacketPimJoinPrune.PIM_HDR_JOIN_PRUNE_WITHOUT_ADDRESS, 0,
len(self.groups), self.hold_time)
for multicast_group in self.groups:
msg += multicast_group.bytes()
return msg
def __len__(self):
return len(self.bytes())
@classmethod
def parse_bytes(cls, data: bytes):
upstream_neighbor_addr_obj = PacketPimEncodedUnicastAddress.parse_bytes(data)
upstream_neighbor_addr_len = len(upstream_neighbor_addr_obj)
data = data[upstream_neighbor_addr_len:]
(_, num_groups, hold_time) = struct.unpack(PacketPimJoinPrune.PIM_HDR_JOIN_PRUNE_WITHOUT_ADDRESS,
data[:PacketPimJoinPrune.PIM_HDR_JOIN_PRUNE_WITHOUT_ADDRESS_LEN])
data = data[PacketPimJoinPrune.PIM_HDR_JOIN_PRUNE_WITHOUT_ADDRESS_LEN:]
pim_payload = cls(upstream_neighbor_addr_obj.unicast_address, hold_time)
for i in range(0, num_groups):
group = PacketPimJoinPruneMulticastGroup.parse_bytes(data)
group_len = len(group)
pim_payload.add_multicast_group(group)
data = data[group_len:]
return pim_payload
import struct
import socket
from Packet.PacketPimEncodedGroupAddress import PacketPimEncodedGroupAddress
from Packet.PacketPimEncodedSourceAddress import PacketPimEncodedSourceAddress
'''
0 1 2 3
0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
| Multicast Group Address 1 (Encoded Group Format) |
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
| Number of Joined Sources | Number of Pruned Sources |
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
| Joined Source Address 1 (Encoded Source Format) |
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
| . |
| . |
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
| Joined Source Address n (Encoded Source Format) |
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
| Pruned Source Address 1 (Encoded Source Format) |
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
| . |
| . |
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
| Pruned Source Address n (Encoded Source Format) |
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
'''
class PacketPimJoinPruneMulticastGroup:
PIM_HDR_JOIN_PRUNE_MULTICAST_GROUP = "! %ss HH"
PIM_HDR_JOIN_PRUNE_MULTICAST_GROUP_WITHOUT_GROUP_ADDRESS = "! HH"
PIM_HDR_JOIN_PRUNE_MULTICAST_GROUP_v4_LEN_ = struct.calcsize(
PIM_HDR_JOIN_PRUNE_MULTICAST_GROUP % PacketPimEncodedGroupAddress.PIM_ENCODED_GROUP_ADDRESS_HDR_LEN)
PIM_HDR_JOIN_PRUNE_MULTICAST_GROUP_v6_LEN_ = struct.calcsize(
PIM_HDR_JOIN_PRUNE_MULTICAST_GROUP % PacketPimEncodedGroupAddress.PIM_ENCODED_GROUP_ADDRESS_HDR_LEN_IPv6)
PIM_HDR_JOIN_PRUNE_MULTICAST_GROUP_WITHOUT_GROUP_ADDRESS_LEN = struct.calcsize(
PIM_HDR_JOIN_PRUNE_MULTICAST_GROUP_WITHOUT_GROUP_ADDRESS)
PIM_HDR_JOINED_PRUNED_SOURCE = "! %ss"
PIM_HDR_JOINED_PRUNED_SOURCE_v4_LEN = PacketPimEncodedSourceAddress.PIM_ENCODED_SOURCE_ADDRESS_HDR_LEN
PIM_HDR_JOINED_PRUNED_SOURCE_v6_LEN = PacketPimEncodedSourceAddress.PIM_ENCODED_SOURCE_ADDRESS_HDR_LEN_IPV6
def __init__(self, multicast_group: str or bytes, joined_src_addresses: list=[], pruned_src_addresses: list=[]):
if type(multicast_group) not in (str, bytes):
raise Exception
elif type(multicast_group) is bytes:
multicast_group = socket.inet_ntoa(multicast_group)
if type(joined_src_addresses) is not list:
raise Exception
if type(pruned_src_addresses) is not list:
raise Exception
self.multicast_group = multicast_group
self.joined_src_addresses = joined_src_addresses
self.pruned_src_addresses = pruned_src_addresses
def bytes(self) -> bytes:
multicast_group_address = PacketPimEncodedGroupAddress(self.multicast_group).bytes()
msg = multicast_group_address + struct.pack(self.PIM_HDR_JOIN_PRUNE_MULTICAST_GROUP_WITHOUT_GROUP_ADDRESS,
len(self.joined_src_addresses), len(self.pruned_src_addresses))
for joined_src_address in self.joined_src_addresses:
joined_src_address_bytes = PacketPimEncodedSourceAddress(joined_src_address).bytes()
msg += joined_src_address_bytes
for pruned_src_address in self.pruned_src_addresses:
pruned_src_address_bytes = PacketPimEncodedSourceAddress(pruned_src_address).bytes()
msg += pruned_src_address_bytes
return msg
def __len__(self):
return len(self.bytes())
@staticmethod
def parse_bytes(data: bytes):
multicast_group_addr_obj = PacketPimEncodedGroupAddress.parse_bytes(data)
multicast_group_addr_len = len(multicast_group_addr_obj)
data = data[multicast_group_addr_len:]
number_join_prune_data = data[:PacketPimJoinPruneMulticastGroup.PIM_HDR_JOIN_PRUNE_MULTICAST_GROUP_WITHOUT_GROUP_ADDRESS_LEN]
(number_joined_sources, number_pruned_sources) = struct.unpack(PacketPimJoinPruneMulticastGroup.PIM_HDR_JOIN_PRUNE_MULTICAST_GROUP_WITHOUT_GROUP_ADDRESS, number_join_prune_data)
joined = []
pruned = []
data = data[PacketPimJoinPruneMulticastGroup.PIM_HDR_JOIN_PRUNE_MULTICAST_GROUP_WITHOUT_GROUP_ADDRESS_LEN:]
for i in range(0, number_joined_sources):
joined_obj = PacketPimEncodedSourceAddress.parse_bytes(data)
joined_obj_len = len(joined_obj)
data = data[joined_obj_len:]
joined.append(joined_obj.source_address)
for i in range(0, number_pruned_sources):
pruned_obj = PacketPimEncodedSourceAddress.parse_bytes(data)
pruned_obj_len = len(pruned_obj)
data = data[pruned_obj_len:]
pruned.append(pruned_obj.source_address)
return PacketPimJoinPruneMulticastGroup(multicast_group_addr_obj.group_address, joined, pruned)
import struct
import socket
from Packet.PacketPimEncodedUnicastAddress import PacketPimEncodedUnicastAddress
from Packet.PacketPimEncodedGroupAddress import PacketPimEncodedGroupAddress
'''
0 1 2 3
0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
|PIM Ver| Type | Reserved | Checksum |
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
| Multicast Group Address (Encoded Group Format) |
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
| Source Address (Encoded Unicast Format) |
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
| Originator Address (Encoded Unicast Format) |
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
|R| Metric Preference |
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
| Metric |
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
| Masklen | TTL |P|N|O|Reserved | Interval |
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
'''
class PacketPimStateRefresh:
PIM_TYPE = 9
PIM_HDR_STATE_REFRESH = "! %ss %ss %ss I I BBBB"
PIM_HDR_STATE_REFRESH_WITHOUT_ADDRESSES = "! I I BBBB"
PIM_HDR_STATE_REFRESH_v4 = PIM_HDR_STATE_REFRESH % (PacketPimEncodedGroupAddress.PIM_ENCODED_GROUP_ADDRESS_HDR_LEN, PacketPimEncodedUnicastAddress.PIM_ENCODED_UNICAST_ADDRESS_HDR_LEN, PacketPimEncodedUnicastAddress.PIM_ENCODED_UNICAST_ADDRESS_HDR_LEN)
PIM_HDR_STATE_REFRESH_v6 = PIM_HDR_STATE_REFRESH % (PacketPimEncodedGroupAddress.PIM_ENCODED_GROUP_ADDRESS_HDR_LEN_IPv6, PacketPimEncodedUnicastAddress.PIM_ENCODED_UNICAST_ADDRESS_HDR_LEN_IPV6, PacketPimEncodedUnicastAddress.PIM_ENCODED_UNICAST_ADDRESS_HDR_LEN_IPV6)
PIM_HDR_STATE_REFRESH_WITHOUT_ADDRESSES_LEN = struct.calcsize(PIM_HDR_STATE_REFRESH_WITHOUT_ADDRESSES)
PIM_HDR_STATE_REFRESH_v4_LEN = struct.calcsize(PIM_HDR_STATE_REFRESH_v4)
PIM_HDR_STATE_REFRESH_v6_LEN = struct.calcsize(PIM_HDR_STATE_REFRESH_v6)
def __init__(self, multicast_group_adress: str or bytes, source_address: str or bytes, originator_adress: str or bytes,
metric_preference: int, metric: int, mask_len: int, ttl: int, prune_indicator_flag: bool,
prune_now_flag: bool, assert_override_flag: bool, interval: int):
if type(multicast_group_adress) is bytes:
multicast_group_adress = socket.inet_ntoa(multicast_group_adress)
if type(source_address) is bytes:
source_address = socket.inet_ntoa(source_address)
if type(originator_adress) is bytes:
originator_adress = socket.inet_ntoa(originator_adress)
self.multicast_group_adress = multicast_group_adress
self.source_address = source_address
self.originator_adress = originator_adress
self.metric_preference = metric_preference
self.metric = metric
self.mask_len = mask_len
self.ttl = ttl
self.prune_indicator_flag = prune_indicator_flag
self.prune_now_flag = prune_now_flag
self.assert_override_flag = assert_override_flag
self.interval = interval
def bytes(self) -> bytes:
multicast_group_adress = PacketPimEncodedGroupAddress(self.multicast_group_adress).bytes()
source_address = PacketPimEncodedUnicastAddress(self.source_address).bytes()
originator_adress = PacketPimEncodedUnicastAddress(self.originator_adress).bytes()
prune_and_assert_flags = (self.prune_indicator_flag << 7) | (self.prune_now_flag << 6) | (self.assert_override_flag << 5)
msg = multicast_group_adress + source_address + originator_adress + \
struct.pack(self.PIM_HDR_STATE_REFRESH_WITHOUT_ADDRESSES, 0x7FFFFFFF & self.metric_preference,
self.metric, self.mask_len, self.ttl, prune_and_assert_flags, self.interval)
return msg
def __len__(self):
return len(self.bytes())
@staticmethod
def parse_bytes(data: bytes):
multicast_group_adress_obj = PacketPimEncodedGroupAddress.parse_bytes(data)
multicast_group_adress_len = len(multicast_group_adress_obj)
data = data[multicast_group_adress_len:]
source_address_obj = PacketPimEncodedUnicastAddress.parse_bytes(data)
source_address_len = len(source_address_obj)
data = data[source_address_len:]
originator_address_obj = PacketPimEncodedUnicastAddress.parse_bytes(data)
originator_address_len = len(originator_address_obj)
data = data[originator_address_len:]
(metric_preference, metric, mask_len, ttl, reserved_and_prune_and_assert_flags, interval) = struct.unpack(PacketPimStateRefresh.PIM_HDR_STATE_REFRESH_WITHOUT_ADDRESSES, data[:PacketPimStateRefresh.PIM_HDR_STATE_REFRESH_WITHOUT_ADDRESSES_LEN])
metric_preference = 0x7FFFFFFF & metric_preference
prune_indicator_flag = (0x80 & reserved_and_prune_and_assert_flags) >> 7
prune_now_flag = (0x40 & reserved_and_prune_and_assert_flags) >> 6
assert_override_flag = (0x20 & reserved_and_prune_and_assert_flags) >> 5
pim_payload = PacketPimStateRefresh(multicast_group_adress_obj.group_address, source_address_obj.unicast_address,
originator_address_obj.unicast_address, metric_preference, metric, mask_len,
ttl, prune_indicator_flag, prune_now_flag, assert_override_flag, interval)
return pim_payload
from Packet.Packet import Packet
from Packet.PacketIpHeader import PacketIpHeader
from Packet.PacketIGMPHeader import PacketIGMPHeader
from .PacketPimHeader import PacketPimHeader
from utils import TYPE_CHECKING
if TYPE_CHECKING:
from Interface import Interface
class ReceivedPacket(Packet):
# choose payload protocol class based on ip protocol number
payload_protocol = {2: PacketIGMPHeader, 103: PacketPimHeader}
def __init__(self, raw_packet: bytes, interface: 'Interface'):
self.interface = interface
# Parse ao packet e preencher objeto Packet
packet_ip_hdr = raw_packet[:PacketIpHeader.IP_HDR_LEN]
ip_header = PacketIpHeader.parse_bytes(packet_ip_hdr)
protocol_number = ip_header.proto
packet_without_ip_hdr = raw_packet[ip_header.hdr_length:]
payload = ReceivedPacket.payload_protocol[protocol_number].parse_bytes(packet_without_ip_hdr)
super().__init__(ip_header=ip_header, payload=payload)
# pim # PIM-DM
\ No newline at end of file
We have implemented the specification of PIM-DM ([RFC3973](https://tools.ietf.org/html/rfc3973)).
This repository stores the implementation of this protocol. The implementation is written in Python language and is destined to Linux systems.
# Requirements
- Linux machine
- Python3 (we have written all code to be compatible with at least Python v3.2)
- pip (to install all dependencies)
# Installation
You may need sudo permitions, in order to run this protocol. This is required because we use raw sockets to exchange control messages. For this reason, some sockets to work properly need to have super user permissions.
First clone this repository:
`git clone https://github.com/pedrofran12/pim.git`
Then enter in the cloned repository and install all dependencies:
`pip3 install -r requirements.txt`
And thats it :D
# Run protocol
In order to interact with the protocol you need to allways execute Run.py file. You can interact with the protocol by executing this file and specifying a command and corresponding arguments:
`sudo python3 Run.py -COMMAND ARGUMENTS`
In order to determine which commands are available you can call the help command:
`sudo python3 Run.py -h`
or
`sudo python3 Run.py --help`
In order to start the protocol you first need to explicitly start it. This will start a daemon process, which will be running in the background. The command is the following:
`sudo python3 Run.py -start`
Then you can enable the protocol in specific interfaces. You need to specify which interfaces will have IGMP enabled and which interfaces will have the PIM-DM enabled.
To enable PIM-DM, without State-Refresh, in a given interface, you need to run the following command:
`sudo python3 Run.py -ai INTERFACE_NAME`
To enable PIM-DM, with State-Refresh, in a given interface, you need to run the following command:
`sudo python3 Run.py -aisf INTERFACE_NAME`
To enable IGMP in a given interface, you need to run the following command:
`sudo python3 Run.py -aiigmp INTERFACE_NAME`
If you have previously enabled an interface without State-Refresh and want to enable it, in the same interface, you first need to disable this interface, and the run the command -aisr. The same happens when you want to disable State Refresh in a previously enabled StateRefresh interface.
To remove a previously added interface, you need run the following commands:
To remove a previously added PIM-DM interface:
`sudo python3 Run.py -ri INTERFACE_NAME`
To remove a previously added IGMP interface:
`sudo python3 Run.py -riigmp INTERFACE_NAME`
If you want to stop the protocol process, and stop the daemon process, you need to explicitly run this command:
`sudo python3 Run.py -stop`
## Commands for monitoring the protocol process
We have built some list commands that can be used to check the "internals" of the implementation.
- List neighbors:
Verify neighbors that have established a neighborhood relationship
`sudo python3 Run.py -ln`
- List state:
List all state machines and corresponding state of all trees that are being monitored. Also list IGMP state for each group being monitored.
`sudo python3 Run.py -ls`
- Multicast Routing Table:
List Linux Multicast Routing Table (equivalent to ip mroute -show)
`sudo python3 Run.py -mr`
## Change settings
Files tree/globals.py and igmp/igmp_globals.py store all timer values and some configurations regarding IGMP and the PIM-DM. If you want to tune the implementation, you can change the values of these files. These configurations are used by all interfaces, meaning that there is no tuning per interface.
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Read Write Lock
"""
import threading
import time
class RWLockRead(object):
"""
A Read/Write lock giving preference to Reader
"""
def __init__(self):
self.V_ReadCount = 0
self.A_Resource = threading.Lock()
self.A_LockReadCount = threading.Lock()
class _aReader(object):
def __init__(self, p_RWLock):
self.A_RWLock = p_RWLock
self.V_Locked = False
def acquire(self, blocking=1, timeout=-1):
p_TimeOut = None if (blocking and timeout < 0) else (timeout if blocking else 0)
c_DeadLine = None if p_TimeOut is None else (time.time() + p_TimeOut)
if not self.A_RWLock.A_LockReadCount.acquire(blocking=1, timeout=-1 if c_DeadLine is None else max(0, c_DeadLine - time.time())):
return False
self.A_RWLock.V_ReadCount += 1
if self.A_RWLock.V_ReadCount == 1:
if not self.A_RWLock.A_Resource.acquire(blocking=1, timeout=-1 if c_DeadLine is None else max(0, c_DeadLine - time.time())):
self.A_RWLock.V_ReadCount -= 1
self.A_RWLock.A_LockReadCount.release()
return False
self.A_RWLock.A_LockReadCount.release()
self.V_Locked = True
return True
def release(self):
if not self.V_Locked: raise RuntimeError("cannot release un-acquired lock")
self.V_Locked = False
self.A_RWLock.A_LockReadCount.acquire()
self.A_RWLock.V_ReadCount -= 1
if self.A_RWLock.V_ReadCount == 0:
self.A_RWLock.A_Resource.release()
self.A_RWLock.A_LockReadCount.release()
def locked(self):
return self.V_Locked
def __enter__(self):
self.acquire()
def __exit__(self, p_Type, p_Value, p_Traceback):
self.release()
class _aWriter(object):
def __init__(self, p_RWLock):
self.A_RWLock = p_RWLock
self.V_Locked = False
def acquire(self, blocking=1, timeout=-1):
self.V_Locked = self.A_RWLock.A_Resource.acquire(blocking, timeout)
return self.V_Locked
def release(self):
if not self.V_Locked: raise RuntimeError("cannot release un-acquired lock")
self.V_Locked = False
self.A_RWLock.A_Resource.release()
def locked(self):
return self.V_Locked
def __enter__(self):
self.acquire()
def __exit__(self, p_Type, p_Value, p_Traceback):
self.release()
def genRlock(self):
"""
Generate a reader lock
"""
return RWLockRead._aReader(self)
def genWlock(self):
"""
Generate a writer lock
"""
return RWLockRead._aWriter(self)
class RWLockWrite(object):
"""
A Read/Write lock giving preference to Writer
"""
def __init__(self):
self.V_ReadCount = 0
self.V_WriteCount = 0
self.A_LockReadCount = threading.Lock()
self.A_LockWriteCount = threading.Lock()
self.A_LockReadEntry = threading.Lock()
self.A_LockReadTry = threading.Lock()
self.A_Resource = threading.Lock()
class _aReader(object):
def __init__(self, p_RWLock):
self.A_RWLock = p_RWLock
self.V_Locked = False
def acquire(self, blocking=1, timeout=-1):
p_TimeOut = None if (blocking and timeout < 0) else (timeout if blocking else 0)
c_DeadLine = None if p_TimeOut is None else (time.time() + p_TimeOut)
if not self.A_RWLock.A_LockReadEntry.acquire(blocking=1, timeout=-1 if c_DeadLine is None else max(0, c_DeadLine - time.time())):
return False
if not self.A_RWLock.A_LockReadTry.acquire(blocking=1, timeout=-1 if c_DeadLine is None else max(0, c_DeadLine - time.time())):
self.A_RWLock.A_LockReadEntry.release()
return False
if not self.A_RWLock.A_LockReadCount.acquire(blocking=1, timeout=-1 if c_DeadLine is None else max(0, c_DeadLine - time.time())):
self.A_RWLock.A_LockReadTry.release()
self.A_RWLock.A_LockReadEntry.release()
return False
self.A_RWLock.V_ReadCount += 1
if (self.A_RWLock.V_ReadCount == 1):
if not self.A_RWLock.A_Resource.acquire(blocking=1, timeout=-1 if c_DeadLine is None else max(0, c_DeadLine - time.time())):
self.A_RWLock.A_LockReadTry.release()
self.A_RWLock.A_LockReadEntry.release()
self.A_RWLock.V_ReadCount -= 1
self.A_RWLock.A_LockReadCount.release()
return False
self.A_RWLock.A_LockReadCount.release()
self.A_RWLock.A_LockReadTry.release()
self.A_RWLock.A_LockReadEntry.release()
self.V_Locked = True
return True
def release(self):
if not self.V_Locked: raise RuntimeError("cannot release un-acquired lock")
self.V_Locked = False
self.A_RWLock.A_LockReadCount.acquire()
self.A_RWLock.V_ReadCount -= 1
if (self.A_RWLock.V_ReadCount == 0):
self.A_RWLock.A_Resource.release()
self.A_RWLock.A_LockReadCount.release()
def locked(self):
return self.V_Locked
def __enter__(self):
self.acquire()
def __exit__(self, p_Type, p_Value, p_Traceback):
self.release()
class _aWriter(object):
def __init__(self, p_RWLock):
self.A_RWLock = p_RWLock
self.V_Locked = False
def acquire(self, blocking=1, timeout=-1):
p_TimeOut = None if (blocking and timeout < 0) else (timeout if blocking else 0)
c_DeadLine = None if p_TimeOut is None else (time.time() + p_TimeOut)
if not self.A_RWLock.A_LockWriteCount.acquire(blocking=1, timeout=-1 if c_DeadLine is None else max(0, c_DeadLine - time.time())):
return False
self.A_RWLock.V_WriteCount += 1
if (self.A_RWLock.V_WriteCount == 1):
if not self.A_RWLock.A_LockReadTry.acquire(blocking=1, timeout=-1 if c_DeadLine is None else max(0, c_DeadLine - time.time())):
self.A_RWLock.V_WriteCount -= 1
self.A_RWLock.A_LockWriteCount.release()
return False
self.A_RWLock.A_LockWriteCount.release()
if not self.A_RWLock.A_Resource.acquire(blocking=1, timeout=-1 if c_DeadLine is None else max(0, c_DeadLine - time.time())):
self.A_RWLock.A_LockWriteCount.acquire()
self.A_RWLock.V_WriteCount -= 1
if self.A_RWLock.V_WriteCount == 0:
self.A_RWLock.A_LockReadTry.release()
self.A_RWLock.A_LockWriteCount.release()
return False
self.V_Locked = True
return True
def release(self):
if not self.V_Locked: raise RuntimeError("cannot release un-acquired lock")
self.V_Locked = False
self.A_RWLock.A_Resource.release()
self.A_RWLock.A_LockWriteCount.acquire()
self.A_RWLock.V_WriteCount -= 1
if (self.A_RWLock.V_WriteCount == 0):
self.A_RWLock.A_LockReadTry.release()
self.A_RWLock.A_LockWriteCount.release()
def locked(self):
return self.V_Locked
def __enter__(self):
self.acquire()
def __exit__(self, p_Type, p_Value, p_Traceback):
self.release()
def genRlock(self):
"""
Generate a reader lock
"""
return RWLockWrite._aReader(self)
def genWlock(self):
"""
Generate a writer lock
"""
return RWLockWrite._aWriter(self)
class RWLockFair(object):
"""
A Read/Write lock giving fairness to both Reader and Writer
"""
def __init__(self):
self.V_ReadCount = 0
self.A_LockReadCount = threading.Lock()
self.A_LockRead = threading.Lock()
self.A_LockWrite = threading.Lock()
class _aReader(object):
def __init__(self, p_RWLock):
self.A_RWLock = p_RWLock
self.V_Locked = False
def acquire(self, blocking=1, timeout=-1):
p_TimeOut = None if (blocking and timeout < 0) else (timeout if blocking else 0)
c_DeadLine = None if p_TimeOut is None else (time.time() + p_TimeOut)
if not self.A_RWLock.A_LockRead.acquire(blocking=1, timeout=-1 if c_DeadLine is None else max(0, c_DeadLine - time.time())):
return False
if not self.A_RWLock.A_LockReadCount.acquire(blocking=1, timeout=-1 if c_DeadLine is None else max(0, c_DeadLine - time.time())):
self.A_RWLock.A_LockRead.release()
return False
self.A_RWLock.V_ReadCount += 1
if self.A_RWLock.V_ReadCount == 1:
if not self.A_RWLock.A_LockWrite.acquire(blocking=1, timeout=-1 if c_DeadLine is None else max(0, c_DeadLine - time.time())):
self.A_RWLock.V_ReadCount -= 1
self.A_RWLock.A_LockReadCount.release()
self.A_RWLock.A_LockRead.release()
return False
self.A_RWLock.A_LockReadCount.release()
self.A_RWLock.A_LockRead.release()
self.V_Locked = True
return True
def release(self):
if not self.V_Locked: raise RuntimeError("cannot release un-acquired lock")
self.V_Locked = False
self.A_RWLock.A_LockReadCount.acquire()
self.A_RWLock.V_ReadCount -= 1
if self.A_RWLock.V_ReadCount == 0:
self.A_RWLock.A_LockWrite.release()
self.A_RWLock.A_LockReadCount.release()
def locked(self):
return self.V_Locked
def __enter__(self):
self.acquire()
def __exit__(self, p_Type, p_Value, p_Traceback):
self.release()
class _aWriter(object):
def __init__(self, p_RWLock):
self.A_RWLock = p_RWLock
self.V_Locked = False
def acquire(self, blocking=1, timeout=-1):
p_TimeOut = None if (blocking and timeout < 0) else (timeout if blocking else 0)
c_DeadLine = None if p_TimeOut is None else (time.time() + p_TimeOut)
if not self.A_RWLock.A_LockRead.acquire(blocking=1, timeout=-1 if c_DeadLine is None else max(0, c_DeadLine - time.time())):
return False
if not self.A_RWLock.A_LockWrite.acquire(blocking=1, timeout=-1 if c_DeadLine is None else max(0, c_DeadLine - time.time())):
self.A_RWLock.A_LockRead.release()
return False
self.V_Locked = True
return True
def release(self):
if not self.V_Locked: raise RuntimeError("cannot release un-acquired lock")
self.V_Locked = False
self.A_RWLock.A_LockWrite.release()
self.A_RWLock.A_LockRead.release()
def locked(self):
return self.V_Locked
def __enter__(self):
self.acquire()
def __exit__(self, p_Type, p_Value, p_Traceback):
self.release()
def genRlock(self):
"""
Generate a reader lock
"""
return RWLockFair._aReader(self)
def genWlock(self):
"""
Generate a writer lock
"""
return RWLockFair._aWriter(self)
#!/usr/bin/env python
from Daemon.Daemon import Daemon
import Main
import _pickle as pickle
import socket
import sys
import os
import argparse
import traceback
def client_socket(data_to_send):
# Create a UDS socket
sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
# Connect the socket to the port where the server is listening
server_address = './uds_socket'
#print('connecting to %s' % server_address)
try:
sock.connect(server_address)
sock.sendall(pickle.dumps(data_to_send))
data_rcv = sock.recv(1024 * 256)
if data_rcv:
print(pickle.loads(data_rcv))
except socket.error:
pass
finally:
#print('closing socket')
sock.close()
class MyDaemon(Daemon):
def run(self):
Main.main()
server_address = './uds_socket'
# Make sure the socket does not already exist
try:
os.unlink(server_address)
except OSError:
if os.path.exists(server_address):
raise
# Create a UDS socket
sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
# Bind the socket to the port
sock.bind(server_address)
# Listen for incoming connections
sock.listen(1)
while True:
try:
connection, client_address = sock.accept()
data = connection.recv(256 * 1024)
print(sys.stderr, 'sending data back to the client')
print(pickle.loads(data))
args = pickle.loads(data)
if 'list_interfaces' in args and args.list_interfaces:
connection.sendall(pickle.dumps(Main.list_enabled_interfaces()))
elif 'list_neighbors' in args and args.list_neighbors:
connection.sendall(pickle.dumps(Main.list_neighbors()))
elif 'list_state' in args and args.list_state:
connection.sendall(pickle.dumps(Main.list_state()))
elif 'add_interface' in args and args.add_interface:
Main.add_pim_interface(args.add_interface[0], False)
connection.shutdown(socket.SHUT_RDWR)
elif 'add_interface_sr' in args and args.add_interface_sr:
Main.add_pim_interface(args.add_interface_sr[0], True)
connection.shutdown(socket.SHUT_RDWR)
elif 'add_interface_igmp' in args and args.add_interface_igmp:
Main.add_igmp_interface(args.add_interface_igmp[0])
connection.shutdown(socket.SHUT_RDWR)
elif 'remove_interface' in args and args.remove_interface:
Main.remove_interface(args.remove_interface[0], pim=True)
connection.shutdown(socket.SHUT_RDWR)
elif 'remove_interface_igmp' in args and args.remove_interface_igmp:
Main.remove_interface(args.remove_interface_igmp[0], igmp=True)
connection.shutdown(socket.SHUT_RDWR)
elif 'stop' in args and args.stop:
Main.stop()
connection.shutdown(socket.SHUT_RDWR)
elif 'test' in args and args.test:
Main.test(args.test[0], args.test[1])
connection.shutdown(socket.SHUT_RDWR)
except Exception:
connection.shutdown(socket.SHUT_RDWR)
traceback.print_exc()
finally:
# Clean up the connection
connection.close()
if __name__ == "__main__":
parser = argparse.ArgumentParser(description='PIM')
group = parser.add_mutually_exclusive_group(required=True)
group.add_argument("-start", "--start", action="store_true", default=False, help="Start PIM")
group.add_argument("-stop", "--stop", action="store_true", default=False, help="Stop PIM")
group.add_argument("-restart", "--restart", action="store_true", default=False, help="Restart PIM")
group.add_argument("-li", "--list_interfaces", action="store_true", default=False, help="List All PIM Interfaces")
group.add_argument("-ln", "--list_neighbors", action="store_true", default=False, help="List All PIM Neighbors")
group.add_argument("-ls", "--list_state", action="store_true", default=False, help="List state of IGMP")
group.add_argument("-mr", "--multicast_routes", action="store_true", default=False, help="List Multicast Routing table")
group.add_argument("-ai", "--add_interface", nargs=1, metavar='INTERFACE_NAME', help="Add PIM interface")
group.add_argument("-aisr", "--add_interface_sr", nargs=1, metavar='INTERFACE_NAME', help="Add PIM interface with State Refresh enabled")
group.add_argument("-aiigmp", "--add_interface_igmp", nargs=1, metavar='INTERFACE_NAME', help="Add IGMP interface")
group.add_argument("-ri", "--remove_interface", nargs=1, metavar='INTERFACE_NAME', help="Remove PIM interface")
group.add_argument("-riigmp", "--remove_interface_igmp", nargs=1, metavar='INTERFACE_NAME', help="Remove IGMP interface")
group.add_argument("-v", "--verbose", action="store_true", default=False, help="Verbose (print all debug messages)")
group.add_argument("-t", "--test", nargs=2, metavar=('ROUTER_NAME', 'SERVER_LOG_IP'), help="Tester... send log information to SERVER_LOG_IP. Set the router name to ROUTER_NAME")
args = parser.parse_args()
print(parser.parse_args())
daemon = MyDaemon('/tmp/Daemon-pim.pid')
if args.start:
print("start")
daemon.start()
sys.exit(0)
elif args.stop:
client_socket(args)
daemon.stop()
sys.exit(0)
elif args.restart:
daemon.restart()
sys.exit(0)
elif args.verbose:
os.system("tailf stdout")
sys.exit(0)
elif args.multicast_routes:
os.system("ip mroute show")
sys.exit(0)
elif not daemon.is_running():
print("PIM is not running")
parser.print_usage()
sys.exit(0)
client_socket(args)
import socket
import time
import struct
# ficheiros importantes: /usr/include/linux/mroute.h
MRT_BASE = 200
MRT_INIT = (MRT_BASE) # Activate the kernel mroute code */
MRT_DONE = (MRT_BASE+1) #/* Shutdown the kernel mroute */
MRT_ADD_VIF = (MRT_BASE+2) #/* Add a virtual interface */
MRT_DEL_VIF = (MRT_BASE+3) #/* Delete a virtual interface */
MRT_ADD_MFC = (MRT_BASE+4) #/* Add a multicast forwarding entry */
MRT_DEL_MFC = (MRT_BASE+5) #/* Delete a multicast forwarding entry */
MRT_VERSION = (MRT_BASE+6) #/* Get the kernel multicast version */
MRT_ASSERT = (MRT_BASE+7) #/* Activate PIM assert mode */
MRT_PIM = (MRT_BASE+8) #/* enable PIM code */
MRT_TABLE = (MRT_BASE+9) #/* Specify mroute table ID */
MRT_ADD_MFC_PROXY = (MRT_BASE+10) #/* Add a (*,*|G) mfc entry */
MRT_DEL_MFC_PROXY = (MRT_BASE+11) #/* Del a (*,*|G) mfc entry */
MRT_MAX = (MRT_BASE+11)
IGMPMSG_NOCACHE = 1
IGMPMSG_WRONGVIF = 2
IGMPMSG_WHOLEPKT = 3
s2 = socket.socket(socket.AF_INET, socket.SOCK_RAW, socket.IPPROTO_IGMP)
#MRT INIT
s2.setsockopt(socket.IPPROTO_IP, MRT_INIT, 1)
#MRT PIM
s2.setsockopt(socket.IPPROTO_IP, MRT_PIM, 1)
#ADD VIRTUAL INTERFACE
#estrutura = struct.pack("HBBI 4s 4s", 1, 0x4, 0, 0, socket.inet_aton("192.168.1.112"), socket.inet_aton("224.1.1.112"))
estrutura = struct.pack("HBBI 4s 4s", 0, 0x0, 1, 0, socket.inet_aton("10.0.0.1"), socket.inet_aton("0.0.0.0"))
print(estrutura)
s2.setsockopt(socket.IPPROTO_IP, MRT_ADD_VIF, estrutura)
estrutura = struct.pack("HBBI 4s 4s", 1, 0x0, 1, 0, socket.inet_aton("192.168.2.2"), socket.inet_aton("0.0.0.0"))
print(estrutura)
s2.setsockopt(socket.IPPROTO_IP, MRT_ADD_VIF, estrutura)
#time.sleep(5)
while True:
print("recv:")
msg = s2.recv(5000)
print(len(msg))
(_, _, im_msgtype, im_mbz, im_vif, _, im_src, im_dst, _) = struct.unpack("II B B B B 4s 4s 8s", msg)
print(im_msgtype)
print(im_mbz)
print(im_vif)
print(socket.inet_ntoa(im_src))
print(socket.inet_ntoa(im_dst))
if im_msgtype == IGMPMSG_NOCACHE:
print("^^ IGMP NO CACHE")
print(struct.unpack("II B B B B 4s 4s 8s", msg))
#s2.setsockopt(socket.IPPROTO_IP, MRT_PIM, 1)
#print(s2.getsockopt(socket.IPPROTO_IP, 208))
#s2.setsockopt(socket.IPPROTO_IP, 208, 0)
#ADD MULTICAST FORWARDING ENTRY
estrutura = struct.pack("4s 4s H BBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBB IIIi", socket.inet_aton("10.0.0.2"), socket.inet_aton("224.1.1.113"), 0, 0,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0, 0, 0, 0, 0)
s2.setsockopt(socket.IPPROTO_IP, MRT_ADD_MFC, estrutura)
time.sleep(30)
#MRT DONE
s2.setsockopt(socket.IPPROTO_IP, MRT_DONE, 1)
s2.close()
exit(1)
import logging
class RootFilter(logging.Filter):
"""
This is a filter which injects contextual information into the log.
Rather than use actual contextual information, we just use random
data in this demo.
"""
def __init__(self, router_name):
super().__init__()
self.router_name = router_name
def filter(self, record):
record.routername = self.router_name
if not hasattr(record, 'tree'):
record.tree = ''
if not hasattr(record, 'vif'):
record.vif = ''
if not hasattr(record, 'interfacename'):
record.interfacename = ''
if not hasattr(record, 'neighbor_ip'):
record.neighbor_ip = ''
return True
import socket
import ipaddress
from pyroute2 import IPDB, IPRoute
import Main
from utils import if_indextoname
def get_route(ip_dst: str):
return UnicastRouting.get_route(ip_dst)
def get_metric(ip_dst: str):
return UnicastRouting.get_metric(ip_dst)
def check_rpf(ip_dst):
return UnicastRouting.check_rpf(ip_dst)
def get_unicast_info(ip_dst):
return UnicastRouting.get_unicast_info(ip_dst)
class UnicastRouting(object):
ipdb = None
def __init__(self):
UnicastRouting.ipr = IPRoute()
UnicastRouting.ipdb = IPDB()
self._ipdb = UnicastRouting.ipdb
self._ipdb.register_callback(UnicastRouting.unicast_changes, mode="post")
# get metrics (routing preference and cost) to IP ip_dst
@staticmethod
def get_metric(ip_dst: str):
(metric_administrative_distance, metric_cost, _, _, mask) = UnicastRouting.get_unicast_info(ip_dst)
return (metric_administrative_distance, metric_cost, mask)
# get output interface IP, used to send data to IP ip_dst
# (root interface IP to ip_dst)
@staticmethod
def check_rpf(ip_dst):
# vif index of rpf interface
return UnicastRouting.get_unicast_info(ip_dst)[3]
@staticmethod
def get_route(ip_dst: str):
ip_bytes = socket.inet_aton(ip_dst)
ip_int = int.from_bytes(ip_bytes, byteorder='big')
info = None
ipdb = UnicastRouting.ipdb # type:IPDB
for mask_len in range(32, 0, -1):
ip_bytes = (ip_int & (0xFFFFFFFF << (32 - mask_len))).to_bytes(4, "big")
ip_dst = socket.inet_ntoa(ip_bytes) + "/" + str(mask_len)
print(ip_dst)
if ip_dst in ipdb.routes:
print(info)
if ipdb.routes[ip_dst]['ipdb_scope'] != 'gc':
info = ipdb.routes[ip_dst]
break
else:
continue
if not info:
print("0.0.0.0/0")
if "default" in ipdb.routes:
info = ipdb.routes["default"]
print(info)
return info
@staticmethod
def get_unicast_info(ip_dst):
metric_administrative_distance = 0xFFFFFFFF
metric_cost = 0xFFFFFFFF
rpf_node = ip_dst
oif = None
mask = 0
unicast_route = UnicastRouting.get_route(ip_dst)
if unicast_route is not None:
oif = unicast_route.get("oif")
next_hop = unicast_route["gateway"]
multipaths = unicast_route["multipath"]
# prefsrc = unicast_route.get("prefsrc")
# rpf_node = ip_dst if (next_hop is None and prefsrc is not None) else next_hop
rpf_node = next_hop if next_hop is not None else ip_dst
highest_ip = ipaddress.ip_address("0.0.0.0")
for m in multipaths:
if m["gateway"] is None:
oif = m.get('oif')
rpf_node = ip_dst
break
elif ipaddress.ip_address(m["gateway"]) > highest_ip:
highest_ip = ipaddress.ip_address(m["gateway"])
oif = m.get('oif')
rpf_node = m["gateway"]
metric_administrative_distance = unicast_route["proto"]
metric_cost = unicast_route["priority"]
metric_cost = metric_cost if metric_cost is not None else 0
mask = unicast_route["dst_len"]
interface_name = None if oif is None else if_indextoname(int(oif))
rpf_if = Main.kernel.vif_name_to_index_dic.get(interface_name)
return (metric_administrative_distance, metric_cost, rpf_node, rpf_if, mask)
@staticmethod
def unicast_changes(ipdb, msg, action):
print("unicast change?")
print(action)
if action == "RTM_NEWROUTE" or action == "RTM_DELROUTE":
print(ipdb.routes)
mask_len = msg["dst_len"]
network_address = None
attrs = msg["attrs"]
print(attrs)
for (key, value) in attrs:
print((key, value))
if key == "RTA_DST":
network_address = value
break
if network_address is None:
network_address = "0.0.0.0"
print(network_address)
print(mask_len)
print(network_address + "/" + str(mask_len))
subnet = ipaddress.ip_network(network_address + "/" + str(mask_len))
print(str(subnet))
Main.kernel.notify_unicast_changes(subnet)
'''
elif action == "RTM_NEWADDR" or action == "RTM_DELADDR":
print(action)
print(msg)
interface_name = None
attrs = msg["attrs"]
for (key, value) in attrs:
print((key, value))
if key == "IFA_LABEL":
interface_name = value
break
UnicastRouting.lock.release()
try:
Main.kernel.notify_interface_changes(interface_name)
except:
import traceback
traceback.print_exc()
pass
bnet = ipaddress.ip_network("0.0.0.0/0")
Main.kernel.notify_unicast_changes(subnet)
elif action == "RTM_NEWLINK" or action == "RTM_DELLINK":
attrs = msg["attrs"]
if_name = None
operation = None
for (key, value) in attrs:
print((key, value))
if key == "IFLA_IFNAME":
if_name = value
elif key == "IFLA_OPERSTATE":
operation = value
if if_name is not None and operation is not None:
break
if if_name is not None:
print(if_name + ": " + operation)
UnicastRouting.lock.release()
if operation == 'DOWN':
Main.kernel.remove_interface(if_name, igmp=True, pim=True)
subnet = ipaddress.ip_network("0.0.0.0/0")
Main.kernel.notify_unicast_changes(subnet)
'''
def stop(self):
if UnicastRouting.ipdb:
UnicastRouting.ipdb = None
if self._ipdb:
self._ipdb.release()
import logging
from threading import Lock
from threading import Timer
from utils import GroupMembershipInterval, LastMemberQueryInterval, TYPE_CHECKING
from .wrapper import NoMembersPresent
if TYPE_CHECKING:
from .RouterState import RouterState
class GroupState(object):
LOGGER = logging.getLogger('pim.igmp.RouterState.GroupState')
def __init__(self, router_state: 'RouterState', group_ip: str):
#logger
extra_dict_logger = router_state.router_state_logger.extra.copy()
extra_dict_logger['tree'] = '(*,' + group_ip + ')'
self.group_state_logger = logging.LoggerAdapter(GroupState.LOGGER, extra_dict_logger)
#timers and state
self.router_state = router_state
self.group_ip = group_ip
self.state = NoMembersPresent
self.timer = None
self.v1_host_timer = None
self.retransmit_timer = None
# lock
self.lock = Lock()
# KernelEntry's instances to notify change of igmp state
self.multicast_interface_state = []
self.multicast_interface_state_lock = Lock()
def print_state(self):
return self.state.print_state()
###########################################
# Set state
###########################################
def set_state(self, state):
self.state = state
self.group_state_logger.debug("change membership state to: " + state.print_state())
###########################################
# Set timers
###########################################
def set_timer(self, alternative: bool=False, max_response_time: int=None):
self.clear_timer()
if not alternative:
time = GroupMembershipInterval
else:
time = self.router_state.interface_state.get_group_membership_time(max_response_time)
timer = Timer(time, self.group_membership_timeout)
timer.start()
self.timer = timer
def clear_timer(self):
if self.timer is not None:
self.timer.cancel()
def set_v1_host_timer(self):
self.clear_v1_host_timer()
v1_host_timer = Timer(GroupMembershipInterval, self.group_membership_v1_timeout)
v1_host_timer.start()
self.v1_host_timer = v1_host_timer
def clear_v1_host_timer(self):
if self.v1_host_timer is not None:
self.v1_host_timer.cancel()
def set_retransmit_timer(self):
self.clear_retransmit_timer()
retransmit_timer = Timer(LastMemberQueryInterval, self.retransmit_timeout)
retransmit_timer.start()
self.retransmit_timer = retransmit_timer
def clear_retransmit_timer(self):
if self.retransmit_timer is not None:
self.retransmit_timer.cancel()
###########################################
# Get group state from specific interface state
###########################################
def get_interface_group_state(self):
return self.state.get_state(self.router_state)
###########################################
# Timer timeout
###########################################
def group_membership_timeout(self):
with self.lock:
self.get_interface_group_state().group_membership_timeout(self)
def group_membership_v1_timeout(self):
with self.lock:
self.get_interface_group_state().group_membership_v1_timeout(self)
def retransmit_timeout(self):
with self.lock:
self.get_interface_group_state().retransmit_timeout(self)
###########################################
# Receive Packets
###########################################
def receive_v1_membership_report(self):
with self.lock:
self.get_interface_group_state().receive_v1_membership_report(self)
def receive_v2_membership_report(self):
with self.lock:
self.get_interface_group_state().receive_v2_membership_report(self)
def receive_leave_group(self):
with self.lock:
self.get_interface_group_state().receive_leave_group(self)
def receive_group_specific_query(self, max_response_time: int):
with self.lock:
self.get_interface_group_state().receive_group_specific_query(self, max_response_time)
###########################################
# Notify Routing
###########################################
def notify_routing_add(self):
with self.multicast_interface_state_lock:
print("notify+", self.multicast_interface_state)
for interface_state in self.multicast_interface_state:
interface_state.notify_igmp(has_members=True)
def notify_routing_remove(self):
with self.multicast_interface_state_lock:
print("notify-", self.multicast_interface_state)
for interface_state in self.multicast_interface_state:
interface_state.notify_igmp(has_members=False)
def add_multicast_routing_entry(self, kernel_entry):
with self.multicast_interface_state_lock:
self.multicast_interface_state.append(kernel_entry)
return self.has_members()
def remove_multicast_routing_entry(self, kernel_entry):
with self.multicast_interface_state_lock:
self.multicast_interface_state.remove(kernel_entry)
def has_members(self):
return self.state is not NoMembersPresent
def remove(self):
with self.multicast_interface_state_lock:
self.clear_retransmit_timer()
self.clear_timer()
self.clear_v1_host_timer()
for interface_state in self.multicast_interface_state:
interface_state.notify_igmp(has_members=False)
del self.multicast_interface_state[:]
from threading import Timer
import logging
from Packet.PacketIGMPHeader import PacketIGMPHeader
from Packet.ReceivedPacket import ReceivedPacket
from utils import Membership_Query, QueryResponseInterval, QueryInterval, OtherQuerierPresentInterval, TYPE_CHECKING
from RWLock.RWLock import RWLockWrite
from .querier.Querier import Querier
from .nonquerier.NonQuerier import NonQuerier
from .GroupState import GroupState
if TYPE_CHECKING:
from InterfaceIGMP import InterfaceIGMP
class RouterState(object):
ROUTER_STATE_LOGGER = logging.getLogger('pim.igmp.RouterState')
def __init__(self, interface: 'InterfaceIGMP'):
#logger
logger_extra = dict()
logger_extra['vif'] = interface.vif_index
logger_extra['interfacename'] = interface.interface_name
self.router_state_logger = logging.LoggerAdapter(RouterState.ROUTER_STATE_LOGGER, logger_extra)
# interface of the router connected to the network
self.interface = interface
# state of the router (Querier/NonQuerier)
self.interface_state = Querier
# state of each group
# Key: GroupIPAddress, Value: GroupState object
self.group_state = {}
self.group_state_lock = RWLockWrite()
# send general query
packet = PacketIGMPHeader(type=Membership_Query, max_resp_time=QueryResponseInterval*10)
self.interface.send(packet.bytes())
# set initial general query timer
timer = Timer(QueryInterval, self.general_query_timeout)
timer.start()
self.general_query_timer = timer
# present timer
self.other_querier_present_timer = None
# Send packet via interface
def send(self, data: bytes, address: str):
self.interface.send(data, address)
############################################
# interface_state methods
############################################
def print_state(self):
return self.interface_state.state_name()
def set_general_query_timer(self):
self.clear_general_query_timer()
general_query_timer = Timer(QueryInterval, self.general_query_timeout)
general_query_timer.start()
self.general_query_timer = general_query_timer
def clear_general_query_timer(self):
if self.general_query_timer is not None:
self.general_query_timer.cancel()
def set_other_querier_present_timer(self):
self.clear_other_querier_present_timer()
other_querier_present_timer = Timer(OtherQuerierPresentInterval, self.other_querier_present_timeout)
other_querier_present_timer.start()
self.other_querier_present_timer = other_querier_present_timer
def clear_other_querier_present_timer(self):
if self.other_querier_present_timer is not None:
self.other_querier_present_timer.cancel()
def general_query_timeout(self):
self.interface_state.general_query_timeout(self)
def other_querier_present_timeout(self):
self.interface_state.other_querier_present_timeout(self)
def change_interface_state(self, querier: bool):
if querier:
self.interface_state = Querier
self.router_state_logger.debug('change querier state to -> Querier')
else:
self.interface_state = NonQuerier
self.router_state_logger.debug('change querier state to -> NonQuerier')
############################################
# group state methods
############################################
def get_group_state(self, group_ip):
with self.group_state_lock.genRlock():
if group_ip in self.group_state:
return self.group_state[group_ip]
with self.group_state_lock.genWlock():
if group_ip in self.group_state:
group_state = self.group_state[group_ip]
else:
group_state = GroupState(self, group_ip)
self.group_state[group_ip] = group_state
return group_state
def receive_v1_membership_report(self, packet: ReceivedPacket):
igmp_group = packet.payload.group_address
#if igmp_group not in self.group_state:
# self.group_state[igmp_group] = GroupState(self, igmp_group)
#self.group_state[igmp_group].receive_v1_membership_report()
self.get_group_state(igmp_group).receive_v1_membership_report()
def receive_v2_membership_report(self, packet: ReceivedPacket):
igmp_group = packet.payload.group_address
#if igmp_group not in self.group_state:
# self.group_state[igmp_group] = GroupState(self, igmp_group)
#self.group_state[igmp_group].receive_v2_membership_report()
self.get_group_state(igmp_group).receive_v2_membership_report()
def receive_leave_group(self, packet: ReceivedPacket):
igmp_group = packet.payload.group_address
#if igmp_group in self.group_state:
# self.group_state[igmp_group].receive_leave_group()
self.get_group_state(igmp_group).receive_leave_group()
def receive_query(self, packet: ReceivedPacket):
self.interface_state.receive_query(self, packet)
igmp_group = packet.payload.group_address
# process group specific query
if igmp_group != "0.0.0.0" and igmp_group in self.group_state:
#if igmp_group != "0.0.0.0":
max_response_time = packet.payload.max_resp_time
#self.group_state[igmp_group].receive_group_specific_query(max_response_time)
self.get_group_state(igmp_group).receive_group_specific_query(max_response_time)
def remove(self):
for group in self.group_state.values():
group.remove()
\ No newline at end of file
from utils import TYPE_CHECKING
from ..wrapper import NoMembersPresent
from ..wrapper import MembersPresent
if TYPE_CHECKING:
from ..GroupState import GroupState
def group_membership_timeout(group_state: 'GroupState'):
group_state.group_state_logger.debug('NonQuerier CheckingMembership: group_membership_timeout')
group_state.set_state(NoMembersPresent)
# NOTIFY ROUTING - !!!!
group_state.notify_routing_remove()
def group_membership_v1_timeout(group_state: 'GroupState'):
group_state.group_state_logger.debug('NonQuerier CheckingMembership: group_membership_v1_timeout')
# do nothing
return
def retransmit_timeout(group_state: 'GroupState'):
group_state.group_state_logger.debug('NonQuerier CheckingMembership: retransmit_timeout')
# do nothing
return
def receive_v1_membership_report(group_state: 'GroupState'):
group_state.group_state_logger.debug('NonQuerier CheckingMembership: receive_v1_membership_report')
receive_v2_membership_report(group_state)
def receive_v2_membership_report(group_state: 'GroupState'):
group_state.group_state_logger.debug('NonQuerier CheckingMembership: receive_v2_membership_report')
group_state.set_timer()
group_state.set_state(MembersPresent)
def receive_leave_group(group_state: 'GroupState'):
group_state.group_state_logger.debug('NonQuerier CheckingMembership: receive_leave_group')
# do nothing
return
def receive_group_specific_query(group_state: 'GroupState', max_response_time: int):
group_state.group_state_logger.debug('NonQuerier CheckingMembership: receive_group_specific_query')
# do nothing
return
from utils import TYPE_CHECKING
from ..wrapper import NoMembersPresent
from ..wrapper import CheckingMembership
if TYPE_CHECKING:
from ..GroupState import GroupState
def group_membership_timeout(group_state: 'GroupState'):
group_state.group_state_logger.debug('NonQuerier MembersPresent: group_membership_timeout')
group_state.set_state(NoMembersPresent)
# NOTIFY ROUTING - !!!!
group_state.notify_routing_remove()
def group_membership_v1_timeout(group_state: 'GroupState'):
group_state.group_state_logger.debug('NonQuerier MembersPresent: group_membership_v1_timeout')
# do nothing
return
def retransmit_timeout(group_state: 'GroupState'):
group_state.group_state_logger.debug('NonQuerier MembersPresent: retransmit_timeout')
# do nothing
return
def receive_v1_membership_report(group_state: 'GroupState'):
group_state.group_state_logger.debug('NonQuerier MembersPresent: receive_v1_membership_report')
receive_v2_membership_report(group_state)
def receive_v2_membership_report(group_state: 'GroupState'):
group_state.group_state_logger.debug('NonQuerier MembersPresent: receive_v2_membership_report')
group_state.set_timer()
def receive_leave_group(group_state: 'GroupState'):
group_state.group_state_logger.debug('NonQuerier MembersPresent: receive_leave_group')
# do nothing
return
def receive_group_specific_query(group_state: 'GroupState', max_response_time: int):
group_state.group_state_logger.debug('NonQuerier MembersPresent: receive_group_specific_query')
group_state.set_timer(alternative=True, max_response_time=max_response_time)
group_state.set_state(CheckingMembership)
from utils import TYPE_CHECKING
from ..wrapper import MembersPresent
if TYPE_CHECKING:
from ..GroupState import GroupState
def group_membership_timeout(group_state: 'GroupState'):
group_state.group_state_logger.debug('NonQuerier NoMembersPresent: group_membership_timeout')
# do nothing
return
def group_membership_v1_timeout(group_state: 'GroupState'):
group_state.group_state_logger.debug('NonQuerier NoMembersPresent: group_membership_v1_timeout')
# do nothing
return
def retransmit_timeout(group_state: 'GroupState'):
group_state.group_state_logger.debug('NonQuerier NoMembersPresent: retransmit_timeout')
# do nothing
return
def receive_v1_membership_report(group_state: 'GroupState'):
group_state.group_state_logger.debug('NonQuerier NoMembersPresent: receive_v1_membership_report')
receive_v2_membership_report(group_state)
def receive_v2_membership_report(group_state: 'GroupState'):
group_state.group_state_logger.debug('NonQuerier NoMembersPresent: receive_v2_membership_report')
group_state.set_timer()
group_state.set_state(MembersPresent)
# NOTIFY ROUTING + !!!!
group_state.notify_routing_add()
def receive_leave_group(group_state: 'GroupState'):
group_state.group_state_logger.debug('NonQuerier NoMembersPresent: receive_leave_group')
# do nothing
return
def receive_group_specific_query(group_state: 'GroupState', max_response_time: int):
group_state.group_state_logger.debug('NonQuerier NoMembersPresent: receive_group_specific_query')
# do nothing
return
from ipaddress import IPv4Address
from utils import Membership_Query, QueryResponseInterval, LastMemberQueryCount, TYPE_CHECKING
from Packet.PacketIGMPHeader import PacketIGMPHeader
from Packet.ReceivedPacket import ReceivedPacket
from . import NoMembersPresent, MembersPresent, CheckingMembership
if TYPE_CHECKING:
from ..RouterState import RouterState
class NonQuerier:
@staticmethod
def general_query_timeout(router_state: 'RouterState'):
router_state.router_state_logger.debug('NonQuerier state: general_query_timeout')
# do nothing
return
@staticmethod
def other_querier_present_timeout(router_state: 'RouterState'):
router_state.router_state_logger.debug('NonQuerier state: other_querier_present_timeout')
#change state to Querier
router_state.change_interface_state(querier=True)
# send general query
packet = PacketIGMPHeader(type=Membership_Query, max_resp_time=QueryResponseInterval*10)
router_state.interface.send(packet.bytes())
# set general query timer
router_state.set_general_query_timer()
@staticmethod
def receive_query(router_state: 'RouterState', packet: ReceivedPacket):
router_state.router_state_logger.debug('NonQuerier state: receive_query')
source_ip = packet.ip_header.ip_src
# if source ip of membership query not lower than the ip of the received interface => ignore
if IPv4Address(source_ip) >= IPv4Address(router_state.interface.get_ip()):
return
# reset other present querier timer
router_state.set_other_querier_present_timer()
# TODO ver se existe uma melhor maneira de fazer isto
@staticmethod
def state_name():
return "Non Querier"
@staticmethod
def get_group_membership_time(max_response_time: int):
return (max_response_time/10.0) * LastMemberQueryCount
# State
@staticmethod
def get_checking_membership_state():
return CheckingMembership
@staticmethod
def get_members_present_state():
return MembersPresent
@staticmethod
def get_no_members_present_state():
return NoMembersPresent
@staticmethod
def get_version_1_members_present_state():
return NonQuerier.get_members_present_state()
from Packet.PacketIGMPHeader import PacketIGMPHeader
from utils import Membership_Query, LastMemberQueryInterval, TYPE_CHECKING
from ..wrapper import NoMembersPresent, MembersPresent, Version1MembersPresent
if TYPE_CHECKING:
from ..GroupState import GroupState
def group_membership_timeout(group_state: 'GroupState'):
group_state.group_state_logger.debug('Querier CheckingMembership: group_membership_timeout')
group_state.clear_retransmit_timer()
group_state.set_state(NoMembersPresent)
# NOTIFY ROUTING - !!!!
group_state.notify_routing_remove()
def group_membership_v1_timeout(group_state: 'GroupState'):
group_state.group_state_logger.debug('Querier CheckingMembership: group_membership_v1_timeout')
# do nothing
return
def retransmit_timeout(group_state: 'GroupState'):
group_state.group_state_logger.debug('Querier CheckingMembership: retransmit_timeout')
group_addr = group_state.group_ip
packet = PacketIGMPHeader(type=Membership_Query, max_resp_time=LastMemberQueryInterval*10, group_address=group_addr)
group_state.router_state.send(data=packet.bytes(), address=group_addr)
group_state.set_retransmit_timer()
def receive_v1_membership_report(group_state: 'GroupState'):
group_state.group_state_logger.debug('Querier CheckingMembership: receive_v1_membership_report')
group_state.set_timer()
group_state.set_v1_host_timer()
group_state.set_state(Version1MembersPresent)
def receive_v2_membership_report(group_state: 'GroupState'):
group_state.group_state_logger.debug('Querier CheckingMembership: receive_v2_membership_report')
group_state.set_timer()
group_state.set_state(MembersPresent)
def receive_leave_group(group_state: 'GroupState'):
group_state.group_state_logger.debug('Querier CheckingMembership: receive_leave_group')
# do nothing
return
def receive_group_specific_query(group_state: 'GroupState', max_response_time: int):
group_state.group_state_logger.debug('Querier CheckingMembership: receive_group_specific_query')
# do nothing
return
from Packet.PacketIGMPHeader import PacketIGMPHeader
from utils import Membership_Query, LastMemberQueryInterval, TYPE_CHECKING
from ..wrapper import Version1MembersPresent, CheckingMembership, NoMembersPresent
if TYPE_CHECKING:
from ..GroupState import GroupState
def group_membership_timeout(group_state: 'GroupState'):
group_state.group_state_logger.debug('Querier MembersPresent: group_membership_timeout')
group_state.set_state(NoMembersPresent)
# NOTIFY ROUTING - !!!!
group_state.notify_routing_remove()
def group_membership_v1_timeout(group_state: 'GroupState'):
group_state.group_state_logger.debug('Querier MembersPresent: group_membership_v1_timeout')
# do nothing
return
def retransmit_timeout(group_state: 'GroupState'):
group_state.group_state_logger.debug('Querier MembersPresent: retransmit_timeout')
# do nothing
return
def receive_v1_membership_report(group_state: 'GroupState'):
group_state.group_state_logger.debug('Querier MembersPresent: receive_v1_membership_report')
group_state.set_timer()
group_state.set_v1_host_timer()
group_state.set_state(Version1MembersPresent)
def receive_v2_membership_report(group_state: 'GroupState'):
group_state.group_state_logger.debug('Querier MembersPresent: receive_v2_membership_report')
group_state.set_timer()
def receive_leave_group(group_state: 'GroupState'):
group_state.group_state_logger.debug('Querier MembersPresent: receive_leave_group')
group_ip = group_state.group_ip
group_state.set_timer(alternative=True)
group_state.set_retransmit_timer()
packet = PacketIGMPHeader(type=Membership_Query, max_resp_time=LastMemberQueryInterval*10, group_address=group_ip)
group_state.router_state.send(data=packet.bytes(), address=group_ip)
group_state.set_state(CheckingMembership)
def receive_group_specific_query(group_state: 'GroupState', max_response_time):
group_state.group_state_logger.debug('Querier MembersPresent: receive_group_specific_query')
# do nothing
return
from utils import TYPE_CHECKING
from ..wrapper import MembersPresent
from ..wrapper import Version1MembersPresent
if TYPE_CHECKING:
from ..GroupState import GroupState
def group_membership_timeout(group_state: 'GroupState'):
group_state.group_state_logger.debug('Querier NoMembersPresent: group_membership_timeout')
# do nothing
return
def group_membership_v1_timeout(group_state: 'GroupState'):
group_state.group_state_logger.debug('Querier NoMembersPresent: group_membership_v1_timeout')
# do nothing
return
def retransmit_timeout(group_state: 'GroupState'):
group_state.group_state_logger.debug('Querier NoMembersPresent: retransmit_timeout')
# do nothing
return
def receive_v1_membership_report(group_state: 'GroupState'):
group_state.group_state_logger.debug('Querier NoMembersPresent: receive_v1_membership_report')
group_state.set_timer()
group_state.set_v1_host_timer()
group_state.set_state(Version1MembersPresent)
# NOTIFY ROUTING + !!!!
group_state.notify_routing_add()
def receive_v2_membership_report(group_state: 'GroupState'):
group_state.group_state_logger.debug('Querier NoMembersPresent: receive_v2_membership_report')
group_state.set_timer()
group_state.set_state(MembersPresent)
# NOTIFY ROUTING + !!!!
group_state.notify_routing_add()
def receive_leave_group(group_state: 'GroupState'):
group_state.group_state_logger.debug('Querier NoMembersPresent: receive_leave_group')
# do nothing
return
def receive_group_specific_query(group_state: 'GroupState', max_response_time: int):
group_state.group_state_logger.debug('Querier NoMembersPresent: receive_group_specific_query')
# do nothing
return
from ipaddress import IPv4Address
from utils import TYPE_CHECKING
from utils import Membership_Query, QueryResponseInterval, LastMemberQueryCount, LastMemberQueryInterval
from Packet.PacketIGMPHeader import PacketIGMPHeader
from Packet.ReceivedPacket import ReceivedPacket
from . import CheckingMembership, MembersPresent, Version1MembersPresent, NoMembersPresent
if TYPE_CHECKING:
from ..RouterState import RouterState
class Querier:
@staticmethod
def general_query_timeout(router_state: 'RouterState'):
router_state.router_state_logger.debug('Querier state: general_query_timeout')
# send general query
packet = PacketIGMPHeader(type=Membership_Query, max_resp_time=QueryResponseInterval*10)
router_state.interface.send(packet.bytes())
# set general query timer
router_state.set_general_query_timer()
@staticmethod
def other_querier_present_timeout(router_state: 'RouterState'):
router_state.router_state_logger.debug('Querier state: other_querier_present_timeout')
# do nothing
return
@staticmethod
def receive_query(router_state: 'RouterState', packet: ReceivedPacket):
router_state.router_state_logger.debug('Querier state: receive_query')
source_ip = packet.ip_header.ip_src
# if source ip of membership query not lower than the ip of the received interface => ignore
if IPv4Address(source_ip) >= IPv4Address(router_state.interface.get_ip()):
return
# if source ip of membership query lower than the ip of the received interface => change state
# change state of interface
# Querier -> Non Querier
router_state.change_interface_state(querier=False)
# set other present querier timer
router_state.clear_general_query_timer()
router_state.set_other_querier_present_timer()
# TODO ver se existe uma melhor maneira de fazer isto
@staticmethod
def state_name():
return "Querier"
@staticmethod
def get_group_membership_time(max_response_time: int):
return LastMemberQueryInterval * LastMemberQueryCount
# State
@staticmethod
def get_checking_membership_state():
return CheckingMembership
@staticmethod
def get_members_present_state():
return MembersPresent
@staticmethod
def get_no_members_present_state():
return NoMembersPresent
@staticmethod
def get_version_1_members_present_state():
return Version1MembersPresent
from utils import TYPE_CHECKING
from ..wrapper import NoMembersPresent
from ..wrapper import MembersPresent
if TYPE_CHECKING:
from ..GroupState import GroupState
def group_membership_timeout(group_state: 'GroupState'):
group_state.group_state_logger.debug('Querier Version1MembersPresent: group_membership_timeout')
group_state.set_state(NoMembersPresent)
# NOTIFY ROUTING - !!!!
group_state.notify_routing_remove()
def group_membership_v1_timeout(group_state: 'GroupState'):
group_state.group_state_logger.debug('Querier Version1MembersPresent: group_membership_v1_timeout')
group_state.set_state(MembersPresent)
def retransmit_timeout(group_state: 'GroupState'):
group_state.group_state_logger.debug('Querier Version1MembersPresent: retransmit_timeout')
# do nothing
return
def receive_v1_membership_report(group_state: 'GroupState'):
group_state.group_state_logger.debug('Querier Version1MembersPresent: receive_v1_membership_report')
group_state.set_timer()
group_state.set_v1_host_timer()
def receive_v2_membership_report(group_state: 'GroupState'):
group_state.group_state_logger.debug('Querier Version1MembersPresent: receive_v2_membership_report')
group_state.set_timer()
def receive_leave_group(group_state: 'GroupState'):
group_state.group_state_logger.debug('Querier Version1MembersPresent: receive_leave_group')
# do nothing
return
def receive_group_specific_query(group_state: 'GroupState', max_response_time: int):
group_state.group_state_logger.debug('Querier Version1MembersPresent: receive_group_specific_query')
# do nothing
return
from utils import TYPE_CHECKING
if TYPE_CHECKING:
from ..RouterState import RouterState
def get_state(router_state: 'RouterState'):
return router_state.interface_state.get_checking_membership_state()
def print_state():
return "CheckingMembership"
'''
def group_membership_timeout(group_state):
get_state(group_state).group_membership_timeout(group_state)
def group_membership_v1_timeout(group_state):
get_state(group_state).group_membership_v1_timeout(group_state)
def retransmit_timeout(group_state):
get_state(group_state).retransmit_timeout(group_state)
def receive_v1_membership_report(group_state, packet: ReceivedPacket):
get_state(group_state).receive_v1_membership_report(group_state, packet)
def receive_v2_membership_report(group_state, packet: ReceivedPacket):
get_state(group_state).receive_v2_membership_report(group_state, packet)
def receive_leave_group(group_state, packet: ReceivedPacket):
get_state(group_state).receive_leave_group(group_state, packet)
def receive_group_specific_query(group_state, packet: ReceivedPacket):
get_state(group_state).receive_group_specific_query(group_state, packet)
'''
\ No newline at end of file
from utils import TYPE_CHECKING
if TYPE_CHECKING:
from ..RouterState import RouterState
def get_state(router_state: 'RouterState'):
return router_state.interface_state.get_members_present_state()
def print_state():
return "MembersPresent"
'''
def group_membership_timeout(group_state):
get_state(group_state).group_membership_timeout(group_state)
def group_membership_v1_timeout(group_state):
get_state(group_state).group_membership_v1_timeout(group_state)
def retransmit_timeout(group_state):
get_state(group_state).retransmit_timeout(group_state)
def receive_v1_membership_report(group_state, packet: ReceivedPacket):
get_state(group_state).receive_v1_membership_report(group_state, packet)
def receive_v2_membership_report(group_state, packet: ReceivedPacket):
get_state(group_state).receive_v2_membership_report(group_state, packet)
def receive_leave_group(group_state, packet: ReceivedPacket):
get_state(group_state).receive_leave_group(group_state, packet)
def receive_group_specific_query(group_state, packet: ReceivedPacket):
get_state(group_state).receive_group_specific_query(group_state, packet)
'''
\ No newline at end of file
from utils import TYPE_CHECKING
if TYPE_CHECKING:
from ..RouterState import RouterState
def get_state(router_state: 'RouterState'):
return router_state.interface_state.get_no_members_present_state()
def print_state():
return "NoMembersPresent"
'''
def group_membership_timeout(group_state):
get_state(group_state).group_membership_timeout(group_state)
def group_membership_v1_timeout(group_state):
get_state(group_state).group_membership_v1_timeout(group_state)
def retransmit_timeout(group_state):
get_state(group_state).retransmit_timeout(group_state)
def receive_v1_membership_report(group_state, packet: ReceivedPacket):
get_state(group_state).receive_v1_membership_report(group_state, packet)
def receive_v2_membership_report(group_state, packet: ReceivedPacket):
get_state(group_state).receive_v2_membership_report(group_state, packet)
def receive_leave_group(group_state, packet: ReceivedPacket):
get_state(group_state).receive_leave_group(group_state, packet)
def receive_group_specific_query(group_state, packet: ReceivedPacket):
get_state(group_state).receive_group_specific_query(group_state, packet)
'''
from utils import TYPE_CHECKING
if TYPE_CHECKING:
from ..RouterState import RouterState
def get_state(router_state: 'RouterState'):
return router_state.interface_state.get_version_1_members_present_state()
def print_state():
return "Version1MembersPresent"
'''
def group_membership_timeout(group_state):
get_state(group_state).group_membership_timeout(group_state)
def group_membership_v1_timeout(group_state):
get_state(group_state).group_membership_v1_timeout(group_state)
def retransmit_timeout(group_state):
get_state(group_state).retransmit_timeout(group_state)
def receive_v1_membership_report(group_state, packet: ReceivedPacket):
get_state(group_state).receive_v1_membership_report(group_state, packet)
def receive_v2_membership_report(group_state, packet: ReceivedPacket):
get_state(group_state).receive_v2_membership_report(group_state, packet)
def receive_leave_group(group_state, packet: ReceivedPacket):
get_state(group_state).receive_leave_group(group_state, packet)
def receive_group_specific_query(group_state, packet: ReceivedPacket):
get_state(group_state).receive_group_specific_query(group_state, packet)
'''
import subprocess
import struct
import socket
from ctypes import create_string_buffer, addressof
SO_ATTACH_FILTER = 26
ETH_P_IP = 0x0800 # Internet Protocol packet
SO_RCVBUFFORCE = 33
def get_s_g_bpf_filter_code(source, group, interface_name):
#cmd = "tcpdump -ddd \"(udp or icmp) and host %s and dst %s\"" % (source, group)
cmd = "tcpdump -ddd \"(ip proto not 2) and host %s and dst %s\"" % (source, group)
result = subprocess.Popen(cmd, shell=True, stdout=subprocess.PIPE)
bpf_filter = b''
tmp = result.stdout.read().splitlines()
num = int(tmp[0])
for line in tmp[1:]:
print(line)
bpf_filter += struct.pack("HBBI", *tuple(map(int, line.split(b' '))))
print(num)
# defined in linux/filter.h.
b = create_string_buffer(bpf_filter)
mem_addr_of_filters = addressof(b)
fprog = struct.pack('HL', num, mem_addr_of_filters)
# Create listening socket with filters
s = socket.socket(socket.AF_PACKET, socket.SOCK_RAW, ETH_P_IP)
s.setsockopt(socket.SOL_SOCKET, SO_ATTACH_FILTER, fprog)
# todo pequeno ajuste (tamanho de buffer pequeno para o caso de trafego em rajadas):
#s.setsockopt(socket.SOL_SOCKET, SO_RCVBUFFORCE, 1)
s.bind((interface_name, ETH_P_IP))
return s
from tree.tree_if_upstream import TreeInterfaceUpstream
from tree.tree_if_downstream import TreeInterfaceDownstream
from .tree_interface import TreeInterface
from threading import Timer, Lock, RLock
from tree.metric import AssertMetric
import UnicastRouting
from time import time
import Main
import logging
class KernelEntry:
TREE_TIMEOUT = 180
KERNEL_LOGGER = logging.getLogger('pim.KernelEntry')
def __init__(self, source_ip: str, group_ip: str):
self.kernel_entry_logger = logging.LoggerAdapter(KernelEntry.KERNEL_LOGGER, {'tree': '(' + source_ip + ',' + group_ip + ')'})
self.kernel_entry_logger.debug('Create KernelEntry')
self.source_ip = source_ip
self.group_ip = group_ip
# OBTAIN UNICAST ROUTING INFORMATION###################################################
(metric_administrative_distance, metric_cost, rpf_node, root_if, mask) = \
UnicastRouting.get_unicast_info(source_ip)
if root_if is None:
raise Exception
self.rpf_node = rpf_node
# (S,G) starts IG state
self._was_olist_null = False
# Locks
self._multicast_change = Lock()
self._lock_test2 = RLock()
self.CHANGE_STATE_LOCK = RLock()
# decide inbound interface based on rpf check
self.inbound_interface_index = root_if
self.interface_state = {} # type: Dict[int, TreeInterface]
with self.CHANGE_STATE_LOCK:
for i in Main.kernel.vif_index_to_name_dic.keys():
try:
if i == self.inbound_interface_index:
self.interface_state[i] = TreeInterfaceUpstream(self, i)
else:
self.interface_state[i] = TreeInterfaceDownstream(self, i)
except:
import traceback
print(traceback.print_exc())
continue
self.change()
self.evaluate_olist_change()
self.timestamp_of_last_state_refresh_message_received = 0
print('Tree created')
def get_inbound_interface_index(self):
return self.inbound_interface_index
def get_outbound_interfaces_indexes(self):
outbound_indexes = [0]*Main.kernel.MAXVIFS
for (index, state) in self.interface_state.items():
outbound_indexes[index] = state.is_forwarding()
return outbound_indexes
################################################
# Receive (S,G) data packets or control packets
################################################
def recv_data_msg(self, index):
print("recv data")
self.interface_state[index].recv_data_msg()
def recv_assert_msg(self, index, packet):
print("recv assert")
pkt_assert = packet.payload.payload
metric = pkt_assert.metric
metric_preference = pkt_assert.metric_preference
assert_sender_ip = packet.ip_header.ip_src
received_metric = AssertMetric(metric_preference=metric_preference, route_metric=metric, ip_address=assert_sender_ip)
self.interface_state[index].recv_assert_msg(received_metric)
def recv_prune_msg(self, index, packet):
print("recv prune msg")
holdtime = packet.payload.payload.hold_time
upstream_neighbor_address = packet.payload.payload.upstream_neighbor_address
self.interface_state[index].recv_prune_msg(upstream_neighbor_address=upstream_neighbor_address, holdtime=holdtime)
def recv_join_msg(self, index, packet):
print("recv join msg")
upstream_neighbor_address = packet.payload.payload.upstream_neighbor_address
self.interface_state[index].recv_join_msg(upstream_neighbor_address)
def recv_graft_msg(self, index, packet):
print("recv graft msg")
upstream_neighbor_address = packet.payload.payload.upstream_neighbor_address
source_ip = packet.ip_header.ip_src
self.interface_state[index].recv_graft_msg(upstream_neighbor_address, source_ip)
def recv_graft_ack_msg(self, index, packet):
print("recv graft ack msg")
source_ip = packet.ip_header.ip_src
self.interface_state[index].recv_graft_ack_msg(source_ip)
def recv_state_refresh_msg(self, index, packet):
print("recv state refresh msg")
source_of_state_refresh = packet.ip_header.ip_src
metric_preference = packet.payload.payload.metric_preference
metric = packet.payload.payload.metric
ttl = packet.payload.payload.ttl
prune_indicator_flag = packet.payload.payload.prune_indicator_flag #P
interval = packet.payload.payload.interval
received_metric = AssertMetric(metric_preference=metric_preference, route_metric=metric, ip_address=source_of_state_refresh, state_refresh_interval=interval)
self.interface_state[index].recv_state_refresh_msg(received_metric, prune_indicator_flag)
iif = packet.interface.vif_index
if iif != self.inbound_interface_index:
return
if self.interface_state[iif].get_neighbor_RPF() != source_of_state_refresh:
return
# refresh limit
timestamp = time()
if (timestamp - self.timestamp_of_last_state_refresh_message_received) < interval:
return
self.timestamp_of_last_state_refresh_message_received = timestamp
if ttl == 0:
return
self.forward_state_refresh_msg(packet.payload.payload)
################################################
# Send state refresh msg
################################################
def forward_state_refresh_msg(self, state_refresh_packet):
for interface in self.interface_state.values():
interface.send_state_refresh(state_refresh_packet)
###############################################################
# Unicast Changes to RPF
###############################################################
def network_update(self):
# TODO TALVEZ OUTRO LOCK PARA BLOQUEAR ENTRADA DE PACOTES
with self.CHANGE_STATE_LOCK:
(metric_administrative_distance, metric_cost, rpf_node, new_inbound_interface_index, _) = \
UnicastRouting.get_unicast_info(self.source_ip)
if new_inbound_interface_index is None:
self.delete()
return
if new_inbound_interface_index != self.inbound_interface_index:
self.rpf_node = rpf_node
# get old interfaces
old_upstream_interface = self.interface_state.get(self.inbound_interface_index, None)
old_downstream_interface = self.interface_state.get(new_inbound_interface_index, None)
# change type of interfaces
if self.inbound_interface_index is not None:
new_downstream_interface = TreeInterfaceDownstream(self, self.inbound_interface_index)
self.interface_state[self.inbound_interface_index] = new_downstream_interface
new_upstream_interface = None
if new_inbound_interface_index is not None:
new_upstream_interface = TreeInterfaceUpstream(self, new_inbound_interface_index)
self.interface_state[new_inbound_interface_index] = new_upstream_interface
self.inbound_interface_index = new_inbound_interface_index
# remove old interfaces
if old_upstream_interface is not None:
old_upstream_interface.delete(change_type_interface=True)
if old_downstream_interface is not None:
old_downstream_interface.delete(change_type_interface=True)
# atualizar tabela de encaminhamento multicast
#self._was_olist_null = False
self.change()
self.evaluate_olist_change()
if new_upstream_interface is not None:
new_upstream_interface.change_on_unicast_routing(interface_change=True)
elif self.rpf_node != rpf_node:
self.rpf_node = rpf_node
self.interface_state[self.inbound_interface_index].change_on_unicast_routing()
# check if add/removal of neighbors from interface afects olist and forward/prune state of interface
def change_at_number_of_neighbors(self):
with self.CHANGE_STATE_LOCK:
self.change()
self.evaluate_olist_change()
def new_or_reset_neighbor(self, if_index, neighbor_ip):
# todo maybe lock de interfaces
self.interface_state[if_index].new_or_reset_neighbor(neighbor_ip)
def is_olist_null(self):
for interface in self.interface_state.values():
if interface.is_forwarding():
return False
return True
def evaluate_olist_change(self):
with self._lock_test2:
is_olist_null = self.is_olist_null()
if self._was_olist_null != is_olist_null:
if is_olist_null:
self.interface_state[self.inbound_interface_index].olist_is_null()
else:
self.interface_state[self.inbound_interface_index].olist_is_not_null()
self._was_olist_null = is_olist_null
def get_source(self):
return self.source_ip
def get_group(self):
return self.group_ip
def change(self):
with self._multicast_change:
Main.kernel.set_multicast_route(self)
def delete(self):
with self._multicast_change:
for state in self.interface_state.values():
state.delete()
Main.kernel.remove_multicast_route(self)
######################################
# Interface change
#######################################
def new_interface(self, index):
with self.CHANGE_STATE_LOCK:
self.interface_state[index] = TreeInterfaceDownstream(self, index)
self.change()
self.evaluate_olist_change()
def remove_interface(self, index):
with self.CHANGE_STATE_LOCK:
#check if removed interface is root interface
if self.inbound_interface_index == index:
self.delete()
elif index in self.interface_state:
self.interface_state.pop(index).delete()
self.change()
self.evaluate_olist_change()
from abc import ABCMeta, abstractmethod
import tree.globals as pim_globals
from .metric import AssertMetric
from utils import TYPE_CHECKING
if TYPE_CHECKING:
from .tree_if_downstream import TreeInterfaceDownstream
class AssertStateABC(metaclass=ABCMeta):
@staticmethod
@abstractmethod
def receivedDataFromDownstreamIf(interface: "TreeInterfaceDownstream"):
"""
An (S,G) Data packet received on downstream interface
@type interface: TreeInterface
"""
raise NotImplementedError()
@staticmethod
@abstractmethod
def receivedInferiorMetricFromWinner(interface: "TreeInterfaceDownstream"):
"""
Receive Inferior (Assert OR State Refresh) from Assert Winner
@type interface: TreeInterface
"""
raise NotImplementedError()
@staticmethod
@abstractmethod
def receivedInferiorMetricFromNonWinner_couldAssertIsTrue(interface: "TreeInterfaceDownstream"):
"""
Receive Inferior (Assert OR State Refresh) from non-Assert Winner
AND CouldAssert==TRUE
@type interface: TreeInterface
"""
raise NotImplementedError()
@staticmethod
@abstractmethod
def receivedPreferedMetric(interface: "TreeInterfaceDownstream", better_metric, is_metric_equal):
"""
Receive Preferred Assert OR State Refresh
@type interface: TreeInterface
@type better_metric: AssertMetric
"""
raise NotImplementedError()
@staticmethod
@abstractmethod
def sendStateRefresh(interface: "TreeInterfaceDownstream", time):
"""
Send State Refresh
@type interface: TreeInterface
@type time: int
"""
raise NotImplementedError()
@staticmethod
@abstractmethod
def assertTimerExpires(interface: "TreeInterfaceDownstream"):
"""
AT(S,G) Expires
@type interface: TreeInterface
"""
raise NotImplementedError()
@staticmethod
@abstractmethod
def couldAssertIsNowFalse(interface: "TreeInterfaceDownstream"):
"""
CouldAssert -> FALSE
@type interface: TreeInterface
"""
raise NotImplementedError()
@staticmethod
@abstractmethod
def couldAssertIsNowTrue(interface: "TreeInterfaceDownstream"):
"""
CouldAssert -> TRUE
@type interface: TreeInterface
"""
raise NotImplementedError()
@staticmethod
@abstractmethod
def winnerLivelinessTimerExpires(interface: "TreeInterfaceDownstream"):
"""
Winner’s NLT(N,I) Expires
@type interface: TreeInterface
"""
raise NotImplementedError()
@staticmethod
@abstractmethod
def receivedPruneOrJoinOrGraft(interface: "TreeInterfaceDownstream"):
"""
Receive Prune(S,G), Join(S,G) or Graft(S,G)
@type interface: TreeInterface
"""
raise NotImplementedError()
def _sendAssert_setAT(interface: "TreeInterfaceDownstream"):
interface.set_assert_timer(pim_globals.ASSERT_TIME)
interface.send_assert()
# Override
def __str__(self) -> str:
return "AssertSM:" + self.__class__.__name__
class NoInfoState(AssertStateABC):
'''
NoInfoState (NI)
This router has no (S,G) Assert state on interface I.
'''
@staticmethod
def receivedDataFromDownstreamIf(interface: "TreeInterfaceDownstream"):
"""
@type interface: TreeInterface
"""
interface.assert_logger.debug('receivedDataFromDownstreamIf, NI -> W')
interface.set_assert_winner_metric(interface.my_assert_metric())
interface.set_assert_state(AssertState.Winner)
NoInfoState._sendAssert_setAT(interface)
@staticmethod
def receivedInferiorMetricFromWinner(interface: "TreeInterfaceDownstream"):
assert False, "this should never ocurr"
@staticmethod
def receivedInferiorMetricFromNonWinner_couldAssertIsTrue(interface: "TreeInterfaceDownstream"):
interface.assert_logger.debug('receivedInferiorMetricFromNonWinner_couldAssertIsTrue, NI -> W')
interface.set_assert_winner_metric(interface.my_assert_metric())
interface.set_assert_state(AssertState.Winner)
NoInfoState._sendAssert_setAT(interface)
@staticmethod
def receivedPreferedMetric(interface: "TreeInterfaceDownstream", better_metric, is_metric_equal):
'''
@type interface: TreeInterface
'''
if is_metric_equal:
return
interface.assert_logger.debug('receivedPreferedMetric, NI -> L')
state_refresh_interval = better_metric.state_refresh_interval
if state_refresh_interval is None:
# event caused by Assert Msg
assert_timer_value = pim_globals.ASSERT_TIME
else:
# event caused by StateRefreshMsg
assert_timer_value = state_refresh_interval*3
interface.set_assert_timer(assert_timer_value)
interface.set_assert_winner_metric(better_metric)
interface.set_assert_state(AssertState.Loser)
# MUST also multicast a Prune(S,G) to the Assert winner
if interface.could_assert():
interface.send_prune(holdtime=assert_timer_value)
@staticmethod
def sendStateRefresh(interface: "TreeInterfaceDownstream", time):
pass
@staticmethod
def assertTimerExpires(interface: "TreeInterfaceDownstream"):
assert False, "this should never ocurr"
@staticmethod
def couldAssertIsNowFalse(interface: "TreeInterfaceDownstream"):
interface.assert_logger.debug('couldAssertIsNowFalse, NI -> NI')
@staticmethod
def couldAssertIsNowTrue(interface: "TreeInterfaceDownstream"):
interface.assert_logger.debug('couldAssertIsNowTrue, NI -> NI')
@staticmethod
def winnerLivelinessTimerExpires(interface: "TreeInterfaceDownstream"):
assert False, "this should never ocurr"
@staticmethod
def receivedPruneOrJoinOrGraft(interface: "TreeInterfaceDownstream"):
interface.assert_logger.debug('receivedPruneOrJoinOrGraft, NI -> NI')
def __str__(self) -> str:
return "NoInfo"
class WinnerState(AssertStateABC):
'''
I am Assert Winner (W)
This router has won an (S,G) Assert on interface I. It is now
responsible for forwarding traffic from S destined for G via
interface I.
'''
@staticmethod
def receivedDataFromDownstreamIf(interface: "TreeInterfaceDownstream"):
"""
@type interface: TreeInterface
"""
interface.assert_logger.debug('receivedDataFromDownstreamIf, W -> W')
WinnerState._sendAssert_setAT(interface)
@staticmethod
def receivedInferiorMetricFromWinner(interface: "TreeInterfaceDownstream"):
assert False, "this should never ocurr"
@staticmethod
def receivedInferiorMetricFromNonWinner_couldAssertIsTrue(interface: "TreeInterfaceDownstream"):
interface.assert_logger.debug('receivedInferiorMetricFromNonWinner_couldAssertIsTrue, W -> W')
WinnerState._sendAssert_setAT(interface)
@staticmethod
def receivedPreferedMetric(interface: "TreeInterfaceDownstream", better_metric, is_metric_equal):
'''
@type better_metric: AssertMetric
'''
if is_metric_equal:
return
interface.assert_logger.debug('receivedPreferedMetric, W -> L')
state_refresh_interval = better_metric.state_refresh_interval
if state_refresh_interval is None:
# event caused by AssertMsg
assert_timer_value = pim_globals.ASSERT_TIME
else:
# event caused by State Refresh Msg
assert_timer_value = state_refresh_interval*3
interface.set_assert_timer(assert_timer_value)
interface.set_assert_winner_metric(better_metric)
interface.set_assert_state(AssertState.Loser)
interface.send_prune(holdtime=assert_timer_value)
@staticmethod
def sendStateRefresh(interface: "TreeInterfaceDownstream", state_refresh_interval):
interface.set_assert_timer(state_refresh_interval*3)
@staticmethod
def assertTimerExpires(interface: "TreeInterfaceDownstream"):
interface.assert_logger.debug('assertTimerExpires, W -> NI')
interface.set_assert_winner_metric(AssertMetric.infinite_assert_metric())
interface.set_assert_state(AssertState.NoInfo)
@staticmethod
def couldAssertIsNowFalse(interface: "TreeInterfaceDownstream"):
interface.assert_logger.debug('couldAssertIsNowFalse, W -> NI')
interface.send_assert_cancel()
interface.clear_assert_timer()
interface.set_assert_winner_metric(AssertMetric.infinite_assert_metric())
interface.set_assert_state(AssertState.NoInfo)
@staticmethod
def couldAssertIsNowTrue(interface: "TreeInterfaceDownstream"):
assert False, "this should never ocurr"
@staticmethod
def winnerLivelinessTimerExpires(interface: "TreeInterfaceDownstream"):
assert False, "this should never ocurr"
@staticmethod
def receivedPruneOrJoinOrGraft(interface: "TreeInterfaceDownstream"):
pass
def __str__(self) -> str:
return "Winner"
class LoserState(AssertStateABC):
'''
I am Assert Loser (L)
This router has lost an (S,G) Assert on interface I. It must not
forward packets from S destined for G onto interface I.
'''
@staticmethod
def receivedDataFromDownstreamIf(interface: "TreeInterfaceDownstream"):
"""
@type interface: TreeInterface
"""
interface.assert_logger.debug('receivedDataFromDownstreamIf, L -> L')
@staticmethod
def receivedInferiorMetricFromWinner(interface: "TreeInterfaceDownstream"):
interface.assert_logger.debug('receivedInferiorMetricFromWinner, L -> NI')
LoserState._to_NoInfo(interface)
@staticmethod
def receivedInferiorMetricFromNonWinner_couldAssertIsTrue(interface: "TreeInterfaceDownstream"):
interface.assert_logger.debug('receivedInferiorMetricFromNonWinner_couldAssertIsTrue, L -> L')
@staticmethod
def receivedPreferedMetric(interface: "TreeInterfaceDownstream", better_metric, is_metric_equal):
'''
@type better_metric: AssertMetric
'''
interface.assert_logger.debug('receivedPreferedMetric, L -> L')
state_refresh_interval = better_metric.state_refresh_interval
if state_refresh_interval is None:
assert_timer_value = pim_globals.ASSERT_TIME
else:
assert_timer_value = state_refresh_interval*3
interface.set_assert_timer(assert_timer_value)
interface.set_assert_winner_metric(better_metric)
interface.set_assert_state(AssertState.Loser)
if not is_metric_equal and interface.could_assert():
interface.send_prune(holdtime=assert_timer_value)
@staticmethod
def sendStateRefresh(interface: "TreeInterfaceDownstream", time):
assert False, "this should never ocurr"
@staticmethod
def assertTimerExpires(interface: "TreeInterfaceDownstream"):
interface.assert_logger.debug('assertTimerExpires, L -> NI')
LoserState._to_NoInfo(interface)
@staticmethod
def couldAssertIsNowFalse(interface: "TreeInterfaceDownstream"):
interface.assert_logger.debug('couldAssertIsNowFalse, L -> NI')
LoserState._to_NoInfo(interface)
@staticmethod
def couldAssertIsNowTrue(interface: "TreeInterfaceDownstream"):
interface.assert_logger.debug('couldAssertIsNowTrue, L -> NI')
LoserState._to_NoInfo(interface)
@staticmethod
def winnerLivelinessTimerExpires(interface: "TreeInterfaceDownstream"):
interface.assert_logger.debug('winnerLivelinessTimerExpires, L -> NI')
LoserState._to_NoInfo(interface)
@staticmethod
def receivedPruneOrJoinOrGraft(interface: "TreeInterfaceDownstream"):
interface.assert_logger.debug('receivedPruneOrJoinOrGraft, L -> L')
interface.send_assert()
@staticmethod
def _to_NoInfo(interface: "TreeInterfaceDownstream"):
interface.clear_assert_timer()
interface.set_assert_winner_metric(AssertMetric.infinite_assert_metric())
interface.set_assert_state(AssertState.NoInfo)
def __str__(self) -> str:
return "Loser"
class AssertState():
NoInfo = NoInfoState()
Winner = WinnerState()
Loser = LoserState()
from abc import ABCMeta, abstractmethod
from tree import globals as pim_globals
from utils import TYPE_CHECKING
if TYPE_CHECKING:
from .tree_if_downstream import TreeInterfaceDownstream
class DownstreamStateABS(metaclass=ABCMeta):
@staticmethod
@abstractmethod
def receivedPrune(interface: "TreeInterfaceDownstream", holdtime):
"""
Receive Prune(S,G)
@type interface: Downstream
"""
raise NotImplementedError()
@staticmethod
@abstractmethod
def receivedJoin(interface: "TreeInterfaceDownstream"):
"""
Receive Join(S,G)
@type interface: Downstream
"""
raise NotImplementedError()
@staticmethod
@abstractmethod
def receivedGraft(interface: "TreeInterfaceDownstream", source_ip):
"""
Receive Graft(S,G)
@type interface: Downstream
"""
raise NotImplementedError()
@staticmethod
@abstractmethod
def PPTexpires(interface: "TreeInterfaceDownstream"):
"""
PPT(S,G) Expires
@type interface: Downstream
"""
raise NotImplementedError()
@staticmethod
@abstractmethod
def PTexpires(interface: "TreeInterfaceDownstream"):
"""
PT(S,G) Expires
@type interface: Downstream
"""
raise NotImplementedError()
@staticmethod
@abstractmethod
def is_now_RPF_Interface(interface: "TreeInterfaceDownstream"):
"""
RPF_Interface(S) becomes I
@type interface: Downstream
"""
raise NotImplementedError()
@staticmethod
@abstractmethod
def send_state_refresh(interface: "TreeInterfaceDownstream"):
"""
Send State Refresh(S,G) out I
@type interface: Downstream
"""
raise NotImplementedError()
def __str__(self):
return "Downstream." + self.__class__.__name__
class NoInfo(DownstreamStateABS):
'''
NoInfo(NI)
The interface has no (S,G) Prune state, and neither the Prune
timer (PT(S,G,I)) nor the PrunePending timer ((PPT(S,G,I)) is
running.
'''
@staticmethod
def receivedPrune(interface: "TreeInterfaceDownstream", holdtime):
"""
Receive Prune(S,G)
@type interface: TreeInterfaceDownstreamDownstream
"""
interface.join_prune_logger.debug("receivedPrune, NI -> PP")
interface.set_prune_state(DownstreamState.PrunePending)
time = 0
if len(interface.get_interface().neighbors) > 1:
time = pim_globals.JP_OVERRIDE_INTERVAL
interface.set_prune_pending_timer(time)
@staticmethod
def receivedJoin(interface: "TreeInterfaceDownstream"):
"""
Receive Join(S,G)
@type interface: TreeInterfaceDownstreamDownstream
"""
# Do nothing
interface.join_prune_logger.debug("receivedJoin, NI -> NI")
@staticmethod
def receivedGraft(interface: "TreeInterfaceDownstream", source_ip):
"""
Receive Graft(S,G)
@type interface: TreeInterfaceDownstreamDownstream
"""
interface.join_prune_logger.debug('receivedGraft, NI -> NI')
interface.send_graft_ack(source_ip)
@staticmethod
def PPTexpires(interface: "TreeInterfaceDownstream"):
"""
PPT(S,G) Expires
@type interface: TreeInterfaceDownstreamDownstream
"""
#assert False, "PPTexpires in state NI"
return
@staticmethod
def PTexpires(interface: "TreeInterfaceDownstream"):
"""
PT(S,G) Expires
@type interface: TreeInterfaceDownstreamDownstream
"""
#assert False, "PTexpires in state NI"
return
@staticmethod
def is_now_RPF_Interface(interface: "TreeInterfaceDownstream"):
"""
RPF_Interface(S) becomes I
@type interface: TreeInterfaceDownstreamDownstream
"""
# Do nothing
return
@staticmethod
def send_state_refresh(interface: "TreeInterfaceDownstream"):
"""
Send State Refresh(S,G) out I
@type interface: TreeInterfaceDownstreamDownstream
"""
# Do nothing
return
def __str__(self):
return "NoInfo"
class PrunePending(DownstreamStateABS):
'''
PrunePending(PP)
The router has received a Prune(S,G) on this interface from a
downstream neighbor and is waiting to see whether the prune will
be overridden by another downstream router. For forwarding
purposes, the PrunePending state functions exactly like the
NoInfo state.
'''
@staticmethod
def receivedPrune(interface: "TreeInterfaceDownstream", holdtime):
"""
Receive Prune(S,G)
@type interface: TreeInterfaceDownstreamDownstream
"""
interface.join_prune_logger.debug('receivedPrune, PP -> PP')
@staticmethod
def receivedJoin(interface: "TreeInterfaceDownstream"):
"""
Receive Join(S,G)
@type interface: TreeInterfaceDownstreamDownstream
"""
interface.join_prune_logger.debug('receivedJoin, PP -> NI')
interface.clear_prune_pending_timer()
interface.set_prune_state(DownstreamState.NoInfo)
@staticmethod
def receivedGraft(interface: "TreeInterfaceDownstream", source_ip):
"""
Receive Graft(S,G)
@type interface: TreeInterfaceDownstreamDownstream
"""
interface.join_prune_logger.debug('receivedGraft, PP -> NI')
interface.clear_prune_pending_timer()
interface.set_prune_state(DownstreamState.NoInfo)
interface.send_graft_ack(source_ip)
@staticmethod
def PPTexpires(interface: "TreeInterfaceDownstream"):
"""
PPT(S,G) Expires
@type interface: TreeInterfaceDownstreamDownstream
"""
interface.join_prune_logger.debug('PPTexpires, PP -> P')
interface.set_prune_state(DownstreamState.Pruned)
interface.set_prune_timer(interface.get_received_prune_holdtime() - pim_globals.JP_OVERRIDE_INTERVAL)
if len(interface.get_interface().neighbors) > 1:
interface.send_pruneecho()
@staticmethod
def PTexpires(interface: "TreeInterfaceDownstream"):
"""
PT(S,G) Expires
@type interface: TreeInterfaceDownstreamDownstream
"""
#assert False, "PTexpires in state PP"
return
@staticmethod
def is_now_RPF_Interface(interface: "TreeInterfaceDownstream"):
"""
RPF_Interface(S) becomes I
@type interface: TreeInterfaceDownstreamDownstream
"""
interface.join_prune_logger.debug('is_now_RPF_Interface, PP -> NI')
interface.clear_prune_pending_timer()
interface.set_prune_state(DownstreamState.NoInfo)
@staticmethod
def send_state_refresh(interface: "TreeInterfaceDownstream"):
"""
Send State Refresh(S,G) out I
@type interface: TreeInterfaceDownstreamDownstream
"""
return
def __str__(self):
return "PrunePending"
class Pruned(DownstreamStateABS):
'''
Pruned(P)
The router has received a Prune(S,G) on this interface from a
downstream neighbor, and the Prune was not overridden. Data from
S addressed to group G is no longer being forwarded on this
interface.
'''
@staticmethod
def receivedPrune(interface: "TreeInterfaceDownstream", holdtime):
"""
Receive Prune(S,G)
@type interface: TreeInterfaceDownstreamDownstream
"""
interface.join_prune_logger.debug('receivedPrune, P -> P')
if holdtime > interface.remaining_prune_timer():
interface.set_prune_timer(holdtime)
@staticmethod
def receivedJoin(interface: "TreeInterfaceDownstream"):
"""
Receive Join(S,G)
@type interface: TreeInterfaceDownstreamDownstream
"""
interface.join_prune_logger.debug('receivedPrune, P -> NI')
interface.clear_prune_timer()
interface.set_prune_state(DownstreamState.NoInfo)
@staticmethod
def receivedGraft(interface: "TreeInterfaceDownstream", source_ip):
"""
Receive Graft(S,G)
@type interface: TreeInterfaceDownstreamDownstream
"""
interface.join_prune_logger.debug('receivedGraft, P -> NI')
interface.clear_prune_timer()
interface.set_prune_state(DownstreamState.NoInfo)
interface.send_graft_ack(source_ip)
@staticmethod
def PPTexpires(interface: "TreeInterfaceDownstream"):
"""
PPT(S,G) Expires
@type interface: TreeInterfaceDownstreamDownstream
"""
#assert False, "PPTexpires in state P"
return
@staticmethod
def PTexpires(interface: "TreeInterfaceDownstream"):
"""
PT(S,G) Expires
@type interface: TreeInterfaceDownstreamDownstream
"""
interface.join_prune_logger.debug('PTexpires, P -> NI')
interface.set_prune_state(DownstreamState.NoInfo)
@staticmethod
def is_now_RPF_Interface(interface: "TreeInterfaceDownstream"):
"""
RPF_Interface(S) becomes I
@type interface: TreeInterfaceDownstreamDownstream
"""
interface.join_prune_logger('is_now_RPF_Interface, P -> NI')
interface.clear_prune_timer()
interface.set_prune_state(DownstreamState.NoInfo)
@staticmethod
def send_state_refresh(interface: "TreeInterfaceDownstream"):
"""
Send State Refresh(S,G) out I
@type interface: TreeInterfaceDownstreamDownstream
"""
interface.join_prune_logger.debug('send_state_refresh, P -> P')
if interface.get_interface().is_state_refresh_capable():
interface.set_prune_timer(interface.get_received_prune_holdtime())
def __str__(self):
return "Pruned"
class DownstreamState():
NoInfo = NoInfo()
Pruned = Pruned()
PrunePending = PrunePending()
'''
Created on Feb 23, 2015
This module is intended to have all constants and global values for pim_dm
@author: alex
'''
ASSERT_TIME = 180
GRAFT_RETRY_PERIOD = 3
JP_OVERRIDE_INTERVAL = 3.0
OVERRIDE_INTERVAL = 2.5
PROPAGATION_DELAY = 0.5
REFRESH_INTERVAL = 60 # State Refresh Interval
SOURCE_LIFETIME = 210
T_LIMIT = 210
ASSERT_CANCEL_METRIC = 0xFFFFFFFF
\ No newline at end of file
from abc import ABCMeta, abstractmethod
class LocalMembershipStateABC(metaclass=ABCMeta):
@staticmethod
@abstractmethod
def has_members():
raise NotImplementedError
class NoInfo(LocalMembershipStateABC):
@staticmethod
def has_members():
return False
class Include(LocalMembershipStateABC):
@staticmethod
def has_members():
return True
class LocalMembership():
NoInfo = NoInfo()
Include = Include()
\ No newline at end of file
import ipaddress
class AssertMetric(object):
def __init__(self, metric_preference: int = 0x7FFFFFFF, route_metric: int = 0xFFFFFFFF, ip_address: str = "0.0.0.0", state_refresh_interval:int = None):
if type(ip_address) is str:
ip_address = ipaddress.ip_address(ip_address)
self._metric_preference = metric_preference
self._route_metric = route_metric
self._ip_address = ip_address
self._state_refresh_interval = state_refresh_interval
def is_better_than(self, other):
if self.metric_preference != other.metric_preference:
return self.metric_preference < other.metric_preference
elif self.route_metric != other.route_metric:
return self.route_metric < other.route_metric
else:
return self.ip_address > other.ip_address
def is_worse(self, other):
return not self.is_better_than(other)
def equal_metric(self, other):
return self.metric_preference == other.metric_preference and self.metric_preference == other.metric_preference \
and self.ip_address == other.ip_address
@staticmethod
def infinite_assert_metric():
'''
@type metric: AssertMetric
'''
return AssertMetric()
@staticmethod
def spt_assert_metric(tree_if):
'''
@type metric: AssertMetric
@type tree_if: TreeInterface
'''
(source_ip, _) = tree_if.get_tree_id()
import UnicastRouting
(metric_preference, metric_cost, _) = UnicastRouting.get_metric(source_ip)
return AssertMetric(metric_preference, metric_cost, tree_if.get_ip())
def i_am_assert_winner(self, tree_if):
return self.get_ip() == tree_if.get_ip()
@property
def metric_preference(self):
return self._metric_preference
@metric_preference.setter
def metric_preference(self, value):
self._metric_preference = value
@property
def route_metric(self):
return self._route_metric
@route_metric.setter
def route_metric(self, value):
self._route_metric = value
@property
def ip_address(self):
return self._ip_address
@ip_address.setter
def ip_address(self, value):
if type(value) is str:
value = ipaddress.ip_address(value)
self._ip_address = value
@property
def state_refresh_interval(self):
return self._state_refresh_interval
@state_refresh_interval.setter
def state_refresh_interval(self, value):
self._state_refresh_interval = value
def get_ip(self):
return str(self._ip_address)
from abc import ABCMeta, abstractstaticmethod
from tree import globals as pim_globals
class OriginatorStateABC(metaclass=ABCMeta):
@abstractstaticmethod
def recvDataMsgFromSource(tree):
pass
@abstractstaticmethod
def SRTexpires(tree):
pass
@abstractstaticmethod
def SATexpires(tree):
pass
@abstractstaticmethod
def SourceNotConnected(tree):
pass
class Originator(OriginatorStateABC):
@staticmethod
def recvDataMsgFromSource(tree):
tree.set_source_active_timer()
@staticmethod
def SRTexpires(tree):
'''
@type tree: Tree
'''
tree.originator_logger.debug('SRT expired, O -> O')
tree.set_state_refresh_timer()
tree.create_state_refresh_msg()
@staticmethod
def SATexpires(tree):
tree.originator_logger.debug('SAT expired, O -> NO')
tree.clear_state_refresh_timer()
tree.set_originator_state(OriginatorState.NotOriginator)
@staticmethod
def SourceNotConnected(tree):
tree.originator_logger.debug('Source no longer directly connected, O -> NO')
tree.clear_state_refresh_timer()
tree.clear_source_active_timer()
tree.set_originator_state(OriginatorState.NotOriginator)
def __str__(self):
return 'Originator'
class NotOriginator(OriginatorStateABC):
@staticmethod
def recvDataMsgFromSource(tree):
'''
@type interface: Tree
'''
tree.originator_logger.debug('new DataMsg from Source, NO -> O')
tree.set_originator_state(OriginatorState.Originator)
tree.set_state_refresh_timer()
tree.set_source_active_timer()
@staticmethod
def SRTexpires(tree):
assert False, "SRTexpires in NO"
@staticmethod
def SATexpires(tree):
assert False, "SATexpires in NO"
@staticmethod
def SourceNotConnected(tree):
return
def __str__(self):
return 'NotOriginator'
class OriginatorState():
NotOriginator = NotOriginator()
Originator = Originator()
'''
Created on Jul 16, 2015
@author: alex
'''
from threading import Timer
from CustomTimer.RemainingTimer import RemainingTimer
from .assert_ import AssertState, AssertStateABC
from .downstream_prune import DownstreamState, DownstreamStateABS
from .tree_interface import TreeInterface
from Packet.PacketPimStateRefresh import PacketPimStateRefresh
from Packet.Packet import Packet
from Packet.PacketPimHeader import PacketPimHeader
import traceback
import logging
import Main
class TreeInterfaceDownstream(TreeInterface):
LOGGER = logging.getLogger('pim.KernelEntry.DownstreamInterface')
def __init__(self, kernel_entry, interface_id):
extra_dict_logger = kernel_entry.kernel_entry_logger.extra.copy()
extra_dict_logger['vif'] = interface_id
extra_dict_logger['interfacename'] = Main.kernel.vif_index_to_name_dic[interface_id]
logger = logging.LoggerAdapter(TreeInterfaceDownstream.LOGGER, extra_dict_logger)
TreeInterface.__init__(self, kernel_entry, interface_id, logger)
self.logger.debug('Created DownstreamInterface')
self.join_prune_logger.debug('Downstream state transitions to ' + str(self._prune_state))
# Last state refresh message sent (resend in case of new neighbors)
self._last_state_refresh_message = None
##########################################
# Set state
##########################################
def set_prune_state(self, new_state: DownstreamStateABS):
with self.get_state_lock():
if new_state != self._prune_state:
self._prune_state = new_state
self.join_prune_logger.debug('Downstream state transitions to ' + str(new_state))
self.change_tree()
self.evaluate_ingroup()
##########################################
# Check timers
##########################################
def is_prune_pending_timer_running(self):
return self._prune_pending_timer is not None and self._prune_pending_timer.is_alive()
def is_prune_timer_running(self):
return self._prune_timer is not None and self._prune_timer.is_alive()
def remaining_prune_timer(self):
return 0 if not self._prune_timer else self._prune_timer.time_remaining()
##########################################
# Set timers
##########################################
def set_prune_pending_timer(self, time):
self.clear_prune_pending_timer()
self._prune_pending_timer = Timer(time, self.prune_pending_timeout)
self._prune_pending_timer.start()
def clear_prune_pending_timer(self):
if self._prune_pending_timer is not None:
self._prune_pending_timer.cancel()
def set_prune_timer(self, time):
self.clear_prune_timer()
#self._prune_timer = Timer(time, self.prune_timeout)
self._prune_timer = RemainingTimer(time, self.prune_timeout)
self._prune_timer.start()
def clear_prune_timer(self):
if self._prune_timer is not None:
self._prune_timer.cancel()
###########################################
# Timer timeout
###########################################
def prune_pending_timeout(self):
self._prune_state.PPTexpires(self)
def prune_timeout(self):
self._prune_state.PTexpires(self)
###########################################
# Recv packets
###########################################
def recv_data_msg(self):
self._assert_state.receivedDataFromDownstreamIf(self)
# Override
def recv_prune_msg(self, upstream_neighbor_address, holdtime):
super().recv_prune_msg(upstream_neighbor_address, holdtime)
if upstream_neighbor_address == self.get_ip():
self.set_receceived_prune_holdtime(holdtime)
self._prune_state.receivedPrune(self, holdtime)
# Override
def recv_join_msg(self, upstream_neighbor_address):
super().recv_join_msg(upstream_neighbor_address)
if upstream_neighbor_address == self.get_ip():
self._prune_state.receivedJoin(self)
# Override
def recv_graft_msg(self, upstream_neighbor_address, source_ip):
print("GRAFT!!!")
super().recv_graft_msg(upstream_neighbor_address, source_ip)
if upstream_neighbor_address == self.get_ip():
self._prune_state.receivedGraft(self, source_ip)
######################################
# Send messages
######################################
def send_state_refresh(self, state_refresh_msg_received):
if state_refresh_msg_received is None:
return
self._last_state_refresh_message = state_refresh_msg_received
if self.lost_assert() or not self.get_interface().is_state_refresh_enabled():
return
interval = state_refresh_msg_received.interval
self._assert_state.sendStateRefresh(self, interval)
self._prune_state.send_state_refresh(self)
prune_indicator_bit = 0
if self.is_pruned():
prune_indicator_bit = 1
import UnicastRouting
(metric_preference, metric, mask) = UnicastRouting.get_metric(state_refresh_msg_received.source_address)
assert_override_flag = 0
if self._assert_state == AssertState.NoInfo:
assert_override_flag = 1
try:
ph = PacketPimStateRefresh(multicast_group_adress=state_refresh_msg_received.multicast_group_adress,
source_address=state_refresh_msg_received.source_address,
originator_adress=state_refresh_msg_received.originator_adress,
metric_preference=metric_preference, metric=metric, mask_len=mask,
ttl=state_refresh_msg_received.ttl - 1,
prune_indicator_flag=prune_indicator_bit,
prune_now_flag=state_refresh_msg_received.prune_now_flag,
assert_override_flag=assert_override_flag,
interval=interval)
pckt = Packet(payload=PacketPimHeader(ph))
self.get_interface().send(pckt.bytes())
except:
traceback.print_exc()
return
##########################################################
# Override
def is_forwarding(self):
return ((self.has_neighbors() and not self.is_pruned()) or self.igmp_has_members()) and not self.lost_assert()
def is_pruned(self):
return self._prune_state == DownstreamState.Pruned
#def lost_assert(self):
# return not AssertMetric.i_am_assert_winner(self) and \
# self._assert_winner_metric.is_better_than(AssertMetric.spt_assert_metric(self))
# Override
# When new neighbor connects, send last state refresh msg
def new_or_reset_neighbor(self, neighbor_ip):
self.send_state_refresh(self._last_state_refresh_message)
# Override
def delete(self, change_type_interface=False):
super().delete(change_type_interface)
self.clear_assert_timer()
self.clear_prune_timer()
self.clear_prune_pending_timer()
def is_downstream(self):
return True
'''
Created on Jul 16, 2015
@author: alex
'''
from .tree_interface import TreeInterface
from .upstream_prune import UpstreamState
from threading import Timer
from CustomTimer.RemainingTimer import RemainingTimer
from .globals import *
import random
from .metric import AssertMetric
from .originator import OriginatorState, OriginatorStateABC
from Packet.PacketPimStateRefresh import PacketPimStateRefresh
import traceback
from . import DataPacketsSocket
import threading
import logging
import Main
class TreeInterfaceUpstream(TreeInterface):
LOGGER = logging.getLogger('pim.KernelEntry.UpstreamInterface')
def __init__(self, kernel_entry, interface_id):
extra_dict_logger = kernel_entry.kernel_entry_logger.extra.copy()
extra_dict_logger['vif'] = interface_id
extra_dict_logger['interfacename'] = Main.kernel.vif_index_to_name_dic[interface_id]
logger = logging.LoggerAdapter(TreeInterfaceUpstream.LOGGER, extra_dict_logger)
TreeInterface.__init__(self, kernel_entry, interface_id, logger)
# Graft/Prune State:
self._graft_prune_state = UpstreamState.Forward
self._graft_retry_timer = None
self._override_timer = None
self._prune_limit_timer = None
self._last_rpf = self.get_neighbor_RPF()
self.join_prune_logger.debug('Upstream state transitions to ' + str(self._graft_prune_state))
# Originator state
self._originator_state = OriginatorState.NotOriginator
self._state_refresh_timer = None
self._source_active_timer = None
self._prune_now_counter = 0
self.originator_logger = logging.LoggerAdapter(TreeInterfaceUpstream.LOGGER.getChild('Originator'), extra_dict_logger)
self.originator_logger.debug('StateRefresh state transitions to ' + str(self._originator_state))
if self.is_S_directly_conn():
self._graft_prune_state.sourceIsNowDirectConnect(self)
if self.get_interface().is_state_refresh_enabled():
self._originator_state.recvDataMsgFromSource(self)
# TODO TESTE SOCKET RECV DATA PCKTS
self.socket_is_enabled = True
(s,g) = self.get_tree_id()
interface_name = self.get_interface().interface_name
self.socket_pkt = DataPacketsSocket.get_s_g_bpf_filter_code(s, g, interface_name)
# run receive method in background
receive_thread = threading.Thread(target=self.socket_recv)
receive_thread.daemon = True
receive_thread.start()
self.logger.debug('Created UpstreamInterface')
def socket_recv(self):
while self.socket_is_enabled:
try:
self.socket_pkt.recvfrom(0)
print("PACOTE DADOS RECEBIDO")
self.recv_data_msg()
except:
traceback.print_exc()
continue
##########################################
# Set state
##########################################
def set_state(self, new_state):
with self.get_state_lock():
if new_state != self._graft_prune_state:
self._graft_prune_state = new_state
self.join_prune_logger.debug('Upstream state transitions to ' + str(new_state))
self.change_tree()
self.evaluate_ingroup()
def set_originator_state(self, new_state: OriginatorStateABC):
if new_state != self._originator_state:
self._originator_state = new_state
self.originator_logger.debug('StateRefresh state transitions to ' + str(new_state))
##########################################
# Check timers
##########################################
def is_graft_retry_timer_running(self):
return self._graft_retry_timer is not None and self._graft_retry_timer.is_alive()
def is_override_timer_running(self):
return self._override_timer is not None and self._override_timer.is_alive()
def is_prune_limit_timer_running(self):
return self._prune_limit_timer is not None and self._prune_limit_timer.is_alive()
def remaining_prune_limit_timer(self):
return 0 if not self._prune_limit_timer else self._prune_limit_timer.time_remaining()
##########################################
# Set timers
##########################################
def set_graft_retry_timer(self, time=GRAFT_RETRY_PERIOD):
self.clear_graft_retry_timer()
self._graft_retry_timer = Timer(time, self.graft_retry_timeout)
self._graft_retry_timer.start()
def clear_graft_retry_timer(self):
if self._graft_retry_timer is not None:
self._graft_retry_timer.cancel()
def set_override_timer(self):
self.clear_override_timer()
self._override_timer = Timer(self.t_override, self.override_timeout)
self._override_timer.start()
def clear_override_timer(self):
if self._override_timer is not None:
self._override_timer.cancel()
def set_prune_limit_timer(self, time=T_LIMIT):
self.clear_prune_limit_timer()
self._prune_limit_timer = RemainingTimer(time, self.prune_limit_timeout)
self._prune_limit_timer.start()
def clear_prune_limit_timer(self):
if self._prune_limit_timer is not None:
self._prune_limit_timer.cancel()
# State Refresh timers
def set_state_refresh_timer(self):
self.clear_state_refresh_timer()
self._state_refresh_timer = Timer(REFRESH_INTERVAL, self.state_refresh_timeout)
self._state_refresh_timer.start()
def clear_state_refresh_timer(self):
if self._state_refresh_timer is not None:
self._state_refresh_timer.cancel()
def set_source_active_timer(self):
self.clear_source_active_timer()
self._source_active_timer = Timer(SOURCE_LIFETIME, self.source_active_timeout)
self._source_active_timer.start()
def clear_source_active_timer(self):
if self._source_active_timer is not None:
self._source_active_timer.cancel()
###########################################
# Timer timeout
###########################################
def graft_retry_timeout(self):
self._graft_prune_state.GRTexpires(self)
def override_timeout(self):
self._graft_prune_state.OTexpires(self)
def prune_limit_timeout(self):
return
# State Refresh timers
def state_refresh_timeout(self):
self._originator_state.SRTexpires(self)
def source_active_timeout(self):
self._originator_state.SATexpires(self)
###########################################
# Recv packets
###########################################
def recv_data_msg(self):
if not self.is_prune_limit_timer_running() and not self.is_S_directly_conn() and self.is_olist_null():
self._graft_prune_state.dataArrivesRPFinterface_OListNull_PLTstoped(self)
elif self.is_S_directly_conn() and self.get_interface().is_state_refresh_enabled():
self._originator_state.recvDataMsgFromSource(self)
def recv_join_msg(self, upstream_neighbor_address):
super().recv_join_msg(upstream_neighbor_address)
if upstream_neighbor_address == self.get_neighbor_RPF():
self._graft_prune_state.seeJoinToRPFnbr(self)
def recv_prune_msg(self, upstream_neighbor_address, holdtime):
super().recv_prune_msg(upstream_neighbor_address, holdtime)
self.set_receceived_prune_holdtime(holdtime)
self._graft_prune_state.seePrune(self)
def recv_graft_ack_msg(self, source_ip_of_graft_ack):
print("GRAFT ACK!!!")
if source_ip_of_graft_ack == self.get_neighbor_RPF():
self._graft_prune_state.recvGraftAckFromRPFnbr(self)
def recv_state_refresh_msg(self, received_metric: AssertMetric, prune_indicator: int):
super().recv_state_refresh_msg(received_metric, prune_indicator)
if self.get_neighbor_RPF() != received_metric.get_ip():
return
if prune_indicator == 1:
self._graft_prune_state.stateRefreshArrivesRPFnbr_pruneIs1(self)
elif prune_indicator == 0 and not self.is_prune_limit_timer_running():
self._graft_prune_state.stateRefreshArrivesRPFnbr_pruneIs0_PLTstoped(self)
####################################
def create_state_refresh_msg(self):
self._prune_now_counter+=1
(source_ip, group_ip) = self.get_tree_id()
ph = PacketPimStateRefresh(multicast_group_adress=group_ip,
source_address=source_ip,
originator_adress=self.get_ip(),
metric_preference=0, metric=0, mask_len=0,
ttl=256,
prune_indicator_flag=0,
prune_now_flag=self._prune_now_counter//3,
assert_override_flag=0,
interval=60)
self._prune_now_counter %= 3
self._kernel_entry.forward_state_refresh_msg(ph)
###########################################
# Change olist
###########################################
def olist_is_null(self):
self._graft_prune_state.olistIsNowNull(self)
def olist_is_not_null(self):
self._graft_prune_state.olistIsNowNotNull(self)
###########################################
# Changes to RPF'(s)
###########################################
# caused by assert transition:
def set_assert_state(self, new_state):
super().set_assert_state(new_state)
self.change_rpf(self.is_olist_null())
# caused by unicast routing table:
def change_on_unicast_routing(self, interface_change=False):
self.change_rpf(self.is_olist_null(), interface_change)
'''
if self.is_S_directly_conn():
self._graft_prune_state.sourceIsNowDirectConnect(self)
else:
self._originator_state.SourceNotConnected(self)
'''
def change_rpf(self, olist_is_null, interface_change=False):
current_rpf = self.get_neighbor_RPF()
if interface_change or self._last_rpf != current_rpf:
self._last_rpf = current_rpf
if olist_is_null:
self._graft_prune_state.RPFnbrChanges_olistIsNull(self)
else:
self._graft_prune_state.RPFnbrChanges_olistIsNotNull(self)
####################################################################
#Override
def is_forwarding(self):
return False
# If new/reset neighbor is RPF neighbor => clear prune limit timer
def new_or_reset_neighbor(self, neighbor_ip):
if neighbor_ip == self.get_neighbor_RPF():
self.clear_prune_limit_timer()
#Override
def delete(self, change_type_interface=False):
self.socket_is_enabled = False
self.socket_pkt.close()
super().delete(change_type_interface)
self.clear_graft_retry_timer()
self.clear_assert_timer()
self.clear_prune_limit_timer()
self.clear_override_timer()
self.clear_state_refresh_timer()
self.clear_source_active_timer()
# Clear Graft/Prune State:
self._graft_prune_state = None
# Clear Originator state
self._originator_state = None
def is_downstream(self):
return False
def is_originator(self):
return self._originator_state == OriginatorState.Originator
#-------------------------------------------------------------------------
# Properties
#-------------------------------------------------------------------------
@property
def t_override(self):
oi = self.get_interface()._override_interval
return random.uniform(0, oi)
'''
Created on Jul 16, 2015
@author: alex
'''
from abc import ABCMeta, abstractmethod
import Main
from threading import RLock
import traceback
from .downstream_prune import DownstreamState
from .assert_ import AssertState, AssertStateABC
from Packet.PacketPimGraft import PacketPimGraft
from Packet.PacketPimGraftAck import PacketPimGraftAck
from Packet.PacketPimJoinPruneMulticastGroup import PacketPimJoinPruneMulticastGroup
from Packet.PacketPimHeader import PacketPimHeader
from Packet.Packet import Packet
from Packet.PacketPimJoinPrune import PacketPimJoinPrune
from Packet.PacketPimAssert import PacketPimAssert
from Packet.PacketPimStateRefresh import PacketPimStateRefresh
from .metric import AssertMetric
from threading import Timer
from .local_membership import LocalMembership
from .globals import *
import logging
class TreeInterface(metaclass=ABCMeta):
def __init__(self, kernel_entry, interface_id, logger: logging.LoggerAdapter):
self._kernel_entry = kernel_entry
self._interface_id = interface_id
self.logger = logger
self.assert_logger = logging.LoggerAdapter(logger.logger.getChild('Assert'), logger.extra)
self.join_prune_logger = logging.LoggerAdapter(logger.logger.getChild('JoinPrune'), logger.extra)
# Local Membership State
try:
interface_name = Main.kernel.vif_index_to_name_dic[interface_id]
igmp_interface = Main.igmp_interfaces[interface_name] # type: InterfaceIGMP
group_state = igmp_interface.interface_state.get_group_state(kernel_entry.group_ip)
#self._igmp_has_members = group_state.add_multicast_routing_entry(self)
igmp_has_members = group_state.add_multicast_routing_entry(self)
self._local_membership_state = LocalMembership.Include if igmp_has_members else LocalMembership.NoInfo
except:
self._local_membership_state = LocalMembership.NoInfo
# Prune State
self._prune_state = DownstreamState.NoInfo
self._prune_pending_timer = None
self._prune_timer = None
# Assert Winner State
self._assert_state = AssertState.NoInfo
self._assert_winner_metric = AssertMetric()
self._assert_timer = None
self.assert_logger.debug("Assert state transitions to NoInfo")
# Received prune hold time
self._received_prune_holdtime = None
self._igmp_lock = RLock()
############################################
# Set ASSERT State
############################################
def set_assert_state(self, new_state: AssertStateABC):
with self.get_state_lock():
if new_state != self._assert_state:
self._assert_state = new_state
self.assert_logger.debug('Assert state transitions to ' + str(new_state))
self.change_tree()
self.evaluate_ingroup()
def set_assert_winner_metric(self, new_assert_metric: AssertMetric):
with self.get_state_lock():
try:
old_neighbor = self.get_interface().get_neighbor(self._assert_winner_metric.get_ip())
new_neighbor = self.get_interface().get_neighbor(new_assert_metric.get_ip())
if old_neighbor is not None:
old_neighbor.unsubscribe_nlt_expiration(self)
if new_neighbor is not None:
new_neighbor.subscribe_nlt_expiration(self)
except:
traceback.print_exc()
finally:
self._assert_winner_metric = new_assert_metric
############################################
# ASSERT Timer
############################################
def set_assert_timer(self, time):
self.clear_assert_timer()
self._assert_timer = Timer(time, self.assert_timeout)
self._assert_timer.start()
def clear_assert_timer(self):
if self._assert_timer is not None:
self._assert_timer.cancel()
def assert_timeout(self):
self._assert_state.assertTimerExpires(self)
###########################################
# Recv packets
###########################################
def recv_data_msg(self):
pass
def recv_assert_msg(self, received_metric: AssertMetric):
if self._assert_winner_metric.is_better_than(received_metric) and \
self._assert_winner_metric.ip_address == received_metric.ip_address:
# received inferior assert from Assert Winner
self._assert_state.receivedInferiorMetricFromWinner(self)
elif self.my_assert_metric().is_better_than(received_metric) and self.could_assert():
# received inferior assert from non assert winner and could_assert
self._assert_state.receivedInferiorMetricFromNonWinner_couldAssertIsTrue(self)
elif received_metric.is_better_than(self._assert_winner_metric) or \
received_metric.equal_metric(self._assert_winner_metric):
#received preferred assert
equal_metric = received_metric.equal_metric(self._assert_winner_metric)
self._assert_state.receivedPreferedMetric(self, received_metric, equal_metric)
def recv_prune_msg(self, upstream_neighbor_address, holdtime):
if upstream_neighbor_address == self.get_ip():
self._assert_state.receivedPruneOrJoinOrGraft(self)
def recv_join_msg(self, upstream_neighbor_address):
if upstream_neighbor_address == self.get_ip():
self._assert_state.receivedPruneOrJoinOrGraft(self)
def recv_graft_msg(self, upstream_neighbor_address, source_ip):
if upstream_neighbor_address == self.get_ip():
self._assert_state.receivedPruneOrJoinOrGraft(self)
def recv_graft_ack_msg(self, source_ip_of_graft_ack):
return
def recv_state_refresh_msg(self, received_metric: AssertMetric, prune_indicator):
self.recv_assert_msg(received_metric)
######################################
# Send messages
######################################
def send_graft(self):
print("send graft")
try:
(source, group) = self.get_tree_id()
ip_dst = self.get_neighbor_RPF()
ph = PacketPimGraft(ip_dst)
ph.add_multicast_group(PacketPimJoinPruneMulticastGroup(group, joined_src_addresses=[source]))
pckt = Packet(payload=PacketPimHeader(ph))
self.get_interface().send(pckt.bytes(), ip_dst)
except:
traceback.print_exc()
return
def send_graft_ack(self, ip_sender):
print("send graft ack")
try:
(source, group) = self.get_tree_id()
ph = PacketPimGraftAck(ip_sender)
ph.add_multicast_group(PacketPimJoinPruneMulticastGroup(group, joined_src_addresses=[source]))
pckt = Packet(payload=PacketPimHeader(ph))
self.get_interface().send(pckt.bytes(), ip_sender)
except:
traceback.print_exc()
return
def send_prune(self, holdtime=None):
if holdtime is None:
holdtime = T_LIMIT
print("send prune")
try:
(source, group) = self.get_tree_id()
ph = PacketPimJoinPrune(self.get_neighbor_RPF(), holdtime)
ph.add_multicast_group(PacketPimJoinPruneMulticastGroup(group, pruned_src_addresses=[source]))
pckt = Packet(payload=PacketPimHeader(ph))
self.get_interface().send(pckt.bytes())
print('sent prune msg')
except:
traceback.print_exc()
return
def send_pruneecho(self):
holdtime = T_LIMIT
try:
(source, group) = self.get_tree_id()
ph = PacketPimJoinPrune(self.get_ip(), holdtime)
ph.add_multicast_group(PacketPimJoinPruneMulticastGroup(group, pruned_src_addresses=[source]))
pckt = Packet(payload=PacketPimHeader(ph))
self.get_interface().send(pckt.bytes())
print("send prune echo")
except:
traceback.print_exc()
return
def send_join(self):
print("send join")
try:
(source, group) = self.get_tree_id()
ph = PacketPimJoinPrune(self.get_neighbor_RPF(), 210)
ph.add_multicast_group(PacketPimJoinPruneMulticastGroup(group, joined_src_addresses=[source]))
pckt = Packet(payload=PacketPimHeader(ph))
self.get_interface().send(pckt.bytes())
except:
traceback.print_exc()
return
def send_assert(self):
print("send assert")
try:
(source, group) = self.get_tree_id()
assert_metric = self.my_assert_metric()
ph = PacketPimAssert(multicast_group_address=group, source_address=source, metric_preference=assert_metric.metric_preference, metric=assert_metric.route_metric)
pckt = Packet(payload=PacketPimHeader(ph))
self.get_interface().send(pckt.bytes())
except:
traceback.print_exc()
return
def send_assert_cancel(self):
print("send assert cancel")
try:
(source, group) = self.get_tree_id()
ph = PacketPimAssert(multicast_group_address=group, source_address=source, metric_preference=float("Inf"), metric=float("Inf"))
pckt = Packet(payload=PacketPimHeader(ph))
self.get_interface().send(pckt.bytes())
except:
traceback.print_exc()
return
def send_state_refresh(self, state_refresh_msg_received: PacketPimStateRefresh):
pass
#############################################################
@abstractmethod
def is_forwarding(self):
pass
def assert_winner_nlt_expires(self):
self._assert_state.winnerLivelinessTimerExpires(self)
@abstractmethod
def new_or_reset_neighbor(self, neighbor_ip):
raise NotImplementedError()
@abstractmethod
def delete(self, change_type_interface=False):
if change_type_interface:
if self.could_assert():
self._assert_state.couldAssertIsNowFalse(self)
else:
self._assert_state.couldAssertIsNowTrue(self)
(s, g) = self.get_tree_id()
# unsubscribe igmp information
try:
interface_name = Main.kernel.vif_index_to_name_dic[self._interface_id]
igmp_interface = Main.igmp_interfaces[interface_name] # type: InterfaceIGMP
group_state = igmp_interface.interface_state.get_group_state(g)
group_state.remove_multicast_routing_entry(self)
except:
pass
# Prune State
self._prune_state = None
# Assert State
self._assert_state = None
self.set_assert_winner_metric(AssertMetric.infinite_assert_metric()) # unsubscribe from current AssertWinner NeighborLivenessTimer
self._assert_winner_metric = None
self.clear_assert_timer()
print('Tree Interface deleted')
def is_olist_null(self):
return self._kernel_entry.is_olist_null()
def evaluate_ingroup(self):
self._kernel_entry.evaluate_olist_change()
#############################################################
# Local Membership (IGMP)
############################################################
def notify_igmp(self, has_members: bool):
with self.get_state_lock():
with self._igmp_lock:
if has_members != self._local_membership_state.has_members():
self._local_membership_state = LocalMembership.Include if has_members else LocalMembership.NoInfo
self.change_tree()
self.evaluate_ingroup()
def igmp_has_members(self):
with self._igmp_lock:
return self._local_membership_state.has_members()
def get_interface(self):
kernel = Main.kernel
interface_name = kernel.vif_index_to_name_dic[self._interface_id]
interface = Main.interfaces[interface_name]
return interface
def get_ip(self):
ip = self.get_interface().get_ip()
return ip
def has_neighbors(self):
try:
return len(self.get_interface().neighbors) > 0
except:
return False
def get_tree_id(self):
return (self._kernel_entry.source_ip, self._kernel_entry.group_ip)
def change_tree(self):
self._kernel_entry.change()
def get_state_lock(self):
return self._kernel_entry.CHANGE_STATE_LOCK
@abstractmethod
def is_downstream(self):
raise NotImplementedError()
# obtain ip of RPF'(S)
def get_neighbor_RPF(self):
'''
RPF'(S)
'''
if self.i_am_assert_loser():
return self._assert_winner_metric.get_ip()
else:
return self._kernel_entry.rpf_node
def is_S_directly_conn(self):
return self._kernel_entry.rpf_node == self._kernel_entry.source_ip
def set_receceived_prune_holdtime(self, holdtime):
self._received_prune_holdtime = holdtime
def get_received_prune_holdtime(self):
return self._received_prune_holdtime
###################################################
# ASSERT
###################################################
def lost_assert(self):
if not self.is_downstream():
return False
else:
return not self._assert_winner_metric.i_am_assert_winner(self) and \
self._assert_winner_metric.is_better_than(AssertMetric.spt_assert_metric(self))
def i_am_assert_loser(self):
return self._assert_state == AssertState.Loser
def could_assert(self):
return self.is_downstream()
def my_assert_metric(self):
'''
The assert metric of this interface for usage in assert state machine
@rtype: AssertMetric
'''
if self.could_assert():
return AssertMetric.spt_assert_metric(self)
else:
return AssertMetric.infinite_assert_metric()
from abc import ABCMeta, abstractmethod
from utils import TYPE_CHECKING
if TYPE_CHECKING:
from .tree_if_upstream import TreeInterfaceUpstream
class UpstreamStateABC(metaclass=ABCMeta):
@staticmethod
@abstractmethod
def dataArrivesRPFinterface_OListNull_PLTstoped(interface: "TreeInterfaceUpstream"):
"""
Data arrives on RPF_Interface(S) AND
olist(S, G) == NULL AND
PLT(S, G) not running
@type interface: Upstream
"""
raise NotImplementedError()
@staticmethod
@abstractmethod
def stateRefreshArrivesRPFnbr_pruneIs1(interface: "TreeInterfaceUpstream"):
"""
State Refresh(S,G) received from RPF‘(S) AND
Prune Indicator == 1
@type interface: Upstream
"""
raise NotImplementedError()
@staticmethod
@abstractmethod
def stateRefreshArrivesRPFnbr_pruneIs0_PLTstoped(interface: "TreeInterfaceUpstream"):
"""
State Refresh(S,G) received from RPF‘(S) AND
Prune Indicator == 0 AND
PLT(S, G) not running
@type interface: Upstream
"""
raise NotImplementedError()
@staticmethod
@abstractmethod
def seeJoinToRPFnbr(interface: "TreeInterfaceUpstream"):
"""
See Join(S,G) to RPF’(S)
@type interface: Upstream
"""
raise NotImplementedError()
@staticmethod
@abstractmethod
def seePrune(interface: "TreeInterfaceUpstream"):
"""
See Prune(S,G)
@type interface: Upstream
"""
raise NotImplementedError()
@staticmethod
@abstractmethod
def OTexpires(interface: "TreeInterfaceUpstream"):
"""
OT(S,G) Expires
@type interface: Upstream
"""
raise NotImplementedError()
@staticmethod
@abstractmethod
def olistIsNowNull(interface: "TreeInterfaceUpstream"):
"""
olist(S,G)->NULL
@type interface: Upstream
"""
raise NotImplementedError()
@staticmethod
@abstractmethod
def olistIsNowNotNull(interface: "TreeInterfaceUpstream"):
"""
olist(S,G)->non-NULL
@type interface: Upstream
"""
raise NotImplementedError()
@staticmethod
@abstractmethod
def RPFnbrChanges_olistIsNotNull(interface: "TreeInterfaceUpstream"):
"""
RPF’(S) Changes AND
olist(S,G) != NULL AND
S not directly connected
@type interface: Upstream
"""
raise NotImplementedError()
@staticmethod
@abstractmethod
def RPFnbrChanges_olistIsNull(interface: "TreeInterfaceUpstream"):
"""
RPF’(S) Changes AND
olist(S,G) == NULL
@type interface: Upstream
"""
raise NotImplementedError()
@staticmethod
@abstractmethod
def sourceIsNowDirectConnect(interface: "TreeInterfaceUpstream"):
"""
S becomes directly connected
@type interface: Upstream
"""
raise NotImplementedError()
@staticmethod
@abstractmethod
def GRTexpires(interface: "TreeInterfaceUpstream"):
"""
GRT(S,G) Expires
@type interface: Upstream
"""
raise NotImplementedError()
@staticmethod
@abstractmethod
def recvGraftAckFromRPFnbr(interface: "TreeInterfaceUpstream"):
"""
Receive GraftAck(S,G) from RPF’(S)
@type interface: Upstream
"""
raise NotImplementedError()
class Forward(UpstreamStateABC):
"""
Forwarding (F)
This is the starting state of the Upsteam(S,G) state machine.
The state machine is in this state if it just started or if
oiflist(S,G) != NULL.
"""
@staticmethod
def dataArrivesRPFinterface_OListNull_PLTstoped(interface: "TreeInterfaceUpstream"):
"""
Data arrives on RPF_Interface(S) AND
olist(S, G) == NULL AND
PLT(S, G) not running
@type interface: TreeInterfaceUpstream
"""
if not interface.is_S_directly_conn():
interface.join_prune_logger.debug("dataArrivesRPFinterface_OListNull_PLTstoped, F -> P")
interface.set_state(UpstreamState.Pruned)
interface.send_prune()
interface.set_prune_limit_timer()
@staticmethod
def stateRefreshArrivesRPFnbr_pruneIs1(interface: "TreeInterfaceUpstream"):
"""
State Refresh(S,G) received from RPF‘(S) AND
Prune Indicator == 1
@type interface: TreeInterfaceUpstream
"""
# if OT is not running the router must set OT to t_override seconds
interface.join_prune_logger.debug('stateRefreshArrivesRPFnbr_pruneIs1, F -> F')
if not interface.is_override_timer_running():
interface.set_override_timer()
@staticmethod
def stateRefreshArrivesRPFnbr_pruneIs0_PLTstoped(interface: "TreeInterfaceUpstream"):
"""
State Refresh(S,G) received from RPF‘(S) AND
Prune Indicator == 0 AND
PLT(S, G) not running
@type interface: TreeInterfaceUpstream
"""
interface.join_prune_logger.debug('stateRefreshArrivesRPFnbr_pruneIs0_PLTstoped, F -> F')
@staticmethod
def seeJoinToRPFnbr(interface: "TreeInterfaceUpstream"):
"""
See Join(S,G) to RPF’(S)
@type interface: TreeInterfaceUpstream
"""
interface.join_prune_logger.debug('seeJoinToRPFnbr, F -> F')
interface.clear_override_timer()
@staticmethod
def seePrune(interface: "TreeInterfaceUpstream"):
"""
See Prune(S,G)
@type interface: TreeInterfaceUpstream
"""
interface.join_prune_logger.debug('seePrune, F -> F')
if not interface.is_S_directly_conn() and not interface.is_override_timer_running():
interface.set_override_timer()
@staticmethod
def OTexpires(interface: "TreeInterfaceUpstream"):
"""
OT(S,G) Expires
@type interface: TreeInterfaceUpstream
"""
interface.join_prune_logger.debug('OTexpires, F -> F')
if not interface.is_S_directly_conn():
interface.send_join()
@staticmethod
def olistIsNowNull(interface: "TreeInterfaceUpstream"):
"""
olist(S,G)->NULL
@type interface: TreeInterfaceUpstream
"""
if not interface.is_S_directly_conn():
interface.join_prune_logger.debug("olistIsNowNull, F -> P")
interface.set_state(UpstreamState.Pruned)
interface.send_prune()
interface.set_prune_limit_timer()
@staticmethod
def olistIsNowNotNull(interface: "TreeInterfaceUpstream"):
"""
olist(S,G)->non-NULL
@type interface: TreeInterfaceUpstream
"""
#assert False, "olistIsNowNotNull (in state F)"
return
@staticmethod
def RPFnbrChanges_olistIsNotNull(interface: "TreeInterfaceUpstream"):
"""
RPF’(S) Changes AND
olist(S,G) != NULL AND
S not directly connected
@type interface: TreeInterfaceUpstream
"""
if not interface.is_S_directly_conn():
interface.join_prune_logger.debug('RPFnbrChanges_olistIsNotNull, F -> AP')
interface.set_state(UpstreamState.AckPending)
interface.send_graft()
interface.set_graft_retry_timer()
@staticmethod
def RPFnbrChanges_olistIsNull(interface: "TreeInterfaceUpstream"):
"""
RPF’(S) Changes AND
olist(S,G) == NULL
@type interface: TreeInterfaceUpstream
"""
interface.join_prune_logger.debug('RPFnbrChanges_olistIsNull, F -> P')
interface.set_state(UpstreamState.Pruned)
@staticmethod
def sourceIsNowDirectConnect(interface: "TreeInterfaceUpstream"):
"""
S becomes directly connected
@type interface: TreeInterfaceUpstream
"""
interface.join_prune_logger.debug("sourceIsNowDirectConnect, F -> F")
@staticmethod
def GRTexpires(interface: "TreeInterfaceUpstream"):
"""
GRT(S,G) Expires
@type interface: TreeInterfaceUpstream
"""
#assert False, "GRTexpires (in state F)"
return
@staticmethod
def recvGraftAckFromRPFnbr(interface: "TreeInterfaceUpstream"):
"""
Receive GraftAck(S,G) from RPF’(S)
@type interface: TreeInterfaceUpstream
"""
interface.join_prune_logger.debug("recvGraftAckFromRPFnbr, F -> F")
def __str__(self):
return "Forwarding"
class Pruned(UpstreamStateABC):
'''
Pruned (P)
The set, olist(S,G), is empty.
The router will not forward data from S addressed to group G.
'''
@staticmethod
def dataArrivesRPFinterface_OListNull_PLTstoped(interface: "TreeInterfaceUpstream"):
"""
Data arrives on RPF_Interface(S) AND
olist(S, G) == NULL AND
PLT(S, G) not running
@type interface: TreeInterfaceUpstream
"""
if not interface.is_S_directly_conn():
interface.join_prune_logger.debug("dataArrivesRPFinterface_OListNull_PLTstoped, P -> P")
interface.set_prune_limit_timer()
interface.send_prune()
@staticmethod
def stateRefreshArrivesRPFnbr_pruneIs1(interface: "TreeInterfaceUpstream"):
"""
State Refresh(S,G) received from RPF‘(S) AND
Prune Indicator == 1
@type interface: TreeInterfaceUpstream
"""
interface.join_prune_logger.debug('stateRefreshArrivesRPFnbr_pruneIs1, P -> P')
interface.set_prune_limit_timer()
@staticmethod
def stateRefreshArrivesRPFnbr_pruneIs0_PLTstoped(interface: "TreeInterfaceUpstream"):
"""
State Refresh(S,G) received from RPF‘(S) AND
Prune Indicator == 0 AND
PLT(S, G) not running
@type interface: TreeInterfaceUpstream
"""
interface.join_prune_logger.debug('stateRefreshArrivesRPFnbr_pruneIs0_PLTstoped, P -> P')
interface.send_prune()
interface.set_prune_limit_timer()
@staticmethod
def seeJoinToRPFnbr(interface: "TreeInterfaceUpstream"):
"""
See Join(S,G) to RPF’(S)
@type interface: TreeInterfaceUpstream
"""
# Do nothing
interface.join_prune_logger.debug('seeJoinToRPFnbr, P -> P')
@staticmethod
def seePrune(interface: "TreeInterfaceUpstream"):
"""
See Prune(S,G)
@type interface: TreeInterfaceUpstream
"""
interface.join_prune_logger.debug('seePrune, P -> P')
if interface.get_received_prune_holdtime() > interface.remaining_prune_limit_timer():
interface.set_prune_limit_timer(time=interface.get_received_prune_holdtime())
@staticmethod
def OTexpires(interface: "TreeInterfaceUpstream"):
"""
OT(S,G) Expires
@type interface: TreeInterfaceUpstream
"""
#assert False, "OTexpires in state Pruned"
return
@staticmethod
def olistIsNowNull(interface: "TreeInterfaceUpstream"):
"""
olist(S,G)->NULL
@type interface: TreeInterfaceUpstream
"""
#assert False, "olistIsNowNull in state Pruned"
return
@staticmethod
def olistIsNowNotNull(interface: "TreeInterfaceUpstream"):
"""
olist(S,G)->non-NULL
@type interface: TreeInterfaceUpstream
"""
if not interface.is_S_directly_conn():
interface.join_prune_logger.debug('olistIsNowNotNull, P -> AP')
interface.clear_prune_limit_timer()
interface.set_state(UpstreamState.AckPending)
interface.send_graft()
interface.set_graft_retry_timer()
@staticmethod
def RPFnbrChanges_olistIsNotNull(interface: "TreeInterfaceUpstream"):
"""
RPF’(S) Changes AND
olist(S,G) != NULL AND
S not directly connected
@type interface: TreeInterfaceUpstream
"""
if not interface.is_S_directly_conn():
interface.join_prune_logger.debug('RPFnbrChanges_olistIsNotNull, P -> AP')
interface.clear_prune_limit_timer()
interface.set_state(UpstreamState.AckPending)
interface.send_graft()
interface.set_graft_retry_timer()
@staticmethod
def RPFnbrChanges_olistIsNull(interface: "TreeInterfaceUpstream"):
"""
RPF’(S) Changes AND
olist(S,G) == NULL
@type interface: TreeInterfaceUpstream
"""
if not interface.is_S_directly_conn():
interface.join_prune_logger.debug('RPFnbrChanges_olistIsNull, P -> P')
interface.clear_prune_limit_timer()
@staticmethod
def sourceIsNowDirectConnect(interface: "TreeInterfaceUpstream"):
"""
S becomes directly connected
@type interface: TreeInterfaceUpstream
"""
interface.join_prune_logger.debug('sourceIsNowDirectConnect, P -> P')
@staticmethod
def GRTexpires(interface: "TreeInterfaceUpstream"):
"""
GRT(S,G) Expires
@type interface: TreeInterfaceUpstream
"""
#assert False, "GRTexpires in state Pruned"
return
@staticmethod
def recvGraftAckFromRPFnbr(interface: "TreeInterfaceUpstream"):
"""
Receive GraftAck(S,G) from RPF’(S)
@type interface: TreeInterfaceUpstream
"""
interface.join_prune_logger.debug('recvGraftAckFromRPFnbr, P -> P')
def __str__(self):
return "Pruned"
class AckPending(UpstreamStateABC):
"""
AckPending (AP)
The router was in the Pruned(P) state, but a transition has
occurred in the Downstream(S,G) state machine for one of this
(S,G) entry’s outgoing interfaces, indicating that traffic from S
addressed to G should again be forwarded. A Graft message has
been sent to RPF’(S), but a Graft Ack message has not yet been
received.
"""
@staticmethod
def dataArrivesRPFinterface_OListNull_PLTstoped(interface: "TreeInterfaceUpstream"):
"""
Data arrives on RPF_Interface(S) AND
olist(S, G) == NULL AND
PLT(S, G) not running
@type interface: TreeInterfaceUpstream
"""
#assert False, "dataArrivesRPFinterface_OListNull_PLTstoped in state AP"
return
@staticmethod
def stateRefreshArrivesRPFnbr_pruneIs1(interface: "TreeInterfaceUpstream"):
"""
State Refresh(S,G) received from RPF‘(S) AND
Prune Indicator == 1
@type interface: TreeInterfaceUpstream
"""
interface.join_prune_logger.debug('stateRefreshArrivesRPFnbr_pruneIs1, AP -> AP')
if not interface.is_override_timer_running():
interface.set_override_timer()
@staticmethod
def stateRefreshArrivesRPFnbr_pruneIs0_PLTstoped(interface: "TreeInterfaceUpstream"):
"""
State Refresh(S,G) received from RPF‘(S) AND
Prune Indicator == 0 AND
PLT(S, G) not running
@type interface: TreeInterfaceUpstream
"""
interface.join_prune_logger.debug('stateRefreshArrivesRPFnbr_pruneIs0_PLTstoped, AP -> F')
interface.clear_graft_retry_timer()
interface.set_state(UpstreamState.Forward)
@staticmethod
def seeJoinToRPFnbr(interface: "TreeInterfaceUpstream"):
"""
See Join(S,G) to RPF’(S)
@type interface: TreeInterfaceUpstream
"""
interface.join_prune_logger.debug('seeJoinToRPFnbr, AP -> AP')
interface.clear_override_timer()
@staticmethod
def seePrune(interface: "TreeInterfaceUpstream"):
"""
See Prune(S,G)
@type interface: TreeInterfaceUpstream
"""
interface.join_prune_logger.debug('seePrune, AP -> AP')
if not interface.is_override_timer_running():
interface.set_override_timer()
@staticmethod
def OTexpires(interface: "TreeInterfaceUpstream"):
"""
OT(S,G) Expires
@type interface: TreeInterfaceUpstream
"""
interface.join_prune_logger.debug('OTexpires, AP -> AP')
interface.send_join()
@staticmethod
def olistIsNowNull(interface: "TreeInterfaceUpstream"):
"""
olist(S,G)->NULL
@type interface: TreeInterfaceUpstream
"""
interface.join_prune_logger.debug('olistIsNowNull, AP -> P')
interface.set_state(UpstreamState.Pruned)
interface.send_prune()
interface.clear_graft_retry_timer()
interface.set_prune_limit_timer()
@staticmethod
def olistIsNowNotNull(interface: "TreeInterfaceUpstream"):
"""
olist(S,G)->non-NULL
@type interface: TreeInterfaceUpstream
"""
#assert False, "olistIsNowNotNull in state AP"
return
@staticmethod
def RPFnbrChanges_olistIsNotNull(interface: "TreeInterfaceUpstream"):
"""
RPF’(S) Changes AND
olist(S,G) != NULL AND
S not directly connected
@type interface: TreeInterfaceUpstream
"""
if not interface.is_S_directly_conn():
interface.join_prune_logger.debug('RPFnbrChanges_olistIsNotNull, AP -> AP')
interface.send_graft()
interface.set_graft_retry_timer()
@staticmethod
def RPFnbrChanges_olistIsNull(interface: "TreeInterfaceUpstream"):
"""
RPF’(S) Changes AND
olist(S,G) == NULL
@type interface: TreeInterfaceUpstream
"""
if not interface.is_S_directly_conn():
interface.join_prune_logger.debug('RPFnbrChanges_olistIsNull, AP -> P')
interface.clear_graft_retry_timer()
interface.set_state(UpstreamState.Pruned)
@staticmethod
def sourceIsNowDirectConnect(interface: "TreeInterfaceUpstream"):
"""
S becomes directly connected
@type interface: TreeInterfaceUpstream
"""
interface.join_prune_logger.debug('sourceIsNowDirectConnect, AP -> F')
interface.set_state(UpstreamState.Forward)
interface.clear_graft_retry_timer()
@staticmethod
def GRTexpires(interface: "TreeInterfaceUpstream"):
"""
GRT(S,G) Expires
@type interface: TreeInterfaceUpstream
"""
interface.join_prune_logger.debug('GRTexpires, AP -> AP')
interface.set_graft_retry_timer()
interface.send_graft()
@staticmethod
def recvGraftAckFromRPFnbr(interface: "TreeInterfaceUpstream"):
"""
Receive GraftAck(S,G) from RPF’(S)
@type interface: TreeInterfaceUpstream
"""
interface.join_prune_logger.debug('recvGraftAckFromRPFnbr, AP -> F')
interface.clear_graft_retry_timer()
interface.set_state(UpstreamState.Forward)
def __str__(self):
return "AckPending"
class UpstreamState():
Forward = Forward()
Pruned = Pruned()
AckPending = AckPending()
import array
'''
import struct
if struct.pack("H",1) == "\x00\x01": # big endian
def checksum(pkt):
if len(pkt) % 2 == 1:
pkt += "\0"
s = sum(array.array("H", pkt))
s = (s >> 16) + (s & 0xffff)
s += s >> 16
s = ~s
return s & 0xffff
else:
def checksum(pkt):
if len(pkt) % 2 == 1:
pkt += "\0"
s = sum(array.array("H", pkt))
s = (s >> 16) + (s & 0xffff)
s += s >> 16
s = ~s
return (((s>>8)&0xff)|s<<8) & 0xffff
'''
HELLO_HOLD_TIME_NO_TIMEOUT = 0xFFFF
HELLO_HOLD_TIME = 160
HELLO_HOLD_TIME_TIMEOUT = 0
def checksum(pkt: bytes) -> bytes:
if len(pkt) % 2 == 1:
pkt += "\0"
s = sum(array.array("H", pkt))
s = (s >> 16) + (s & 0xffff)
s += s >> 16
s = ~s
return (((s >> 8) & 0xff) | s << 8) & 0xffff
import ctypes
import ctypes.util
libc = ctypes.CDLL(ctypes.util.find_library('c'))
def if_nametoindex(name):
if not isinstance(name, str):
raise TypeError('name must be a string.')
ret = libc.if_nametoindex(name)
if not ret:
raise RuntimeError("Invalid Name")
return ret
def if_indextoname(index):
if not isinstance(index, int):
raise TypeError('index must be an int.')
libc.if_indextoname.argtypes = [ctypes.c_uint32, ctypes.c_char_p]
libc.if_indextoname.restype = ctypes.c_char_p
ifname = ctypes.create_string_buffer(32)
ifname = libc.if_indextoname(index, ifname)
if not ifname:
raise RuntimeError ("Inavlid Index")
return ifname.decode("utf-8")
# obtain TYPE_CHECKING (for type hinting)
try:
from typing import TYPE_CHECKING
except ImportError:
TYPE_CHECKING = False
# IGMP timers (in seconds)
RobustnessVariable = 2
QueryInterval = 125
QueryResponseInterval = 10
MaxResponseTime_QueryResponseInterval = QueryResponseInterval*10
GroupMembershipInterval = RobustnessVariable * QueryInterval + QueryResponseInterval
OtherQuerierPresentInterval = RobustnessVariable * QueryInterval + QueryResponseInterval/2
StartupQueryInterval = QueryInterval / 4
StartupQueryCount = RobustnessVariable
LastMemberQueryInterval = 1
MaxResponseTime_LastMemberQueryInterval = LastMemberQueryInterval*10
LastMemberQueryCount = RobustnessVariable
UnsolicitedReportInterval = 10
Version1RouterPresentTimeout = 400
# IGMP msg type
Membership_Query = 0x11
Version_1_Membership_Report = 0x12
Version_2_Membership_Report = 0x16
Leave_Group = 0x17
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment