#!/usr/bin/python

import glob
import json
import os
import socket
import sys
import time
import shlex
import subprocess
# Provide a load_source(name, path) helper on every supported interpreter:
# use importlib's SourceFileLoader when importlib.machinery can be imported,
# otherwise fall back to the legacy imp.load_source.
try:
    from importlib.machinery import SourceFileLoader
    def load_source(mod, path):
        # Load the file at 'path' as a module named 'mod' and return it.
        return SourceFileLoader(mod, path).load_module()
except ImportError:
    from imp import load_source

# apiclient has no .py suffix, so it is loaded by explicit path.  The copy
# under /opt/confluent/bin is tried first; if loading raises IOError the
# copy under /etc/confluent is used instead.
try:
    apiclient = load_source('apiclient', '/opt/confluent/bin/apiclient')
except IOError:
    apiclient = load_source('apiclient', '/etc/confluent/apiclient')

def add_lla(iface, mac):
    """Add a MAC-derived IPv6 link-local address to an interface.

    The address is built from the first six colon-separated octets of
    ``mac`` (first octet XOR 2, ``ff:fe`` inserted in the middle) and is
    added with ``ip addr add ... scope link``.

    :param iface: Name of the interface to add the address to.
    :param mac: Hardware address as colon-separated hex octets.
    :returns: The address string that was added (including ``/64``), or
              None if ``mac`` could not be parsed or the ``ip`` command
              failed.
    """
    pieces = mac.split(':')
    try:
        # Callers may hand over whatever sits in the address column of
        # 'ip -br l'; for devices without a hardware address that is not a
        # MAC at all, so treat anything unparsable as "nothing to add".
        initbyte = int(pieces[0], 16) ^ 2
        lla = 'fe80::{0:x}{1}:{2}ff:fe{3}:{4}{5}/64'.format(
            initbyte, pieces[1], pieces[2], pieces[3], pieces[4], pieces[5])
    except (ValueError, IndexError):
        return None
    try:
        subprocess.check_call(['ip', 'addr', 'add', 'dev', iface, lla, 'scope', 'link'])
    except Exception:
        # Best effort: failing to add a temporary address is not fatal.
        return None
    return lla

#cli = apiclient.HTTPSClient(json=True)
#c = cli.grab_url_with_status('/confluent-api/self/netcfg')
def add_missing_llas():
    """Give every usable interface an IPv6 link-local address if it lacks one.

    Interfaces are taken from ``ip -br l``; loopback and carrier-less
    interfaces are ignored, and interfaces that are not up are brought up.
    For each remaining interface with a hardware address and no ``fe80::``
    address in ``ip -br -6 a``, add_lla() is used to add one.

    :returns: dict mapping interface name to the address that was added,
              suitable for passing to rm_tmp_llas().
    """
    #NetworkManager goes out of its way to suppress ipv6 lla, so will just add some
    added = {}
    linkinfo = subprocess.check_output(['ip', '-br', 'l']).decode('utf8')
    ifaces = {}
    for line in linkinfo.split('\n'):
        fields = line.strip().split()
        if not fields or 'LOOPBACK' in fields[-1] or 'NO-CARRIER' in fields[-1]:
            continue
        # Brief output shows stacked devices as 'name@parent'; only the part
        # before '@' is a usable device name for other ip commands.
        name = fields[0].split('@', 1)[0]
        if 'UP' not in fields[-1]:
            subprocess.call(['ip', 'link', 'set', name, 'up'])
        if len(fields) < 3 or fields[2].startswith('<'):
            # No hardware address column (the third field is already the
            # flags), so there is nothing to derive an address from.
            continue
        ifaces[name] = fields[2]
    ips = {}
    ipinfo = subprocess.check_output(['ip', '-br', '-6', 'a']).decode('utf8')
    for line in ipinfo.split('\n'):
        fields = line.strip().split(None, 2)
        if not fields:
            continue
        # A row may carry no address list at all; treat that as "no addresses".
        ips[fields[0].split('@', 1)[0]] = fields[2] if len(fields) > 2 else ''
    for iface in ifaces:
        for addr in ips.get(iface, '').split():
            if addr.startswith('fe80::'):
                break
        else:
            # No link-local address found on this interface.
            newlla = add_lla(iface, ifaces[iface])
            if newlla:
                added[iface] = newlla
    return added

