diff --git a/confluent_osdeploy/common/initramfs/opt/confluent/bin/apiclient b/confluent_osdeploy/common/initramfs/opt/confluent/bin/apiclient index d468b4d9..dc2d8e4b 100644 --- a/confluent_osdeploy/common/initramfs/opt/confluent/bin/apiclient +++ b/confluent_osdeploy/common/initramfs/opt/confluent/bin/apiclient @@ -3,6 +3,7 @@ try: import http.client as client except ImportError: import httplib as client +import base64 import ctypes import ctypes.util import glob @@ -15,6 +16,12 @@ import sys import struct import time import re +import hmac +import hashlib +try: + import json +except ImportError: + json = None class InvalidApiKey(Exception): pass @@ -72,7 +79,7 @@ def get_my_addresses(): return addrs -def scan_confluents(): +def scan_confluents(confuuid=None): srvs = {} s6 = socket.socket(socket.AF_INET6, socket.SOCK_DGRAM) s6.setsockopt(socket.IPPROTO_IPV6, socket.IPV6_V6ONLY, 1) @@ -84,12 +91,13 @@ def scan_confluents(): s4.bind(('0.0.0.0', 1900)) doneidxs = set([]) msg = 'M-SEARCH * HTTP/1.1\r\nST: urn:xcat.org:service:confluent:' - with open('/etc/confluent/confluent.deploycfg') as dcfg: - for line in dcfg.read().split('\n'): - if line.startswith('confluent_uuid:'): - confluentuuid = line.split(': ')[1] - msg += '/confluentuuid=' + confluentuuid - break + if not confuuid: + with open('/etc/confluent/confluent.deploycfg') as dcfg: + for line in dcfg.read().split('\n'): + if line.startswith('confluent_uuid:'): + confluentuuid = line.split(': ')[1] + msg += '/confluentuuid=' + confluentuuid + break try: with open('/sys/devices/virtual/dmi/id/product_uuid') as uuidin: msg += '/uuid=' + uuidin.read().strip() @@ -126,6 +134,7 @@ def scan_confluents(): srvlist = [] if r: r = r[0] + nodename = None while r: for s in r: (rsp, peer) = s.recvfrom(9000) @@ -133,6 +142,7 @@ def scan_confluents(): current = None for line in rsp: if line.startswith(b'NODENAME: '): + nodename = line.replace(b'NODENAME: ', b'').strip().decode('utf8') current = {} elif line.startswith(b'DEFAULTNET: 1'): current['isdefault'] = True @@ -148,16 +158,32 @@ def scan_confluents(): r = select.select((s4, s6), (), (), 2) if r: r = r[0] + if not os.path.exists('/etc/confluent/confluent.info'): + with open('/etc/confluent/confluent.info', 'w+') as cinfo: + if nodename: + cinfo.write('NODENAME: {0}\n'.format(nodename)) + for srv in srvlist: + cinfo.write('MANAGER: {0}\n'.format(srv)) return srvlist, srvs -def get_net_apikey(nodename, mgr): +def get_net_apikey(nodename, mgr, hmackey=None, confuuid=None): alpha = 'ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789./' newpass = ''.join([alpha[x >> 2] for x in bytearray(os.urandom(32))]) salt = '$5$' + ''.join([alpha[x >> 2] for x in bytearray(os.urandom(8))]) newpass = newpass.encode('utf8') salt = salt.encode('utf8') crypted = c_crypt(newpass, salt) + if hmackey: + hmacvalue = hmac.new(hmackey.encode('utf8'), crypted, hashlib.sha256).digest() + hmacvalue = base64.b64encode(hmacvalue).decode('utf8') + client = HTTPSClient(host=mgr, phmac=hmacvalue, nodename=nodename, confuuid=confuuid) + try: + status, rsp = client.grab_url_with_status('/confluent-api/self/registerapikey', data=crypted, returnrsp=True) + if status == 200: + return newpass.decode('utf8') + except Exception: + pass for addrinfo in socket.getaddrinfo(mgr, 13001, 0, socket.SOCK_STREAM): try: clisock = socket.socket(addrinfo[0], addrinfo[1]) @@ -195,7 +221,7 @@ def get_net_apikey(nodename, mgr): return '' -def get_apikey(nodename, hosts, errout=None): +def get_apikey(nodename, hosts, errout=None, hmackey=None, confuuid=None): apikey = "" if os.path.exists('/etc/confluent/confluent.apikey'): apikey = open('/etc/confluent/confluent.apikey').read().strip() @@ -204,16 +230,16 @@ def get_apikey(nodename, hosts, errout=None): while not apikey: for host in hosts: try: - apikey = get_net_apikey(nodename, host) + apikey = get_net_apikey(nodename, host, hmackey=hmackey, confuuid=confuuid) except OSError: apikey = None if apikey: break else: - srvlist, _ = scan_confluents() + srvlist, _ = scan_confluents(confuuid=confuuid) for host in srvlist: try: - apikey = get_net_apikey(nodename, host) + apikey = get_net_apikey(nodename, host, hmackey=hmackey, confuuid=confuuid) except OSError: apikey = None if apikey: @@ -231,35 +257,43 @@ def get_apikey(nodename, hosts, errout=None): return apikey class HTTPSClient(client.HTTPConnection, object): - def __init__(self, usejson=False, port=443, host=None, errout=None, phmac=None, checkonly=False): + def __init__(self, usejson=False, port=443, host=None, errout=None, phmac=None, checkonly=False, hmackey=None, nodename=None, confuuid=None): self.ignorehosts = set([]) self.phmac = phmac + self.hmackey = hmackey + self.confuuid = confuuid self.errout = None + self.stdheaders = {} + if nodename: + self.stdheaders['CONFLUENT_NODENAME'] = nodename if errout: self.errout = open(errout, 'w') self.errout.flush() - self.stdheaders = {} mgtiface = None if usejson: self.stdheaders['ACCEPT'] = 'application/json' if host: self.hosts = [host] - with open('/etc/confluent/confluent.info') as cinfo: - info = cinfo.read().split('\n') - for line in info: - if line.startswith('NODENAME:'): - node = line.split(' ')[1] - self.stdheaders['CONFLUENT_NODENAME'] = node + if not nodename: + with open('/etc/confluent/confluent.info') as cinfo: + info = cinfo.read().split('\n') + for line in info: + if line.startswith('NODENAME:'): + nodename = line.split(' ')[1] + self.stdheaders['CONFLUENT_NODENAME'] = nodename else: self.hosts = [] - info = open('/etc/confluent/confluent.info').read().split('\n') + try: + info = open('/etc/confluent/confluent.info').read().split('\n') + except Exception: + info = [] havedefault = '0' plainhost = '' for line in info: host = '' if line.startswith('NODENAME:'): - node = line.split(' ')[1] - self.stdheaders['CONFLUENT_NODENAME'] = node + nodename = line.split(' ')[1] + self.stdheaders['CONFLUENT_NODENAME'] = nodename if line.startswith('MANAGER:') and not host: host = line.split(' ')[1] self.hosts.append(host) @@ -294,15 +328,14 @@ class HTTPSClient(client.HTTPConnection, object): if plainhost and not self.hosts: self.hosts.append(plainhost) if self.phmac: - with open(phmac, 'r') as hmacin: - self.stdheaders['CONFLUENT_CRYPTHMAC'] = hmacin.read() + self.stdheaders['CONFLUENT_CRYPTHMAC'] = self.phmac elif not checkonly: - self.stdheaders['CONFLUENT_APIKEY'] = get_apikey(node, self.hosts, errout=self.errout) + self.stdheaders['CONFLUENT_APIKEY'] = get_apikey(nodename, self.hosts, errout=self.errout, hmackey=hmackey, confuuid=self.confuuid) if mgtiface: self.stdheaders['CONFLUENT_MGTIFACE'] = mgtiface self.port = port self.host = None - self.node = node + self.node = nodename host = self.check_connections() client.HTTPConnection.__init__(self, host, port) self.connect() @@ -342,7 +375,7 @@ class HTTPSClient(client.HTTPConnection, object): continue break if not foundsrv: - srvlist, srvs = scan_confluents() + srvlist, srvs = scan_confluents(self.confuuid) hosts = [] for srv in srvlist: if srvs[srv].get('isdefault', False): @@ -416,7 +449,7 @@ class HTTPSClient(client.HTTPConnection, object): with open('/etc/confluent/confluent.apikey', 'w+') as akfile: akfile.write('') self.stdheaders['CONFLUENT_APIKEY'] = get_apikey( - self.node, [self.host], errout=self.errout) + self.node, [self.host], errout=self.errout, hmackey=self.hmackey, confuuid=self.confuuid) if rsp.status == 503: # confluent is down, but the server running confluent is otherwise up authed = False self.ignorehosts.add(self.host) @@ -545,8 +578,24 @@ if __name__ == '__main__': phmac = sys.argv.index('-p') sys.argv.pop(phmac) phmac = sys.argv.pop(phmac) + with open(phmac, 'r') as hmacin: + phmac = hmacin.read() except ValueError: phmac = None + try: + identfile = sys.argv.index('-i') + sys.argv.pop(identfile) + identfile = sys.argv.pop(identfile) + with open(identfile) as idin: + data = idin.read() + identinfo = json.loads(data) + nodename = identinfo.get('nodename', None) + hmackey = identinfo.get('apitoken', None) + confuuid = identinfo.get('confluent_uuid', None) + except ValueError: + hmackey = None + nodename = None + confuuid = None try: checkonly = False idxit = sys.argv.index('-c') @@ -558,7 +607,7 @@ if __name__ == '__main__': data = open(sys.argv[-1]).read() if outbin: with open(outbin, 'ab+') as outf: - reader = HTTPSClient(usejson=usejson, errout=errout).grab_url( + reader = HTTPSClient(usejson=usejson, errout=errout, hmackey=hmackey, nodename=nodename, confuuid=confuuid).grab_url( sys.argv[1], data, returnrsp=True) chunk = reader.read(16384) while chunk: @@ -566,7 +615,7 @@ if __name__ == '__main__': chunk = reader.read(16384) sys.exit(0) - mclient = HTTPSClient(usejson, errout=errout, phmac=phmac, checkonly=checkonly) + mclient = HTTPSClient(usejson, errout=errout, phmac=phmac, checkonly=checkonly, hmackey=hmackey, nodename=nodename, confuuid=confuuid) if waitfor: status = 201 while status != waitfor: