2
0
mirror of https://github.com/xcat2/confluent.git synced 2026-04-30 20:37:47 +00:00

Fix async handling of passed file descriptors

This commit is contained in:
Jarrod Johnson
2026-04-30 09:25:09 -04:00
parent bfc27595dc
commit 7f604e3e35
3 changed files with 6 additions and 26 deletions

View File

@@ -225,7 +225,7 @@ class Command(object):
raise Exception('Not supported with connected confluent server')
if not self.unixdomain:
raise Exception('Can only add a file to a unix domain connection')
tlvdata.send(self.connection, {'filename': name, 'mode': mode}, handle)
await tlvdata.send(self.connection, {'filename': name, 'mode': mode}, handle)
async def authenticate(self, username, password):
await tlvdata.send(self.connection,

View File

@@ -249,7 +249,7 @@ async def send(handle, data, filehandle=None):
else:
tl |= (2 << 24)
await cloop.sock_sendall(handle, struct.pack("!I", tl))
await send_fds(handle, b'', [filehandle])
await send_fds(handle, sdata, [filehandle])
async def _grabhdl(handle, size):
@@ -292,29 +292,8 @@ async def recv(handle):
if datatype == tlv.Types.filehandle:
if isinstance(handle, tuple):
raise Exception('Filehandle not supported over TLS socket')
filehandles = array.array('i')
rawbuffer = bytearray(2048)
pkttype = ctypes.c_ubyte * 2048
data = pkttype.from_buffer(rawbuffer)
cmsgsize = CMSG_SPACE(ctypes.sizeof(ctypes.c_int)).value
cmsgarr = bytearray(cmsgsize)
cmtype = ctypes.c_ubyte * cmsgsize
cmsg = cmtype.from_buffer(cmsgarr)
cmsg.cmsg_level = socket.SOL_SOCKET
cmsg.cmsg_type = SCM_RIGHTS
cmsg.cmsg_len = CMSG_LEN(ctypes.sizeof(ctypes.c_int))
iov = iovec()
iov.iov_base = ctypes.addressof(data)
iov.iov_len = 2048
msg = msghdr()
msg.msg_iov = ctypes.pointer(iov)
msg.msg_iovlen = 1
msg.msg_control = ctypes.addressof(cmsg)
msg.msg_controllen = ctypes.sizeof(cmsg)
i = await recv_fds(handle, 2048, 4)
data = i[0]
filehandles = i[1]
data = json.loads(bytes(data))
msg, filehandles = await recv_fds(handle, dlen, 4)
data = json.loads(bytes(msg))
return ClientFile(data['filename'], data['mode'], filehandles[0])
else:
data = await _grabhdl(handle, dlen)

View File

@@ -535,10 +535,11 @@ async def oslist():
async def osimport(imagefile, checkonly=False, custname=None):
c = client.Command()
imagefile = os.path.abspath(imagefile)
await c.ensure_connected()
if c.unixdomain:
ofile = open(imagefile, 'rb')
try:
c.add_file(imagefile, ofile.fileno(), 'rb')
await c.add_file(imagefile, ofile.fileno(), 'rb')
except Exception:
pass
importing = False