diff --git a/unit_tests/utilities/test_zaza_utilities_openstack.py b/unit_tests/utilities/test_zaza_utilities_openstack.py index 1ced14b..9aa3cde 100644 --- a/unit_tests/utilities/test_zaza_utilities_openstack.py +++ b/unit_tests/utilities/test_zaza_utilities_openstack.py @@ -696,6 +696,31 @@ class TestOpenStackUtils(ut_utils.BaseTestCase): password='reallyhardpassord', username='bob') + def test_ssh_command(self): + paramiko_mock = mock.MagicMock() + self.patch_object(openstack_utils.paramiko, 'SSHClient', + return_value=paramiko_mock) + self.patch_object(openstack_utils.paramiko, 'AutoAddPolicy', + return_value='some_policy') + stdout = io.StringIO("myvm") + + paramiko_mock.exec_command.return_value = ('stdin', stdout, 'stderr') + + def verifier(_stdin, stdout, _stderr): + self.assertEqual('myvm', stdout.readlines()[0].strip()) + + openstack_utils.ssh_command( + 'bob', + '10.0.0.10', + 'myvm', + 'uname -n', + password='reallyhardpassord', + verify=verifier) + paramiko_mock.connect.assert_called_once_with( + '10.0.0.10', + password='reallyhardpassord', + username='bob') + def test_ssh_test_wrong_server(self): paramiko_mock = mock.MagicMock() self.patch_object(openstack_utils.paramiko, 'SSHClient', diff --git a/zaza/openstack/utilities/openstack.py b/zaza/openstack/utilities/openstack.py index 9b2e617..0b41a65 100644 --- a/zaza/openstack/utilities/openstack.py +++ b/zaza/openstack/utilities/openstack.py @@ -1921,6 +1921,48 @@ def ssh_test(username, ip, vm_name, password=None, privkey=None): :type privkey: str :raises: exceptions.SSHFailed """ + def verify(stdin, stdout, stderr): + return_string = stdout.readlines()[0].strip() + + if return_string == vm_name: + logging.info('SSH to %s(%s) succesfull' % (vm_name, ip)) + else: + logging.info('SSH to %s(%s) failed (%s != %s)' % (vm_name, ip, + return_string, + vm_name)) + raise exceptions.SSHFailed() + + ssh_command(username, ip, vm_name, 'uname -n', + password=password, privkey=privkey, verify=verify) + + +def ssh_command(username, + ip, + vm_name, + command, + password=None, + privkey=None, + verify=None): + """SSH to given ip using supplied credentials. + + :param username: Username to connect with + :type username: str + :param ip: IP address to ssh to. + :type ip: str + :param vm_name: Name of VM. + :type vm_name: str + :param command: What command to run on the remote host + :type command: str + :param password: Password to authenticate with. If supplied it is used + rather than privkey. + :type password: str + :param privkey: Private key to authenticate with. If a password is + supplied it is used rather than the private key. + :type privkey: str + :param verify: A callable to verify the command output with + :type verify: callable + :raises: exceptions.SSHFailed + """ logging.info('Attempting to ssh to %s(%s)' % (vm_name, ip)) ssh = paramiko.SSHClient() ssh.set_missing_host_key_policy(paramiko.AutoAddPolicy()) @@ -1929,16 +1971,15 @@ def ssh_test(username, ip, vm_name, password=None, privkey=None): else: key = paramiko.RSAKey.from_private_key(io.StringIO(privkey)) ssh.connect(ip, username=username, password='', pkey=key) - stdin, stdout, stderr = ssh.exec_command('uname -n') - return_string = stdout.readlines()[0].strip() - ssh.close() - if return_string == vm_name: - logging.info('SSH to %s(%s) succesfull' % (vm_name, ip)) - else: - logging.info('SSH to %s(%s) failed (%s != %s)' % (vm_name, ip, - return_string, - vm_name)) - raise exceptions.SSHFailed() + logging.info("Running {} on {}".format(command, vm_name)) + stdin, stdout, stderr = ssh.exec_command(command) + if verify and callable(verify): + try: + verify(stdin, stdout, stderr) + except Exception as e: + raise(e) + finally: + ssh.close() @tenacity.retry(wait=tenacity.wait_exponential(multiplier=0.01),