diff --git a/unit_tests/utilitites/test_zaza_utilitites_cert.py b/unit_tests/utilitites/test_zaza_utilitites_cert.py index 876c299..0b9be6f 100644 --- a/unit_tests/utilitites/test_zaza_utilitites_cert.py +++ b/unit_tests/utilitites/test_zaza_utilitites_cert.py @@ -1,3 +1,5 @@ +import mock + import unit_tests.utils as ut_utils import zaza.utilities.cert as cert @@ -100,3 +102,87 @@ class TestUtilitiesCert(ut_utils.BaseTestCase): self.cryptography.x509.BasicConstraints.assert_called_with( ca=True, path_length=None ) + + def sign_csr_mocks(self): + self.patch_object(cert, 'serialization') + self.patch_object(cert, 'cryptography') + self.expect_bend = self.cryptography.hazmat.backends.default_backend() + self.builder_mock = mock.MagicMock() + self.builder_mock.serial_number.return_value = self.builder_mock + self.builder_mock.issuer_name.return_value = self.builder_mock + self.builder_mock.not_valid_before.return_value = self.builder_mock + self.builder_mock.not_valid_after.return_value = self.builder_mock + self.builder_mock.subject_name.return_value = self.builder_mock + self.builder_mock.public_key.return_value = self.builder_mock + self.builder_mock.add_extension.return_value = self.builder_mock + + self.cryptography.x509.CertificateBuilder.return_value = \ + self.builder_mock + self.bcons_mock = mock.MagicMock() + self.cryptography.x509.BasicConstraints.side_effect = self.bcons_mock + + def test_sign_csr(self): + self.sign_csr_mocks() + cert.sign_csr('acsr', 'secretkey', ca_cert='cacert') + self.serialization.load_pem_private_key.assert_called_with( + b'secretkey', + password=None, + backend=self.expect_bend) + self.cryptography.x509.load_pem_x509_csr.assert_called_with( + b'acsr', + self.expect_bend) + self.cryptography.x509.load_pem_x509_certificate.assert_called_with( + b'cacert', + self.expect_bend) + + def test_sign_csr_key_password(self): + self.sign_csr_mocks() + cert.sign_csr('acsr', 'secretkey', ca_cert='cacert', + ca_private_key_password='bob') + self.serialization.load_pem_private_key.assert_called_with( + b'secretkey', + password='bob', + backend=self.expect_bend) + self.cryptography.x509.load_pem_x509_csr.assert_called_with( + b'acsr', + self.expect_bend) + self.cryptography.x509.load_pem_x509_certificate.assert_called_with( + b'cacert', + self.expect_bend) + self.bcons_mock.assert_called_with(ca=False, path_length=None) + self.builder_mock.add_extension.assert_called_once_with( + self.bcons_mock(), + critical=True) + + def test_sign_csr_issuer_name(self): + self.sign_csr_mocks() + cert.sign_csr('acsr', 'secretkey', issuer_name='issuer') + self.serialization.load_pem_private_key.assert_called_with( + b'secretkey', + password=None, + backend=self.expect_bend) + self.cryptography.x509.load_pem_x509_csr.assert_called_with( + b'acsr', + self.expect_bend) + self.bcons_mock.assert_called_with(ca=False, path_length=None) + self.builder_mock.issuer_name.assert_called_once_with('issuer') + self.builder_mock.add_extension.assert_called_once_with( + self.bcons_mock(), + critical=True) + + def test_sign_csr_generate_ca(self): + self.sign_csr_mocks() + cert.sign_csr('acsr', 'secretkey', issuer_name='issuer', + generate_ca=True) + self.serialization.load_pem_private_key.assert_called_with( + b'secretkey', + password=None, + backend=self.expect_bend) + self.cryptography.x509.load_pem_x509_csr.assert_called_with( + b'acsr', + self.expect_bend) + self.bcons_mock.assert_called_with(ca=True, path_length=None) + self.builder_mock.issuer_name.assert_called_once_with('issuer') + self.builder_mock.add_extension.assert_called_once_with( + self.bcons_mock(), + critical=True) diff --git a/zaza/utilities/cert.py b/zaza/utilities/cert.py index 753499b..bb317e0 100644 --- a/zaza/utilities/cert.py +++ b/zaza/utilities/cert.py @@ -16,6 +16,7 @@ import cryptography from cryptography.hazmat.primitives.asymmetric import rsa +import cryptography.hazmat.primitives.hashes as hashes import cryptography.hazmat.primitives.serialization as serialization import datetime @@ -121,3 +122,68 @@ def generate_cert(common_name, certificate.public_bytes( serialization.Encoding.PEM) ) + + +def sign_csr(csr, ca_private_key, ca_cert=None, issuer_name=None, + ca_private_key_password=None, generate_ca=False): + """Sign CSR with the given key. + + :param csr: Certificate to sign + :type csr: str + :param ca_private_key: Private key to be used to sign csr + :type ca_private_key: str + :param ca_cert: Cert to base some options from + :type ca_cert: str + :param issuer_name: Issuer name, must match provided_private_key issuer + :type issuer_name: Optional[str] + :param ca_private_key_password: Password to decrypt ca_private_key + :type ca_private_key_password: Optional[str] + :param generate_ca: Allow resulting cert to be used as ca + :type generate_ca: bool + :returns: x.509 certificate + :rtype: cryptography.x509.Certificate + """ + backend = cryptography.hazmat.backends.default_backend() + # Create x509 artifacts + root_ca_pkey = serialization.load_pem_private_key( + ca_private_key.encode(), + password=ca_private_key_password, + backend=backend) + + new_csr = cryptography.x509.load_pem_x509_csr( + csr.encode(), + backend) + + if ca_cert: + root_ca_cert = cryptography.x509.load_pem_x509_certificate( + ca_cert.encode(), + backend) + issuer_name = root_ca_cert.subject + else: + issuer_name = issuer_name + # Create builder + builder = cryptography.x509.CertificateBuilder() + builder = builder.serial_number( + cryptography.x509.random_serial_number()) + builder = builder.issuer_name(issuer_name) + builder = builder.not_valid_before( + datetime.datetime.today() - datetime.timedelta(1, 0, 0), + ) + builder = builder.not_valid_after( + datetime.datetime.today() + datetime.timedelta(80, 0, 0), + ) + builder = builder.subject_name(new_csr.subject) + builder = builder.public_key(new_csr.public_key()) + + builder = builder.add_extension( + cryptography.x509.BasicConstraints(ca=generate_ca, path_length=None), + critical=True + ) + + # Sign the csr + signer_ca_cert = builder.sign( + private_key=root_ca_pkey, + algorithm=hashes.SHA256(), + backend=backend) + + return signer_ca_cert.public_bytes(encoding=serialization.Encoding.PEM)