#!/usr/bin/python
try:
    import http.client as client
except ImportError:
    import httplib as client
import ctypes
import ctypes.util
import glob
import os
import select
import socket
import subprocess
import ssl
import sys
import struct
import time

class InvalidApiKey(Exception):
    """Exception type naming an invalid API key condition.

    Nothing in this module raises it; presumably it is kept for code that
    imports this module -- verify against callers before removing.
    """

def msg_align(len):
    """Round ``len`` up to the next multiple of four.

    Used to step over netlink messages and attributes, which are padded to
    4-byte boundaries.
    """
    padded = len + 3
    # Dropping the two low bits is the same as masking with ~3.
    return padded - (padded & 3)

# Protocol constant handed to ssl.SSLContext by HTTPSClient; fall back to
# PROTOCOL_SSLv23 when the ssl module does not provide PROTOCOL_TLS_CLIENT.
try:
    _protoparm = ssl.PROTOCOL_TLS_CLIENT
except AttributeError:    
    _protoparm = ssl.PROTOCOL_SSLv23
# Locate the C library that provides crypt().  When find_library() comes up
# empty, guess the soname: libcrypt.so.2 if it exists under /lib64 or
# /usr/lib64, otherwise libcrypt.so.1.
cryptname = ctypes.util.find_library('crypt')
if not cryptname:
    if os.path.exists('/lib64/libcrypt.so.2') or os.path.exists('/usr/lib64/libcrypt.so.2'):
        cryptname = 'libcrypt.so.2'
    else:
        cryptname = 'libcrypt.so.1'
# NOTE: loading happens at import time, so importing this module raises
# OSError if the library cannot be loaded.
c_libcrypt = ctypes.CDLL(cryptname)
c_crypt = c_libcrypt.crypt
# Declare the C signature so ctypes passes both arguments (key, salt) as
# char pointers and converts the returned char pointer to bytes (or None
# when the C function returns NULL).
c_crypt.argtypes = (ctypes.c_char_p, ctypes.c_char_p)
c_crypt.restype = ctypes.c_char_p


def get_my_addresses():
    """Enumerate this host's IP addresses with an rtnetlink address dump.

    An RTM_GETADDR dump request is sent over a NETLINK_ROUTE socket and the
    RTM_NEWADDR replies are parsed.  Only addresses whose scope is 0
    (universe) or 253 (link) are reported.

    Returns:
        A list of ``(family, address, prefixlen, ifindex)`` tuples, where
        ``address`` is a memoryview slice holding the raw address bytes of
        the IFA_ADDRESS attribute.

    Raises:
        OSError (socket.error): if the netlink socket cannot be created,
            bound, written or read.
    """
    nlhdrsz = struct.calcsize('IHHII')
    ifaddrsz = struct.calcsize('BBBBI')
    # RTM_GETADDR = 22
    # nlmsghdr struct: u32 len, u16 type, u16 flags, u32 seq, u32 pid
    # flags 0x301 = NLM_F_REQUEST | NLM_F_DUMP
    nlhdr = struct.pack('IHHII', nlhdrsz + ifaddrsz, 22, 0x301, 0, 0)
    # ifaddrmsg struct: u8 family, u8 prefixlen, u8 flags, u8 scope, u32 index
    ifaddrmsg = struct.pack('BBBBI', 0, 0, 0, 0, 0)
    addrs = []
    s = socket.socket(socket.AF_NETLINK, socket.SOCK_RAW, socket.NETLINK_ROUTE)
    try:
        s.bind((0, 0))
        s.sendall(nlhdr + ifaddrmsg)
        finished = False
        while not finished:
            pdata = s.recv(65536)
            if not pdata:
                break
            v = memoryview(pdata)
            # One datagram may carry several netlink messages back to back.
            while len(v) >= nlhdrsz:
                length, typ = struct.unpack('IH', v[:6])
                if typ in (2, 3):
                    # NLMSG_ERROR (2) or NLMSG_DONE (3): nothing more will
                    # follow, wherever in the datagram the marker sits.
                    finished = True
                    break
                if length < nlhdrsz:
                    # Malformed header; advancing by it would never
                    # consume the buffer.
                    break
                if typ == 20:  # RTM_NEWADDR
                    fam, plen, _, scope, ridx = struct.unpack(
                        'BBBBI', v[nlhdrsz:nlhdrsz + ifaddrsz])
                    if scope in (253, 0):
                        rta = v[nlhdrsz + ifaddrsz:length]
                        # Walk the rtattr list: u16 len, u16 type, payload.
                        while len(rta) >= 4:
                            rtalen, rtatyp = struct.unpack('HH', rta[:4])
                            if rtalen < 4:
                                break
                            if rtatyp == 1:  # IFA_ADDRESS
                                addrs.append((fam, rta[4:rtalen], plen, ridx))
                            rta = rta[msg_align(rtalen):]
                v = v[msg_align(length):]
    finally:
        # The original never released the netlink socket.
        s.close()
    return addrs


def scan_confluents():
    """Discover confluent servers with an SSDP-style M-SEARCH on UDP 1900.

    The search message carries the confluent uuid from
    /etc/confluent/confluent.deploycfg (when present there), the DMI product
    uuid and the hardware address of every interface under /sys/class/net.
    It is multicast to ff02::c once per interface index that has an IPv6
    address and to 239.255.255.250 once per local IPv4 address.  Replies are
    collected until no socket becomes readable within the wait window
    (4 seconds for the first wait, 2 seconds afterwards).

    Returns:
        ``(srvlist, srvs)`` where ``srvlist`` lists the responder addresses
        in arrival order (link-local IPv6 addresses get a ``%<ifindex>``
        suffix) and ``srvs`` maps each address to a dict that may hold
        ``isdefault``, ``mgtiface`` and ``myidx``.

    Raises:
        IOError/OSError: if the configuration or sysfs files cannot be read
            or the sockets cannot be set up.
    """
    srvs = {}
    srvlist = []
    s4 = None
    s6 = socket.socket(socket.AF_INET6, socket.SOCK_DGRAM)
    try:
        s6.setsockopt(socket.IPPROTO_IPV6, socket.IPV6_V6ONLY, 1)
        s6.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEPORT, 1)
        s6.bind(('::', 1900))
        s4 = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
        s4.setsockopt(socket.SOL_SOCKET, socket.SO_BROADCAST, 1)
        s4.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEPORT, 1)
        s4.bind(('0.0.0.0', 1900))

        # Assemble the search message from its path-like components.
        parts = ['M-SEARCH * HTTP/1.1\r\nST: urn:xcat.org:service:confluent:']
        with open('/etc/confluent/confluent.deploycfg') as dcfg:
            for line in dcfg.read().split('\n'):
                if line.startswith('confluent_uuid:'):
                    parts.append('/confluentuuid=' + line.split(': ')[1])
                    break
        with open('/sys/devices/virtual/dmi/id/product_uuid') as uuidin:
            parts.append('/uuid=' + uuidin.read().strip())
        for addrf in glob.glob('/sys/class/net/*/address'):
            with open(addrf) as addrin:
                parts.append('/mac=' + addrin.read().strip())
        msg = ''.join(parts).encode('utf8')

        doneidxs = set()
        for fam, rawaddr, _, ifidx in get_my_addresses():
            if fam == socket.AF_INET6:
                # One multicast per interface is enough for IPv6.
                if ifidx in doneidxs:
                    continue
                try:
                    s6.setsockopt(socket.IPPROTO_IPV6,
                                  socket.IPV6_MULTICAST_IF, ifidx)
                except TypeError:
                    # The original fallback called .tobytes() on the int
                    # index, which could only fail; pack it instead.
                    s6.setsockopt(socket.IPPROTO_IPV6,
                                  socket.IPV6_MULTICAST_IF,
                                  struct.pack('I', ifidx))
                try:
                    s6.sendto(msg, ('ff02::c', 1900))
                except OSError:
                    pass  # best effort: this interface may be unusable
                doneidxs.add(ifidx)
            elif fam == socket.AF_INET:
                try:
                    s4.setsockopt(socket.IPPROTO_IP,
                                  socket.IP_MULTICAST_IF, rawaddr)
                except TypeError:
                    # Interpreters that reject a memoryview here get bytes.
                    s4.setsockopt(socket.IPPROTO_IP,
                                  socket.IP_MULTICAST_IF, rawaddr.tobytes())
                try:
                    s4.sendto(msg, ('239.255.255.250', 1900))
                except OSError:
                    pass  # best effort, as for IPv6

        ready = select.select((s4, s6), (), (), 4)[0]
        while ready:
            for s in ready:
                (rsp, peer) = s.recvfrom(9000)
                # Start from an empty dict so a reply without a NODENAME
                # line no longer leaves None behind (which broke both the
                # 'myidx' assignment below and callers using .get()).
                current = {}
                for line in rsp.split(b'\r\n'):
                    if line.startswith(b'NODENAME: '):
                        current = {}
                    elif line.startswith(b'DEFAULTNET: 1'):
                        current['isdefault'] = True
                    elif line.startswith(b'MGTIFACE: '):
                        current['mgtiface'] = line.replace(
                            b'MGTIFACE: ', b'').strip().decode('utf8')
                if len(peer) > 2:
                    # IPv6 peer tuples end with the scope (interface) id.
                    current['myidx'] = peer[-1]
                currip = peer[0]
                if currip.startswith('fe80::') and '%' not in currip:
                    currip = '{0}%{1}'.format(currip, peer[-1])
                srvs[currip] = current
                srvlist.append(currip)
            ready = select.select((s4, s6), (), (), 2)[0]
    finally:
        # The original leaked both sockets on every call.
        s6.close()
        if s4 is not None:
            s4.close()
    return srvlist, srvs


def get_net_apikey(nodename, mgr):
    """Negotiate a fresh API key with a confluent server on TCP port 13001.

    A random 32-character key is generated locally and its crypt() hash
    (``$5$`` salt prefix) is offered to the server.  Every address that
    ``mgr`` resolves to is tried in turn, connecting from local port 302.

    Args:
        nodename: name of this node, sent in the hello message.
        mgr: host name or address of the server.

    Returns:
        The plain-text key as a str when a server acknowledges it with
        message type 5, otherwise ``''``.

    Raises:
        OSError (socket.error): on resolution, bind, connect or I/O failure,
            including the 1-second socket timeout.
        Exception: if the server banner is not the expected 8 bytes.
    """
    alpha = 'ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789./'
    # x >> 2 maps each random byte onto one of the 64 alphabet characters.
    newpass = ''.join([alpha[x >> 2] for x in bytearray(os.urandom(32))])
    salt = '$5$' + ''.join([alpha[x >> 2] for x in bytearray(os.urandom(8))])
    newpass = newpass.encode('utf8')
    salt = salt.encode('utf8')
    crypted = c_crypt(newpass, salt)
    # Encode once so the length byte counts bytes on the wire, not characters.
    nodebytes = bytearray(nodename.encode('utf8'))
    hellostr = bytearray([1, len(nodebytes)]) + nodebytes + bytearray(b'\x00\x00')
    for addrinfo in socket.getaddrinfo(mgr, 13001, 0, socket.SOCK_STREAM):
        # Created outside the try block so the finally clause never sees an
        # unbound name when socket creation itself fails.
        clisock = socket.socket(addrinfo[0], addrinfo[1])
        try:
            clisock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
            if addrinfo[0] == socket.AF_INET:
                cliaddr = ('0.0.0.0', 302)
            else:
                cliaddr = ('::', 302)
            clisock.bind(cliaddr)
            clisock.settimeout(1)
            clisock.connect(addrinfo[-1])
            rsp = clisock.recv(8)
            if rsp != b'\xc2\xd1-\xa8\x80\xd8j\xba':
                raise Exception('Unrecognized credential banner')
            clisock.sendall(hellostr)
            rsp = bytearray(clisock.recv(2))
            if not rsp:
                continue
            if rsp[0] == 128:
                continue
            if rsp[0] == 2:
                if len(rsp) < 2:
                    # Truncated header: the token length is missing.
                    continue
                echotoken = clisock.recv(rsp[1])
                clisock.recv(2)  # drain \x00\x00
                # Echo the token back (type 3), then offer the hash (type 4).
                clisock.sendall(bytes(bytearray([3, rsp[1]])))
                clisock.sendall(echotoken)
                clisock.sendall(bytes(bytearray([4, len(crypted)])))
                clisock.sendall(crypted)
                clisock.sendall(b'\x00\x00')
            rsp = bytearray(clisock.recv(2))
            # An empty read (peer closed) used to raise IndexError here.
            if rsp and rsp[0] == 5:
                return newpass.decode('utf8')
        finally:
            clisock.close()
    return ''


