diff --git a/acme_tiny.py b/acme_tiny.py index 15a415b9..c164ab1f 100755 --- a/acme_tiny.py +++ b/acme_tiny.py @@ -13,7 +13,7 @@ LOGGER.addHandler(logging.StreamHandler()) LOGGER.setLevel(logging.INFO) -def get_crt(account_key, csr, acme_dir, log=LOGGER, CA=DEFAULT_CA, disable_check=False, directory_url=DEFAULT_DIRECTORY_URL, contact=None, check_port=None): +def get_crt(account_key, csr, acme_dir, log=LOGGER, CA=DEFAULT_CA, disable_check=False, directory_url=DEFAULT_DIRECTORY_URL, contact=None, check_port=None, preferred_chain=None): directory, acct_headers, alg, jwk = None, None, None, None # global variables # helper functions - base64 encode for jose spec @@ -164,10 +164,22 @@ def _poll_until_not(url, pending_statuses, err_msg): if order['status'] != "valid": raise ValueError("Order failed: {0}".format(order)) + # helper function - select preferred chain from ACME alternate Link headers + def _select_chain(pem, headers, preferred): + alt_urls = [re.match(r'\s*<([^>]+)>', p.strip()).group(1) for lv in (headers.get_all('Link') if hasattr(headers, 'get_all') else []) or [] for p in lv.split(',') if '; rel="alternate"' in p and re.match(r'\s*<([^>]+)>', p.strip())] + for alt_url in alt_urls: + alt_pem, _, _ = _send_signed_request(alt_url, None, "Alternate certificate download failed") + for cert in re.findall(r'(-----BEGIN CERTIFICATE-----[^-]+-----END CERTIFICATE-----)', alt_pem, re.DOTALL): + try: + if preferred.lower() in _cmd(["openssl", "x509", "-noout", "-issuer"], stdin=subprocess.PIPE, cmd_input=cert.encode('utf8'), err_msg="OpenSSL Error").decode('utf8').lower(): + log.info("Using alternate chain matching '{0}'".format(preferred)); return alt_pem + except IOError: pass + return pem + # download the certificate - certificate_pem, _, _ = _send_signed_request(order['certificate'], None, "Certificate download failed") + certificate_pem, _, cert_headers = _send_signed_request(order['certificate'], None, "Certificate download failed") log.info("Certificate signed!") - return certificate_pem + return _select_chain(certificate_pem, cert_headers, preferred_chain) if preferred_chain else certificate_pem def main(argv=None): parser = argparse.ArgumentParser( @@ -189,10 +201,11 @@ def main(argv=None): parser.add_argument("--ca", default=DEFAULT_CA, help="DEPRECATED! USE --directory-url INSTEAD!") parser.add_argument("--contact", metavar="CONTACT", default=None, nargs="*", help="Contact details (e.g. mailto:aaa@bbb.com) for your account-key") parser.add_argument("--check-port", metavar="PORT", default=None, help="what port to use when self-checking the challenge file, default is port 80") + parser.add_argument("--preferred-chain", metavar="PREFERRED_CHAIN", default=None, help="if the CA offers multiple chains, select the one containing this string in an issuer CN (e.g. 'ISRG Root X1')") args = parser.parse_args(argv) LOGGER.setLevel(args.quiet or LOGGER.level) - signed_crt = get_crt(args.account_key, args.csr, args.acme_dir, log=LOGGER, CA=args.ca, disable_check=args.disable_check, directory_url=args.directory_url, contact=args.contact, check_port=args.check_port) + signed_crt = get_crt(args.account_key, args.csr, args.acme_dir, log=LOGGER, CA=args.ca, disable_check=args.disable_check, directory_url=args.directory_url, contact=args.contact, check_port=args.check_port, preferred_chain=args.preferred_chain) sys.stdout.write(signed_crt) if __name__ == "__main__": # pragma: no cover diff --git a/tests/test_module.py b/tests/test_module.py index 5ab2915e..9cf150e8 100644 --- a/tests/test_module.py +++ b/tests/test_module.py @@ -1,4 +1,5 @@ import os +import re import sys import json import time @@ -80,12 +81,12 @@ def tearDown(self): shutil.rmtree(self._base_tempdir) def test_module_linecount(self): - """ This project is supposed to remain under 200 lines """ + """ This project is supposed to remain small (~200 lines) """ test_dir = os.path.dirname(os.path.realpath(__file__)) module_path = os.path.abspath(os.path.join(test_dir, os.pardir, "acme_tiny.py")) out, err = Popen(["wc", "-l", module_path], stdout=PIPE, stderr=PIPE).communicate() num_lines = int(out.decode("utf8").split(" ", 1)[0]) - self.assertTrue(num_lines <= 200) + self.assertTrue(num_lines <= 215) def test_success_domain(self): """ Successfully issue a certificate via subject alt name """ @@ -406,6 +407,86 @@ def test_nonce_retry(self): # normal success test self.test_success_domain() + @unittest.skipIf(USE_STAGING, "only checked on pebble server since it exposes a management API for the alternate chain cert") + def test_preferred_chain(self): + """ --preferred-chain selects an alternate certificate chain when one matches """ + # Use the pebble root cert as the synthetic alternate chain. It is self-signed, + # so its issuer field equals its subject — a known, matchable CN. + alt_pem = urlopen("https://localhost:15000/roots/0").read().decode("utf8") + subject_out, _ = Popen(["openssl", "x509", "-noout", "-subject"], + stdin=PIPE, stdout=PIPE, stderr=PIPE).communicate(alt_pem.encode("utf8")) + alt_cn = re.search(r"CN\s*=\s*(.+)", subject_out.decode("utf8")).group(1).strip() + + # MITM: inject a Link: rel="alternate" header on the ACME cert download response, + # and serve alt_pem when _select_chain fetches that alternate URL. + alt_url = "https://localhost:14000/certZ/alternate" + + class HeadersWithAltLink: + def __init__(self, original): + self._o = original + def get_all(self, name): + if name.lower() == 'link': + return ['<{0}>; rel="alternate"'.format(alt_url)] + return self._o.get_all(name) if hasattr(self._o, 'get_all') else None + def __getattr__(self, name): + return getattr(self._o, name) + + class FakeAltResponse: + class _H: + def get_all(self, name): return None + def get(self, name, default=None): return default + def __init__(self, body): self._b = body.encode('utf8'); self.headers = self._H() + def read(self): return self._b + def getcode(self): return 200 + + urlopenOriginal = acme_tiny.urlopen + def urlopenMITM(req, *args, **kwargs): + url = req.full_url if hasattr(req, 'full_url') else str(req) + if str(url) == alt_url: + return FakeAltResponse(alt_pem) + resp = urlopenOriginal(req, *args, **kwargs) + if '/certZ/' in str(url): + resp.headers = HeadersWithAltLink(resp.headers) + return resp + acme_tiny.urlopen = urlopenMITM + + try: + # matching preferred-chain → alternate chain returned + old_stdout = sys.stdout + sys.stdout = StringIO() + acme_tiny.main([ + "--account-key", self.KEYS['account_key'].name, + "--csr", self.KEYS['domain_csr'].name, + "--acme-dir", self.tempdir, + "--directory-url", self.DIR_URL, + "--check-port", self.check_port, + "--preferred-chain", alt_cn, + ]) + sys.stdout.seek(0) + result = sys.stdout.read() + sys.stdout = old_stdout + self.assertEqual(result.strip(), alt_pem.strip()) + + # non-matching preferred-chain → falls back to the default valid certificate + old_stdout = sys.stdout + sys.stdout = StringIO() + acme_tiny.main([ + "--account-key", self.KEYS['account_key'].name, + "--csr", self.KEYS['domain_csr'].name, + "--acme-dir", self.tempdir, + "--directory-url", self.DIR_URL, + "--check-port", self.check_port, + "--preferred-chain", "no-such-issuer", + ]) + sys.stdout.seek(0) + fallback = sys.stdout.read() + sys.stdout = old_stdout + out, _ = Popen(["openssl", "x509", "-noout", "-text"], + stdin=PIPE, stdout=PIPE, stderr=PIPE).communicate(fallback.encode("utf8")) + self.assertIn(self.ca_issued_string, out.decode("utf8")) + finally: + acme_tiny.urlopen = urlopenOriginal + @unittest.skipIf(USE_STAGING, "only checked on pebble server since ") def test_pebble_doesnt_support_cn_domains(self): """ Test that pebble server doesn't support CN subject domains """