diff --git a/confluent_client/bin/confetty b/confluent_client/bin/confetty index 0ce3cd5e..1078fe8e 100755 --- a/confluent_client/bin/confetty +++ b/confluent_client/bin/confetty @@ -1,4 +1,4 @@ -#!/usr/bin/python2 +#!/usr/bin/python3 # vim: tabstop=4 shiftwidth=4 softtabstop=4 # Copyright 2014 IBM Corporation @@ -41,6 +41,7 @@ # esc-( would interfere with normal esc use too much # ~ I will not use for now... +import asyncio import math import getpass import optparse @@ -360,7 +361,7 @@ def print_result(res): print(output.encode('utf-8')) -def do_command(command, server): +async def do_command(command, server): global exitcode global target global currconsole @@ -399,7 +400,7 @@ def do_command(command, server): target = otarget else: foundchild = False - for res in session.read(parentpath, server): + async for res in session.read(parentpath, server): try: if res['item']['href'] == childname: foundchild = True @@ -434,7 +435,7 @@ def do_command(command, server): pass else: targpath = target - for res in session.read(targpath): + async for res in session.read(targpath): if 'item' in res: # a link relation if type(res['item']) == dict: print(res['item']["href"]) @@ -854,7 +855,7 @@ opts, shellargs = parser.parse_args() username = None passphrase = None -def server_connect(): +async def server_connect(): global session, username, passphrase if opts.controlpath: termhandler.TermHandler(opts.controlpath) @@ -864,7 +865,7 @@ def server_connect(): session = client.Command(os.environ['CONFLUENT_HOST']) else: # unix domain session = client.Command() - + await session.ensure_connected() # Next stop, reading and writing from whichever of stdin and server goes first. #see pyghmi code for solconnect.py if not session.authenticated and username is not None: @@ -890,10 +891,10 @@ if sys.stdout.isatty(): import readline -def main(): +async def main(): global inconsole try: - server_connect() + await server_connect() except (EOFError, KeyboardInterrupt) as _: raise BailOut(0) except socket.gaierror: @@ -920,10 +921,10 @@ def main(): session_node = get_session_node(shellargs) if session_node is not None: consoleonly = True - do_command("start /nodes/%s/console/session" % session_node, netserver) + await do_command("start /nodes/%s/console/session" % session_node, netserver) doexit = True elif shellargs: - do_command(shellargs, netserver) + await do_command(shellargs, netserver) quitconfetty(fullexit=True, fixterm=False) powerstate = None @@ -954,7 +955,7 @@ def main(): else: currcommand = prompt() try: - do_command(currcommand, netserver) + await do_command(currcommand, netserver) except socket.error: try: server_connect() @@ -1029,7 +1030,7 @@ if __name__ == '__main__': if opts.mintime: deadline = os.times()[4] + float(opts.mintime) try: - main() + asyncio.get_event_loop().run_until_complete(main()) except BailOut as e: errcode = e.errorcode except Exception as e: diff --git a/confluent_client/bin/nodedefine b/confluent_client/bin/nodedefine index 6e56722a..ce347bec 100755 --- a/confluent_client/bin/nodedefine +++ b/confluent_client/bin/nodedefine @@ -1,4 +1,4 @@ -#!/usr/bin/python2 +#!/usr/bin/python3 # vim: tabstop=4 shiftwidth=4 softtabstop=4 # Copyright 2017 Lenovo @@ -15,6 +15,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +import asyncio import optparse import os import signal @@ -32,27 +33,31 @@ if path.startswith('/opt'): import confluent.client as client -argparser = optparse.OptionParser( - usage='''\n %prog noderange attribute1=value1 attribute2=value,... - \n ''') -(options, args) = argparser.parse_args() -requestargs=None -try: - noderange = args[0] -except IndexError: - argparser.print_help() - sys.exit(1) -client.check_globbing(noderange) -session = client.Command() -exitcode = 0 -attribs = {'name': noderange} -for arg in args[1:]: - key, val = arg.split('=', 1) - attribs[key] = val -for r in session.create('/noderange/', attribs): - if 'error' in r: - sys.stderr.write(r['error'] + '\n') - exitcode |= 1 - if 'created' in r: - print('{0}: created'.format(r['created'])) -sys.exit(exitcode) +async def main(): + argparser = optparse.OptionParser( + usage='''\n %prog noderange attribute1=value1 attribute2=value,... + \n ''') + (options, args) = argparser.parse_args() + requestargs = None + try: + noderange = args[0] + except IndexError: + argparser.print_help() + sys.exit(1) + client.check_globbing(noderange) + session = client.Command() + exitcode = 0 + attribs = {'name': noderange} + for arg in args[1:]: + key, val = arg.split('=', 1) + attribs[key] = val + async for r in session.create('/noderange/', attribs): + if 'error' in r: + sys.stderr.write(r['error'] + '\n') + exitcode |= 1 + if 'created' in r: + print('{0}: created'.format(r['created'])) + sys.exit(exitcode) + +if __name__ == '__main__': + asyncio.get_event_loop().run_until_complete(main()) diff --git a/confluent_client/bin/nodelist b/confluent_client/bin/nodelist index 462ed922..fe40d25c 100755 --- a/confluent_client/bin/nodelist +++ b/confluent_client/bin/nodelist @@ -1,4 +1,4 @@ -#!/usr/libexec/platform-python +#!/usr/bin/python3 # vim: tabstop=4 shiftwidth=4 softtabstop=4 # Copyright 2015-2017 Lenovo @@ -21,6 +21,7 @@ import optparse import os import signal import sys +import asyncio @@ -35,7 +36,7 @@ if path.startswith('/opt'): import confluent.client as client -def main(): +async def main(): argparser = optparse.OptionParser( usage="Usage: %prog noderange\n" " or: %prog [options] noderange ...") @@ -61,7 +62,7 @@ def main(): if len(args) > 1: exitcode=client.printattributes(session,requestargs,showtype,nodetype,noderange,options) else: - for res in session.read(nodelist): + async for res in session.read(nodelist): if 'error' in res: sys.stderr.write(res['error'] + '\n') exitcode = 1 @@ -73,4 +74,4 @@ def main(): sys.exit(exitcode) if __name__ == '__main__': - main() + asyncio.get_event_loop().run_until_complete(main()) diff --git a/confluent_client/confluent/client.py b/confluent_client/confluent/client.py index ad29ff02..633671f4 100644 --- a/confluent_client/confluent/client.py +++ b/confluent_client/confluent/client.py @@ -62,6 +62,7 @@ def stringify(instr): return instr.encode('utf-8') return instr + class Tabulator(object): def __init__(self, headers): self.headers = headers @@ -121,7 +122,8 @@ def printerror(res, node=None): for node in res.get('databynode', {}): exitcode = res['databynode'][node].get('errorcode', exitcode) if 'error' in res['databynode'][node]: - sys.stderr.write('{0}: {1}\n'.format(node, res['databynode'][node]['error'])) + sys.stderr.write( + '{0}: {1}\n'.format(node, res['databynode'][node]['error'])) if exitcode == 0: exitcode = 1 if 'error' in res: @@ -169,6 +171,11 @@ class Command(object): self.serverloc = '/var/run/confluent/api.sock' else: self.serverloc = server + self.connected = False + + async def ensure_connected(self): + if self.connected: + return True if os.path.isabs(self.serverloc) and os.path.exists(self.serverloc): self._connect_unix() self.unixdomain = True @@ -176,9 +183,9 @@ class Command(object): raise Exception('Confluent service is not available') else: self._connect_tls() - self.protversion = int(tlvdata.recv(self.connection).split( + self.protversion = int((await tlvdata.recv(self.connection)).split( b'--')[1].strip()[1:]) - authdata = tlvdata.recv(self.connection) + authdata = await tlvdata.recv(self.connection) if authdata['authpassed'] == 1: self.authenticated = True else: @@ -186,19 +193,21 @@ class Command(object): if not self.authenticated and 'CONFLUENT_USER' in os.environ: username = os.environ['CONFLUENT_USER'] passphrase = os.environ['CONFLUENT_PASSPHRASE'] - self.authenticate(username, passphrase) + await self.authenticate(username, passphrase) + self.connected = True - def add_file(self, name, handle, mode): + async def add_file(self, name, handle, mode): + await self.ensure_connected() if self.protversion < 3: raise Exception('Not supported with connected confluent server') if not self.unixdomain: raise Exception('Can only add a file to a unix domain connection') tlvdata.send(self.connection, {'filename': name, 'mode': mode}, handle) - def authenticate(self, username, password): + async def authenticate(self, username, password): tlvdata.send(self.connection, {'username': username, 'password': password}) - authdata = tlvdata.recv(self.connection) + authdata = await tlvdata.recv(self.connection) if authdata['authpassed'] == 1: self.authenticated = True @@ -328,25 +337,32 @@ class Command(object): cprint('') return 0 - def read(self, path, parameters=None): + async def read(self, path, parameters=None): + await self.ensure_connected() if not self.authenticated: raise Exception('Unauthenticated') - return send_request('retrieve', path, self.connection, parameters) + async for rsp in send_request( + 'retrieve', path, self.connection, parameters): + yield rsp - def update(self, path, parameters=None): + async def update(self, path, parameters=None): + await self.ensure_connected() if not self.authenticated: raise Exception('Unauthenticated') - return send_request('update', path, self.connection, parameters) + return await send_request('update', path, self.connection, parameters) - def create(self, path, parameters=None): + async def create(self, path, parameters=None): + await self.ensure_connected() if not self.authenticated: raise Exception('Unauthenticated') - return send_request('create', path, self.connection, parameters) + async for rsp in send_request('create', path, self.connection, parameters): + yield rsp - def delete(self, path, parameters=None): + async def delete(self, path, parameters=None): + await self.ensure_connected() if not self.authenticated: raise Exception('Unauthenticated') - return send_request('delete', path, self.connection, parameters) + return await send_request('delete', path, self.connection, parameters) def _connect_unix(self): self.connection = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM) @@ -411,7 +427,7 @@ class Command(object): khf[hostid] = fingerprint -def send_request(operation, path, server, parameters=None): +async def send_request(operation, path, server, parameters=None): """This function iterates over all the responses received from the server. @@ -424,16 +440,16 @@ def send_request(operation, path, server, parameters=None): payload = {'operation': operation, 'path': path} if parameters is not None: payload['parameters'] = parameters - tlvdata.send(server, payload) - result = tlvdata.recv(server) + await tlvdata.send(server, payload) + result = await tlvdata.recv(server) while '_requestdone' not in result: try: yield result except GeneratorExit: while '_requestdone' not in result: - result = tlvdata.recv(server) + result = await tlvdata.recv(server) raise - result = tlvdata.recv(server) + result = await tlvdata.recv(server) def attrrequested(attr, attrlist, seenattributes, node=None): @@ -720,7 +736,7 @@ def updateattrib(session, updateargs, nodetype, noderange, options, dictassign=N else: exitcode = session.simple_noderange_command(noderange, 'attributes/all', value, key) - except: + except Exception: sys.stderr.write('Error: {0} not a valid expression\n'.format(str(updateargs[1:]))) exitcode = 1 sys.exit(exitcode) diff --git a/confluent_client/confluent/tlvdata.py b/confluent_client/confluent/tlvdata.py index 7fcb663b..f865aa22 100644 --- a/confluent_client/confluent/tlvdata.py +++ b/confluent_client/confluent/tlvdata.py @@ -16,15 +16,12 @@ # limitations under the License. import array +import asyncio import ctypes import ctypes.util import confluent.tlv as tlv -try: - import eventlet.green.socket as socket - import eventlet.green.select as select -except ImportError: - import socket - import select +import socket +import select from datetime import datetime import json import os @@ -40,6 +37,7 @@ try: except NameError: pass + class iovec(ctypes.Structure): # from uio.h _fields_ = [('iov_base', ctypes.c_void_p), ('iov_len', ctypes.c_size_t)] @@ -56,6 +54,7 @@ class cmsghdr(ctypes.Structure): # also from bits/socket.h @classmethod def init_data(cls, cmsg_len, cmsg_level, cmsg_type, cmsg_data): Data = ctypes.c_ubyte * ctypes.sizeof(cmsg_data) + class _flexhdr(ctypes.Structure): _fields_ = cls._fields_ + [('cmsg_data', Data)] @@ -102,9 +101,59 @@ libc = ctypes.CDLL(ctypes.util.find_library('c')) recvmsg = libc.recvmsg recvmsg.argtypes = [ctypes.c_int, ctypes.POINTER(msghdr), ctypes.c_int] recvmsg.restype = ctypes.c_int -sendmsg = libc.sendmsg -sendmsg.argtypes = [ctypes.c_int, ctypes.POINTER(msghdr), ctypes.c_int] -sendmsg.restype = ctypes.c_size_t + + +def _sendmsg(loop, fut, sock, msg, fds, rfd): + if rfd is not None: + loop.remove_reader(rfd) + if fut.cancelled(): + return + try: + retdata = sock.sendmsg([msg], [(socket.SOL_SOCKET, socket.SCM_RIGHTS, array.array("i", fds))]) + except (BlockingIOError, InterruptedError): + fd = sock.fileno() + loop.add_reader(fd, _sendmsg, loop, fut, sock, fd) + except Exception as exc: + fut.set_exception(exc) + else: + fut.set_result(retdata) + + +def send_fds(sock, msg, fds): + cloop = asyncio.get_event_loop() + fut = cloop.create_future() + _sendmsg(cloop, fut, sock, msg, fds, None) + return fut + + +def _recvmsg(loop, fut, sock, msglen, maxfds, rfd): + if rfd is not None: + loop.remove_reader(rfd) + fds = array.array("i") # Array of ints + try: + msg, ancdata, flags, addr = sock.recvmsg( + msglen, socket.CMSG_LEN(maxfds * fds.itemsize)) + except (BlockingIOError, InterruptedError): + fd = sock.fileno() + loop.add_reader(fd, _recvmsg, loop, fut, sock, fd) + except Exception as exc: + fut.set_exception(exc) + else: + for cmsg_level, cmsg_type, cmsg_data in ancdata: + if (cmsg_level == socket.SOL_SOCKET + and cmsg_type == socket.SCM_RIGHTS): + # Append data, ignoring any truncated integers at the end. + fds.frombytes( + cmsg_data[ + :len(cmsg_data) - (len(cmsg_data) % fds.itemsize)]) + fut.set_result(msglen, list(fds)) + +def recv_fds(sock, msglen, maxfds): + cloop = asyncio.get_event_loop() + fut = cloop.create_future() + _recvmsg(cloop, fut, sock, msglen, maxfds, None) + return fut + def decodestr(value): ret = None @@ -141,7 +190,8 @@ def _unicode_list(currlist): _unicode_list(currlist[i]) -def send(handle, data, filehandle=None): +async def send(handle, data, filehandle=None): + cloop = asyncio.get_event_loop() if isinstance(data, unicode): try: data = data.encode('utf-8') @@ -156,10 +206,10 @@ def send(handle, data, filehandle=None): if tl < 16777216: # type for string is '0', so we don't need # to xor anything in - handle.sendall(struct.pack("!I", tl)) + await cloop.sock_sendall(handle, struct.pack("!I", tl)) else: raise Exception("String data length exceeds protocol") - handle.sendall(data) + await cloop.sock_sendall(handle, data) elif isinstance(data, dict): # JSON currently only goes to 4 bytes # Some structured message, like what would be seen in http responses unicode_dictvalues(data) # make everything unicode, assuming UTF-8 @@ -171,41 +221,32 @@ def send(handle, data, filehandle=None): # xor in the type (0b1 << 24) if filehandle is None: tl |= 16777216 - handle.sendall(struct.pack("!I", tl)) - handle.sendall(sdata) + await cloop.sock_sendall(handle, struct.pack("!I", tl)) + await cloop.sock_sendall(handle, sdata) else: tl |= (2 << 24) - handle.sendall(struct.pack("!I", tl)) - cdtype = ctypes.c_ubyte * len(sdata) - cdata = cdtype.from_buffer(bytearray(sdata)) - ciov = iovec(iov_base=ctypes.addressof(cdata), - iov_len=ctypes.c_size_t(ctypes.sizeof(cdata))) - fd = ctypes.c_int(filehandle) - cmh = cmsghdr.init_data( - cmsg_len=CMSG_LEN( - ctypes.sizeof(fd)), cmsg_level=socket.SOL_SOCKET, - cmsg_type=SCM_RIGHTS, cmsg_data=fd) - mh = msghdr(msg_name=None, msg_len=0, msg_iov=iovec_ptr(ciov), - msg_iovlen=1, msg_control=ctypes.addressof(cmh), - msg_controllen=ctypes.c_size_t(ctypes.sizeof(cmh))) - sendmsg(handle.fileno(), mh, 0) + await cloop.sock_sendall(handle, struct.pack("!I", tl)) + await send_fds(handle, b'', [filehandle]) -def recvall(handle, size): - rd = handle.recv(size) +async def recvall(handle, size): + cloop = asyncio.get_event_loop() + rd = await cloop.sock_recv(handle, size) while len(rd) < size: - nd = handle.recv(size - len(rd)) + nd = await cloop.sock_recv(handle, size - len(rd)) if not nd: raise Exception("Error reading data") rd += nd return rd -def recv(handle): - tl = handle.recv(4) + +async def recv(handle): + cloop = asyncio.get_event_loop() + tl = await cloop.sock_recv(handle, 4) if not tl: return None while len(tl) < 4: - ndata = handle.recv(4 - len(tl)) + ndata = await cloop.sock_recv(handle, 4 - len(tl)) if not ndata: raise Exception("Error reading data") tl += ndata @@ -239,23 +280,16 @@ def recv(handle): msg.msg_iovlen = 1 msg.msg_control = ctypes.addressof(cmsg) msg.msg_controllen = ctypes.sizeof(cmsg) - select.select([handle], [], []) - i = recvmsg(handle.fileno(), ctypes.pointer(msg), 0) - cdata = cmsgarr[CMSG_LEN(0).value:] - data = rawbuffer[:i] - if cmsg.cmsg_level == socket.SOL_SOCKET and cmsg.cmsg_type == SCM_RIGHTS: - try: - filehandles.fromstring(bytes( - cdata[:len(cdata) - len(cdata) % filehandles.itemsize])) - except AttributeError: - filehandles.frombytes(bytes( - cdata[:len(cdata) - len(cdata) % filehandles.itemsize])) + i = await recv_fds(handle, 2048, 4) + print(repr(i)) + data = i[0] + filehandles = i[1] data = json.loads(bytes(data)) return ClientFile(data['filename'], data['mode'], filehandles[0]) else: - data = handle.recv(dlen) + data = await cloop.sock_recv(handle, dlen) while len(data) < dlen: - ndata = handle.recv(dlen - len(data)) + ndata = await asyncio.sock_recv(handle, dlen - len(data)) if not ndata: raise Exception("Error reading data") data += ndata diff --git a/confluent_server/confluent/core.py b/confluent_server/confluent/core.py index 09a298e2..21d24b1e 100644 --- a/confluent_server/confluent/core.py +++ b/confluent_server/confluent/core.py @@ -1267,7 +1267,7 @@ def handle_discovery(pathcomponents, operation, configmanager, inputdata): if pathcomponents[0] == 'detected': pass -def handle_path(path, operation, configmanager, inputdata=None, autostrip=True): +async def handle_path(path, operation, configmanager, inputdata=None, autostrip=True): """Given a full path request, return an object. The plugins should generally return some sort of iterator. @@ -1281,7 +1281,7 @@ def handle_path(path, operation, configmanager, inputdata=None, autostrip=True): if not pathcomponents: # root collection list return enumerate_collections(rootcollections) elif pathcomponents[0] == 'noderange': - return handle_node_request(configmanager, inputdata, operation, + return await handle_node_request(configmanager, inputdata, operation, pathcomponents, autostrip) elif pathcomponents[0] == 'deployment': return handle_deployment(configmanager, inputdata, pathcomponents, @@ -1295,7 +1295,7 @@ def handle_path(path, operation, configmanager, inputdata=None, autostrip=True): operation) elif pathcomponents[0] == 'nodes': # single node request of some sort - return handle_node_request(configmanager, inputdata, + return await handle_node_request(configmanager, inputdata, operation, pathcomponents, autostrip) elif pathcomponents[0] == 'discovery': return disco.handle_api_request( diff --git a/confluent_server/confluent/main.py b/confluent_server/confluent/main.py index b49d8f56..4aa84675 100644 --- a/confluent_server/confluent/main.py +++ b/confluent_server/confluent/main.py @@ -39,6 +39,7 @@ import confluent.httpapi as httpapi import confluent.log as log import confluent.collective.manager as collective import confluent.discovery.protocols.pxe as pxe +from eventlet.asyncio import spawn_for_awaitable try: import confluent.sockapi as sockapi except ImportError: @@ -313,7 +314,7 @@ def run(args): sock_bind_host, sock_bind_port = _get_connector_config('socket') try: sockservice = sockapi.SockApi(sock_bind_host, sock_bind_port) - sockservice.start() + spawn_for_awaitable(sockservice.start()) except NameError: pass webservice = httpapi.HttpApi(http_bind_host, http_bind_port) diff --git a/confluent_server/confluent/plugins/hardwaremanagement/enclosure.py b/confluent_server/confluent/plugins/hardwaremanagement/enclosure.py index d910400b..f70ce40f 100644 --- a/confluent_server/confluent/plugins/hardwaremanagement/enclosure.py +++ b/confluent_server/confluent/plugins/hardwaremanagement/enclosure.py @@ -34,7 +34,7 @@ async def reseat_bays(encmgr, bays, configmanager, rspq): finally: await rspq.put(None) -def update(nodes, element, configmanager, inputdata): +async def update(nodes, element, configmanager, inputdata): emebs = configmanager.get_node_attributes( nodes, (u'enclosure.manager', u'enclosure.bay')) baysbyencmgr = {} @@ -57,7 +57,7 @@ def update(nodes, element, configmanager, inputdata): for encmgr in baysbyencmgr: currtask = asyncio.create_task(reseat_bays(encmgr, baysbyencmgr[encmgr], configmanager, rspq)) reseattasks.append(currtask) - while not all([task.done() for task in reseattasks]):; + while not all([task.done() for task in reseattasks]): nrsp = await rspq.get() if nrsp is not None: yield nrsp diff --git a/confluent_server/confluent/plugins/hardwaremanagement/ipmi.py b/confluent_server/confluent/plugins/hardwaremanagement/ipmi.py index 103df7c2..91ab9511 100644 --- a/confluent_server/confluent/plugins/hardwaremanagement/ipmi.py +++ b/confluent_server/confluent/plugins/hardwaremanagement/ipmi.py @@ -187,7 +187,7 @@ class IpmiCommandWrapper(ipmicommand.Command): (node,), ('secret.hardwaremanagementuser', 'collective.manager', 'secret.hardwaremanagementpassword', 'secret.ipmikg', 'hardwaremanagement.manager'), self._attribschanged) - amait super().create(**kwargs) + await super().create(**kwargs) self.setup_confluent_keyhandler() try: os.makedirs('/var/cache/confluent/ipmi/') diff --git a/confluent_server/confluent/sockapi.py b/confluent_server/confluent/sockapi.py index 2d4db15b..0a4ff1b7 100644 --- a/confluent_server/confluent/sockapi.py +++ b/confluent_server/confluent/sockapi.py @@ -21,6 +21,7 @@ # import atexit +import asyncio import ctypes import ctypes.util import errno @@ -32,7 +33,7 @@ import sys import traceback import eventlet.green.select as select -import eventlet.green.socket as socket +import socket import eventlet.green.ssl as ssl import eventlet.support.greendns as greendns import eventlet @@ -107,15 +108,15 @@ class ClientConsole(object): self.pendingdata = [] -def send_data(connection, data): +async def send_data(connection, data): try: - tlvdata.send(connection, data) + await tlvdata.send(connection, data) except IOError as ie: if ie.errno != errno.EPIPE: raise -def sessionhdl(connection, authname, skipauth=False, cert=None): +async def sessionhdl(connection, authname, skipauth=False, cert=None): try: # For now, trying to test the console stuff, so let's just do n4. authenticated = False @@ -132,10 +133,11 @@ def sessionhdl(connection, authname, skipauth=False, cert=None): # version 0 == original, version 1 == pickle3 allowed, 2 = pickle forbidden, msgpack allowed # v3 - filehandle allowed # v4 - schema change and keepalive changes - send_data(connection, "Confluent -- v4 --") + + await send_data(connection, "Confluent -- v4 --") while not authenticated: # prompt for name and passphrase - send_data(connection, {'authpassed': 0}) - response = tlvdata.recv(connection) + await send_data(connection, {'authpassed': 0}) + response = await tlvdata.recv(connection) if not response: return if 'collective' in response: @@ -162,8 +164,8 @@ def sessionhdl(connection, authname, skipauth=False, cert=None): else: authenticated = True cfm = authdata[1] - send_data(connection, {'authpassed': 1}) - request = tlvdata.recv(connection) + await send_data(connection, {'authpassed': 1}) + request = await tlvdata.recv(connection) if request and isinstance(request, dict) and 'collective' in request: if skipauth: if not libssl: @@ -185,27 +187,27 @@ def sessionhdl(connection, authname, skipauth=False, cert=None): {'collective': {'error': 'collective management commands may only be used by root'}}) while request is not None: try: - process_request( + await process_request( connection, request, cfm, authdata, authname, skipauth) except exc.ConfluentException as e: if ((not isinstance(e, exc.LockedCredentials)) and e.apierrorcode == 500): tracelog.log(traceback.format_exc(), ltype=log.DataTypes.event, event=log.Events.stacktrace) - send_data(connection, {'errorcode': e.apierrorcode, + await send_data(connection, {'errorcode': e.apierrorcode, 'error': e.apierrorstr, 'detail': e.get_error_body()}) - send_data(connection, {'_requestdone': 1}) + await send_data(connection, {'_requestdone': 1}) except SystemExit: sys.exit(0) except Exception as e: tracelog.log(traceback.format_exc(), ltype=log.DataTypes.event, event=log.Events.stacktrace) - send_data(connection, {'errorcode': 500, + await send_data(connection, {'errorcode': 500, 'error': 'Unexpected error - ' + str(e)}) - send_data(connection, {'_requestdone': 1}) + await send_data(connection, {'_requestdone': 1}) try: - request = tlvdata.recv(connection) + request = await tlvdata.recv(connection) except Exception: request = None finally: @@ -216,15 +218,16 @@ def sessionhdl(connection, authname, skipauth=False, cert=None): except Exception: pass -def send_response(responses, connection): +async def send_response(responses, connection): if responses is None: return + responses = await responses for rsp in responses: - send_data(connection, rsp.raw()) - send_data(connection, {'_requestdone': 1}) + await send_data(connection, rsp.raw()) + await send_data(connection, {'_requestdone': 1}) -def process_request(connection, request, cfm, authdata, authname, skipauth): +async def process_request(connection, request, cfm, authdata, authname, skipauth): if isinstance(request, tlvdata.ClientFile): cfm.add_client_file(request) return @@ -267,7 +270,7 @@ def process_request(connection, request, cfm, authdata, authname, skipauth): send_data(connection, {"errorcode": 400, "error": "Bad Request - " + str(e)}) send_data(connection, {"_requestdone": 1}) - send_response(hdlr, connection) + await send_response(hdlr, connection) return def start_proxy_term(connection, cert, request): @@ -387,7 +390,7 @@ def _tlshandler(bind_host, bind_port): if addr[1] < 1000: eventlet.spawn_n(cs.handle_client, cnn, addr) else: - eventlet.spawn_n(_tlsstartup, cnn) + asyncio.create_task(_tlsstartup(cnn)) if ffi: @@ -395,8 +398,7 @@ if ffi: def verify_stub(store, misc): return 1 - -def _tlsstartup(cnn): +async def _tlsstartup(cnn): authname = None cert = None conf.init_config() @@ -435,7 +437,7 @@ def _tlsstartup(cnn): cnn = ctx.wrap_socket(cnn, server_side=True) except AttributeError: raise Exception('Unable to find workable SSL support') - sessionhdl(cnn, authname, cert=cert) + await sessionhdl(cnn, authname, cert=cert) def removesocket(): try: @@ -443,8 +445,10 @@ def removesocket(): except OSError: pass -def _unixdomainhandler(): +async def _unixdomainhandler(): + aloop = asyncio.get_event_loop() unixsocket = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM) + unixsocket.settimeout(0) try: os.remove("/var/run/confluent/api.sock") except OSError: # if file does not exist, no big deal @@ -458,7 +462,7 @@ def _unixdomainhandler(): atexit.register(removesocket) unixsocket.listen(5) while True: - cnn, addr = unixsocket.accept() + cnn, addr = await aloop.sock_accept(unixsocket) creds = cnn.getsockopt(socket.SOL_SOCKET, SO_PEERCRED, struct.calcsize('iII')) pid, uid, gid = struct.unpack('iII', creds) @@ -479,7 +483,7 @@ def _unixdomainhandler(): except KeyError: cnn.close() return - eventlet.spawn_n(sessionhdl, cnn, authname, skipauth) + asyncio.create_task(sessionhdl(cnn, authname, skipauth)) class SockApi(object): @@ -489,7 +493,7 @@ class SockApi(object): self.bind_host = bindhost or '::' self.bind_port = bindport or 13001 - def start(self): + async def start(self): global auditlog global tracelog tracelog = log.Logger('trace') @@ -500,7 +504,7 @@ class SockApi(object): else: eventlet.spawn_n(self.watch_for_cert) eventlet.spawn_n(self.watch_resolv) - self.unixdomainserver = eventlet.spawn(_unixdomainhandler) + self.unixdomainserver = asyncio.create_task(_unixdomainhandler()) def watch_resolv(self): while True: