diff --git a/zaza/openstack/charm_tests/vault/tests.py b/zaza/openstack/charm_tests/vault/tests.py index 4719c4b..5c93046 100644 --- a/zaza/openstack/charm_tests/vault/tests.py +++ b/zaza/openstack/charm_tests/vault/tests.py @@ -248,9 +248,11 @@ class VaultTest(BaseVaultTest): if 'pause' not in vault_actions or 'resume' not in vault_actions: raise unittest.SkipTest("The version of charm-vault tested does " "not have pause/resume actions") + # this pauses and resumes the LEAD unit with self.pause_resume(['vault']): logging.info("Testing pause resume") - self.assertTrue(self.clients[0].hvac_client.seal_status['sealed']) + lead_client = vault_utils.extract_lead_unit_client(self.clients) + self.assertTrue(lead_client.hvac_client.seal_status['sealed']) if __name__ == '__main__': diff --git a/zaza/openstack/charm_tests/vault/utils.py b/zaza/openstack/charm_tests/vault/utils.py index b6f4cf5..c05fcf6 100644 --- a/zaza/openstack/charm_tests/vault/utils.py +++ b/zaza/openstack/charm_tests/vault/utils.py @@ -173,6 +173,37 @@ def get_clients(units=None, cacert=None): return clients +def extract_lead_unit_client( + clients=None, application_name='vault', cacert=None): + """Find the lead unit client. + + This returns the lead unit client from a list of clients. If no clients + are passed, then the clients are resolved using the cacert (if needed) and + the application_name. The client is then matched to the lead unit. If + clients are passed, but no leader is found in them, then the function + raises a RuntimeError. + + :param clients: List of CharmVaultClient + :type clients: List[CharmVaultClient] + :param application_name: The application name + :type application_name: str + :param cacert: Path to CA cert used for vaults api cert. + :type cacert: str + :returns: The leader client + :rtype: CharmVaultClient + :raises: RuntimeError if the lead unit cannot be found + """ + if clients is None: + units = zaza.model.get_app_ips('vault') + clients = get_clients(units, cacert) + lead_ip = zaza.model.get_lead_unit_ip(application_name) + for client in clients: + if client.addr == lead_ip: + return client + raise RuntimeError("Leader client not found for application: {}" + .format(application_name)) + + def is_initialized(client): """Check if vault is initialized.