Commit 4c12b098 authored by Pedro Oliveira's avatar Pedro Oliveira

fix Hello&Neighbor (timers and deadlocks) & check olist of all trees when...

fix Hello&Neighbor (timers and deadlocks) & check olist of all trees when interface changes number of neighbors & packet reception handled inside of interfaces' classes (instead of methods in separate files) & fix send of state refresh through non-root interfaces (if assert loser dont check prune and assert state)
parent 2d3d8f7e
from Packet.ReceivedPacket import ReceivedPacket
from Packet.PacketPimAssert import PacketPimAssert
import Main
import traceback
class Assert:
TYPE = 5
def __init__(self):
Main.add_protocol(Assert.TYPE, self)
# receive handler
def receive_handle(self, packet: ReceivedPacket):
interface = packet.interface
interface_name = interface.interface_name
ip = packet.ip_header.ip_src
print("ip = ", ip)
pkt_assert = packet.payload.payload # type: PacketPimAssert
metric = pkt_assert.metric
metric_preference = pkt_assert.metric_preference
source = pkt_assert.source_address
group = pkt_assert.multicast_group_address
source_group = (source, group)
interface_name = packet.interface.interface_name
interface_index = Main.kernel.vif_name_to_index_dic[interface_name]
try:
#Main.kernel.routing[source_group].recv_assert_msg(interface_index, packet)
Main.kernel.get_routing_entry(source_group).recv_assert_msg(interface_index, packet)
except:
traceback.print_exc()
from Packet.Packet import Packet
from Packet.ReceivedPacket import ReceivedPacket
import Main
import traceback
class Graft:
TYPE = 6
def __init__(self):
Main.add_protocol(Graft.TYPE, self)
# receive handler
def receive_handle(self, packet: ReceivedPacket):
print("GRAFT!!")
interface = packet.interface
ip = packet.ip_header.ip_src
print("ip = ", ip)
pkt_join_prune = packet.payload.payload # type: PacketPimJoinPrune
# if im not upstream neighbor ignore message
if pkt_join_prune.upstream_neighbor_address != interface.ip_interface:
#return
pass
interface_name = interface.interface_name
interface_index = Main.kernel.vif_name_to_index_dic[interface_name]
join_prune_groups = pkt_join_prune.groups
for group in join_prune_groups:
multicast_group = group.multicast_group
joined_src_addresses = group.joined_src_addresses
for source_address in joined_src_addresses:
source_group = (source_address, multicast_group)
try:
Main.kernel.get_routing_entry(source_group).recv_graft_msg(interface_index, packet)
except:
try:
#import time
#time.sleep(2)
Main.kernel.get_routing_entry(source_group).recv_graft_msg(interface_index, packet)
except:
pass
# todo o que fazer quando n existe arvore para (s,g) ???
traceback.print_exc()
print("ATENCAO!!!!")
print(Main.kernel.routing)
continue
from Packet.ReceivedPacket import ReceivedPacket
import Main
import traceback
class GraftAck:
TYPE = 7
def __init__(self):
Main.add_protocol(GraftAck.TYPE, self)
# receive handler
def receive_handle(self, packet: ReceivedPacket):
print("GRAFT ACK!!")
interface = packet.interface
ip = packet.ip_header.ip_src
print("ip = ", ip)
pkt_join_prune = packet.payload.payload # type: PacketPimJoinPrune
# if im not upstream neighbor ignore message
if pkt_join_prune.upstream_neighbor_address != interface.ip_interface:
#return
pass
interface_name = interface.interface_name
interface_index = Main.kernel.vif_name_to_index_dic[interface_name]
join_prune_groups = pkt_join_prune.groups
for group in join_prune_groups:
multicast_group = group.multicast_group
joined_src_addresses = group.joined_src_addresses
for source_address in joined_src_addresses:
source_group = (source_address, multicast_group)
try:
Main.kernel.get_routing_entry(source_group).recv_graft_ack_msg(interface_index, packet)
except:
try:
#import time
#time.sleep(2)
Main.kernel.get_routing_entry(source_group).recv_graft_ack_msg(interface_index, packet)
except:
pass
# todo o que fazer quando n existe arvore para (s,g) ???
traceback.print_exc()
print("ATENCAO!!!!")
print(Main.kernel.routing)
continue
from Packet.ReceivedPacket import ReceivedPacket
import Main
from Neighbor import Neighbor
class Hello:
TYPE = 0
TRIGGERED_HELLO_DELAY = 16 # TODO: configure via external file??
def __init__(self):
Main.add_protocol(Hello.TYPE, self)
# receive handler
def receive_handle(self, packet: ReceivedPacket):
interface = packet.interface
ip = packet.ip_header.ip_src
print("ip = ", ip)
options = packet.payload.payload.get_options()
if (1 in options) and (20 in options):
#hello_hold_time = options[1]
hello_hold_time = options[1].holdtime
#generation_id = options[20]
generation_id = options[20].generation_id
else:
raise Exception
with interface.neighbors_lock.genWlock():
if ip in interface.neighbors:
neighbor = interface.neighbors[ip]
else:
interface.neighbors[ip] = Neighbor(interface, ip, generation_id, hello_hold_time)
return
neighbor.receive_hello(generation_id, hello_hold_time)
"""
with neighbor.neighbor_lock:
# Already know Neighbor
print("neighbor conhecido")
neighbor.heartbeat()
if neighbor.hello_hold_time != hello_hold_time:
print("keep alive period diferente")
neighbor.set_hello_hold_time(hello_hold_time)
if neighbor.generation_id != generation_id:
print("neighbor reiniciado")
neighbor.set_generation_id(generation_id)
with interface.neighbors_lock.genWlock():
#if interface.get_neighbor(ip) is None:
if ip in interface.neighbors:
# Unknown Neighbor
if (1 in options) and (20 in options):
try:
#Main.add_neighbor(packet.interface, ip, options[20], options[1])
print("non neighbor and options inside")
except Exception:
# Received Neighbor with Timeout
print("non neighbor and options inside but neighbor timedout")
pass
return
print("non neighbor and required options not inside")
else:
# Already know Neighbor
print("neighbor conhecido")
neighbor = Main.get_neighbor(ip)
neighbor.heartbeat()
if 1 in options and neighbor.hello_hold_time != options[1]:
print("keep alive period diferente")
neighbor.set_hello_hold_time(options[1])
if 20 in options and neighbor.generation_id != options[20]:
print("neighbor reiniciado")
neighbor.remove()
Main.add_neighbor(packet.interface, ip, options[20], options[1])
"""
\ No newline at end of file
from Packet.ReceivedPacket import ReceivedPacket
from utils import *
from ipaddress import IPv4Address
class IGMP:
# receive handler
@staticmethod
def receive_handle(packet: ReceivedPacket):
interface = packet.interface
ip_src = packet.ip_header.ip_src
ip_dst = packet.ip_header.ip_dst
#print("ip = ", ip_src)
igmp_hdr = packet.payload
igmp_type = igmp_hdr.type
igmp_group = igmp_hdr.group_address
# source ip can't be 0.0.0.0 or multicast
if ip_src == "0.0.0.0" or IPv4Address(ip_src).is_multicast:
return
if igmp_type == Version_1_Membership_Report and ip_dst == igmp_group and IPv4Address(igmp_group).is_multicast:
interface.interface_state.receive_v1_membership_report(packet)
elif igmp_type == Version_2_Membership_Report and ip_dst == igmp_group and IPv4Address(igmp_group).is_multicast:
interface.interface_state.receive_v2_membership_report(packet)
elif igmp_type == Leave_Group and ip_dst == "224.0.0.2" and IPv4Address(igmp_group).is_multicast:
interface.interface_state.receive_leave_group(packet)
elif igmp_type == Membership_Query and (ip_dst == igmp_group or (ip_dst == "224.0.0.1" and igmp_group == "0.0.0.0")):
interface.interface_state.receive_query(packet)
else:
raise Exception("Exception igmp packet: type={}; ip_dst={}; packet_group_report={}".format(igmp_type, ip_dst, igmp_group))
import socket import socket
from abc import ABCMeta, abstractmethod
import threading import threading
import random import random
import netifaces import netifaces
...@@ -8,110 +9,55 @@ import traceback ...@@ -8,110 +9,55 @@ import traceback
from RWLock.RWLock import RWLockWrite from RWLock.RWLock import RWLockWrite
class Interface(object): class Interface(metaclass=ABCMeta):
MCAST_GRP = '224.0.0.13' MCAST_GRP = '224.0.0.13'
# substituir ip por interface ou algo parecido def __init__(self, interface_name, recv_socket, send_socket, vif_index):
def __init__(self, interface_name: str):
self.interface_name = interface_name self.interface_name = interface_name
ip_interface = netifaces.ifaddresses(interface_name)[netifaces.AF_INET][0]['addr']
self.ip_mask_interface = netifaces.ifaddresses(interface_name)[netifaces.AF_INET][0]['netmask']
self.ip_interface = ip_interface
s = socket.socket(socket.AF_INET, socket.SOCK_RAW, socket.IPPROTO_PIM) # virtual interface index for the multicast routing table
self.vif_index = vif_index
# allow other sockets to bind this port too # set receive socket and send socket
s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) self._send_socket = send_socket
self._recv_socket = recv_socket
# explicitly join the multicast group on the interface specified
#s.setsockopt(socket.SOL_IP, socket.IP_ADD_MEMBERSHIP, socket.inet_aton(Interface.MCAST_GRP) + socket.inet_aton(ip_interface))
s.setsockopt(socket.IPPROTO_IP, socket.IP_ADD_MEMBERSHIP,
socket.inet_aton(Interface.MCAST_GRP) + socket.inet_aton(ip_interface))
s.setsockopt(socket.SOL_SOCKET, 25, str(interface_name + '\0').encode('utf-8'))
# set socket output interface
s.setsockopt(socket.SOL_IP, socket.IP_MULTICAST_IF, socket.inet_aton(ip_interface))
# set socket TTL to 1
s.setsockopt(socket.IPPROTO_IP, socket.IP_MULTICAST_TTL, 1)
s.setsockopt(socket.IPPROTO_IP, socket.IP_TTL, 1)
# don't receive outgoing packets
s.setsockopt(socket.IPPROTO_IP, socket.IP_MULTICAST_LOOP, 0)
self.socket = s
self.interface_enabled = True self.interface_enabled = True
# generation id
#self.generation_id = random.getrandbits(32)
# todo neighbors
#self.neighbors = {}
#self.neighbors_lock = RWLockWrite()
# run receive method in background # run receive method in background
#receive_thread = threading.Thread(target=self.receive) receive_thread = threading.Thread(target=self.receive)
#receive_thread.daemon = True receive_thread.daemon = True
#receive_thread.start() receive_thread.start()
def receive(self): def receive(self):
try:
(raw_packet, (ip, _)) = self.socket.recvfrom(256 * 1024)
if raw_packet:
packet = ReceivedPacket(raw_packet, self)
else:
packet = None
return packet
except Exception:
traceback.print_exc()
return None
"""
while self.interface_enabled: while self.interface_enabled:
try: try:
(raw_packet, (ip, _)) = self.socket.recvfrom(256 * 1024) (raw_bytes, _) = self._recv_socket.recvfrom(256 * 1024)
if raw_packet: if raw_bytes:
packet = ReceivedPacket(raw_packet, self) self._receive(raw_bytes)
Main.protocols[packet.payload.get_pim_type()].receive_handle(packet) # TODO: perceber se existe melhor maneira de fazer isto
except Exception: except Exception:
traceback.print_exc() traceback.print_exc()
continue continue
"""
@abstractmethod
def _receive(self, raw_bytes):
raise NotImplementedError
def send(self, data: bytes, group_ip: str): def send(self, data: bytes, group_ip: str):
if self.interface_enabled and data: if self.interface_enabled and data:
self.socket.sendto(data, (group_ip, 0)) self._send_socket.sendto(data, (group_ip, 0))
def remove(self): def remove(self):
self.interface_enabled = False self.interface_enabled = False
try: try:
self.socket.shutdown(socket.SHUT_RDWR) self._recv_socket.shutdown(socket.SHUT_RDWR)
except Exception: except Exception:
pass pass
self.socket.close() self._recv_socket.close()
self._send_socket.close()
def is_enabled(self): def is_enabled(self):
return self.interface_enabled return self.interface_enabled
@abstractmethod
def get_ip(self): def get_ip(self):
return self.ip_interface raise NotImplementedError
\ No newline at end of file
"""
def add_neighbor(self, ip, random_number, hello_hold_time):
with self.neighbors_lock.genWlock():
if ip not in self.neighbors:
print("ADD NEIGHBOR")
from Neighbor import Neighbor
n = Neighbor(self, ip, random_number, hello_hold_time)
self.neighbors[ip] = n
Main.protocols[0].force_send(self)
def get_neighbors(self):
with self.neighbors_lock.genRlock():
return self.neighbors.values()
def get_neighbor(self, ip):
with self.neighbors_lock.genRlock():
return self.neighbors[ip]
"""
\ No newline at end of file
import socket import socket
import struct import struct
import threading
import netifaces import netifaces
from Packet.ReceivedPacket import ReceivedPacket from Packet.ReceivedPacket import ReceivedPacket
import Main from Interface import Interface
import traceback
from ctypes import create_string_buffer, addressof from ctypes import create_string_buffer, addressof
from ipaddress import IPv4Address
from utils import Version_1_Membership_Report, Version_2_Membership_Report, Leave_Group, Membership_Query
if not hasattr(socket, 'SO_BINDTODEVICE'): if not hasattr(socket, 'SO_BINDTODEVICE'):
socket.SO_BINDTODEVICE = 25 socket.SO_BINDTODEVICE = 25
class InterfaceIGMP(object): class InterfaceIGMP(Interface):
ETH_P_IP = 0x0800 # Internet Protocol packet ETH_P_IP = 0x0800 # Internet Protocol packet
SO_ATTACH_FILTER = 26
FILTER_IGMP = [ FILTER_IGMP = [
struct.pack('HBBI', 0x28, 0, 0, 0x0000000c), struct.pack('HBBI', 0x28, 0, 0, 0x0000000c),
...@@ -22,10 +24,6 @@ class InterfaceIGMP(object): ...@@ -22,10 +24,6 @@ class InterfaceIGMP(object):
struct.pack('HBBI', 0x6, 0, 0, 0x00000000), struct.pack('HBBI', 0x6, 0, 0, 0x00000000),
] ]
SO_ATTACH_FILTER = 26
PACKET_MR_ALLMULTI = 2
def __init__(self, interface_name: str, vif_index:int): def __init__(self, interface_name: str, vif_index:int):
# RECEIVE SOCKET # RECEIVE SOCKET
rcv_s = socket.socket(socket.AF_PACKET, socket.SOCK_RAW, socket.htons(InterfaceIGMP.ETH_P_IP)) rcv_s = socket.socket(socket.AF_PACKET, socket.SOCK_RAW, socket.htons(InterfaceIGMP.ETH_P_IP))
...@@ -40,7 +38,6 @@ class InterfaceIGMP(object): ...@@ -40,7 +38,6 @@ class InterfaceIGMP(object):
# bind to interface # bind to interface
rcv_s.bind((interface_name, 0x0800)) rcv_s.bind((interface_name, 0x0800))
self.recv_socket = rcv_s
# SEND SOCKET # SEND SOCKET
snd_s = socket.socket(socket.AF_INET, socket.SOCK_RAW, socket.IPPROTO_IGMP) snd_s = socket.socket(socket.AF_INET, socket.SOCK_RAW, socket.IPPROTO_IGMP)
...@@ -48,20 +45,12 @@ class InterfaceIGMP(object): ...@@ -48,20 +45,12 @@ class InterfaceIGMP(object):
# bind to interface # bind to interface
snd_s.setsockopt(socket.SOL_SOCKET, socket.SO_BINDTODEVICE, str(interface_name + "\0").encode('utf-8')) snd_s.setsockopt(socket.SOL_SOCKET, socket.SO_BINDTODEVICE, str(interface_name + "\0").encode('utf-8'))
self.send_socket = snd_s super().__init__(interface_name=interface_name, recv_socket=rcv_s, send_socket=snd_s, vif_index=vif_index)
self.interface_enabled = True
self.interface_name = interface_name
from igmp.RouterState import RouterState from igmp.RouterState import RouterState
self.interface_state = RouterState(self) self.interface_state = RouterState(self)
# virtual interface index for the multicast routing table
self.vif_index = vif_index
# run receive method in background
receive_thread = threading.Thread(target=self.receive)
receive_thread.daemon = True
receive_thread.start()
def get_ip(self): def get_ip(self):
return netifaces.ifaddresses(self.interface_name)[netifaces.AF_INET][0]['addr'] return netifaces.ifaddresses(self.interface_name)[netifaces.AF_INET][0]['addr']
...@@ -70,24 +59,48 @@ class InterfaceIGMP(object): ...@@ -70,24 +59,48 @@ class InterfaceIGMP(object):
def ip_interface(self): def ip_interface(self):
return self.get_ip() return self.get_ip()
def send(self, data: bytes, address: str="224.0.0.1"): def send(self, data: bytes, address: str="224.0.0.1"):
if self.interface_enabled: super().send(data, address)
self.send_socket.sendto(data, (address, 0))
def _receive(self, raw_bytes):
def receive(self): if raw_bytes:
while self.interface_enabled: raw_bytes = raw_bytes[14:]
try: packet = ReceivedPacket(raw_bytes, self)
(raw_packet, _) = self.recv_socket.recvfrom(256 * 1024) ip_src = packet.ip_header.ip_src
if raw_packet: if not (ip_src == "0.0.0.0" or IPv4Address(ip_src).is_multicast):
raw_packet = raw_packet[14:] self.PKT_FUNCTIONS[packet.payload.get_igmp_type()](self, packet)
packet = ReceivedPacket(raw_packet, self)
Main.igmp.receive_handle(packet)
except Exception: ###########################################
traceback.print_exc() # Recv packets
continue ###########################################
def receive_version_1_membership_report(self, packet):
def remove(self): ip_dst = packet.ip_header.ip_dst
self.interface_enabled = False igmp_group = packet.payload.group_address
self.recv_socket.close() if ip_dst == igmp_group and IPv4Address(igmp_group).is_multicast:
self.send_socket.close() self.interface_state.receive_v1_membership_report(packet)
def receive_version_2_membership_report(self, packet):
ip_dst = packet.ip_header.ip_dst
igmp_group = packet.payload.group_address
if ip_dst == igmp_group and IPv4Address(igmp_group).is_multicast:
self.interface_state.receive_v2_membership_report(packet)
def receive_leave_group(self, packet):
ip_dst = packet.ip_header.ip_dst
igmp_group = packet.payload.group_address
if ip_dst == "224.0.0.2" and IPv4Address(igmp_group).is_multicast:
self.interface_state.receive_leave_group(packet)
def receive_membership_query(self, packet):
ip_dst = packet.ip_header.ip_dst
igmp_group = packet.payload.group_address
if ip_dst == igmp_group or (ip_dst == "224.0.0.1" and igmp_group == "0.0.0.0"):
self.interface_state.receive_query(packet)
PKT_FUNCTIONS = {
Version_1_Membership_Report: receive_version_1_membership_report,
Version_2_Membership_Report: receive_version_2_membership_report,
Leave_Group: receive_leave_group,
Membership_Query: receive_membership_query,
}
This diff is collapsed.
from Packet.Packet import Packet
from Packet.ReceivedPacket import ReceivedPacket
from Packet.PacketPimJoinPrune import PacketPimJoinPrune
from Packet.PacketPimJoinPruneMulticastGroup import PacketPimJoinPruneMulticastGroup
from Interface import Interface
import Main
import traceback
class JoinPrune:
TYPE = 3
def __init__(self):
Main.add_protocol(JoinPrune.TYPE, self)
# receive handler
def receive_handle(self, packet: ReceivedPacket):
interface = packet.interface
ip = packet.ip_header.ip_src
print("ip = ", ip)
pkt_join_prune = packet.payload.payload # type: PacketPimJoinPrune
# if im not upstream neighbor ignore message
if pkt_join_prune.upstream_neighbor_address != interface.ip_interface:
#return
pass
interface_name = interface.interface_name
interface_index = Main.kernel.vif_name_to_index_dic[interface_name]
# todo holdtime
holdtime = pkt_join_prune.hold_time
join_prune_groups = pkt_join_prune.groups
for group in join_prune_groups:
multicast_group = group.multicast_group
joined_src_addresses = group.joined_src_addresses
pruned_src_addresses = group.pruned_src_addresses
for source_address in joined_src_addresses:
source_group = (source_address, multicast_group)
try:
#Main.kernel.routing[source_group].recv_join_msg(interface_index, packet)
Main.kernel.get_routing_entry(source_group).recv_join_msg(interface_index, packet)
except:
try:
#import time
#time.sleep(2)
Main.kernel.get_routing_entry(source_group).recv_join_msg(interface_index, packet)
except:
pass
# todo o que fazer quando n existe arvore para (s,g) ???
traceback.print_exc()
print("ATENCAO!!!!")
print(Main.kernel.routing)
continue
for source_address in pruned_src_addresses:
source_group = (source_address, multicast_group)
try:
#Main.kernel.routing[source_group].recv_prune_msg(interface_index, packet)
Main.kernel.get_routing_entry(source_group).recv_prune_msg(interface_index, packet)
except:
try:
#import time
#time.sleep(2)
Main.kernel.get_routing_entry(source_group).recv_prune_msg(interface_index, packet)
except:
pass
# todo o que fazer quando n existe arvore para (s,g) ???
traceback.print_exc()
print("ATENCAO!!!!")
print(Main.kernel.routing)
continue
...@@ -470,3 +470,8 @@ class Kernel: ...@@ -470,3 +470,8 @@ class Kernel:
pass pass
# When interface changes number of neighbors verify if olist changes and prune/forward respectively
def interface_change_number_of_neighbors(self):
with self.rwlock.genWlock():
for entry in self.routing.values():
entry.change_at_number_of_neighbors()
...@@ -10,10 +10,9 @@ import UnicastRouting ...@@ -10,10 +10,9 @@ import UnicastRouting
interfaces = {} # interfaces with multicast routing enabled interfaces = {} # interfaces with multicast routing enabled
igmp_interfaces = {} # igmp interfaces igmp_interfaces = {} # igmp interfaces
protocols = {}
kernel = None kernel = None
igmp = None igmp = None
unicast_routing = None
def add_pim_interface(interface_name, state_refresh_capable:bool=False): def add_pim_interface(interface_name, state_refresh_capable:bool=False):
kernel.create_pim_interface(interface_name=interface_name, state_refresh_capable=state_refresh_capable) kernel.create_pim_interface(interface_name=interface_name, state_refresh_capable=state_refresh_capable)
...@@ -64,10 +63,6 @@ def remove_interface(interface_name, pim=False, igmp=False): ...@@ -64,10 +63,6 @@ def remove_interface(interface_name, pim=False, igmp=False):
# print(igmp_interfaces) # print(igmp_interfaces)
kernel.remove_interface(interface_name, pim=pim, igmp=igmp) kernel.remove_interface(interface_name, pim=pim, igmp=igmp)
def add_protocol(protocol_number, protocol_obj):
global protocols
protocols[protocol_number] = protocol_obj
def list_neighbors(): def list_neighbors():
interfaces_list = interfaces.values() interfaces_list = interfaces.values()
t = PrettyTable(['Interface', 'Neighbor IP', 'Hello Hold Time', "Generation ID", "Uptime"]) t = PrettyTable(['Interface', 'Neighbor IP', 'Hello Hold Time', "Generation ID", "Uptime"])
...@@ -157,32 +152,16 @@ def list_routing_state(): ...@@ -157,32 +152,16 @@ def list_routing_state():
def stop(): def stop():
remove_interface("*", pim=True, igmp=True) remove_interface("*", pim=True, igmp=True)
kernel.exit() kernel.exit()
UnicastRouting.stop() unicast_routing.stop()
def main(): def main():
from Hello import Hello
from IGMP import IGMP
from Assert import Assert
from JoinPrune import JoinPrune
from GraftAck import GraftAck
from Graft import Graft
from StateRefresh import StateRefresh
Hello()
Assert()
JoinPrune()
Graft()
GraftAck()
StateRefresh()
global kernel global kernel
kernel = Kernel() kernel = Kernel()
global igmp global unicast_routing
igmp = IGMP() unicast_routing = UnicastRouting.UnicastRouting()
global u
u = UnicastRouting.UnicastRouting()
global interfaces global interfaces
global igmp_interfaces global igmp_interfaces
......
from threading import Timer from threading import Timer
import time import time
from utils import HELLO_HOLD_TIME_NO_TIMEOUT, HELLO_HOLD_TIME_TIMEOUT, TYPE_CHECKING from utils import HELLO_HOLD_TIME_NO_TIMEOUT, HELLO_HOLD_TIME_TIMEOUT, TYPE_CHECKING
from threading import Lock from threading import Lock, RLock
from RWLock.RWLock import RWLockWrite
import Main import Main
if TYPE_CHECKING: if TYPE_CHECKING:
from InterfacePIM import InterfacePim from InterfacePIM import InterfacePim
class Neighbor: class Neighbor:
def __init__(self, contact_interface: "InterfacePim", ip, generation_id: int, hello_hold_time: int): def __init__(self, contact_interface: "InterfacePim", ip, generation_id: int, hello_hold_time: int, state_refresh_capable:bool):
if hello_hold_time == HELLO_HOLD_TIME_TIMEOUT: if hello_hold_time == HELLO_HOLD_TIME_TIMEOUT:
raise Exception raise Exception
self.contact_interface = contact_interface self.contact_interface = contact_interface
...@@ -17,7 +16,7 @@ class Neighbor: ...@@ -17,7 +16,7 @@ class Neighbor:
self.generation_id = generation_id self.generation_id = generation_id
# todo lan prune delay # todo lan prune delay
# todo override interval # todo override interval
# todo state refresh capable self.state_refresh_capable = state_refresh_capable
self.neighbor_liveness_timer = None self.neighbor_liveness_timer = None
self.hello_hold_time = None self.hello_hold_time = None
...@@ -26,13 +25,9 @@ class Neighbor: ...@@ -26,13 +25,9 @@ class Neighbor:
self.neighbor_lock = Lock() self.neighbor_lock = Lock()
self.tree_interface_nlt_subscribers = [] self.tree_interface_nlt_subscribers = []
self.tree_interface_nlt_subscribers_lock = RWLockWrite() self.tree_interface_nlt_subscribers_lock = RLock()
# send hello to new neighbor
#self.contact_interface.send_hello()
# todo RANDOM DELAY??? => DO NOTHING... EVENTUALLY THE HELLO MESSAGE WILL BE SENT
def set_hello_hold_time(self, hello_hold_time: int): def set_hello_hold_time(self, hello_hold_time: int):
self.hello_hold_time = hello_hold_time self.hello_hold_time = hello_hold_time
if self.neighbor_liveness_timer is not None: if self.neighbor_liveness_timer is not None:
...@@ -69,14 +64,11 @@ class Neighbor: ...@@ -69,14 +64,11 @@ class Neighbor:
print('HELLO TIMER EXPIRED... remove neighbor') print('HELLO TIMER EXPIRED... remove neighbor')
if self.neighbor_liveness_timer is not None: if self.neighbor_liveness_timer is not None:
self.neighbor_liveness_timer.cancel() self.neighbor_liveness_timer.cancel()
#Main.remove_neighbor(self.ip)
interface_name = self.contact_interface.interface_name
neighbor_ip = self.ip
del self.contact_interface.neighbors[self.ip] self.contact_interface.remove_neighbor(self.ip)
# notify interfaces which have this neighbor as AssertWinner # notify interfaces which have this neighbor as AssertWinner
with self.tree_interface_nlt_subscribers_lock.genRlock(): with self.tree_interface_nlt_subscribers_lock:
for tree_if in self.tree_interface_nlt_subscribers: for tree_if in self.tree_interface_nlt_subscribers:
tree_if.assert_winner_nlt_expires() tree_if.assert_winner_nlt_expires()
...@@ -85,22 +77,23 @@ class Neighbor: ...@@ -85,22 +77,23 @@ class Neighbor:
return return
def receive_hello(self, generation_id, hello_hold_time): def receive_hello(self, generation_id, hello_hold_time, state_refresh_capable):
if hello_hold_time == HELLO_HOLD_TIME_TIMEOUT: if hello_hold_time == HELLO_HOLD_TIME_TIMEOUT:
self.set_hello_hold_time(hello_hold_time) self.set_hello_hold_time(hello_hold_time)
else: else:
self.time_of_last_update = time.time() self.time_of_last_update = time.time()
self.set_generation_id(generation_id) self.set_generation_id(generation_id)
self.set_hello_hold_time(hello_hold_time) self.set_hello_hold_time(hello_hold_time)
if state_refresh_capable != self.state_refresh_capable:
self.state_refresh_capable = state_refresh_capable
def subscribe_nlt_expiration(self, tree_if): def subscribe_nlt_expiration(self, tree_if):
with self.tree_interface_nlt_subscribers_lock.genWlock(): with self.tree_interface_nlt_subscribers_lock:
if tree_if not in self.tree_interface_nlt_subscribers: if tree_if not in self.tree_interface_nlt_subscribers:
self.tree_interface_nlt_subscribers.append(tree_if) self.tree_interface_nlt_subscribers.append(tree_if)
def unsubscribe_nlt_expiration(self, tree_if): def unsubscribe_nlt_expiration(self, tree_if):
with self.tree_interface_nlt_subscribers_lock.genWlock(): with self.tree_interface_nlt_subscribers_lock:
if tree_if in self.tree_interface_nlt_subscribers: if tree_if in self.tree_interface_nlt_subscribers:
self.tree_interface_nlt_subscribers.remove(tree_if) self.tree_interface_nlt_subscribers.remove(tree_if)
...@@ -47,6 +47,9 @@ class PacketIGMPHeader(PacketPayload): ...@@ -47,6 +47,9 @@ class PacketIGMPHeader(PacketPayload):
self.max_resp_time = max_resp_time self.max_resp_time = max_resp_time
self.group_address = group_address self.group_address = group_address
def get_igmp_type(self):
return self.type
def bytes(self) -> bytes: def bytes(self) -> bytes:
# obter mensagem e criar checksum # obter mensagem e criar checksum
msg_without_chcksum = struct.pack(PacketIGMPHeader.IGMP_HDR, self.type, self.max_resp_time, 0, msg_without_chcksum = struct.pack(PacketIGMPHeader.IGMP_HDR, self.type, self.max_resp_time, 0,
......
...@@ -73,20 +73,4 @@ class PacketPimHeader(PacketPayload): ...@@ -73,20 +73,4 @@ class PacketPimHeader(PacketPayload):
pim_payload = data[PacketPimHeader.PIM_HDR_LEN:] pim_payload = data[PacketPimHeader.PIM_HDR_LEN:]
pim_payload = PacketPimHeader.PIM_MSG_TYPES[pim_type].parse_bytes(pim_payload) pim_payload = PacketPimHeader.PIM_MSG_TYPES[pim_type].parse_bytes(pim_payload)
'''
if pim_type == 0: # hello
pim_payload = PacketPimHello.parse_bytes(pim_payload)
elif pim_type == 3: # join/prune
pim_payload = PacketPimJoinPrune.parse_bytes(pim_payload)
print("hold_time = ", pim_payload.hold_time)
print("upstream_neighbor = ", pim_payload.upstream_neighbor_address)
for i in pim_payload.groups:
print(i.multicast_group)
print(i.joined_src_addresses)
print(i.pruned_src_addresses)
elif pim_type == 5: # assert
pim_payload = PacketPimAssert.parse_bytes(pim_payload)
else:
raise Exception
'''
return PacketPimHeader(pim_payload) return PacketPimHeader(pim_payload)
import random
from threading import Timer
from Packet.Packet import Packet
from Packet.ReceivedPacket import ReceivedPacket
from Packet.PacketPimHello import PacketPimHello
from Packet.PacketPimHeader import PacketPimHeader
from Packet.PacketPimStateRefresh import PacketPimStateRefresh
from Interface import Interface
import Main
class StateRefresh:
TYPE = 9
def __init__(self):
Main.add_protocol(StateRefresh.TYPE, self)
# receive handler
def receive_handle(self, packet: ReceivedPacket):
#check if interface supports state refresh
if not packet.interface._state_refresh_capable:
return
ip = packet.ip_header.ip_src
print("ip = ", ip)
pkt_state_refresh = packet.payload.payload # type: PacketPimStateRefresh
# TODO
interface_index = packet.interface.vif_index
source = pkt_state_refresh.source_address
group = pkt_state_refresh.multicast_group_adress
source_group = (source, group)
try:
Main.kernel.get_routing_entry(source_group).recv_state_refresh_msg(interface_index, packet)
except:
try:
# import time
# time.sleep(2)
Main.kernel.get_routing_entry(source_group).recv_state_refresh_msg(interface_index, packet)
except:
pass
...@@ -195,9 +195,11 @@ class KernelEntry: ...@@ -195,9 +195,11 @@ class KernelEntry:
self.interface_state[self.inbound_interface_index].change_rpf(self._was_olist_null) self.interface_state[self.inbound_interface_index].change_rpf(self._was_olist_null)
def nbr_event(self, link, node, event): # check if add/removal of neighbors from interface afects olist and forward/prune state of interface
# todo pode ser interessante verificar se a adicao/remocao de vizinhos se altera o olist def change_at_number_of_neighbors(self):
return with self.CHANGE_STATE_LOCK:
self.change()
self.evaluate_olist_change()
def is_olist_null(self): def is_olist_null(self):
for interface in self.interface_state.values(): for interface in self.interface_state.values():
......
...@@ -120,12 +120,13 @@ class TreeInterfaceDownstream(TreeInterface): ...@@ -120,12 +120,13 @@ class TreeInterfaceDownstream(TreeInterface):
return return
interval = state_refresh_msg_received.interval interval = state_refresh_msg_received.interval
self._assert_state.sendStateRefresh(self, interval)
self._prune_state.send_state_refresh(self)
if self.lost_assert(): if self.lost_assert():
return return
self._assert_state.sendStateRefresh(self, interval)
self._prune_state.send_state_refresh(self)
prune_indicator_bit = 0 prune_indicator_bit = 0
if self.is_pruned(): if self.is_pruned():
prune_indicator_bit = 1 prune_indicator_bit = 1
...@@ -164,7 +165,7 @@ class TreeInterfaceDownstream(TreeInterface): ...@@ -164,7 +165,7 @@ class TreeInterfaceDownstream(TreeInterface):
# Override # Override
def is_forwarding(self): def is_forwarding(self):
return ((len(self.get_interface().neighbors) >= 1 and not self.is_pruned()) or self.igmp_has_members()) and not self.lost_assert() return ((self.has_neighbors() and not self.is_pruned()) or self.igmp_has_members()) and not self.lost_assert()
#return self._assert_state == AssertState.Winner and self.is_in_group() #return self._assert_state == AssertState.Winner and self.is_in_group()
def is_pruned(self): def is_pruned(self):
......
...@@ -193,7 +193,6 @@ class TreeInterfaceUpstream(TreeInterface): ...@@ -193,7 +193,6 @@ class TreeInterfaceUpstream(TreeInterface):
#################################### ####################################
def create_state_refresh_msg(self): def create_state_refresh_msg(self):
self._prune_now_counter+=1 self._prune_now_counter+=1
self._prune_now_counter%=3
(source_ip, group_ip) = self.get_tree_id() (source_ip, group_ip) = self.get_tree_id()
ph = PacketPimStateRefresh(multicast_group_adress=group_ip, ph = PacketPimStateRefresh(multicast_group_adress=group_ip,
source_address=source_ip, source_address=source_ip,
...@@ -201,9 +200,11 @@ class TreeInterfaceUpstream(TreeInterface): ...@@ -201,9 +200,11 @@ class TreeInterfaceUpstream(TreeInterface):
metric_preference=0, metric=0, mask_len=0, metric_preference=0, metric=0, mask_len=0,
ttl=256, ttl=256,
prune_indicator_flag=0, prune_indicator_flag=0,
prune_now_flag=(self._prune_now_counter+1)//3, prune_now_flag=self._prune_now_counter//3,
assert_override_flag=0, assert_override_flag=0,
interval=60) interval=60)
self._prune_now_counter %= 3
self._kernel_entry.forward_state_refresh_msg(ph) self._kernel_entry.forward_state_refresh_msg(ph)
########################################### ###########################################
......
...@@ -30,19 +30,8 @@ from .globals import * ...@@ -30,19 +30,8 @@ from .globals import *
class TreeInterface(metaclass=ABCMeta): class TreeInterface(metaclass=ABCMeta):
def __init__(self, kernel_entry, interface_id): def __init__(self, kernel_entry, interface_id):
'''
@type interface: SFMRInterface
@type node: Node
'''
#assert isinstance(interface, SFMRInterface)
self._kernel_entry = kernel_entry self._kernel_entry = kernel_entry
self._interface_id = interface_id self._interface_id = interface_id
#self._interface = interface
#self._node = node
#self._tree_id = tree_id
#self._cost = cost
#self._evaluate_ig = evaluate_ig_cb
# Local Membership State # Local Membership State
try: try:
...@@ -53,7 +42,6 @@ class TreeInterface(metaclass=ABCMeta): ...@@ -53,7 +42,6 @@ class TreeInterface(metaclass=ABCMeta):
igmp_has_members = group_state.add_multicast_routing_entry(self) igmp_has_members = group_state.add_multicast_routing_entry(self)
self._local_membership_state = LocalMembership.Include if igmp_has_members else LocalMembership.NoInfo self._local_membership_state = LocalMembership.Include if igmp_has_members else LocalMembership.NoInfo
except: except:
#traceback.print_exc()
self._local_membership_state = LocalMembership.NoInfo self._local_membership_state = LocalMembership.NoInfo
...@@ -86,24 +74,17 @@ class TreeInterface(metaclass=ABCMeta): ...@@ -86,24 +74,17 @@ class TreeInterface(metaclass=ABCMeta):
self.evaluate_ingroup() self.evaluate_ingroup()
def set_assert_winner_metric(self, new_assert_metric: AssertMetric): def set_assert_winner_metric(self, new_assert_metric: AssertMetric):
import ipaddress
with self.get_state_lock(): with self.get_state_lock():
try: try:
old_neighbor = self.get_interface().get_neighbor(str(self._assert_winner_metric.ip_address)) old_neighbor = self.get_interface().get_neighbor(self._assert_winner_metric.get_ip())
new_neighbor = self.get_interface().get_neighbor(str(new_assert_metric.ip_address)) new_neighbor = self.get_interface().get_neighbor(new_assert_metric.get_ip())
if old_neighbor is not None: if old_neighbor is not None:
old_neighbor.unsubscribe_nlt_expiration(self) old_neighbor.unsubscribe_nlt_expiration(self)
if new_neighbor is not None: if new_neighbor is not None:
new_neighbor.subscribe_nlt_expiration(self) new_neighbor.subscribe_nlt_expiration(self)
''' except:
if new_assert_metric.ip_address == ipaddress.ip_address("0.0.0.0") or new_assert_metric.ip_address is None: traceback.print_exc()
if old_neighbor is not None:
old_neighbor.unsubscribe_nlt_expiration(self)
else:
old_neighbor.unsubscribe_nlt_expiration(self)
new_neighbor.subscribe_nlt_expiration(self)
'''
finally: finally:
self._assert_winner_metric = new_assert_metric self._assert_winner_metric = new_assert_metric
...@@ -340,6 +321,12 @@ class TreeInterface(metaclass=ABCMeta): ...@@ -340,6 +321,12 @@ class TreeInterface(metaclass=ABCMeta):
ip = self.get_interface().get_ip() ip = self.get_interface().get_ip()
return ip return ip
def has_neighbors(self):
try:
return len(self.get_interface().neighbors) > 0
except:
return False
def get_tree_id(self): def get_tree_id(self):
return (self._kernel_entry.source_ip, self._kernel_entry.group_ip) return (self._kernel_entry.source_ip, self._kernel_entry.group_ip)
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment