diff --git a/confluent_client/bin/confetty b/confluent_client/bin/confetty index fc6dfb3e..44010f8d 100755 --- a/confluent_client/bin/confetty +++ b/confluent_client/bin/confetty @@ -41,7 +41,7 @@ # esc-( would interfere with normal esc use too much # ~ I will not use for now... -import asyncio +import math import getpass import optparse import os @@ -51,7 +51,6 @@ import signal import socket import struct import sys -import concurrent.futures import time try: import fcntl @@ -235,23 +234,14 @@ session = None def completer(text, state): try: - with concurrent.futures.ThreadPoolExecutor() as executor: - future = executor.submit(run_completer, text, state) - return future.result() - except Exception: + return rcompleter(text, state) + except: pass - import traceback - traceback.print_exc() + #import traceback + #traceback.print_exc() -def run_completer(text, state): - loop = asyncio.new_event_loop() - asyncio.set_event_loop(loop) - return loop.run_until_complete( - rcompleter(text, state)) - - -async def rcompleter(text, state): +def rcompleter(text, state): global candidates global valid_commands cline = readline.get_line_buffer() @@ -281,7 +271,7 @@ async def rcompleter(text, state): if candidates is None: candidates = [] targpath = fullpath_target(lastarg) - async for res in session.read(targpath): + for res in session.read(targpath): if 'item' in res: # a link relation if type(res['item']) == dict: candidates.append(res['item']["href"]) @@ -370,7 +360,7 @@ def print_result(res): print(output.encode('utf-8')) -async def do_command(command, server): +def do_command(command, server): global exitcode global target global currconsole @@ -409,7 +399,7 @@ async def do_command(command, server): target = otarget else: foundchild = False - async for res in session.read(parentpath, server): + for res in session.read(parentpath, server): try: if res['item']['href'] == childname: foundchild = True @@ -444,7 +434,7 @@ async def do_command(command, server): pass else: targpath = target - async for res in session.read(targpath): + for res in session.read(targpath): if 'item' in res: # a link relation if type(res['item']) == dict: print(res['item']["href"]) @@ -484,9 +474,9 @@ async def do_command(command, server): startconsole(nodename) return elif argv[0] == 'set': - await setvalues(argv[1:]) + setvalues(argv[1:]) elif argv[0] == 'create': - await createresource(argv[1:]) + createresource(argv[1:]) elif argv[0] in ('rm', 'delete', 'remove'): delresource(argv[1]) elif argv[0] in ('unset', 'clear'): @@ -501,7 +491,7 @@ def shutdown(): tlvdata.send(session.connection, {'operation': 'shutdown', 'path': '/'}) -async def createresource(args): +def createresource(args): resname = args[0] attribs = args[1:] keydata = parameterize_attribs(attribs) @@ -514,12 +504,12 @@ async def createresource(args): collection, _, resname = targpath.rpartition('/') if 'name' not in keydata: keydata['name'] = resname - await makecall(session.create, (collection, keydata)) + makecall(session.create, (collection, keydata)) -async def makecall(callout, args): +def makecall(callout, args): global exitcode - async for response in callout(*args): + for response in callout(*args): if 'deleted' in response: print("Deleted: " + response['deleted']) if 'created' in response: @@ -550,12 +540,12 @@ def clearvalues(resource, attribs): sys.stderr.write('Error: ' + res['error'] + '\n') -async def delresource(resname): +def delresource(resname): resname = fullpath_target(resname) - await makecall(session.delete, (resname,)) + makecall(session.delete, (resname,)) -async def setvalues(attribs): +def setvalues(attribs): global exitcode if '=' in attribs[0]: # going straight to attribute resource = attribs[0][:attribs[0].index("=")] @@ -569,7 +559,7 @@ async def setvalues(attribs): if not keydata: return targpath = fullpath_target(resource) - async for res in session.update(targpath, keydata): + for res in session.update(targpath, keydata): if 'error' in res: if 'errorcode' in res: exitcode = res['errorcode'] @@ -864,7 +854,7 @@ opts, shellargs = parser.parse_args() username = None passphrase = None -async def server_connect(): +def server_connect(): global session, username, passphrase if opts.controlpath: termhandler.TermHandler(opts.controlpath) @@ -874,7 +864,7 @@ async def server_connect(): session = client.Command(os.environ['CONFLUENT_HOST']) else: # unix domain session = client.Command() - await session.ensure_connected() + # Next stop, reading and writing from whichever of stdin and server goes first. #see pyghmi code for solconnect.py if not session.authenticated and username is not None: @@ -900,14 +890,11 @@ if sys.stdout.isatty(): import readline -async def main(): +def main(): global inconsole - global consoleonly - global doexit - global doexit try: - await server_connect() - except (EOFError, KeyboardInterrupt): + server_connect() + except (EOFError, KeyboardInterrupt) as _: raise BailOut(0) except socket.gaierror: sys.stderr.write('Could not connect to confluent\n') @@ -929,13 +916,14 @@ async def main(): doexit = False inconsole = False + pendingcommand = "" session_node = get_session_node(shellargs) if session_node is not None: consoleonly = True - await do_command("start /nodes/%s/console/session" % session_node, netserver) + do_command("start /nodes/%s/console/session" % session_node, netserver) doexit = True elif shellargs: - await do_command(shellargs, netserver) + do_command(shellargs, netserver) quitconfetty(fullexit=True, fixterm=False) powerstate = None @@ -966,7 +954,7 @@ async def main(): else: currcommand = prompt() try: - await do_command(currcommand, netserver) + do_command(currcommand, netserver) except socket.error: try: server_connect() @@ -1041,10 +1029,10 @@ if __name__ == '__main__': if opts.mintime: deadline = os.times()[4] + float(opts.mintime) try: - asyncio.get_event_loop().run_until_complete(main()) + main() except BailOut as e: errcode = e.errorcode - except Exception: + except Exception as e: import traceback excinfo = traceback.print_exc() try: diff --git a/confluent_client/bin/nodepower b/confluent_client/bin/nodepower index 968b83c6..32b9f906 100755 --- a/confluent_client/bin/nodepower +++ b/confluent_client/bin/nodepower @@ -15,7 +15,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -import asyncio import optparse import os import signal @@ -86,8 +85,4 @@ if options.previous: def outhandler(node, res): for k in res[node]: client.cprint('{0}: {1}: {2}'.format(node, k.replace('inlet_', ''), res[node][k])) -async def main(): - sys.exit(await session.simple_noderange_command(noderange, '/power/{0}'.format(powurl), setstate, promptover=options.maxnodes, key='state', outhandler=outhandler)) - -if __name__ == '__main__': - asyncio.get_event_loop().run_until_complete(main()) +sys.exit(session.simple_noderange_command(noderange, '/power/{0}'.format(powurl), setstate, promptover=options.maxnodes, key='state', outhandler=outhandler)) diff --git a/confluent_client/confluent/asynclient.py b/confluent_client/confluent/asynclient.py new file mode 100644 index 00000000..7e641eee --- /dev/null +++ b/confluent_client/confluent/asynclient.py @@ -0,0 +1,827 @@ +# vim: tabstop=4 shiftwidth=4 softtabstop=4 + +# Copyright 2014 IBM Corporation +# Copyright 2015-2019 Lenovo +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import asyncio +import ctypes +import ctypes.util +import dbm +import csv +import errno +import fnmatch +import hashlib +import os +import shlex +import socket +import ssl +import sys +import confluent.asynctlvdata as tlvdata +import confluent.sortutil as sortutil +libssl = ctypes.CDLL(ctypes.util.find_library('ssl')) +libssl.SSL_CTX_set_cert_verify_callback.argtypes = [ + ctypes.c_void_p, ctypes.c_void_p, ctypes.c_void_p] + +SO_PASSCRED = 16 + +_attraliases = { + 'bmc': 'hardwaremanagement.manager', + 'bmcuser': 'secret.hardwaremanagementuser', + 'switchuser': 'secret.hardwaremanagementuser', + 'bmcpass': 'secret.hardwaremanagementpassword', + 'switchpass': 'secret.hardwaremanagementpassword', +} + +try: + getinput = raw_input +except NameError: + getinput = input + + +class PyObject_HEAD(ctypes.Structure): + _fields_ = [ + ("ob_refcnt", ctypes.c_ssize_t), + ("ob_type", ctypes.c_void_p), + ] + + +# see main/Modules/_ssl.c, only caring about the SSL_CTX pointer +class PySSLContext(ctypes.Structure): + _fields_ = [ + ("ob_base", PyObject_HEAD), + ("ctx", ctypes.c_void_p), + ] + + +@ctypes.CFUNCTYPE(ctypes.c_int, ctypes.c_void_p, ctypes.c_void_p) +def verify_stub(store, misc): + return 1 + + +class NestedDict(dict): + def __missing__(self, key): + value = self[key] = type(self)() + return value + + +def stringify(instr): + # Normalize unicode and bytes to 'str', correcting for + # current python version + if isinstance(instr, bytes) and not isinstance(instr, str): + return instr.decode('utf-8') + elif not isinstance(instr, bytes) and not isinstance(instr, str): + return instr.encode('utf-8') + return instr + + +class Tabulator(object): + def __init__(self, headers): + self.headers = headers + self.rows = [] + + def add_row(self, row): + self.rows.append(row) + + def get_table(self, order=None): + i = 0 + fmtstr = '' + separator = [] + for head in self.headers: + if order and order == head: + order = i + neededlen = len(head) + for row in self.rows: + if len(row[i]) > neededlen: + neededlen = len(row[i]) + separator.append('-' * (neededlen + 1)) + fmtstr += '{{{0}:>{1}}}|'.format(i, neededlen + 1) + i = i + 1 + fmtstr = fmtstr[:-1] + yield fmtstr.format(*self.headers) + yield fmtstr.format(*separator) + if order is not None: + for row in sorted( + self.rows, + key=lambda x: sortutil.naturalize_string(x[order])): + yield fmtstr.format(*row) + else: + for row in self.rows: + yield fmtstr.format(*row) + + def write_csv(self, output, order=None): + output = csv.writer(output) + output.writerow(self.headers) + i = 0 + for head in self.headers: + if order and order == head: + order = i + i = i + 1 + if order is not None: + for row in sorted( + self.rows, + key=lambda x: sortutil.naturalize_string(x[order])): + output.writerow(row) + else: + for row in self.rows: + output.writerow(row) + + +def printerror(res, node=None): + exitcode = 0 + if 'errorcode' in res: + exitcode = res['errorcode'] + for node in res.get('databynode', {}): + exitcode = res['databynode'][node].get('errorcode', exitcode) + if 'error' in res['databynode'][node]: + sys.stderr.write( + '{0}: {1}\n'.format(node, res['databynode'][node]['error'])) + if exitcode == 0: + exitcode = 1 + if 'error' in res: + if node: + sys.stderr.write('{0}: {1}\n'.format(node, res['error'])) + else: + sys.stderr.write('{0}\n'.format(res['error'])) + if 'errorcode' not in res: + exitcode = 1 + return exitcode + + +def cprint(txt): + try: + print(txt) + except UnicodeEncodeError: + print(txt.encode('utf8')) + sys.stdout.flush() + +def _parseserver(string): + if ']:' in string: + server, port = string[1:].split(']:') + elif string[0] == '[': + server = string[1:-1] + port = '13001' + elif ':' in string: + server, port = string.split(':') + else: + server = string + port = '13001' + return server, port + + +class Command(object): + def __init__(self, server=None): + self._prevdict = None + self._prevkeyname = None + self.connection = None + self._currnoderange = None + self.unixdomain = False + if server is None: + if 'CONFLUENT_HOST' in os.environ: + self.serverloc = os.environ['CONFLUENT_HOST'] + else: + self.serverloc = '/var/run/confluent/api.sock' + else: + self.serverloc = server + self.connected = False + + async def ensure_connected(self): + if self.connected: + return True + if os.path.isabs(self.serverloc) and os.path.exists(self.serverloc): + self._connect_unix() + self.unixdomain = True + elif self.serverloc == '/var/run/confluent/api.sock': + raise Exception('Confluent service is not available') + else: + await self._connect_tls() + self.protversion = int((await tlvdata.recv(self.connection)).split( + b'--')[1].strip()[1:]) + authdata = await tlvdata.recv(self.connection) + if authdata['authpassed'] == 1: + self.authenticated = True + else: + self.authenticated = False + if not self.authenticated and 'CONFLUENT_USER' in os.environ: + username = os.environ['CONFLUENT_USER'] + passphrase = os.environ['CONFLUENT_PASSPHRASE'] + await self.authenticate(username, passphrase) + self.connected = True + + async def add_file(self, name, handle, mode): + await self.ensure_connected() + if self.protversion < 3: + raise Exception('Not supported with connected confluent server') + if not self.unixdomain: + raise Exception('Can only add a file to a unix domain connection') + tlvdata.send(self.connection, {'filename': name, 'mode': mode}, handle) + + async def authenticate(self, username, password): + await tlvdata.send(self.connection, + {'username': username, 'password': password}) + authdata = await tlvdata.recv(self.connection) + if authdata['authpassed'] == 1: + self.authenticated = True + + def add_precede_key(self, keyname): + self._prevkeyname = keyname + + def add_precede_dict(self, dict): + self._prevdict = dict + + def handle_results(self, ikey, rc, res, errnodes=None, outhandler=None): + if 'error' in res: + if errnodes is not None: + errnodes.add(self._currnoderange) + sys.stderr.write('Error: {0}\n'.format(res['error'])) + if 'errorcode' in res: + return res['errorcode'] + else: + return 1 + if 'databynode' not in res: + return 0 + res = res['databynode'] + for node in res: + if 'error' in res[node]: + if errnodes is not None: + errnodes.add(node) + sys.stderr.write('{0}: Error: {1}\n'.format( + node, res[node]['error'])) + if 'errorcode' in res[node]: + rc |= res[node]['errorcode'] + else: + rc |= 1 + elif ikey in res[node]: + if 'value' in res[node][ikey]: + val = res[node][ikey]['value'] + elif 'isset' in res[node][ikey]: + val = '********' if res[node][ikey] else '' + else: + val = repr(res[node][ikey]) + if self._prevkeyname and self._prevkeyname in res[node]: + cprint('{0}: {2}->{1}'.format( + node, val, res[node][self._prevkeyname]['value'])) + elif self._prevdict and node in self._prevdict: + cprint('{0}: {2}->{1}'.format( + node, val, self._prevdict[node])) + else: + cprint('{0}: {1}'.format(node, val)) + elif outhandler: + outhandler(node, res) + return rc + + async def simple_noderange_command(self, noderange, resource, input=None, + key=None, errnodes=None, promptover=None, outhandler=None, **kwargs): + try: + self._currnoderange = noderange + rc = 0 + if resource[0] == '/': + resource = resource[1:] + # The implicit key is the resource basename + if key is None: + ikey = resource.rpartition('/')[-1] + else: + ikey = key + if input is None: + async for res in self.read('/noderange/{0}/{1}'.format( + noderange, resource)): + rc = self.handle_results(ikey, rc, res, errnodes, outhandler) + else: + await self.stop_if_noderange_over(noderange, promptover) + kwargs[ikey] = input + async for res in self.update('/noderange/{0}/{1}'.format( + noderange, resource), kwargs): + rc = self.handle_results(ikey, rc, res, errnodes, outhandler) + self._currnoderange = None + return rc + except KeyboardInterrupt: + cprint('') + return 0 + + async def stop_if_noderange_over(self, noderange, maxnodes): + if maxnodes is None: + return + nsize = await self.get_noderange_size(noderange) + if nsize > maxnodes: + if nsize == 1: + nodename = [x async for x in self.read( + '/noderange/{0}/nodes/'.format(noderange))][0].get('item', {}).get('href', None) + nodename = nodename[:-1] + p = getinput('Command is about to affect node {0}, continue (y/n)? '.format(nodename)) + else: + p = getinput('Command is about to affect {0} nodes, continue (y/n)? '.format(nsize)) + if p.lower() != 'y': + sys.stderr.write('Aborting at user request\n') + sys.exit(1) + raise Exception("Aborting at user request") + + + async def get_noderange_size(self, noderange): + numnodes = 0 + async for node in self.read('/noderange/{0}/nodes/'.format(noderange)): + if node.get('item', {}).get('href', None): + numnodes += 1 + else: + raise Exception("Error trying to size noderange {0}".format(noderange)) + return numnodes + + async def simple_nodegroups_command(self, noderange, resource, input=None, key=None, **kwargs): + try: + rc = 0 + if resource[0] == '/': + resource = resource[1:] + # The implicit key is the resource basename + if key is None: + ikey = resource.rpartition('/')[-1] + else: + ikey = key + if input is None: + for res in await self.read('/nodegroups/{0}/{1}'.format( + noderange, resource)): + rc = self.handle_results(ikey, rc, res) + else: + kwargs[ikey] = input + for res in await self.update('/nodegroups/{0}/{1}'.format( + noderange, resource), kwargs): + rc = self.handle_results(ikey, rc, res) + return rc + except KeyboardInterrupt: + cprint('') + return 0 + + async def read(self, path, parameters=None): + await self.ensure_connected() + if not self.authenticated: + raise Exception('Unauthenticated') + async for rsp in send_request( + 'retrieve', path, self.connection, parameters): + yield rsp + + async def update(self, path, parameters=None): + await self.ensure_connected() + if not self.authenticated: + raise Exception('Unauthenticated') + async for rsp in send_request( + 'update', path, self.connection, parameters): + yield rsp + + async def create(self, path, parameters=None): + await self.ensure_connected() + if not self.authenticated: + raise Exception('Unauthenticated') + async for rsp in send_request( + 'create', path, self.connection, parameters): + yield rsp + + async def delete(self, path, parameters=None): + await self.ensure_connected() + if not self.authenticated: + raise Exception('Unauthenticated') + async for rsp in send_request( + 'delete', path, self.connection, parameters): + yield rsp + + def _connect_unix(self): + self.connection = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM) + self.connection.setsockopt(socket.SOL_SOCKET, SO_PASSCRED, 1) + self.connection.connect(self.serverloc) + + async def _connect_tls(self): + server, port = _parseserver(self.serverloc) + for res in socket.getaddrinfo(server, port, socket.AF_UNSPEC, + socket.SOCK_STREAM): + af, socktype, proto, canonname, sa = res + try: + self.connection = socket.socket(af, socktype, proto) + self.connection.setsockopt( + socket.IPPROTO_TCP, socket.TCP_NODELAY, 1) + except: + self.connection = None + continue + try: + self.connection.settimeout(5) + self.connection.connect(sa) + self.connection.settimeout(0) + except: + raise + self.connection.close() + self.connection = None + continue + break + if self.connection is None: + raise Exception("Failed to connect to %s" % self.serverloc) + #TODO(jbjohnso): server certificate validation + clientcfgdir = os.path.join(os.path.expanduser("~"), ".confluent") + try: + os.makedirs(clientcfgdir) + except OSError as exc: + if not (exc.errno == errno.EEXIST and os.path.isdir(clientcfgdir)): + raise + cacert = os.path.join(clientcfgdir, "ca.pem") + certreqs = ssl.CERT_REQUIRED + knownhosts = False + if not os.path.exists(cacert): + cacert = None + certreqs = ssl.CERT_NONE + knownhosts = True + ctx = ssl.SSLContext(ssl.PROTOCOL_TLSv1_2) + ssl_ctx = PySSLContext.from_address(id(ctx)).ctx + libssl.SSL_CTX_set_cert_verify_callback(ssl_ctx, verify_stub, 0) + sreader = asyncio.StreamReader() + sreaderprot = asyncio.StreamReaderProtocol(sreader) + cloop = asyncio.get_event_loop() + tport, _ = await cloop.create_connection( + lambda: sreaderprot, sock=self.connection, ssl=ctx, server_hostname='x') + swriter = asyncio.StreamWriter(tport, sreaderprot, sreader, cloop) + self.connection = (sreader, swriter) + #self.connection = ssl.wrap_socket(self.connection, ca_certs=cacert, + # cert_reqs=certreqs) + if knownhosts: + certdata = tport.get_extra_info('ssl_object').getpeercert(binary_form=True) + # certdata = self.connection.getpeercert(binary_form=True) + fingerprint = 'sha512$' + hashlib.sha512(certdata).hexdigest() + fingerprint = fingerprint.encode('utf-8') + hostid = '@'.join((port, server)) + khf = dbm.open(os.path.join(clientcfgdir, "knownhosts"), 'c', 384) + if hostid in khf: + if fingerprint == khf[hostid]: + return + else: + replace = getinput( + "MISMATCHED CERTIFICATE DATA, ACCEPT NEW? (y/n):") + if replace not in ('y', 'Y'): + raise Exception("BAD CERTIFICATE") + cprint('Adding new key for %s:%s' % (server, port)) + khf[hostid] = fingerprint + + +async def send_request(operation, path, server, parameters=None): + """This function iterates over all the responses + received from the server. + + :param operation: The operation to request, retrieve, update, delete, + create, start, stop + :param path: The URI path to the resource to operate on + :param server: The socket to send data over + :param parameters: Parameters if any to send along with the request + """ + payload = {'operation': operation, 'path': path} + if parameters is not None: + payload['parameters'] = parameters + await tlvdata.send(server, payload) + result = await tlvdata.recv(server) + while '_requestdone' not in result: + try: + yield result + except GeneratorExit: + while '_requestdone' not in result: + result = await tlvdata.recv(server) + raise + result = await tlvdata.recv(server) + + +def attrrequested(attr, attrlist, seenattributes, node=None): + for candidate in attrlist: + truename = candidate + if candidate.startswith('hm'): + candidate = candidate.replace('hm', 'hardwaremanagement', 1) + if candidate in _attraliases: + candidate = _attraliases[candidate] + if fnmatch.fnmatch(attr.lower(), candidate.lower()): + if node is None: + seenattributes.add(truename) + else: + seenattributes[node][truename] = True + return True + elif attr.lower().startswith(candidate.lower() + '.'): + if node is None: + seenattributes.add(truename) + else: + seenattributes[node][truename] = 1 + return True + return False + + +async def printattributes(session, requestargs, showtype, nodetype, noderange, options): + path = '/{0}/{1}/attributes/{2}'.format(nodetype, noderange, showtype) + return await print_attrib_path(path, session, requestargs, options) + +def _sort_attrib(k): + if isinstance(k[1], dict) and k[1].get('sortid', None) is not None: + return k[1]['sortid'] + return k[0] + +async def print_attrib_path(path, session, requestargs, options, rename=None, attrprefix=None): + exitcode = 0 + seenattributes = NestedDict() + allnodes = set([]) + async for res in session.read(path): + if 'error' in res: + sys.stderr.write(res['error'] + '\n') + exitcode = 1 + continue + for node in sorted(res['databynode']): + allnodes.add(node) + for attr, val in sorted(res['databynode'][node].items(), key=_sort_attrib): + if attr == 'error': + sys.stderr.write('{0}: Error: {1}\n'.format(node, val)) + continue + if attr == 'errorcode': + exitcode |= val + continue + seenattributes[node][attr] = True + if rename: + printattr = rename.get(attr, attr) + else: + printattr = attr + if attrprefix: + printattr = attrprefix + printattr + currattr = res['databynode'][node][attr] + if show_attr(attr, requestargs, seenattributes, options, node): + if 'value' in currattr: + if currattr['value'] is not None: + val = currattr['value'] + if isinstance(val, list): + val = ','.join(val) + attrout = '{0}: {1}: {2}'.format( + node, printattr, val).strip() + else: + attrout = '{0}: {1}:'.format(node, printattr) + elif 'isset' in currattr: + if currattr['isset']: + attrout = '{0}: {1}: ********'.format(node, + printattr) + else: + attrout = '{0}: {1}:'.format(node, printattr) + elif isinstance(currattr, dict) and 'broken' in currattr: + attrout = '{0}: {1}: *ERROR* BROKEN EXPRESSION: ' \ + '{2}'.format(node, printattr, + currattr['broken']) + elif isinstance(currattr, list) or isinstance(currattr, tuple): + attrout = '{0}: {1}: {2}'.format(node, attr, ','.join(map(str, currattr))) + elif isinstance(currattr, dict): + dictout = [] + for k, v in currattr.items: + dictout.append("{0}={1}".format(k, v)) + attrout = '{0}: {1}: {2}'.format(node, printattr, ','.join(map(str, dictout))) + else: + cprint("CODE ERROR" + repr(attr)) + try: + blame = options.blame + except AttributeError: + blame = False + if blame or (isinstance(currattr, dict) and 'broken' in currattr): + blamedata = [] + if 'inheritedfrom' in currattr: + blamedata.append('inherited from group {0}'.format( + currattr['inheritedfrom'] + )) + if 'expression' in currattr: + blamedata.append( + 'derived from expression "{0}"'.format( + currattr['expression'])) + if blamedata: + attrout += ' (' + ', '.join(blamedata) + ')' + try: + comparedefault = options.comparedefault + except AttributeError: + comparedefault = False + if comparedefault: + try: + exclude = options.exclude + except AttributeError: + exclude = False + if ((requestargs and not exclude) or + (currattr.get('default', None) is not None and + currattr.get('value', None) is not None and + currattr['value'] != currattr['default'])): + cval = ','.join(currattr['value']) if isinstance( + currattr['value'], list) else currattr['value'] + dval = ','.join(currattr['default']) if isinstance( + currattr['default'], list) else currattr['default'] + cprint('{0}: {1}: {2} (Default: {3})'.format( + node, printattr, cval, dval)) + else: + + try: + details = options.detail + except AttributeError: + details = False + if details: + if currattr.get('help', None): + attrout += u' (Help: {0})'.format( + currattr['help']) + if currattr.get('possible', None): + try: + attrout += u' (Choices: {0})'.format( + ','.join(currattr['possible'])) + except TypeError: + pass + cprint(attrout) + somematched = set([]) + printmissing = set([]) + badnodes = NestedDict() + if not exitcode: + if requestargs: + for attr in requestargs: + for node in allnodes: + if attr in seenattributes[node]: + somematched.add(attr) + else: + badnodes[node][attr] = True + exitcode = 1 + for node in sortutil.natural_sort(badnodes): + for attr in badnodes[node]: + if attr in somematched: + sys.stderr.write( + 'Error: {0} matches no valid value for {1}\n'.format( + attr, node)) + else: + printmissing.add(attr) + for missing in printmissing: + sys.stderr.write('Error: {0} not a valid attribute\n'.format(missing)) + return exitcode + + +def show_attr(attr, requestargs, seenattributes, options, node): + try: + reverse = options.exclude + except AttributeError: + reverse = False + if requestargs is None or requestargs == []: + return True + processattr = attrrequested(attr, requestargs, seenattributes, node) + if reverse: + processattr = not processattr + return processattr + + +def printgroupattributes(session, requestargs, showtype, nodetype, noderange, options): + exitcode = 0 + seenattributes = set([]) + for res in session.read('/{0}/{1}/attributes/{2}'.format(nodetype, noderange, showtype)): + if 'error' in res: + sys.stderr.write(res['error'] + '\n') + exitcode = 1 + continue + for attr in res: + seenattributes.add(attr) + currattr = res[attr] + if (requestargs is None or requestargs == [] or attrrequested(attr, requestargs, seenattributes)): + if 'value' in currattr: + if currattr['value'] is not None: + attrout = '{0}: {1}: {2}'.format( + noderange, attr, currattr['value']) + else: + attrout = '{0}: {1}:'.format(noderange, attr) + elif 'isset' in currattr: + if currattr['isset']: + attrout = '{0}: {1}: ********'.format(noderange, attr) + else: + attrout = '{0}: {1}:'.format(noderange, attr) + elif isinstance(currattr, dict) and 'broken' in currattr: + attrout = '{0}: {1}: *ERROR* BROKEN EXPRESSION: ' \ + '{2}'.format(noderange, attr, + currattr['broken']) + elif 'expression' in currattr: + attrout = '{0}: {1}: (will derive from expression {2})'.format(noderange, attr, currattr['expression']) + elif isinstance(currattr, list) or isinstance(currattr, tuple): + attrout = '{0}: {1}: {2}'.format(noderange, attr, ','.join(map(str, currattr))) + elif isinstance(currattr, dict): + dictout = [] + for k, v in currattr.items: + dictout.append("{0}={1}".format(k, v)) + attrout = '{0}: {1}: {2}'.format(noderange, attr, ','.join(map(str, dictout))) + else: + cprint("CODE ERROR" + repr(attr)) + cprint(attrout) + if not exitcode: + if requestargs: + for attr in requestargs: + if attr not in seenattributes: + sys.stderr.write('Error: {0} not a valid attribute\n'.format(attr)) + exitcode = 1 + return exitcode + +async def updateattrib(session, updateargs, nodetype, noderange, options, dictassign=None): + # update attribute + exitcode = 0 + if options.clear: + targpath = '/{0}/{1}/attributes/all'.format(nodetype, noderange) + keydata = {} + for attrib in updateargs[1:]: + keydata[attrib] = None + async for res in session.update(targpath, keydata): + for node in res.get('databynode', {}): + for warnmsg in res['databynode'][node].get('_warnings', []): + sys.stderr.write('Warning: ' + warnmsg + '\n') + if 'error' in res: + if 'errorcode' in res: + exitcode = res['errorcode'] + sys.stderr.write('Error: ' + res['error'] + '\n') + sys.exit(exitcode) + elif hasattr(options, 'environment') and options.environment: + for key in updateargs[1:]: + key = key.replace('.', '_') + value = os.environ.get( + key, os.environ[key.upper()]) + # Let's do one pass to make sure that there's not a usage problem + for key in updateargs[1:]: + key = key.replace('.', '_') + value = os.environ.get( + key, os.environ[key.upper()]) + if (nodetype == "nodegroups"): + exitcode = await session.simple_nodegroups_command(noderange, + 'attributes/all', + value, key) + else: + exitcode = await session.simple_noderange_command(noderange, + 'attributes/all', + value, key) + sys.exit(exitcode) + elif dictassign: + for key in dictassign: + if nodetype == 'nodegroups': + exitcode = await session.simple_nodegroups_command( + noderange, 'attributes/all', dictassign[key], key) + else: + exitcode = await session.simple_noderange_command( + noderange, 'attributes/all', dictassign[key], key) + else: + if "=" in updateargs[1]: + update_ready = True + for arg in updateargs[1:]: + if not '=' in arg: + update_ready = False + exitcode = 1 + if not update_ready: + sys.stderr.write('Error: {0} Can not set and read at the same time!\n'.format(str(updateargs[1:]))) + sys.exit(exitcode) + try: + for val in updateargs[1:]: + val = val.split('=', 1) + if val[0][-1] in (',', '-', '^'): + key = val[0][:-1] + if val[0][-1] == ',': + value = {'prepend': val[1]} + elif val[0][-1] in ('-', '^'): + value = {'remove': val[1]} + else: + key = val[0] + value = val[1] + if (nodetype == "nodegroups"): + exitcode = await session.simple_nodegroups_command(noderange, 'attributes/all', + value, key) + else: + exitcode = await session.simple_noderange_command(noderange, 'attributes/all', + value, key) + except Exception: + sys.stderr.write('Error: {0} not a valid expression\n'.format(str(updateargs[1:]))) + exitcode = 1 + sys.exit(exitcode) + return exitcode + + +# So we try to prevent bad things from happening when globbing +# We tried to head this off at the shell, but the various solutions would end +# up breaking the shell in various ways (breaking pipe capability if using +# DEBUG, breaking globbing if in pipe, etc) +# Then we tried to parse the original commandline instead, however shlex isn't +# going to parse full bourne language (e.g. knowing that '|' and '>' and +# a world of other things would not be in our command line +# so finally, just make sure the noderange appears verbatim in the command line +# if we glob to something, then bash will change noderange and this should +# detect it and save the user from tragedy +def check_globbing(noderange): + if not os.path.exists(noderange): + return True + rawargs = os.environ.get('CURRENT_CMDLINE', None) + if rawargs: + rawargs = shlex.split(rawargs) + for arg in rawargs: + if arg.startswith('$'): + arg = arg[1:] + if arg.endswith(';'): + arg = arg[:-1] + arg = os.environ.get(arg, '$' + arg) + if arg.startswith(noderange): + break + else: + sys.stderr.write( + 'Shell glob conflict detected, specified target "{0}" ' + 'not in command line, but is a file. You can use "set -f" in ' + 'bash or change directories such that there is no filename ' + 'that would conflict.' + '\n'.format(noderange)) + sys.exit(1) diff --git a/confluent_client/confluent/asynctlvdata.py b/confluent_client/confluent/asynctlvdata.py new file mode 100644 index 00000000..629687d6 --- /dev/null +++ b/confluent_client/confluent/asynctlvdata.py @@ -0,0 +1,318 @@ +# vim: tabstop=4 shiftwidth=4 softtabstop=4 + +# Copyright 2014 IBM Corporation +# Copyright 2015 Lenovo +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import array +import asyncio +import ctypes +import ctypes.util +import confluent.tlv as tlv +import socket +from datetime import datetime +import json +import os +import struct + +try: + unicode +except NameError: + unicode = str + +try: + range = xrange +except NameError: + pass + + +class iovec(ctypes.Structure): # from uio.h + _fields_ = [('iov_base', ctypes.c_void_p), + ('iov_len', ctypes.c_size_t)] + + +iovec_ptr = ctypes.POINTER(iovec) + + +class cmsghdr(ctypes.Structure): # also from bits/socket.h + _fields_ = [('cmsg_len', ctypes.c_size_t), + ('cmsg_level', ctypes.c_int), + ('cmsg_type', ctypes.c_int)] + + @classmethod + def init_data(cls, cmsg_len, cmsg_level, cmsg_type, cmsg_data): + Data = ctypes.c_ubyte * ctypes.sizeof(cmsg_data) + + class _flexhdr(ctypes.Structure): + _fields_ = cls._fields_ + [('cmsg_data', Data)] + + datab = Data(*bytearray(cmsg_data)) + return _flexhdr(cmsg_len=cmsg_len, cmsg_level=cmsg_level, + cmsg_type=cmsg_type, cmsg_data=datab) + + +def CMSG_LEN(length): + sizeof_cmshdr = ctypes.sizeof(cmsghdr) + return ctypes.c_size_t(CMSG_ALIGN(sizeof_cmshdr).value + length) + + +SCM_RIGHTS = 1 + + +class msghdr(ctypes.Structure): # from bits/socket.h + _fields_ = [('msg_name', ctypes.c_void_p), + ('msg_namelen', ctypes.c_uint), + ('msg_iov', ctypes.POINTER(iovec)), + ('msg_iovlen', ctypes.c_size_t), + ('msg_control', ctypes.c_void_p), + ('msg_controllen', ctypes.c_size_t), + ('msg_flags', ctypes.c_int)] + + +def CMSG_ALIGN(length): # bits/socket.h + ret = (length + ctypes.sizeof(ctypes.c_size_t) - 1 + & ~(ctypes.sizeof(ctypes.c_size_t) - 1)) + return ctypes.c_size_t(ret) + + +def CMSG_SPACE(length): # bits/socket.h + ret = CMSG_ALIGN(length).value + CMSG_ALIGN(ctypes.sizeof(cmsghdr)).value + return ctypes.c_size_t(ret) + + +class ClientFile(object): + def __init__(self, name, mode, fd): + self.fileobject = os.fdopen(fd, mode) + self.filename = name + + + + +def _sendmsg(loop, fut, sock, msg, fds, rfd): + if rfd is not None: + loop.remove_reader(rfd) + if fut.cancelled(): + return + try: + retdata = sock.sendmsg( + [msg], + [(socket.SOL_SOCKET, socket.SCM_RIGHTS, array.array("i", fds))]) + except (BlockingIOError, InterruptedError): + fd = sock.fileno() + loop.add_reader(fd, _sendmsg, loop, fut, sock, fd) + except Exception as exc: + fut.set_exception(exc) + else: + fut.set_result(retdata) + + +def send_fds(sock, msg, fds): + cloop = asyncio.get_event_loop() + fut = cloop.create_future() + _sendmsg(cloop, fut, sock, msg, fds, None) + return fut + + +def _recvmsg(loop, fut, sock, msglen, maxfds, rfd): + if rfd is not None: + loop.remove_reader(rfd) + fds = array.array("i") # Array of ints + try: + msg, ancdata, flags, addr = sock.recvmsg( + msglen, socket.CMSG_LEN(maxfds * fds.itemsize)) + except (BlockingIOError, InterruptedError): + fd = sock.fileno() + loop.add_reader(fd, _recvmsg, loop, fut, sock, fd) + except Exception as exc: + fut.set_exception(exc) + else: + for cmsg_level, cmsg_type, cmsg_data in ancdata: + if (cmsg_level == socket.SOL_SOCKET + and cmsg_type == socket.SCM_RIGHTS): + # Append data, ignoring any truncated integers at the end. + fds.frombytes( + cmsg_data[ + :len(cmsg_data) - (len(cmsg_data) % fds.itemsize)]) + fut.set_result(msglen, list(fds)) + + +def recv_fds(sock, msglen, maxfds): + cloop = asyncio.get_event_loop() + fut = cloop.create_future() + _recvmsg(cloop, fut, sock, msglen, maxfds, None) + return fut + + +def decodestr(value): + ret = None + try: + ret = value.decode('utf-8') + except UnicodeDecodeError: + try: + ret = value.decode('cp437') + except UnicodeDecodeError: + ret = value + except AttributeError: + return value + return ret + + +def unicode_dictvalues(dictdata): + for key in dictdata: + if isinstance(dictdata[key], bytes): + dictdata[key] = decodestr(dictdata[key]) + elif isinstance(dictdata[key], datetime): + dictdata[key] = dictdata[key].strftime('%Y-%m-%dT%H:%M:%S') + elif isinstance(dictdata[key], list): + _unicode_list(dictdata[key]) + elif isinstance(dictdata[key], dict): + unicode_dictvalues(dictdata[key]) + + +def _unicode_list(currlist): + for i in range(len(currlist)): + if isinstance(currlist[i], str): + currlist[i] = decodestr(currlist[i]) + elif isinstance(currlist[i], dict): + unicode_dictvalues(currlist[i]) + elif isinstance(currlist[i], list): + _unicode_list(currlist[i]) + + +async def sendall(handle, data): + if isinstance(handle, tuple): + handle[1].write(data) + return await handle[1].drain() + else: + cloop = asyncio.get_event_loop() + return await cloop.sock_sendall(handle, data) + + +async def send(handle, data, filehandle=None): + cloop = asyncio.get_event_loop() + if isinstance(data, unicode): + try: + data = data.encode('utf-8') + except AttributeError: + pass + if isinstance(data, bytes) or isinstance(data, unicode): + # plain text, e.g. console data + tl = len(data) + if tl == 0: + # if you don't have anything to say, don't say anything at all + return + if tl < 16777216: + # type for string is '0', so we don't need + # to xor anything in + await sendall(handle, struct.pack("!I", tl)) + else: + raise Exception("String data length exceeds protocol") + await sendall(handle, data) + elif isinstance(data, dict): # JSON currently only goes to 4 bytes + # Some structured message, like what would be seen in http responses + unicode_dictvalues(data) # make everything unicode, assuming UTF-8 + sdata = json.dumps(data, ensure_ascii=False, separators=(',', ':')) + sdata = sdata.encode('utf-8') + tl = len(sdata) + if tl > 16777215: + raise Exception("JSON data exceeds protocol limits") + # xor in the type (0b1 << 24) + if filehandle is None: + tl |= 16777216 + await sendall(handle, struct.pack("!I", tl)) + await sendall(handle, sdata) + elif isinstance(handle, tuple): + raise Exception("Cannot send filehandle over network socket") + else: + tl |= (2 << 24) + await cloop.sock_sendall(handle, struct.pack("!I", tl)) + await send_fds(handle, b'', [filehandle]) + + +async def _grabhdl(handle, size): + if isinstance(handle, tuple): + return await handle[0].read(size) + else: + cloop = asyncio.get_event_loop() + return await cloop.sock_recv(handle, size) + + +async def recvall(handle, size): + rd = await _grabhdl(handle, size) + while len(rd) < size: + nd = await _grabhdl(handle, size - len(rd)) + if not nd: + raise Exception("Error reading data") + rd += nd + return rd + + +async def recv(handle): + tl = await _grabhdl(handle, 4) + if not tl: + return None + while len(tl) < 4: + ndata = await _grabhdl(handle, 4 - len(tl)) + if not ndata: + raise Exception("Error reading data") + tl += ndata + if len(tl) == 0: + return None + tl = struct.unpack("!I", tl)[0] + if tl & 0b10000000000000000000000000000000: + raise Exception("Protocol Violation, reserved bit set") + # 4 byte tlv + dlen = tl & 16777215 # grab lower 24 bits + datatype = (tl & 2130706432) >> 24 # grab 7 bits from near beginning + if dlen == 0: + return None + if datatype == tlv.Types.filehandle: + if isinstance(handle, tuple): + raise Exception('Filehandle not supported over TLS socket') + filehandles = array.array('i') + rawbuffer = bytearray(2048) + pkttype = ctypes.c_ubyte * 2048 + data = pkttype.from_buffer(rawbuffer) + cmsgsize = CMSG_SPACE(ctypes.sizeof(ctypes.c_int)).value + cmsgarr = bytearray(cmsgsize) + cmtype = ctypes.c_ubyte * cmsgsize + cmsg = cmtype.from_buffer(cmsgarr) + cmsg.cmsg_level = socket.SOL_SOCKET + cmsg.cmsg_type = SCM_RIGHTS + cmsg.cmsg_len = CMSG_LEN(ctypes.sizeof(ctypes.c_int)) + iov = iovec() + iov.iov_base = ctypes.addressof(data) + iov.iov_len = 2048 + msg = msghdr() + msg.msg_iov = ctypes.pointer(iov) + msg.msg_iovlen = 1 + msg.msg_control = ctypes.addressof(cmsg) + msg.msg_controllen = ctypes.sizeof(cmsg) + i = await recv_fds(handle, 2048, 4) + print(repr(i)) + data = i[0] + filehandles = i[1] + data = json.loads(bytes(data)) + return ClientFile(data['filename'], data['mode'], filehandles[0]) + else: + data = await _grabhdl(handle, dlen) + while len(data) < dlen: + ndata = await _grabhdl(handle, dlen - len(data)) + if not ndata: + raise Exception("Error reading data") + data += ndata + if datatype == tlv.Types.text: + return data + elif datatype == tlv.Types.json: + return json.loads(data) diff --git a/confluent_client/confluent/client.py b/confluent_client/confluent/client.py index 117f1d83..a9957b96 100644 --- a/confluent_client/confluent/client.py +++ b/confluent_client/confluent/client.py @@ -15,10 +15,10 @@ # See the License for the specific language governing permissions and # limitations under the License. -import asyncio -import ctypes -import ctypes.util -import dbm +try: + import anydbm as dbm +except ImportError: + import dbm import csv import errno import fnmatch @@ -30,9 +30,6 @@ import ssl import sys import confluent.tlvdata as tlvdata import confluent.sortutil as sortutil -libssl = ctypes.CDLL(ctypes.util.find_library('ssl')) -libssl.SSL_CTX_set_cert_verify_callback.argtypes = [ - ctypes.c_void_p, ctypes.c_void_p, ctypes.c_void_p] SO_PASSCRED = 16 @@ -50,26 +47,6 @@ except NameError: getinput = input -class PyObject_HEAD(ctypes.Structure): - _fields_ = [ - ("ob_refcnt", ctypes.c_ssize_t), - ("ob_type", ctypes.c_void_p), - ] - - -# see main/Modules/_ssl.c, only caring about the SSL_CTX pointer -class PySSLContext(ctypes.Structure): - _fields_ = [ - ("ob_base", PyObject_HEAD), - ("ctx", ctypes.c_void_p), - ] - - -@ctypes.CFUNCTYPE(ctypes.c_int, ctypes.c_void_p, ctypes.c_void_p) -def verify_stub(store, misc): - return 1 - - class NestedDict(dict): def __missing__(self, key): value = self[key] = type(self)() @@ -85,7 +62,6 @@ def stringify(instr): return instr.encode('utf-8') return instr - class Tabulator(object): def __init__(self, headers): self.headers = headers @@ -145,8 +121,7 @@ def printerror(res, node=None): for node in res.get('databynode', {}): exitcode = res['databynode'][node].get('errorcode', exitcode) if 'error' in res['databynode'][node]: - sys.stderr.write( - '{0}: {1}\n'.format(node, res['databynode'][node]['error'])) + sys.stderr.write('{0}: {1}\n'.format(node, res['databynode'][node]['error'])) if exitcode == 0: exitcode = 1 if 'error' in res: @@ -194,21 +169,16 @@ class Command(object): self.serverloc = '/var/run/confluent/api.sock' else: self.serverloc = server - self.connected = False - - async def ensure_connected(self): - if self.connected: - return True if os.path.isabs(self.serverloc) and os.path.exists(self.serverloc): self._connect_unix() self.unixdomain = True elif self.serverloc == '/var/run/confluent/api.sock': raise Exception('Confluent service is not available') else: - await self._connect_tls() - self.protversion = int((await tlvdata.recv(self.connection)).split( + self._connect_tls() + self.protversion = int(tlvdata.recv(self.connection).split( b'--')[1].strip()[1:]) - authdata = await tlvdata.recv(self.connection) + authdata = tlvdata.recv(self.connection) if authdata['authpassed'] == 1: self.authenticated = True else: @@ -216,21 +186,19 @@ class Command(object): if not self.authenticated and 'CONFLUENT_USER' in os.environ: username = os.environ['CONFLUENT_USER'] passphrase = os.environ['CONFLUENT_PASSPHRASE'] - await self.authenticate(username, passphrase) - self.connected = True + self.authenticate(username, passphrase) - async def add_file(self, name, handle, mode): - await self.ensure_connected() + def add_file(self, name, handle, mode): if self.protversion < 3: raise Exception('Not supported with connected confluent server') if not self.unixdomain: raise Exception('Can only add a file to a unix domain connection') tlvdata.send(self.connection, {'filename': name, 'mode': mode}, handle) - async def authenticate(self, username, password): - await tlvdata.send(self.connection, - {'username': username, 'password': password}) - authdata = await tlvdata.recv(self.connection) + def authenticate(self, username, password): + tlvdata.send(self.connection, + {'username': username, 'password': password}) + authdata = tlvdata.recv(self.connection) if authdata['authpassed'] == 1: self.authenticated = True @@ -281,7 +249,7 @@ class Command(object): outhandler(node, res) return rc - async def simple_noderange_command(self, noderange, resource, input=None, + def simple_noderange_command(self, noderange, resource, input=None, key=None, errnodes=None, promptover=None, outhandler=None, **kwargs): try: self._currnoderange = noderange @@ -294,13 +262,13 @@ class Command(object): else: ikey = key if input is None: - async for res in self.read('/noderange/{0}/{1}'.format( + for res in self.read('/noderange/{0}/{1}'.format( noderange, resource)): rc = self.handle_results(ikey, rc, res, errnodes, outhandler) else: - await self.stop_if_noderange_over(noderange, promptover) + self.stop_if_noderange_over(noderange, promptover) kwargs[ikey] = input - async for res in self.update('/noderange/{0}/{1}'.format( + for res in self.update('/noderange/{0}/{1}'.format( noderange, resource), kwargs): rc = self.handle_results(ikey, rc, res, errnodes, outhandler) self._currnoderange = None @@ -309,14 +277,14 @@ class Command(object): cprint('') return 0 - async def stop_if_noderange_over(self, noderange, maxnodes): + def stop_if_noderange_over(self, noderange, maxnodes): if maxnodes is None: return - nsize = await self.get_noderange_size(noderange) + nsize = self.get_noderange_size(noderange) if nsize > maxnodes: if nsize == 1: - nodename = [x async for x in self.read( - '/noderange/{0}/nodes/'.format(noderange))][0].get('item', {}).get('href', None) + nodename = list(self.read( + '/noderange/{0}/nodes/'.format(noderange)))[0].get('item', {}).get('href', None) nodename = nodename[:-1] p = getinput('Command is about to affect node {0}, continue (y/n)? '.format(nodename)) else: @@ -327,16 +295,16 @@ class Command(object): raise Exception("Aborting at user request") - async def get_noderange_size(self, noderange): + def get_noderange_size(self, noderange): numnodes = 0 - async for node in self.read('/noderange/{0}/nodes/'.format(noderange)): + for node in self.read('/noderange/{0}/nodes/'.format(noderange)): if node.get('item', {}).get('href', None): numnodes += 1 else: raise Exception("Error trying to size noderange {0}".format(noderange)) return numnodes - async def simple_nodegroups_command(self, noderange, resource, input=None, key=None, **kwargs): + def simple_nodegroups_command(self, noderange, resource, input=None, key=None, **kwargs): try: rc = 0 if resource[0] == '/': @@ -347,12 +315,12 @@ class Command(object): else: ikey = key if input is None: - for res in await self.read('/nodegroups/{0}/{1}'.format( + for res in self.read('/nodegroups/{0}/{1}'.format( noderange, resource)): rc = self.handle_results(ikey, rc, res) else: kwargs[ikey] = input - for res in await self.update('/nodegroups/{0}/{1}'.format( + for res in self.update('/nodegroups/{0}/{1}'.format( noderange, resource), kwargs): rc = self.handle_results(ikey, rc, res) return rc @@ -360,44 +328,32 @@ class Command(object): cprint('') return 0 - async def read(self, path, parameters=None): - await self.ensure_connected() + def read(self, path, parameters=None): if not self.authenticated: raise Exception('Unauthenticated') - async for rsp in send_request( - 'retrieve', path, self.connection, parameters): - yield rsp + return send_request('retrieve', path, self.connection, parameters) - async def update(self, path, parameters=None): - await self.ensure_connected() + def update(self, path, parameters=None): if not self.authenticated: raise Exception('Unauthenticated') - async for rsp in send_request( - 'update', path, self.connection, parameters): - yield rsp + return send_request('update', path, self.connection, parameters) - async def create(self, path, parameters=None): - await self.ensure_connected() + def create(self, path, parameters=None): if not self.authenticated: raise Exception('Unauthenticated') - async for rsp in send_request( - 'create', path, self.connection, parameters): - yield rsp + return send_request('create', path, self.connection, parameters) - async def delete(self, path, parameters=None): - await self.ensure_connected() + def delete(self, path, parameters=None): if not self.authenticated: raise Exception('Unauthenticated') - async for rsp in send_request( - 'delete', path, self.connection, parameters): - yield rsp + return send_request('delete', path, self.connection, parameters) def _connect_unix(self): self.connection = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM) self.connection.setsockopt(socket.SOL_SOCKET, SO_PASSCRED, 1) self.connection.connect(self.serverloc) - async def _connect_tls(self): + def _connect_tls(self): server, port = _parseserver(self.serverloc) for res in socket.getaddrinfo(server, port, socket.AF_UNSPEC, socket.SOCK_STREAM): @@ -412,7 +368,7 @@ class Command(object): try: self.connection.settimeout(5) self.connection.connect(sa) - self.connection.settimeout(0) + self.connection.settimeout(None) except: raise self.connection.close() @@ -435,21 +391,10 @@ class Command(object): cacert = None certreqs = ssl.CERT_NONE knownhosts = True - ctx = ssl.SSLContext(ssl.PROTOCOL_TLSv1_2) - ssl_ctx = PySSLContext.from_address(id(ctx)).ctx - libssl.SSL_CTX_set_cert_verify_callback(ssl_ctx, verify_stub, 0) - sreader = asyncio.StreamReader() - sreaderprot = asyncio.StreamReaderProtocol(sreader) - cloop = asyncio.get_event_loop() - tport, _ = await cloop.create_connection( - lambda: sreaderprot, sock=self.connection, ssl=ctx, server_hostname='x') - swriter = asyncio.StreamWriter(tport, sreaderprot, sreader, cloop) - self.connection = (sreader, swriter) - #self.connection = ssl.wrap_socket(self.connection, ca_certs=cacert, - # cert_reqs=certreqs) + self.connection = ssl.wrap_socket(self.connection, ca_certs=cacert, + cert_reqs=certreqs) if knownhosts: - certdata = tport.get_extra_info('ssl_object').getpeercert(binary_form=True) - # certdata = self.connection.getpeercert(binary_form=True) + certdata = self.connection.getpeercert(binary_form=True) fingerprint = 'sha512$' + hashlib.sha512(certdata).hexdigest() fingerprint = fingerprint.encode('utf-8') hostid = '@'.join((port, server)) @@ -466,7 +411,7 @@ class Command(object): khf[hostid] = fingerprint -async def send_request(operation, path, server, parameters=None): +def send_request(operation, path, server, parameters=None): """This function iterates over all the responses received from the server. @@ -479,16 +424,16 @@ async def send_request(operation, path, server, parameters=None): payload = {'operation': operation, 'path': path} if parameters is not None: payload['parameters'] = parameters - await tlvdata.send(server, payload) - result = await tlvdata.recv(server) + tlvdata.send(server, payload) + result = tlvdata.recv(server) while '_requestdone' not in result: try: yield result except GeneratorExit: while '_requestdone' not in result: - result = await tlvdata.recv(server) + result = tlvdata.recv(server) raise - result = await tlvdata.recv(server) + result = tlvdata.recv(server) def attrrequested(attr, attrlist, seenattributes, node=None): @@ -513,20 +458,20 @@ def attrrequested(attr, attrlist, seenattributes, node=None): return False -async def printattributes(session, requestargs, showtype, nodetype, noderange, options): +def printattributes(session, requestargs, showtype, nodetype, noderange, options): path = '/{0}/{1}/attributes/{2}'.format(nodetype, noderange, showtype) - return await print_attrib_path(path, session, requestargs, options) + return print_attrib_path(path, session, requestargs, options) def _sort_attrib(k): if isinstance(k[1], dict) and k[1].get('sortid', None) is not None: return k[1]['sortid'] return k[0] -async def print_attrib_path(path, session, requestargs, options, rename=None, attrprefix=None): +def print_attrib_path(path, session, requestargs, options, rename=None, attrprefix=None): exitcode = 0 seenattributes = NestedDict() allnodes = set([]) - async for res in session.read(path): + for res in session.read(path): if 'error' in res: sys.stderr.write(res['error'] + '\n') exitcode = 1 @@ -714,7 +659,7 @@ def printgroupattributes(session, requestargs, showtype, nodetype, noderange, op exitcode = 1 return exitcode -async def updateattrib(session, updateargs, nodetype, noderange, options, dictassign=None): +def updateattrib(session, updateargs, nodetype, noderange, options, dictassign=None): # update attribute exitcode = 0 if options.clear: @@ -722,7 +667,7 @@ async def updateattrib(session, updateargs, nodetype, noderange, options, dictas keydata = {} for attrib in updateargs[1:]: keydata[attrib] = None - async for res in session.update(targpath, keydata): + for res in session.update(targpath, keydata): for node in res.get('databynode', {}): for warnmsg in res['databynode'][node].get('_warnings', []): sys.stderr.write('Warning: ' + warnmsg + '\n') @@ -742,21 +687,21 @@ async def updateattrib(session, updateargs, nodetype, noderange, options, dictas value = os.environ.get( key, os.environ[key.upper()]) if (nodetype == "nodegroups"): - exitcode = await session.simple_nodegroups_command(noderange, + exitcode = session.simple_nodegroups_command(noderange, 'attributes/all', value, key) else: - exitcode = await session.simple_noderange_command(noderange, + exitcode = session.simple_noderange_command(noderange, 'attributes/all', value, key) sys.exit(exitcode) elif dictassign: for key in dictassign: if nodetype == 'nodegroups': - exitcode = await session.simple_nodegroups_command( + exitcode = session.simple_nodegroups_command( noderange, 'attributes/all', dictassign[key], key) else: - exitcode = await session.simple_noderange_command( + exitcode = session.simple_noderange_command( noderange, 'attributes/all', dictassign[key], key) else: if "=" in updateargs[1]: @@ -781,12 +726,12 @@ async def updateattrib(session, updateargs, nodetype, noderange, options, dictas key = val[0] value = val[1] if (nodetype == "nodegroups"): - exitcode = await session.simple_nodegroups_command(noderange, 'attributes/all', + exitcode = session.simple_nodegroups_command(noderange, 'attributes/all', value, key) else: - exitcode = await session.simple_noderange_command(noderange, 'attributes/all', + exitcode = session.simple_noderange_command(noderange, 'attributes/all', value, key) - except Exception: + except: sys.stderr.write('Error: {0} not a valid expression\n'.format(str(updateargs[1:]))) exitcode = 1 sys.exit(exitcode) diff --git a/confluent_client/confluent/tlvdata.py b/confluent_client/confluent/tlvdata.py index 629687d6..7fcb663b 100644 --- a/confluent_client/confluent/tlvdata.py +++ b/confluent_client/confluent/tlvdata.py @@ -16,11 +16,15 @@ # limitations under the License. import array -import asyncio import ctypes import ctypes.util import confluent.tlv as tlv -import socket +try: + import eventlet.green.socket as socket + import eventlet.green.select as select +except ImportError: + import socket + import select from datetime import datetime import json import os @@ -36,7 +40,6 @@ try: except NameError: pass - class iovec(ctypes.Structure): # from uio.h _fields_ = [('iov_base', ctypes.c_void_p), ('iov_len', ctypes.c_size_t)] @@ -53,7 +56,6 @@ class cmsghdr(ctypes.Structure): # also from bits/socket.h @classmethod def init_data(cls, cmsg_len, cmsg_level, cmsg_type, cmsg_data): Data = ctypes.c_ubyte * ctypes.sizeof(cmsg_data) - class _flexhdr(ctypes.Structure): _fields_ = cls._fields_ + [('cmsg_data', Data)] @@ -96,63 +98,13 @@ class ClientFile(object): self.fileobject = os.fdopen(fd, mode) self.filename = name - - - -def _sendmsg(loop, fut, sock, msg, fds, rfd): - if rfd is not None: - loop.remove_reader(rfd) - if fut.cancelled(): - return - try: - retdata = sock.sendmsg( - [msg], - [(socket.SOL_SOCKET, socket.SCM_RIGHTS, array.array("i", fds))]) - except (BlockingIOError, InterruptedError): - fd = sock.fileno() - loop.add_reader(fd, _sendmsg, loop, fut, sock, fd) - except Exception as exc: - fut.set_exception(exc) - else: - fut.set_result(retdata) - - -def send_fds(sock, msg, fds): - cloop = asyncio.get_event_loop() - fut = cloop.create_future() - _sendmsg(cloop, fut, sock, msg, fds, None) - return fut - - -def _recvmsg(loop, fut, sock, msglen, maxfds, rfd): - if rfd is not None: - loop.remove_reader(rfd) - fds = array.array("i") # Array of ints - try: - msg, ancdata, flags, addr = sock.recvmsg( - msglen, socket.CMSG_LEN(maxfds * fds.itemsize)) - except (BlockingIOError, InterruptedError): - fd = sock.fileno() - loop.add_reader(fd, _recvmsg, loop, fut, sock, fd) - except Exception as exc: - fut.set_exception(exc) - else: - for cmsg_level, cmsg_type, cmsg_data in ancdata: - if (cmsg_level == socket.SOL_SOCKET - and cmsg_type == socket.SCM_RIGHTS): - # Append data, ignoring any truncated integers at the end. - fds.frombytes( - cmsg_data[ - :len(cmsg_data) - (len(cmsg_data) % fds.itemsize)]) - fut.set_result(msglen, list(fds)) - - -def recv_fds(sock, msglen, maxfds): - cloop = asyncio.get_event_loop() - fut = cloop.create_future() - _recvmsg(cloop, fut, sock, msglen, maxfds, None) - return fut - +libc = ctypes.CDLL(ctypes.util.find_library('c')) +recvmsg = libc.recvmsg +recvmsg.argtypes = [ctypes.c_int, ctypes.POINTER(msghdr), ctypes.c_int] +recvmsg.restype = ctypes.c_int +sendmsg = libc.sendmsg +sendmsg.argtypes = [ctypes.c_int, ctypes.POINTER(msghdr), ctypes.c_int] +sendmsg.restype = ctypes.c_size_t def decodestr(value): ret = None @@ -167,7 +119,6 @@ def decodestr(value): return value return ret - def unicode_dictvalues(dictdata): for key in dictdata: if isinstance(dictdata[key], bytes): @@ -190,17 +141,7 @@ def _unicode_list(currlist): _unicode_list(currlist[i]) -async def sendall(handle, data): - if isinstance(handle, tuple): - handle[1].write(data) - return await handle[1].drain() - else: - cloop = asyncio.get_event_loop() - return await cloop.sock_sendall(handle, data) - - -async def send(handle, data, filehandle=None): - cloop = asyncio.get_event_loop() +def send(handle, data, filehandle=None): if isinstance(data, unicode): try: data = data.encode('utf-8') @@ -215,10 +156,10 @@ async def send(handle, data, filehandle=None): if tl < 16777216: # type for string is '0', so we don't need # to xor anything in - await sendall(handle, struct.pack("!I", tl)) + handle.sendall(struct.pack("!I", tl)) else: raise Exception("String data length exceeds protocol") - await sendall(handle, data) + handle.sendall(data) elif isinstance(data, dict): # JSON currently only goes to 4 bytes # Some structured message, like what would be seen in http responses unicode_dictvalues(data) # make everything unicode, assuming UTF-8 @@ -230,40 +171,41 @@ async def send(handle, data, filehandle=None): # xor in the type (0b1 << 24) if filehandle is None: tl |= 16777216 - await sendall(handle, struct.pack("!I", tl)) - await sendall(handle, sdata) - elif isinstance(handle, tuple): - raise Exception("Cannot send filehandle over network socket") + handle.sendall(struct.pack("!I", tl)) + handle.sendall(sdata) else: tl |= (2 << 24) - await cloop.sock_sendall(handle, struct.pack("!I", tl)) - await send_fds(handle, b'', [filehandle]) + handle.sendall(struct.pack("!I", tl)) + cdtype = ctypes.c_ubyte * len(sdata) + cdata = cdtype.from_buffer(bytearray(sdata)) + ciov = iovec(iov_base=ctypes.addressof(cdata), + iov_len=ctypes.c_size_t(ctypes.sizeof(cdata))) + fd = ctypes.c_int(filehandle) + cmh = cmsghdr.init_data( + cmsg_len=CMSG_LEN( + ctypes.sizeof(fd)), cmsg_level=socket.SOL_SOCKET, + cmsg_type=SCM_RIGHTS, cmsg_data=fd) + mh = msghdr(msg_name=None, msg_len=0, msg_iov=iovec_ptr(ciov), + msg_iovlen=1, msg_control=ctypes.addressof(cmh), + msg_controllen=ctypes.c_size_t(ctypes.sizeof(cmh))) + sendmsg(handle.fileno(), mh, 0) -async def _grabhdl(handle, size): - if isinstance(handle, tuple): - return await handle[0].read(size) - else: - cloop = asyncio.get_event_loop() - return await cloop.sock_recv(handle, size) - - -async def recvall(handle, size): - rd = await _grabhdl(handle, size) +def recvall(handle, size): + rd = handle.recv(size) while len(rd) < size: - nd = await _grabhdl(handle, size - len(rd)) + nd = handle.recv(size - len(rd)) if not nd: raise Exception("Error reading data") rd += nd return rd - -async def recv(handle): - tl = await _grabhdl(handle, 4) +def recv(handle): + tl = handle.recv(4) if not tl: return None while len(tl) < 4: - ndata = await _grabhdl(handle, 4 - len(tl)) + ndata = handle.recv(4 - len(tl)) if not ndata: raise Exception("Error reading data") tl += ndata @@ -278,8 +220,6 @@ async def recv(handle): if dlen == 0: return None if datatype == tlv.Types.filehandle: - if isinstance(handle, tuple): - raise Exception('Filehandle not supported over TLS socket') filehandles = array.array('i') rawbuffer = bytearray(2048) pkttype = ctypes.c_ubyte * 2048 @@ -299,16 +239,23 @@ async def recv(handle): msg.msg_iovlen = 1 msg.msg_control = ctypes.addressof(cmsg) msg.msg_controllen = ctypes.sizeof(cmsg) - i = await recv_fds(handle, 2048, 4) - print(repr(i)) - data = i[0] - filehandles = i[1] + select.select([handle], [], []) + i = recvmsg(handle.fileno(), ctypes.pointer(msg), 0) + cdata = cmsgarr[CMSG_LEN(0).value:] + data = rawbuffer[:i] + if cmsg.cmsg_level == socket.SOL_SOCKET and cmsg.cmsg_type == SCM_RIGHTS: + try: + filehandles.fromstring(bytes( + cdata[:len(cdata) - len(cdata) % filehandles.itemsize])) + except AttributeError: + filehandles.frombytes(bytes( + cdata[:len(cdata) - len(cdata) % filehandles.itemsize])) data = json.loads(bytes(data)) return ClientFile(data['filename'], data['mode'], filehandles[0]) else: - data = await _grabhdl(handle, dlen) + data = handle.recv(dlen) while len(data) < dlen: - ndata = await _grabhdl(handle, dlen - len(data)) + ndata = handle.recv(dlen - len(data)) if not ndata: raise Exception("Error reading data") data += ndata diff --git a/confluent_server/confluent/collective/manager.py b/confluent_server/confluent/collective/manager.py index 752636ff..edc9db88 100644 --- a/confluent_server/confluent/collective/manager.py +++ b/confluent_server/confluent/collective/manager.py @@ -21,7 +21,7 @@ import confluent.config.configmanager as cfm import confluent.exceptions as exc import confluent.log as log import confluent.noderange as noderange -import confluent.tlvdata as tlvdata +import confluent.asynctlvdata as tlvdata import confluent.util as util import socket import ssl diff --git a/confluent_server/confluent/core.py b/confluent_server/confluent/core.py index f41f580b..be67b438 100644 --- a/confluent_server/confluent/core.py +++ b/confluent_server/confluent/core.py @@ -37,7 +37,7 @@ import asyncio import confluent import confluent.alerts as alerts import confluent.log as log -import confluent.tlvdata as tlvdata +import confluent.asynctlvdata as tlvdata import confluent.config.attributes as attrscheme import confluent.config.configmanager as cfm import confluent.collective.manager as collective diff --git a/confluent_server/confluent/httpapi.py b/confluent_server/confluent/httpapi.py index 5701ebf9..b62db319 100644 --- a/confluent_server/confluent/httpapi.py +++ b/confluent_server/confluent/httpapi.py @@ -30,7 +30,7 @@ from aiohttp import web, web_urldispatcher, connector, ClientSession, WSMsgType import confluent.auth as auth import confluent.config.attributes as attribs import confluent.config.configmanager as configmanager -import confluent.consoleserver as consoleserver +#import confluent.consoleserver as consoleserver import confluent.discovery.core as disco import confluent.forwarder as forwarder import confluent.exceptions as exc @@ -40,7 +40,7 @@ import confluent.core as pluginapi import confluent.asynchttp import confluent.selfservice as selfservice import confluent.shellserver as shellserver -import confluent.tlvdata +import confluent.asynctlvdata as tlvdata import confluent.util as util import copy import json @@ -52,7 +52,6 @@ try: import urlparse except ModuleNotFoundError: import urllib.parse as urlparse -tlvdata = confluent.tlvdata _cleaner = None diff --git a/confluent_server/confluent/sockapi.py b/confluent_server/confluent/sockapi.py index 47bfbdd3..bfb257a2 100644 --- a/confluent_server/confluent/sockapi.py +++ b/confluent_server/confluent/sockapi.py @@ -38,7 +38,7 @@ import ssl import confluent.auth as auth import confluent.credserver as credserver import confluent.config.conf as conf -import confluent.tlvdata as tlvdata +import confluent.asynctlvdata as tlvdata #import confluent.consoleserver as consoleserver import confluent.config.configmanager as configmanager import confluent.exceptions as exc