def rm_tmp_llas(tmpllas):
    """Delete the addresses listed in tmpllas from their interfaces.

    :param tmpllas: dict mapping interface name to the address to delete
                    with ``ip addr del``.
    :raises subprocess.CalledProcessError: if an ``ip`` invocation fails.
    """
    for ifname, lla in tmpllas.items():
        delcmd = ['ip', 'addr', 'del', 'dev', ifname, lla]
        subprocess.check_call(delcmd)

def await_tentative():
    """Wait until ``ip a`` no longer reports any 'tentative' address.

    Polls once per second and gives up after ten one-second sleeps.
    """
    retries = 10
    while True:
        if b'tentative' not in subprocess.check_output(['ip', 'a']):
            return
        if retries == 0:
            return
        retries -= 1
        time.sleep(1)

def map_idx_to_name(linkinfo=None):
    """Parse ``ip l`` output into index-to-name and name-to-type maps.

    :param linkinfo: Text in the format printed by ``ip l``.  When None
                     (the default) the command is run to obtain it.
    :returns: tuple ``(idxmap, devtype)``.  ``idxmap`` maps interface index
              (int) to interface name for every device that does not have a
              ``master``.  ``devtype`` maps every device name (enslaved or
              not) to the word after ``link/`` on its detail line, with
              ``ether`` reported as ``ethernet``.
    """
    idxmap = {}
    devtype = {}
    prevdev = None
    if linkinfo is None:
        linkinfo = subprocess.check_output(['ip', 'l']).decode('utf8')
    for line in linkinfo.splitlines():
        if not line.strip():
            continue
        if line.startswith(' '):
            # Indented lines are details of the most recent device header.
            firstword = line.split()[0]
            if prevdev is not None and firstword.startswith('link/'):
                typ = firstword.split('/', 1)[1]
                devtype[prevdev] = 'ethernet' if typ == 'ether' else typ
            continue
        parts = line.split(':', 2)
        if len(parts) != 3:
            continue
        idx, iface, rst = parts
        # Stacked devices are printed as 'name@parent'; keep only the name.
        iface = iface.strip().split('@', 1)[0]
        prevdev = iface
        if 'master' in rst.split():
            # Enslaved devices are not addressable on their own, so they
            # are left out of the index map (but still get a devtype).
            continue
        try:
            idxmap[int(idx)] = iface
        except ValueError:
            continue
    return idxmap, devtype


def get_interface_name(iname, settings):
    """Choose the interface name a group of settings applies to.

    :param iname: Name of the interface the settings were fetched through.
    :param settings: dict of settings; 'interface_names' and 'current_nic'
                     are consulted.
    :returns: the 'interface_names' value when it is set; otherwise iname
              when 'current_nic' is truthy; otherwise None.
    """
    requested = settings.get('interface_names', None)
    if requested:
        return requested
    use_current = settings.get('current_nic', False)
    return iname if use_current else None

