diff --git a/confluent_server/confluent/discovery/protocols/pxe.py b/confluent_server/confluent/discovery/protocols/pxe.py index a1c69b27..49e78d74 100644 --- a/confluent_server/confluent/discovery/protocols/pxe.py +++ b/confluent_server/confluent/discovery/protocols/pxe.py @@ -30,6 +30,7 @@ import ctypes.util import eventlet import eventlet.green.socket as socket import eventlet.green.select as select +import netifaces import struct libc = ctypes.CDLL(ctypes.util.find_library('c')) @@ -60,7 +61,7 @@ class sockaddr_ll(ctypes.Structure): ('sll_hatype', ctypes.c_ushort), ('sll_pkttype', ctypes.c_ubyte), ('sll_halen', ctypes.c_ubyte), - ('sll_addr', ctypes.c_ubyte * 8)] + ('sll_addr', ctypes.c_ubyte * 20)] class iovec(ctypes.Structure): # from uio.h _fields_ = [('iov_base', ctypes.c_void_p), @@ -106,6 +107,22 @@ recvmsg.restype = ctypes.c_size_t pkttype = ctypes.c_char * 2048 +_idxtoname = libc.if_indextoname +_idxtoname.argtypes = [ctypes.c_uint, ctypes.c_char_p] + +def idxtoname(idx): + name = (ctypes.c_char * 16)() + _idxtoname(idx, name) + return name.value.strip() + +_idxtobcast = {} +def get_bcastaddr(idx): + if idx not in _idxtobcast: + bc = netifaces.ifaddresses(idxtoname(idx))[17][0]['broadcast'] + bc = bytearray([int(x, 16) for x in bc.split(':')]) + _idxtobcast[idx] = bc + return _idxtobcast[idx] + IP_PKTINFO = 8 @@ -238,7 +255,8 @@ def proxydhcp(): node = uuidmap[disco['uuid']] if not node: continue - myipn = myipbypeer.get(rqv[28:44].tobytes(), None) + hwlen = rq[2] + myipn = myipbypeer.get(rqv[28:28+hwlen].tobytes(), None) if not myipn: continue if opts.get(77, None) == b'iPXE': @@ -435,7 +453,6 @@ def check_reply(node, info, packet, sock, cfg, reqview): repview[0:1] = b'\x02' repview[1:10] = reqview[1:10] # duplicate txid, hwlen, and others repview[10:11] = b'\x80' # always set broadcast - hwaddr = reqview[28:44].tobytes() repview[28:44] = reqview[28:44] # copy chaddr field if httpboot: proto = 'https' if insecuremode == 'never' else 'http' @@ -480,7 +497,14 @@ def check_reply(node, info, packet, sock, cfg, reqview): else: repview[replen - 1:replen + 10] = b'\x3c\x09PXEClient' replen += 11 - myipbypeer[repview[28:44].tobytes()] = myipn + hwlen = bytearray(reqview[2:3].tobytes())[0] + fulladdr = repview[28:28+hwlen].tobytes() + myipbypeer[fulladdr] = myipn + if hwlen == 8: # omnipath may present a mangled proxydhcp request later + shortaddr = bytearray(6) + shortaddr[0] = 2 + shortaddr[1:] = fulladdr[3:] + myipbypeer[bytes(shortaddr)] = myipn if netmask: repview[replen - 1:replen + 1] = b'\x01\x04' repview[replen + 1:replen + 5] = netmask @@ -501,26 +525,23 @@ def check_reply(node, info, packet, sock, cfg, reqview): datasum = ~datasum & 0xffff repview[26:28] = struct.pack('!H', datasum) if clipn: - staticassigns[hwaddr] = (clipn, repview[:replen + 28].tobytes()) - elif hwaddr in staticassigns: - del staticassigns[hwaddr] + staticassigns[fulladdr] = (clipn, repview[:replen + 28].tobytes()) + elif fulladdr in staticassigns: + del staticassigns[fulladdr] send_raw_packet(repview, replen + 28, reqview, info) def send_raw_packet(repview, replen, reqview, info): + ifidx = info['netinfo']['ifidx'] tsock = socket.socket(socket.AF_PACKET, socket.SOCK_DGRAM, socket.htons(0x800)) targ = sockaddr_ll() - bcastaddr = bytearray(8) - hwlen = reqview[2] - if not isinstance(hwlen, int): - # python 2 is different than python 3... - hwlen = bytearray(hwlen)[0] - bcastaddr[:hwlen] = b'\xff' * hwlen - targ.sll_addr = (ctypes.c_ubyte * 8).from_buffer(bcastaddr) + bcastaddr = get_bcastaddr(ifidx) + hwlen = len(bcastaddr) + targ.sll_addr = (ctypes.c_ubyte * 20).from_buffer(bcastaddr) targ.sll_family = socket.AF_PACKET targ.sll_halen = hwlen targ.sll_protocol = socket.htons(0x800) - targ.sll_ifindex = info['netinfo']['ifidx'] + targ.sll_ifindex = ifidx try: pkt = ctypes.byref((ctypes.c_char * (replen)).from_buffer(repview)) except TypeError: @@ -531,7 +552,8 @@ def send_raw_packet(repview, replen, reqview, info): ctypes.sizeof(targ)) def ack_request(pkt, rq, info): - hwaddr = rq[28:44].tobytes() + hwlen = bytearray(rq[2:3].tobytes())[0] + hwaddr = rq[28:28+hwlen].tobytes() myipn = myipbypeer.get(hwaddr, None) if not myipn or pkt.get(54, None) != myipn: return @@ -557,7 +579,7 @@ def consider_discover(info, packet, sock, cfg, reqview): check_reply(macmap[info['hwaddr']], info, packet, sock, cfg, reqview) elif info.get('uuid', None) in uuidmap: check_reply(uuidmap[info['uuid']], info, packet, sock, cfg, reqview) - elif packet[53] == b'\x03': + elif packet.get(53, None) == b'\x03': ack_request(packet, reqview, info)