2
0
mirror of https://github.com/xcat2/confluent.git synced 2026-02-18 13:49:09 +00:00

Further progress toward asyncio

Basic operations can now happen with some async flows.
This commit is contained in:
Jarrod Johnson
2024-03-04 16:18:55 -05:00
parent 25f2698ae6
commit 0a8ec96cdf
10 changed files with 207 additions and 145 deletions

View File

@@ -1,4 +1,4 @@
#!/usr/bin/python2
#!/usr/bin/python3
# vim: tabstop=4 shiftwidth=4 softtabstop=4
# Copyright 2014 IBM Corporation
@@ -41,6 +41,7 @@
# esc-( would interfere with normal esc use too much
# ~ I will not use for now...
import asyncio
import math
import getpass
import optparse
@@ -360,7 +361,7 @@ def print_result(res):
print(output.encode('utf-8'))
def do_command(command, server):
async def do_command(command, server):
global exitcode
global target
global currconsole
@@ -399,7 +400,7 @@ def do_command(command, server):
target = otarget
else:
foundchild = False
for res in session.read(parentpath, server):
async for res in session.read(parentpath, server):
try:
if res['item']['href'] == childname:
foundchild = True
@@ -434,7 +435,7 @@ def do_command(command, server):
pass
else:
targpath = target
for res in session.read(targpath):
async for res in session.read(targpath):
if 'item' in res: # a link relation
if type(res['item']) == dict:
print(res['item']["href"])
@@ -854,7 +855,7 @@ opts, shellargs = parser.parse_args()
username = None
passphrase = None
def server_connect():
async def server_connect():
global session, username, passphrase
if opts.controlpath:
termhandler.TermHandler(opts.controlpath)
@@ -864,7 +865,7 @@ def server_connect():
session = client.Command(os.environ['CONFLUENT_HOST'])
else: # unix domain
session = client.Command()
await session.ensure_connected()
# Next stop, reading and writing from whichever of stdin and server goes first.
#see pyghmi code for solconnect.py
if not session.authenticated and username is not None:
@@ -890,10 +891,10 @@ if sys.stdout.isatty():
import readline
def main():
async def main():
global inconsole
try:
server_connect()
await server_connect()
except (EOFError, KeyboardInterrupt) as _:
raise BailOut(0)
except socket.gaierror:
@@ -920,10 +921,10 @@ def main():
session_node = get_session_node(shellargs)
if session_node is not None:
consoleonly = True
do_command("start /nodes/%s/console/session" % session_node, netserver)
await do_command("start /nodes/%s/console/session" % session_node, netserver)
doexit = True
elif shellargs:
do_command(shellargs, netserver)
await do_command(shellargs, netserver)
quitconfetty(fullexit=True, fixterm=False)
powerstate = None
@@ -954,7 +955,7 @@ def main():
else:
currcommand = prompt()
try:
do_command(currcommand, netserver)
await do_command(currcommand, netserver)
except socket.error:
try:
server_connect()
@@ -1029,7 +1030,7 @@ if __name__ == '__main__':
if opts.mintime:
deadline = os.times()[4] + float(opts.mintime)
try:
main()
asyncio.get_event_loop().run_until_complete(main())
except BailOut as e:
errcode = e.errorcode
except Exception as e:

View File