def _first_apikey(nodename, hosts):
    """Return the first API key any host in ``hosts`` grants, else None."""
    for host in hosts:
        try:
            apikey = get_net_apikey(nodename, host)
        except (OSError, socket.error):
            continue
        if apikey:
            return apikey
    return None


def get_apikey(nodename, hosts, errout=None):
    """Return this node's API key, negotiating a new one when needed.

    A non-empty key stored in /etc/confluent/confluent.apikey is returned
    as is.  Otherwise each entry of ``hosts`` is asked for a key; if none
    grants one, servers found by scan_confluents() are asked as well.  The
    whole procedure repeats every 10 seconds until a key is obtained, which
    is then saved to the key file with mode 0600.

    Args:
        nodename: name of this node.
        hosts: iterable of server names/addresses to try first.
        errout: optional writable file object that also receives the
            failure message written to stderr.

    Returns:
        The API key as a str.
    """
    keypath = '/etc/confluent/confluent.apikey'
    apikey = ""
    if os.path.exists(keypath):
        with open(keypath) as keyin:
            apikey = keyin.read().strip()
        if apikey:
            return apikey
    while not apikey:
        apikey = _first_apikey(nodename, hosts)
        if not apikey:
            srvlist, _ = scan_confluents()
            apikey = _first_apikey(nodename, srvlist)
        if not apikey:
            errmsg = "Failed getting API token, check deployment.apiarmed attribute on {}\n".format(nodename)
            sys.stderr.write(errmsg)
            if errout:
                errout.write(errmsg)
            time.sleep(10)
    # Create the file owner-only from the start so the key is never briefly
    # readable under the default umask; the chmod still tightens a file that
    # already existed with wider permissions.
    keyfd = os.open(keypath, os.O_WRONLY | os.O_CREAT | os.O_TRUNC, 0o600)
    with os.fdopen(keyfd, 'w') as apiout:
        apiout.write(apikey)
    os.chmod(keypath, 0o600)
    return apikey.strip()

class HTTPSClient(client.HTTPConnection, object):
    """HTTPS connection to a confluent deployment server.

    Server certificates are verified against /etc/confluent/ca.pem.  Every
    request carries the headers kept in ``self.stdheaders`` (node name plus
    either an API key or a crypt HMAC, and optionally the management
    interface).
    """

    def __init__(self, usejson=False, port=443, host=None, errout=None, phmac=None, checkonly=False):
        """Pick a reachable server, prepare the headers and connect.

        Args:
            usejson: when true, send ``ACCEPT: application/json``.
            port: TCP port of the server.
            host: explicit server; when omitted, candidates are taken from
                the MANAGER/EXTMGRINFO lines of /etc/confluent/confluent.info
                and the deploy_server entries of confluent.deploycfg.
            errout: optional path of a file that receives error messages.
            phmac: optional path of a file whose content is sent as the
                CONFLUENT_CRYPTHMAC header instead of an API key.
            checkonly: when true (and no phmac), skip obtaining an API key.

        Raises:
            Exception: if confluent.info has no NODENAME line or no server
                can be reached.
        """
        self.phmac = phmac
        self.errout = None
        if errout:
            self.errout = open(errout, 'w')
            self.errout.flush()
        self.stdheaders = {}
        mgtiface = None
        node = None
        if usejson:
            self.stdheaders['ACCEPT'] = 'application/json'
        if host:
            self.hosts = [host]
            with open('/etc/confluent/confluent.info') as cinfo:
                info = cinfo.read().split('\n')
            for line in info:
                if line.startswith('NODENAME:'):
                    node = line.split(' ')[1]
                    self.stdheaders['CONFLUENT_NODENAME'] = node
        else:
            self.hosts = []
            with open('/etc/confluent/confluent.info') as cinfo:
                info = cinfo.read().split('\n')
            havedefault = '0'
            plainhost = ''
            for line in info:
                host = ''
                if line.startswith('NODENAME:'):
                    node = line.split(' ')[1]
                    self.stdheaders['CONFLUENT_NODENAME'] = node
                if line.startswith('MANAGER:') and not host:
                    host = line.split(' ')[1]
                    self.hosts.append(host)
                    if not plainhost:
                        plainhost = host
                if line.startswith('EXTMGRINFO:'):
                    # Value is '|'-separated: host|mgtiface|default-flag
                    extinfo = line.split(' ')[1]
                    extinfo = extinfo.split('|')
                    if havedefault == '0' and extinfo[2] == '1':
                        # First entry flagged as default wins.
                        host, mgtiface, havedefault = extinfo[:3]
                        if '%' in host:
                            ifidx = host.split('%', 1)[1]
                            with open('/tmp/confluent.ifidx', 'w+') as ifout:
                                ifout.write(ifidx)
                    if not mgtiface:
                        host, mgtiface, havedefault = extinfo[:3]
                    if host:
                        if havedefault == '0':
                            if '%' in host:
                                ifidx = host.split('%', 1)[1]
                                with open('/tmp/confluent.ifidx', 'w+') as ifout:
                                    ifout.write(ifidx)
                        self.hosts.append(host)
            try:
                with open('/etc/confluent/confluent.deploycfg') as dcfg:
                    info = dcfg.read().split('\n')
            except Exception:
                # deploycfg is optional input here; carry on without it.
                info = None
            if info:
                for line in info:
                    if line.startswith('deploy_server: ') or line.startswith('deploy_server_v6: '):
                        self.hosts.append(line.split(': ', 1)[1])
            if plainhost and not self.hosts:
                self.hosts.append(plainhost)
        if node is None:
            # Previously this surfaced later as an UnboundLocalError.
            raise Exception(
                'No NODENAME entry found in /etc/confluent/confluent.info')
        if self.phmac:
            with open(phmac, 'r') as hmacin:
                self.stdheaders['CONFLUENT_CRYPTHMAC'] = hmacin.read()
        elif not checkonly:
            self.stdheaders['CONFLUENT_APIKEY'] = get_apikey(node, self.hosts, errout=self.errout)
        if mgtiface:
            self.stdheaders['CONFLUENT_MGTIFACE'] = mgtiface
        self.port = port
        self.host = None
        self.node = node
        host = self.check_connections()
        # super() keeps this independent of the module-level name ``client``,
        # which the script entry point rebinds.
        super(HTTPSClient, self).__init__(host, port)
        self.connect()

    def set_header(self, key, val):
        """Set (or replace) a header sent with every subsequent request."""
        self.stdheaders[key] = val

    def check_connections(self):
        """Return the first candidate server that can be reached.

        Hosts in ``self.hosts`` are probed with a TCP connect followed by a
        verified TLS handshake, first with a 0.1 s timeout and then with
        5 s.  If none answers, servers found by scan_confluents() are
        probed with a plain TCP connect, entries flagged as default first.

        Raises:
            Exception: if no server can be reached at all.
        """
        foundsrv = None
        hosts = self.hosts
        ctx = ssl.SSLContext(_protoparm)
        ctx.load_verify_locations('/etc/confluent/ca.pem')
        ctx.verify_mode = ssl.CERT_REQUIRED
        ctx.check_hostname = True
        for timeo in (0.1, 5):
            for host in hosts:
                psock = None
                try:
                    addrinf = socket.getaddrinfo(host, self.port)[0]
                    psock = socket.socket(addrinf[0])
                    psock.settimeout(timeo)
                    psock.connect(addrinf[4])
                    # Certificates never carry the %scope suffix.
                    chost = host.split('%', 1)[0]
                    psock = ctx.wrap_socket(psock, server_hostname=chost)
                    foundsrv = host
                    break
                except (OSError, socket.error):
                    continue
                except ssl.SSLError:
                    continue
                except ssl.CertificateError:
                    continue
                finally:
                    # Release the probe socket on success and failure alike;
                    # the original leaked it on every failed attempt and
                    # never closed the wrapped TLS socket.
                    if psock is not None:
                        psock.close()
            else:
                continue
            break
        if not foundsrv:
            srvlist, srvs = scan_confluents()
            hosts = []
            for srv in srvlist:
                if srvs[srv].get('isdefault', False):
                    hosts = [srv] + hosts
                else:
                    hosts = hosts + [srv]
            # Same value the loop variable above is left with when every
            # probe failed.
            timeo = 5
            for host in hosts:
                psock = None
                try:
                    addrinf = socket.getaddrinfo(host, self.port)[0]
                    psock = socket.socket(addrinf[0])
                    psock.settimeout(timeo)
                    psock.connect(addrinf[4])
                    foundsrv = host
                    break
                except OSError:
                    continue
                finally:
                    if psock is not None:
                        psock.close()
            else:
                raise Exception('Unable to reach any hosts')
        return foundsrv

    def connect(self, tries=3):
        """Open a certificate-verified TLS connection to ``self.host``.

        Also sets the ``Host`` header, bracketing bare IPv6 literals.  A
        certificate failure prints a hint and exits the process with
        status 1; a timeout during the handshake is retried up to
        ``tries`` more times.
        """
        addrinf = socket.getaddrinfo(self.host, self.port)[0]
        psock = socket.socket(addrinf[0])
        psock.settimeout(15)
        psock.connect(addrinf[4])
        ctx = ssl.SSLContext(_protoparm)
        ctx.load_verify_locations('/etc/confluent/ca.pem')
        host = self.host.split('%', 1)[0]
        if '[' not in host and ':' in host:
            self.stdheaders['Host'] = '[{0}]'.format(host)
        else:
            self.stdheaders['Host'] = '{0}'.format(host)
        ctx.verify_mode = ssl.CERT_REQUIRED
        ctx.check_hostname = True
        try:
            self.sock = ctx.wrap_socket(psock, server_hostname=host)
        except (ssl.SSLError, ssl.CertificateError):
            # CertificateError is caught explicitly for interpreters where
            # it is not a subclass of SSLError, as check_connections does.
            errmsg = 'Error validating certificate on deployer (try `osdeploy initialize -t` on the deployment server {0})\n'.format(host)
            sys.stderr.write(errmsg)
            if self.errout:
                self.errout.write(errmsg)
                self.errout.flush()
            sys.exit(1)
        except socket.timeout:
            # Do not leave the half-open socket behind before retrying.
            psock.close()
            if not tries:
                raise
            return self.connect(tries=tries-1)

    def grab_url(self, url, data=None, returnrsp=False):
        """Like grab_url_with_status() but return only the body/response."""
        return self.grab_url_with_status(url, data, returnrsp)[1]

    def grab_url_with_status(self, url, data=None, returnrsp=False):
        """Request ``url`` and return ``(status, payload)``.

        POST is used when ``data`` is truthy, GET otherwise.  ``payload``
        is the response object when ``returnrsp`` is true, else the body
        bytes.  On a 401 the stored API key is cleared, a new key is
        obtained and the request is repeated.

        Raises:
            Exception: carrying the response body, for any status that is
                neither 2xx nor 401.
        """
        method = 'POST' if data else 'GET'
        while True:
            self.request(method, url, data, headers=self.stdheaders)
            rsp = self.getresponse()
            if 200 <= rsp.status < 300:
                if returnrsp:
                    return rsp.status, rsp
                return rsp.status, rsp.read()
            if rsp.status != 401:
                raise Exception(rsp.read())
            # Drain the body so the connection can be reused, then drop the
            # rejected key and negotiate a new one.
            rsp.read()
            with open('/etc/confluent/confluent.apikey', 'w+') as akfile:
                akfile.write('')
            self.stdheaders['CONFLUENT_APIKEY'] = get_apikey(
                self.node, [self.host], errout=self.errout)

if __name__ == '__main__':
    # Command line: [options] URL [DATAFILE]
    #   -j        ask for JSON
    #   -o FILE   append the response body to FILE
    #   -w CODE   repeat the request until the HTTP status equals CODE
    #   -d DATA   request body (turns the request into a POST)
    #   -e FILE   also write error messages to FILE
    #   -p FILE   send the content of FILE as crypt HMAC instead of an API key
    #   -c        only print the server that would be used
    def _pop_option(argv, flag):
        """Remove ``flag`` and the value after it from argv; return the value.

        Returns None when ``flag`` is not present.
        """
        try:
            idx = argv.index(flag)
        except ValueError:
            return None
        argv.pop(idx)
        return argv.pop(idx)

    usejson = '-j' in sys.argv
    if len(sys.argv) == 1:
        HTTPSClient()
        sys.exit(0)
    outbin = _pop_option(sys.argv, '-o')
    waitfor = _pop_option(sys.argv, '-w')
    if waitfor is not None:
        try:
            waitfor = int(waitfor)
        except ValueError:
            # A non-numeric status is treated as if -w was not given.
            waitfor = None
    data = _pop_option(sys.argv, '-d')
    errout = _pop_option(sys.argv, '-e')
    phmac = _pop_option(sys.argv, '-p')
    checkonly = '-c' in sys.argv
    if checkonly:
        sys.argv.remove('-c')
    if len(sys.argv) > 2 and os.path.exists(sys.argv[-1]):
        # A trailing argument naming an existing file supplies the body.
        with open(sys.argv[-1]) as datain:
            data = datain.read()
    if outbin:
        # NOTE(review): -p is not honoured on this path; confirm whether
        # that is intentional.
        with open(outbin, 'ab+') as outf:
            reader = HTTPSClient(usejson=usejson, errout=errout).grab_url(
                sys.argv[1], data, returnrsp=True)
            chunk = reader.read(16384)
            while chunk:
                outf.write(chunk)
                chunk = reader.read(16384)
        sys.exit(0)
    # Named apiclient so the imported http module ``client`` is not shadowed.
    apiclient = HTTPSClient(usejson, errout=errout, phmac=phmac, checkonly=checkonly)
    if waitfor:
        # Always issue at least one request; the original started from a
        # sentinel status of 201 and so never ran for ``-w 201``.
        while True:
            status, rsp = apiclient.grab_url_with_status(sys.argv[1], data)
            sys.stdout.write(rsp.decode())
            if status == waitfor:
                break
    elif checkonly:
        sys.stdout.write(apiclient.check_connections())
    else:
        sys.stdout.write(apiclient.grab_url(sys.argv[1], data).decode())