class WickedManager(object):
    """Apply network settings by writing sysconfig files and calling wicked.

    Configuration is kept in ``ifcfg-<name>`` and ``ifroute-<name>`` files
    inside ``cfgdir``.
    """

    def __init__(self, cfgdir='/etc/sysconfig/network'):
        """Read the existing ifcfg files.

        :param cfgdir: Directory holding the ifcfg-*/ifroute-* files.
        """
        self.cfgdir = cfgdir
        self.teamidx = 0  # next index for auto-generated 'team<N>' names
        self.read_connections()

    def _cfgpath(self, kind, name):
        """Return the path of the '<kind>-<name>' file inside cfgdir."""
        return os.path.join(self.cfgdir, '{0}-{1}'.format(kind, name))

    def read_connections(self):
        """Load every ifcfg-* file (except ifcfg-lo) into self.cfgbydev.

        self.cfgbydev maps device name to a dict of the KEY=value
        assignments found in its file, with shell quoting removed.
        """
        self.cfgbydev = {}
        prefix = 'ifcfg-'
        for ifcfg in glob.glob(os.path.join(self.cfgdir, prefix + '*')):
            devname = os.path.basename(ifcfg)[len(prefix):]
            if devname == 'lo':
                continue
            currcfg = {}
            self.cfgbydev[devname] = currcfg
            with open(ifcfg) as cfgin:
                cfglines = cfgin.read().splitlines()
            for cfg in cfglines:
                try:
                    # comments=True drops '#' comments but, unlike cutting
                    # the line at the first '#', leaves a quoted '#' alone.
                    tokens = shlex.split(cfg, comments=True)
                except ValueError:
                    # Unbalanced quoting; ignore the line rather than abort.
                    continue
                kv = ' '.join(tokens).split('=', 1)
                if len(kv) != 2:
                    continue
                k, v = kv
                currcfg[k.strip()] = v.strip()

    def apply_configuration(self, cfg):
        """Write the ifcfg/ifroute files for one network and bring it up.

        :param cfg: dict with 'interfaces' (collection of interface names)
                    and 'settings' (dict).  Settings consulted here are
                    ipv4_method, ipv6_method, ipv4_address, ipv6_address,
                    ipv4_gateway, ipv6_gateway, team_mode and
                    connection_name.

        With more than one interface a team device is defined; this needs
        'team_mode', otherwise a warning is written to stderr and nothing
        is changed.  With one interface, any TEAM_* entries from its
        previously read ifcfg file are carried over.
        """
        stgs = cfg['settings']
        ipcfg = 'STARTMODE=auto\n'
        routecfg = ''
        bootproto4 = stgs.get('ipv4_method', 'none')
        bootproto6 = stgs.get('ipv6_method', 'none')
        if bootproto4 == 'dhcp' and bootproto6 == 'dhcp':
            ipcfg += 'BOOTPROTO=dhcp\n'
        elif bootproto4 == 'dhcp':
            ipcfg += 'BOOTPROTO=dhcp4\n'
        elif bootproto6 == 'dhcp':
            ipcfg += 'BOOTPROTO=dhcp6\n'
        else:
            ipcfg += 'BOOTPROTO=static\n'
        if stgs.get('ipv4_address', None):
            ipcfg += 'IPADDR=' + stgs['ipv4_address'] + '\n'
        v4gw = stgs.get('ipv4_gateway', None)
        if stgs.get('ipv6_address', None):
            ipcfg += 'IPADDR_V6=' + stgs['ipv6_address'] + '\n'
        v6gw = stgs.get('ipv6_gateway', None)
        cname = None
        if len(cfg['interfaces']) > 1:  # creating new team
            if not stgs.get('team_mode', None):
                sys.stderr.write("Warning, multiple interfaces ({0}) without a team_mode, skipping setup\n".format(','.join(cfg['interfaces'])))
                return
            if not stgs.get('connection_name', None):
                stgs['connection_name'] = 'team{0}'.format(self.teamidx)
                self.teamidx += 1
            cname = stgs['connection_name']
            with open(self._cfgpath('ifcfg', cname), 'w') as teamout:
                teamout.write(ipcfg)
                teamout.write('TEAM_RUNNER={0}\n'.format(stgs['team_mode']))
                idx = 1
                for iface in cfg['interfaces']:
                    subprocess.call(['wicked', 'ifdown', iface])
                    # Remove each file independently so that a missing
                    # ifcfg file does not leave a stale ifroute file behind.
                    for stale in (self._cfgpath('ifcfg', iface),
                                  self._cfgpath('ifroute', iface)):
                        try:
                            os.remove(stale)
                        except OSError:
                            pass
                    teamout.write('TEAM_PORT_DEVICE_{0}={1}\n'.format(idx, iface))
                    idx += 1
        else:
            cname = list(cfg['interfaces'])[0]
            priorcfg = self.cfgbydev.get(cname, {})
            for cf in priorcfg:
                if cf.startswith('TEAM_'):
                    ipcfg += '{0}={1}\n'.format(cf, priorcfg[cf])
            with open(self._cfgpath('ifcfg', cname), 'w') as iout:
                iout.write(ipcfg)
        if v4gw:
            routecfg += 'default {0} - {1}\n'.format(v4gw, cname)
        if v6gw:
            routecfg += 'default {0} - {1}\n'.format(v6gw, cname)
        if routecfg:
            with open(self._cfgpath('ifroute', cname), 'w') as routeout:
                routeout.write(routecfg)
        subprocess.call(['wicked', 'ifup', cname])