@@ -1,4 +1,4 @@
#!/usr/bin/python2
#!/usr/bin/python3
# vim: tabstop=4 shiftwidth=4 softtabstop=4
# Copyright 2017 Lenovo
@@ -15,6 +15,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import asyncio
import optparse
import os
import signal
@@ -32,27 +33,31 @@ if path.startswith('/opt'):
import confluent.client as client
argparser = optparse.OptionParser(
usage='''\n %prog noderange attribute1=value1 attribute2=value,...
\n ''')
(options, args) = argparser.parse_args()
requestargs=None
try:
noderange = args[0]
except IndexError:
argparser.print_help()
sys.exit(1)
client.check_globbing(noderange)
session = client.Command()
exitcode = 0
attribs = {'name': noderange}
for arg in args[1:]:
key, val = arg.split('=', 1)
attribs[key] = val
for r in session.create('/noderange/', attribs):
if 'error' in r:
sys.stderr.write(r['error'] + '\n')
exitcode |= 1
if 'created' in r:
print('{0}: created'.format(r['created']))
sys.exit(exitcode)
async def main():
argparser = optparse.OptionParser(
usage='''\n %prog noderange attribute1=value1 attribute2=value,...
\n ''')
(options, args) = argparser.parse_args()
requestargs = None
try:
noderange = args[0]
except IndexError:
argparser.print_help()
sys.exit(1)
client.check_globbing(noderange)
session = client.Command()
exitcode = 0
attribs = {'name': noderange}
for arg in args[1:]:
key, val = arg.split('=', 1)
attribs[key] = val
async for r in session.create('/noderange/', attribs):
if 'error' in r:
sys.stderr.write(r['error'] + '\n')
exitcode |= 1
if 'created' in r:
print('{0}: created'.format(r['created']))
sys.exit(exitcode)
if __name__ == '__main__':
asyncio.get_event_loop().run_until_complete(main())

View File

@@ -1,4 +1,4 @@
#!/usr/libexec/platform-python
#!/usr/bin/python3
# vim: tabstop=4 shiftwidth=4 softtabstop=4
# Copyright 2015-2017 Lenovo
@@ -21,6 +21,7 @@ import optparse
import os
import signal
import sys
import asyncio
@@ -35,7 +36,7 @@ if path.startswith('/opt'):
import confluent.client as client
def main():
async def main():
argparser = optparse.OptionParser(
usage="Usage: %prog noderange\n"
" or: %prog [options] noderange <nodeattribute>...")
@@ -61,7 +62,7 @@ def main():
if len(args) > 1:
exitcode=client.printattributes(session,requestargs,showtype,nodetype,noderange,options)
else:
for res in session.read(nodelist):
async for res in session.read(nodelist):
if 'error' in res:
sys.stderr.write(res['error'] + '\n')
exitcode = 1
@@ -73,4 +74,4 @@ def main():
sys.exit(exitcode)
if __name__ == '__main__':
main()
asyncio.get_event_loop().run_until_complete(main())

View File

