2
0
mirror of https://github.com/xcat2/confluent.git synced 2026-05-01 04:47:45 +00:00
Files
confluent/confluent_client/confluent/asynctlvdata.py
2026-04-30 09:25:09 -04:00

309 lines
9.4 KiB
Python

# vim: tabstop=4 shiftwidth=4 softtabstop=4
# Copyright 2014 IBM Corporation
# Copyright 2015 Lenovo
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import array
import asyncio
import ctypes
import ctypes.util
import confluent.tlv as tlv
import socket
from datetime import datetime
import json
import os
import struct
try:
unicode
except NameError:
unicode = str
try:
range = xrange
except NameError:
pass
class iovec(ctypes.Structure): # from uio.h
_fields_ = [('iov_base', ctypes.c_void_p),
('iov_len', ctypes.c_size_t)]
iovec_ptr = ctypes.POINTER(iovec)
class cmsghdr(ctypes.Structure): # also from bits/socket.h
_fields_ = [('cmsg_len', ctypes.c_size_t),
('cmsg_level', ctypes.c_int),
('cmsg_type', ctypes.c_int)]
@classmethod
def init_data(cls, cmsg_len, cmsg_level, cmsg_type, cmsg_data):
Data = ctypes.c_ubyte * ctypes.sizeof(cmsg_data)
class _flexhdr(ctypes.Structure):
_fields_ = cls._fields_ + [('cmsg_data', Data)]
datab = Data(*bytearray(cmsg_data))
return _flexhdr(cmsg_len=cmsg_len, cmsg_level=cmsg_level,
cmsg_type=cmsg_type, cmsg_data=datab)
def CMSG_LEN(length):
sizeof_cmshdr = ctypes.sizeof(cmsghdr)
return ctypes.c_size_t(CMSG_ALIGN(sizeof_cmshdr).value + length)
SCM_RIGHTS = 1
class msghdr(ctypes.Structure): # from bits/socket.h
_fields_ = [('msg_name', ctypes.c_void_p),
('msg_namelen', ctypes.c_uint),
('msg_iov', ctypes.POINTER(iovec)),
('msg_iovlen', ctypes.c_size_t),
('msg_control', ctypes.c_void_p),
('msg_controllen', ctypes.c_size_t),
('msg_flags', ctypes.c_int)]
def CMSG_ALIGN(length): # bits/socket.h
ret = (length + ctypes.sizeof(ctypes.c_size_t) - 1
& ~(ctypes.sizeof(ctypes.c_size_t) - 1))
return ctypes.c_size_t(ret)
def CMSG_SPACE(length): # bits/socket.h
ret = CMSG_ALIGN(length).value + CMSG_ALIGN(ctypes.sizeof(cmsghdr)).value
return ctypes.c_size_t(ret)
class ClientFile(object):
def __init__(self, name, mode, fd):
self.fileobject = os.fdopen(fd, mode)
self.filename = name
def _sendmsg(loop, fut, sock, msg, fds, rfd):
if rfd is not None:
loop.remove_reader(rfd)
if fut.cancelled():
return
try:
retdata = sock.sendmsg(
[msg],
[(socket.SOL_SOCKET, socket.SCM_RIGHTS, array.array("i", fds))])
except (BlockingIOError, InterruptedError):
fd = sock.fileno()
loop.add_reader(fd, _sendmsg, loop, fut, sock, fd)
except Exception as exc:
fut.set_exception(exc)
else:
fut.set_result(retdata)
def send_fds(sock, msg, fds):
cloop = asyncio.get_event_loop()
fut = cloop.create_future()
_sendmsg(cloop, fut, sock, msg, fds, None)
return fut
def _recvmsg(loop, fut, sock, msglen, maxfds, rfd):
if rfd is not None:
loop.remove_reader(rfd)
fds = array.array("i") # Array of ints
try:
msg, ancdata, flags, addr = sock.recvmsg(
msglen, socket.CMSG_LEN(maxfds * fds.itemsize))
except (BlockingIOError, InterruptedError):
fd = sock.fileno()
loop.add_reader(fd, _recvmsg, loop, fut, sock, msglen, maxfds, fd)
except Exception as exc:
fut.set_exception(exc)
else:
for cmsg_level, cmsg_type, cmsg_data in ancdata:
if (cmsg_level == socket.SOL_SOCKET
and cmsg_type == socket.SCM_RIGHTS):
# Append data, ignoring any truncated integers at the end.
fds.frombytes(
cmsg_data[
:len(cmsg_data) - (len(cmsg_data) % fds.itemsize)])
fut.set_result((msg, list(fds)))
def recv_fds(sock, msglen, maxfds):
cloop = asyncio.get_event_loop()
fut = cloop.create_future()
_recvmsg(cloop, fut, sock, msglen, maxfds, None)
return fut
def decodestr(value):
ret = None
try:
ret = value.decode('utf-8')
except UnicodeDecodeError:
try:
ret = value.decode('cp437')
except UnicodeDecodeError:
ret = value
except AttributeError:
return value
return ret
def unicode_dictvalues(dictdata):
for key in dictdata:
if isinstance(dictdata[key], bytes):
dictdata[key] = decodestr(dictdata[key])
elif isinstance(dictdata[key], datetime):
dictdata[key] = dictdata[key].strftime('%Y-%m-%dT%H:%M:%S')
elif isinstance(dictdata[key], list):
_unicode_list(dictdata[key])
elif isinstance(dictdata[key], dict):
unicode_dictvalues(dictdata[key])
def _unicode_list(currlist):
for i in range(len(currlist)):
if isinstance(currlist[i], str):
currlist[i] = decodestr(currlist[i])
elif isinstance(currlist[i], dict):
unicode_dictvalues(currlist[i])
elif isinstance(currlist[i], list):
_unicode_list(currlist[i])
async def sendall(handle, data):
if isinstance(handle, tuple):
handle[1].write(data)
return await handle[1].drain()
else:
cloop = asyncio.get_event_loop()
return await cloop.sock_sendall(handle, data)
def get_socket(handle):
if isinstance(handle, tuple):
return handle[1].transport.get_extra_info('socket')
else:
return handle
async def close(handle):
if isinstance(handle, tuple):
handle[1].close()
await handle[1].wait_closed()
else:
handle.close()
async def send(handle, data, filehandle=None):
cloop = asyncio.get_event_loop()
if isinstance(data, unicode):
try:
data = data.encode('utf-8')
except AttributeError:
pass
if isinstance(data, bytes) or isinstance(data, unicode):
# plain text, e.g. console data
tl = len(data)
if tl == 0:
# if you don't have anything to say, don't say anything at all
return
if tl < 16777216:
# type for string is '0', so we don't need
# to xor anything in
await sendall(handle, struct.pack("!I", tl))
else:
raise Exception("String data length exceeds protocol")
await sendall(handle, data)
elif isinstance(data, dict): # JSON currently only goes to 4 bytes
# Some structured message, like what would be seen in http responses
unicode_dictvalues(data) # make everything unicode, assuming UTF-8
sdata = json.dumps(data, ensure_ascii=False, separators=(',', ':'))
sdata = sdata.encode('utf-8')
tl = len(sdata)
if tl > 16777215:
raise Exception("JSON data exceeds protocol limits")
# xor in the type (0b1 << 24)
if filehandle is None:
tl |= 16777216
await sendall(handle, struct.pack("!I", tl))
await sendall(handle, sdata)
elif isinstance(handle, tuple):
raise Exception("Cannot send filehandle over network socket")
else:
tl |= (2 << 24)
await cloop.sock_sendall(handle, struct.pack("!I", tl))
await send_fds(handle, sdata, [filehandle])
async def _grabhdl(handle, size):
if isinstance(handle, tuple):
return await handle[0].read(size)
else:
cloop = asyncio.get_event_loop()
return await cloop.sock_recv(handle, size)
async def recvall(handle, size):
rd = await _grabhdl(handle, size)
while len(rd) < size:
nd = await _grabhdl(handle, size - len(rd))
if not nd:
raise Exception("Error reading data")
rd += nd
return rd
async def recv(handle):
tl = await _grabhdl(handle, 4)
if not tl:
return None
while len(tl) < 4:
ndata = await _grabhdl(handle, 4 - len(tl))
if not ndata:
raise Exception("Error reading data")
tl += ndata
if len(tl) == 0:
return None
tl = struct.unpack("!I", tl)[0]
if tl & 0b10000000000000000000000000000000:
raise Exception("Protocol Violation, reserved bit set")
# 4 byte tlv
dlen = tl & 16777215 # grab lower 24 bits
datatype = (tl & 2130706432) >> 24 # grab 7 bits from near beginning
if dlen == 0:
return None
if datatype == tlv.Types.filehandle:
if isinstance(handle, tuple):
raise Exception('Filehandle not supported over TLS socket')
msg, filehandles = await recv_fds(handle, dlen, 4)
data = json.loads(bytes(msg))
return ClientFile(data['filename'], data['mode'], filehandles[0])
else:
data = await _grabhdl(handle, dlen)
while len(data) < dlen:
ndata = await _grabhdl(handle, dlen - len(data))
if not ndata:
raise Exception("Error reading data")
data += ndata
if datatype == tlv.Types.text:
return data
elif datatype == tlv.Types.json:
return json.loads(data)