class NetworkManager(object):
    """Apply network settings through NetworkManager's nmcli command."""

    def __init__(self, devtypes):
        """Index the existing connections.

        :param devtypes: dict mapping device name to link type; used as the
                         'type' argument when a new connection is added.
        """
        self.connections = {}
        self.uuidbyname = {}
        self.uuidbydev = {}
        self.connectiondetail = {}
        self.read_connections()
        self.teamidx = 0  # next index for auto-generated 'team<N>' names
        self.devtypes = devtypes

    @staticmethod
    def _split_terse(line):
        """Split one line of 'nmcli -t' output into its fields.

        Fields are separated by ':'; a backslash makes the following
        character literal, so names containing ':' or a backslash survive.
        """
        fields = []
        curr = []
        chars = iter(line)
        for ch in chars:
            if ch == '\\':
                curr.append(next(chars, ''))
            elif ch == ':':
                fields.append(''.join(curr))
                curr = []
            else:
                curr.append(ch)
        fields.append(''.join(curr))
        return fields

    @staticmethod
    def _flatten_args(cmdargs):
        """Turn a {property: value} dict into a flat nmcli argument list."""
        cargs = []
        for arg in cmdargs:
            cargs.append(arg)
            cargs.append(cmdargs[arg])
        return cargs

    def read_connections(self):
        """Rebuild the connection indexes from nmcli output.

        Populates self.connections (by uuid), self.uuidbyname,
        self.uuidbydev and self.connectiondetail (uuid -> dict of the
        'key: value' lines of 'nmcli c s <uuid>', skipping values that are
        '--' or contain '(default)').
        """
        self.connections = {}
        self.uuidbyname = {}
        self.uuidbydev = {}
        self.connectiondetail = {}
        ci = subprocess.check_output(['nmcli', '-t', 'c']).decode('utf8')
        for inf in ci.splitlines():
            fields = self._split_terse(inf)
            if len(fields) != 4:
                # Not a name:uuid:type:device row; nothing to index.
                continue
            n, u, t, dev = fields
            if n == 'NAME':
                continue
            if dev == '--':
                dev = None
            self.uuidbyname[n] = u
            if dev:
                self.uuidbydev[dev] = u
            self.connections[u] = {'name': n, 'uuid': u, 'type': t, 'dev': dev}
            deats = {}
            for deat in subprocess.check_output(['nmcli', 'c', 's', u]).decode('utf8').splitlines():
                if ':' not in deat:
                    continue
                k, v = deat.split(':', 1)
                v = v.strip()
                if v == '--':
                    continue
                if '(default)' in v:
                    continue
                deats[k] = v
            self.connectiondetail[u] = deats

    def add_team_member(self, team, member):
        """Make interface 'member' a port of team connection 'team'.

        Does nothing if the connection currently active on the member
        already has 'team' as its connection.master.  Otherwise any
        connection named after the member is deleted, a team-slave
        connection is added, and the member's previous DNS / DHCP hostname
        settings (if any) are copied onto the team connection.

        :raises subprocess.CalledProcessError: if an nmcli call fails.
        """
        bondcfg = {}
        if member in self.uuidbydev:
            myuuid = self.uuidbydev[member]
            deats = self.connectiondetail[myuuid]
            currteam = deats.get('connection.master', None)
            if currteam == team:
                return
            for stg in ('ipv4.dhcp-hostname', 'ipv4.dns', 'ipv6.dns', 'ipv6.dhcp-hostname'):
                if deats.get(stg, None):
                    bondcfg[stg] = deats[stg]
        if member in self.uuidbyname:
            subprocess.check_call(['nmcli', 'c', 'del', self.uuidbyname[member]])
        subprocess.check_call(['nmcli', 'c', 'add', 'type', 'team-slave', 'master', team, 'con-name', member, 'connection.interface-name', member])
        if bondcfg:
            subprocess.check_call(['nmcli', 'c', 'm', team] + self._flatten_args(bondcfg))

    def apply_configuration(self, cfg):
        """Create or modify the connection for one network and activate it.

        :param cfg: dict with 'interfaces' (collection of interface names)
                    and 'settings' (dict).  Settings consulted here are
                    ipv4_method, ipv6_method, ipv4_address, ipv6_address,
                    ipv4_gateway, ipv6_gateway, team_mode and
                    connection_name.

        With more than one interface a team connection is created; this
        needs 'team_mode', otherwise a warning is written to stderr and
        nothing is changed.
        :raises subprocess.CalledProcessError: if an nmcli call fails.
        """
        cmdargs = {}
        stgs = cfg['settings']
        cmdargs['ipv6.method'] = stgs.get('ipv6_method', 'link-local')
        if stgs.get('ipv6_address', None):
            cmdargs['ipv6.addresses'] = stgs['ipv6_address']
        cmdargs['ipv4.method'] = stgs.get('ipv4_method', 'disabled')
        if stgs.get('ipv4_address', None):
            cmdargs['ipv4.addresses'] = stgs['ipv4_address']
        if stgs.get('ipv4_gateway', None):
            cmdargs['ipv4.gateway'] = stgs['ipv4_gateway']
        if stgs.get('ipv6_gateway', None):
            cmdargs['ipv6.gateway'] = stgs['ipv6_gateway']
        if len(cfg['interfaces']) > 1:  # team time.. should be..
            if not stgs.get('team_mode', None):
                sys.stderr.write("Warning, multiple interfaces ({0}) without a team_mode, skipping setup\n".format(','.join(cfg['interfaces'])))
                return
            if not stgs.get('connection_name', None):
                stgs['connection_name'] = 'team{0}'.format(self.teamidx)
                self.teamidx += 1
            cname = stgs['connection_name']
            cargs = self._flatten_args(cmdargs)
            subprocess.check_call(['nmcli', 'c', 'add', 'type', 'team', 'con-name', cname, 'connection.interface-name', cname, 'team.runner', stgs['team_mode']] + cargs)
            for iface in cfg['interfaces']:
                self.add_team_member(cname, iface)
            subprocess.check_call(['nmcli', 'c', 'u', cname])
        else:
            cname = stgs.get('connection_name', None)
            iname = list(cfg['interfaces'])[0]
            if not cname:
                cname = iname
            u = self.uuidbyname.get(cname, None)
            if u:
                # Must be set before the argument list is built, otherwise
                # the interface binding never reaches nmcli.
                cmdargs['connection.interface-name'] = iname
                subprocess.check_call(['nmcli', 'c', 'm', u] + self._flatten_args(cmdargs))
                subprocess.check_call(['nmcli', 'c', 'u', u])
            else:
                # 'ip' reports Ethernet links as 'ether' while nmcli names
                # the connection type 'ethernet'.  Unknown devices are
                # assumed to be Ethernet rather than aborting.
                devtype = self.devtypes.get(iname, 'ethernet')
                if devtype == 'ether':
                    devtype = 'ethernet'
                subprocess.check_call(['nmcli', 'c', 'add', 'type', devtype, 'con-name', cname, 'connection.interface-name', iname] + self._flatten_args(cmdargs))



if __name__ == '__main__':
    # Overall flow: stop firewalld if it is running, make sure interfaces
    # have link-local addresses, ask each discovered confluent server for
    # this node's network configuration (once per local interface), then
    # hand the collected settings to nmcli or wicked.
    havefirewall = subprocess.call(['systemctl', 'status', 'firewalld'])
    havefirewall = havefirewall == 0
    if havefirewall:
        subprocess.check_call(['systemctl', 'stop', 'firewalld'])
    try:
        tmpllas = add_missing_llas()
        try:
            await_tentative()
            idxmap, devtypes = map_idx_to_name()
            netname_to_interfaces = {}
            myaddrs = apiclient.get_my_addresses()
            srvs, _ = apiclient.scan_confluents()
            doneidxs = set([])
            for srv in srvs:
                s = socket.create_connection((srv, 443))
                try:
                    myname = s.getsockname()
                finally:
                    s.close()
                # Work out which local interface index reaches this server.
                curridx = None
                if len(myname) == 4:
                    # IPv6 socket name; the last element is the scope id.
                    curridx = myname[-1]
                else:
                    myname = socket.inet_pton(socket.AF_INET, myname[0])
                    for addr in myaddrs:
                        if myname == addr[1].tobytes():
                            curridx = addr[-1]
                if curridx is None:
                    sys.stderr.write("Warning, unable to determine local interface used to reach {0}, skipping\n".format(srv))
                    continue
                if curridx in doneidxs:
                    continue
                # NOTE(review): status is not checked; presumably a failed
                # request would surface as a json decode error - confirm.
                status, nc = apiclient.HTTPSClient(usejson=True, host=srv).grab_url_with_status('/confluent-api/self/netcfg')
                nc = json.loads(nc)
                # An index missing from idxmap (e.g. an enslaved device)
                # yields None, which only matters for 'current_nic' settings.
                currname = idxmap.get(curridx, None)
                iname = get_interface_name(currname, nc.get('default', {}))
                if iname:
                    if 'default' in netname_to_interfaces:
                        netname_to_interfaces['default']['interfaces'].add(iname)
                    else:
                        netname_to_interfaces['default'] = {'interfaces': set([iname]), 'settings': nc['default']}
                for netname in nc.get('extranets', {}):
                    # Prefix keeps an extra network from colliding with 'default'.
                    uname = '_' + netname
                    iname = get_interface_name(currname, nc['extranets'][netname])
                    if iname:
                        if uname in netname_to_interfaces:
                            netname_to_interfaces[uname]['interfaces'].add(iname)
                        else:
                            netname_to_interfaces[uname] = {'interfaces': set([iname]), 'settings': nc['extranets'][netname]}
                doneidxs.add(curridx)
        finally:
            # The temporary addresses are only needed to reach the servers;
            # remove them even if discovery failed.
            rm_tmp_llas(tmpllas)
        nm = None
        if os.path.exists('/usr/bin/nmcli'):
            nm = NetworkManager(devtypes)
        elif os.path.exists('/usr/sbin/wicked'):
            nm = WickedManager()
        if nm is None:
            sys.stderr.write("Warning, neither nmcli nor wicked found, unable to apply network configuration\n")
        else:
            for netn in netname_to_interfaces:
                nm.apply_configuration(netname_to_interfaces[netn])
    finally:
        # Never leave the firewall stopped, even when something above failed.
        if havefirewall:
            subprocess.check_call(['systemctl', 'start', 'firewalld'])
    await_tentative()