@@ -62,6 +62,7 @@ def stringify(instr):
return instr.encode('utf-8')
return instr
class Tabulator(object):
def __init__(self, headers):
self.headers = headers
@@ -121,7 +122,8 @@ def printerror(res, node=None):
for node in res.get('databynode', {}):
exitcode = res['databynode'][node].get('errorcode', exitcode)
if 'error' in res['databynode'][node]:
sys.stderr.write('{0}: {1}\n'.format(node, res['databynode'][node]['error']))
sys.stderr.write(
'{0}: {1}\n'.format(node, res['databynode'][node]['error']))
if exitcode == 0:
exitcode = 1
if 'error' in res:
@@ -169,6 +171,11 @@ class Command(object):
self.serverloc = '/var/run/confluent/api.sock'
else:
self.serverloc = server
self.connected = False
async def ensure_connected(self):
if self.connected:
return True
if os.path.isabs(self.serverloc) and os.path.exists(self.serverloc):
self._connect_unix()
self.unixdomain = True
@@ -176,9 +183,9 @@ class Command(object):
raise Exception('Confluent service is not available')
else:
self._connect_tls()
self.protversion = int(tlvdata.recv(self.connection).split(
self.protversion = int((await tlvdata.recv(self.connection)).split(
b'--')[1].strip()[1:])
authdata = tlvdata.recv(self.connection)
authdata = await tlvdata.recv(self.connection)
if authdata['authpassed'] == 1:
self.authenticated = True
else:
@@ -186,19 +193,21 @@ class Command(object):
if not self.authenticated and 'CONFLUENT_USER' in os.environ:
username = os.environ['CONFLUENT_USER']
passphrase = os.environ['CONFLUENT_PASSPHRASE']
self.authenticate(username, passphrase)
await self.authenticate(username, passphrase)
self.connected = True
def add_file(self, name, handle, mode):
async def add_file(self, name, handle, mode):
await self.ensure_connected()
if self.protversion < 3:
raise Exception('Not supported with connected confluent server')
if not self.unixdomain:
raise Exception('Can only add a file to a unix domain connection')
tlvdata.send(self.connection, {'filename': name, 'mode': mode}, handle)
def authenticate(self, username, password):
async def authenticate(self, username, password):
tlvdata.send(self.connection,
{'username': username, 'password': password})
authdata = tlvdata.recv(self.connection)
authdata = await tlvdata.recv(self.connection)
if authdata['authpassed'] == 1:
self.authenticated = True
@@ -328,25 +337,32 @@ class Command(object):
cprint('')
return 0
def read(self, path, parameters=None):
async def read(self, path, parameters=None):
await self.ensure_connected()
if not self.authenticated:
raise Exception('Unauthenticated')
return send_request('retrieve', path, self.connection, parameters)
async for rsp in send_request(
'retrieve', path, self.connection, parameters):
yield rsp
def update(self, path, parameters=None):
async def update(self, path, parameters=None):
await self.ensure_connected()
if not self.authenticated:
raise Exception('Unauthenticated')
return send_request('update', path, self.connection, parameters)
return await send_request('update', path, self.connection, parameters)
def create(self, path, parameters=None):
async def create(self, path, parameters=None):
await self.ensure_connected()
if not self.authenticated:
raise Exception('Unauthenticated')
return send_request('create', path, self.connection, parameters)
async for rsp in send_request('create', path, self.connection, parameters):
yield rsp
def delete(self, path, parameters=None):
async def delete(self, path, parameters=None):
await self.ensure_connected()
if not self.authenticated:
raise Exception('Unauthenticated')
return send_request('delete', path, self.connection, parameters)
return await send_request('delete', path, self.connection, parameters)
def _connect_unix(self):
self.connection = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
@@ -411,7 +427,7 @@ class Command(object):
khf[hostid] = fingerprint
def send_request(operation, path, server, parameters=None):
async def send_request(operation, path, server, parameters=None):
"""This function iterates over all the responses
received from the server.
@@ -424,16 +440,16 @@ def send_request(operation, path, server, parameters=None):
payload = {'operation': operation, 'path': path}
if parameters is not None:
payload['parameters'] = parameters
tlvdata.send(server, payload)
result = tlvdata.recv(server)
await tlvdata.send(server, payload)
result = await tlvdata.recv(server)
while '_requestdone' not in result:
try:
yield result
except GeneratorExit:
while '_requestdone' not in result:
result = tlvdata.recv(server)
result = await tlvdata.recv(server)
raise
result = tlvdata.recv(server)
result = await tlvdata.recv(server)
def attrrequested(attr, attrlist, seenattributes, node=None):
@@ -720,7 +736,7 @@ def updateattrib(session, updateargs, nodetype, noderange, options, dictassign=N
else:
exitcode = session.simple_noderange_command(noderange, 'attributes/all',
value, key)
except:
except Exception:
sys.stderr.write('Error: {0} not a valid expression\n'.format(str(updateargs[1:])))
exitcode = 1
sys.exit(exitcode)

View File

@@ -16,15 +16,12 @@
# limitations under the License.
import array
import asyncio
import ctypes
import ctypes.util
import confluent.tlv as tlv
try:
import eventlet.green.socket as socket
import eventlet.green.select as select
except ImportError:
import socket
import select
import socket
import select
from datetime import datetime
import json
import os
@@ -40,6 +37,7 @@ try:
except NameError:
pass
class iovec(ctypes.Structure): # from uio.h
_fields_ = [('iov_base', ctypes.c_void_p),
('iov_len', ctypes.c_size_t)]
@@ -56,6 +54,7 @@ class cmsghdr(ctypes.Structure): # also from bits/socket.h
@classmethod
def init_data(cls, cmsg_len, cmsg_level, cmsg_type, cmsg_data):
Data = ctypes.c_ubyte * ctypes.sizeof(cmsg_data)
class _flexhdr(ctypes.Structure):
_fields_ = cls._fields_ + [('cmsg_data', Data)]
@@ -102,9 +101,59 @@ libc = ctypes.CDLL(ctypes.util.find_library('c'))
recvmsg = libc.recvmsg
recvmsg.argtypes = [ctypes.c_int, ctypes.POINTER(msghdr), ctypes.c_int]
recvmsg.restype = ctypes.c_int
sendmsg = libc.sendmsg
sendmsg.argtypes = [ctypes.c_int, ctypes.POINTER(msghdr), ctypes.c_int]
sendmsg.restype = ctypes.c_size_t
def _sendmsg(loop, fut, sock, msg, fds, rfd):
if rfd is not None:
loop.remove_reader(rfd)
if fut.cancelled():
return
try:
retdata = sock.sendmsg([msg], [(socket.SOL_SOCKET, socket.SCM_RIGHTS, array.array("i", fds))])
except (BlockingIOError, InterruptedError):
fd = sock.fileno()
loop.add_reader(fd, _sendmsg, loop, fut, sock, fd)
except Exception as exc:
fut.set_exception(exc)
else:
fut.set_result(retdata)
def send_fds(sock, msg, fds):
cloop = asyncio.get_event_loop()
fut = cloop.create_future()
_sendmsg(cloop, fut, sock, msg, fds, None)
return fut
def _recvmsg(loop, fut, sock, msglen, maxfds, rfd):
if rfd is not None:
loop.remove_reader(rfd)
fds = array.array("i") # Array of ints
try:
msg, ancdata, flags, addr = sock.recvmsg(
msglen, socket.CMSG_LEN(maxfds * fds.itemsize))
except (BlockingIOError, InterruptedError):
fd = sock.fileno()
loop.add_reader(fd, _recvmsg, loop, fut, sock, fd)
except Exception as exc:
fut.set_exception(exc)
else:
for cmsg_level, cmsg_type, cmsg_data in ancdata:
if (cmsg_level == socket.SOL_SOCKET
and cmsg_type == socket.SCM_RIGHTS):
# Append data, ignoring any truncated integers at the end.
fds.frombytes(
cmsg_data[
:len(cmsg_data) - (len(cmsg_data) % fds.itemsize)])
fut.set_result(msglen, list(fds))
def recv_fds(sock, msglen, maxfds):
cloop = asyncio.get_event_loop()
fut = cloop.create_future()
_recvmsg(cloop, fut, sock, msglen, maxfds, None)
return fut
def decodestr(value):
ret = None
@@ -141,7 +190,8 @@ def _unicode_list(currlist):
_unicode_list(currlist[i])
def send(handle, data, filehandle=None):
async def send(handle, data, filehandle=None):
cloop = asyncio.get_event_loop()
if isinstance(data, unicode):
try:
data = data.encode('utf-8')
@@ -156,10 +206,10 @@ def send(handle, data, filehandle=None):
if tl < 16777216:
# type for string is '0', so we don't need
# to xor anything in
handle.sendall(struct.pack("!I", tl))
await cloop.sock_sendall(handle, struct.pack("!I", tl))
else:
raise Exception("String data length exceeds protocol")
handle.sendall(data)
await cloop.sock_sendall(handle, data)
elif isinstance(data, dict): # JSON currently only goes to 4 bytes
# Some structured message, like what would be seen in http responses
unicode_dictvalues(data) # make everything unicode, assuming UTF-8
@@ -171,41 +221,32 @@ def send(handle, data, filehandle=None):
# xor in the type (0b1 << 24)
if filehandle is None:
tl |= 16777216
handle.sendall(struct.pack("!I", tl))
handle.sendall(sdata)
await cloop.sock_sendall(handle, struct.pack("!I", tl))
await cloop.sock_sendall(handle, sdata)
else:
tl |= (2 << 24)
handle.sendall(struct.pack("!I", tl))
cdtype = ctypes.c_ubyte * len(sdata)
cdata = cdtype.from_buffer(bytearray(sdata))
ciov = iovec(iov_base=ctypes.addressof(cdata),
iov_len=ctypes.c_size_t(ctypes.sizeof(cdata)))
fd = ctypes.c_int(filehandle)
cmh = cmsghdr.init_data(
cmsg_len=CMSG_LEN(
ctypes.sizeof(fd)), cmsg_level=socket.SOL_SOCKET,
cmsg_type=SCM_RIGHTS, cmsg_data=fd)
mh = msghdr(msg_name=None, msg_len=0, msg_iov=iovec_ptr(ciov),
msg_iovlen=1, msg_control=ctypes.addressof(cmh),
msg_controllen=ctypes.c_size_t(ctypes.sizeof(cmh)))
sendmsg(handle.fileno(), mh, 0)
await cloop.sock_sendall(handle, struct.pack("!I", tl))
await send_fds(handle, b'', [filehandle])
def recvall(handle, size):
rd = handle.recv(size)
async def recvall(handle, size):
cloop = asyncio.get_event_loop()
rd = await cloop.sock_recv(handle, size)
while len(rd) < size:
nd = handle.recv(size - len(rd))
nd = await cloop.sock_recv(handle, size - len(rd))
if not nd:
raise Exception("Error reading data")
rd += nd
return rd
def recv(handle):
tl = handle.recv(4)
async def recv(handle):
cloop = asyncio.get_event_loop()
tl = await cloop.sock_recv(handle, 4)
if not tl:
return None
while len(tl) < 4:
ndata = handle.recv(4 - len(tl))
ndata = await cloop.sock_recv(handle, 4 - len(tl))
if not ndata:
raise Exception("Error reading data")
tl += ndata
@@ -239,23 +280,16 @@ def recv(handle):
msg.msg_iovlen = 1
msg.msg_control = ctypes.addressof(cmsg)
msg.msg_controllen = ctypes.sizeof(cmsg)
select.select([handle], [], [])
i = recvmsg(handle.fileno(), ctypes.pointer(msg), 0)
cdata = cmsgarr[CMSG_LEN(0).value:]
data = rawbuffer[:i]
if cmsg.cmsg_level == socket.SOL_SOCKET and cmsg.cmsg_type == SCM_RIGHTS:
try:
filehandles.fromstring(bytes(
cdata[:len(cdata) - len(cdata) % filehandles.itemsize]))
except AttributeError:
filehandles.frombytes(bytes(
cdata[:len(cdata) - len(cdata) % filehandles.itemsize]))
i = await recv_fds(handle, 2048, 4)
print(repr(i))
data = i[0]
filehandles = i[1]
data = json.loads(bytes(data))
return ClientFile(data['filename'], data['mode'], filehandles[0])
else:
data = handle.recv(dlen)
data = await cloop.sock_recv(handle, dlen)
while len(data) < dlen:
ndata = handle.recv(dlen - len(data))
ndata = await asyncio.sock_recv(handle, dlen - len(data))
if not ndata:
raise Exception("Error reading data")
data += ndata

View File

@@ -1267,7 +1267,7 @@ def handle_discovery(pathcomponents, operation, configmanager, inputdata):
if pathcomponents[0] == 'detected':
pass
def handle_path(path, operation, configmanager, inputdata=None, autostrip=True):
async def handle_path(path, operation, configmanager, inputdata=None, autostrip=True):
"""Given a full path request, return an object.
The plugins should generally return some sort of iterator.
@@ -1281,7 +1281,7 @@ def handle_path(path, operation, configmanager, inputdata=None, autostrip=True):
if not pathcomponents: # root collection list
return enumerate_collections(rootcollections)
elif pathcomponents[0] == 'noderange':
return handle_node_request(configmanager, inputdata, operation,
return await handle_node_request(configmanager, inputdata, operation,
pathcomponents, autostrip)
elif pathcomponents[0] == 'deployment':
return handle_deployment(configmanager, inputdata, pathcomponents,
@@ -1295,7 +1295,7 @@ def handle_path(path, operation, configmanager, inputdata=None, autostrip=True):
operation)
elif pathcomponents[0] == 'nodes':
# single node request of some sort
return handle_node_request(configmanager, inputdata,
return await handle_node_request(configmanager, inputdata,
operation, pathcomponents, autostrip)
elif pathcomponents[0] == 'discovery':
return disco.handle_api_request(

View File

@@ -39,6 +39,7 @@ import confluent.httpapi as httpapi
import confluent.log as log
import confluent.collective.manager as collective
import confluent.discovery.protocols.pxe as pxe
from eventlet.asyncio import spawn_for_awaitable
try:
import confluent.sockapi as sockapi
except ImportError:
@@ -313,7 +314,7 @@ def run(args):
sock_bind_host, sock_bind_port = _get_connector_config('socket')
try:
sockservice = sockapi.SockApi(sock_bind_host, sock_bind_port)
sockservice.start()
spawn_for_awaitable(sockservice.start())
except NameError:
pass
webservice = httpapi.HttpApi(http_bind_host, http_bind_port)

View File

@@ -34,7 +34,7 @@ async def reseat_bays(encmgr, bays, configmanager, rspq):
finally:
await rspq.put(None)
def update(nodes, element, configmanager, inputdata):
async def update(nodes, element, configmanager, inputdata):
emebs = configmanager.get_node_attributes(
nodes, (u'enclosure.manager', u'enclosure.bay'))
baysbyencmgr = {}
@@ -57,7 +57,7 @@ def update(nodes, element, configmanager, inputdata):
for encmgr in baysbyencmgr:
currtask = asyncio.create_task(reseat_bays(encmgr, baysbyencmgr[encmgr], configmanager, rspq))
reseattasks.append(currtask)
while not all([task.done() for task in reseattasks]):;
while not all([task.done() for task in reseattasks]):
nrsp = await rspq.get()
if nrsp is not None:
yield nrsp

View File

@@ -187,7 +187,7 @@ class IpmiCommandWrapper(ipmicommand.Command):
(node,), ('secret.hardwaremanagementuser', 'collective.manager',
'secret.hardwaremanagementpassword', 'secret.ipmikg',
'hardwaremanagement.manager'), self._attribschanged)
amait super().create(**kwargs)
await super().create(**kwargs)
self.setup_confluent_keyhandler()
try:
os.makedirs('/var/cache/confluent/ipmi/')

View File

@@ -21,6 +21,7 @@
#
import atexit
import asyncio
import ctypes
import ctypes.util
import errno
@@ -32,7 +33,7 @@ import sys
import traceback
import eventlet.green.select as select
import eventlet.green.socket as socket
import socket
import eventlet.green.ssl as ssl
import eventlet.support.greendns as greendns
import eventlet
@@ -107,15 +108,15 @@ class ClientConsole(object):
self.pendingdata = []
def send_data(connection, data):
async def send_data(connection, data):
try:
tlvdata.send(connection, data)
await tlvdata.send(connection, data)
except IOError as ie:
if ie.errno != errno.EPIPE:
raise
def sessionhdl(connection, authname, skipauth=False, cert=None):
async def sessionhdl(connection, authname, skipauth=False, cert=None):
try:
# For now, trying to test the console stuff, so let's just do n4.
authenticated = False
@@ -132,10 +133,11 @@ def sessionhdl(connection, authname, skipauth=False, cert=None):
# version 0 == original, version 1 == pickle3 allowed, 2 = pickle forbidden, msgpack allowed
# v3 - filehandle allowed
# v4 - schema change and keepalive changes
send_data(connection, "Confluent -- v4 --")
await send_data(connection, "Confluent -- v4 --")
while not authenticated: # prompt for name and passphrase
send_data(connection, {'authpassed': 0})
response = tlvdata.recv(connection)
await send_data(connection, {'authpassed': 0})
response = await tlvdata.recv(connection)
if not response:
return
if 'collective' in response:
@@ -162,8 +164,8 @@ def sessionhdl(connection, authname, skipauth=False, cert=None):
else:
authenticated = True
cfm = authdata[1]
send_data(connection, {'authpassed': 1})
request = tlvdata.recv(connection)
await send_data(connection, {'authpassed': 1})
request = await tlvdata.recv(connection)
if request and isinstance(request, dict) and 'collective' in request:
if skipauth:
if not libssl:
@@ -185,27 +187,27 @@ def sessionhdl(connection, authname, skipauth=False, cert=None):
{'collective': {'error': 'collective management commands may only be used by root'}})
while request is not None:
try:
process_request(
await process_request(
connection, request, cfm, authdata, authname, skipauth)
except exc.ConfluentException as e:
if ((not isinstance(e, exc.LockedCredentials)) and
e.apierrorcode == 500):
tracelog.log(traceback.format_exc(), ltype=log.DataTypes.event,
event=log.Events.stacktrace)
send_data(connection, {'errorcode': e.apierrorcode,
await send_data(connection, {'errorcode': e.apierrorcode,
'error': e.apierrorstr,
'detail': e.get_error_body()})
send_data(connection, {'_requestdone': 1})
await send_data(connection, {'_requestdone': 1})
except SystemExit:
sys.exit(0)
except Exception as e:
tracelog.log(traceback.format_exc(), ltype=log.DataTypes.event,
event=log.Events.stacktrace)
send_data(connection, {'errorcode': 500,
await send_data(connection, {'errorcode': 500,
'error': 'Unexpected error - ' + str(e)})
send_data(connection, {'_requestdone': 1})
await send_data(connection, {'_requestdone': 1})
try:
request = tlvdata.recv(connection)
request = await tlvdata.recv(connection)
except Exception:
request = None
finally:
@@ -216,15 +218,16 @@ def sessionhdl(connection, authname, skipauth=False, cert=None):
except Exception:
pass
def send_response(responses, connection):
async def send_response(responses, connection):
if responses is None:
return
responses = await responses
for rsp in responses:
send_data(connection, rsp.raw())
send_data(connection, {'_requestdone': 1})
await send_data(connection, rsp.raw())
await send_data(connection, {'_requestdone': 1})
def process_request(connection, request, cfm, authdata, authname, skipauth):
async def process_request(connection, request, cfm, authdata, authname, skipauth):
if isinstance(request, tlvdata.ClientFile):
cfm.add_client_file(request)
return
@@ -267,7 +270,7 @@ def process_request(connection, request, cfm, authdata, authname, skipauth):
send_data(connection, {"errorcode": 400,
"error": "Bad Request - " + str(e)})
send_data(connection, {"_requestdone": 1})
send_response(hdlr, connection)
await send_response(hdlr, connection)
return
def start_proxy_term(connection, cert, request):
@@ -387,7 +390,7 @@ def _tlshandler(bind_host, bind_port):
if addr[1] < 1000:
eventlet.spawn_n(cs.handle_client, cnn, addr)
else:
eventlet.spawn_n(_tlsstartup, cnn)
asyncio.create_task(_tlsstartup(cnn))
if ffi:
@@ -395,8 +398,7 @@ if ffi:
def verify_stub(store, misc):
return 1
def _tlsstartup(cnn):
async def _tlsstartup(cnn):
authname = None
cert = None
conf.init_config()
@@ -435,7 +437,7 @@ def _tlsstartup(cnn):
cnn = ctx.wrap_socket(cnn, server_side=True)
except AttributeError:
raise Exception('Unable to find workable SSL support')
sessionhdl(cnn, authname, cert=cert)
await sessionhdl(cnn, authname, cert=cert)
def removesocket():
try:
@@ -443,8 +445,10 @@ def removesocket():
except OSError:
pass
def _unixdomainhandler():
async def _unixdomainhandler():
aloop = asyncio.get_event_loop()
unixsocket = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
unixsocket.settimeout(0)
try:
os.remove("/var/run/confluent/api.sock")
except OSError: # if file does not exist, no big deal
@@ -458,7 +462,7 @@ def _unixdomainhandler():
atexit.register(removesocket)
unixsocket.listen(5)
while True:
cnn, addr = unixsocket.accept()
cnn, addr = await aloop.sock_accept(unixsocket)
creds = cnn.getsockopt(socket.SOL_SOCKET, SO_PEERCRED,
struct.calcsize('iII'))
pid, uid, gid = struct.unpack('iII', creds)
@@ -479,7 +483,7 @@ def _unixdomainhandler():
except KeyError:
cnn.close()
return
eventlet.spawn_n(sessionhdl, cnn, authname, skipauth)
asyncio.create_task(sessionhdl(cnn, authname, skipauth))
class SockApi(object):
@@ -489,7 +493,7 @@ class SockApi(object):
self.bind_host = bindhost or '::'
self.bind_port = bindport or 13001
def start(self):
async def start(self):
global auditlog
global tracelog
tracelog = log.Logger('trace')
@@ -500,7 +504,7 @@ class SockApi(object):
else:
eventlet.spawn_n(self.watch_for_cert)
eventlet.spawn_n(self.watch_resolv)
self.unixdomainserver = eventlet.spawn(_unixdomainhandler)
self.unixdomainserver = asyncio.create_task(_unixdomainhandler())
def watch_resolv(self):
while True: