#!/usr/bin/python

from __future__ import print_function

import base64
import re
import sys
from OpenSSL import crypto

# How to use:
# 1. Download the "X509, Base64 encoded" cert from the Certificate
#    Services Manager
# 2. Run "vhostcert import < foo_mit_edu.cer". Save stdout (a base64
#    blob)
# 3. Log in as root to a Scripts server, and run vhostedit foo.mit.edu
# 4. Add entries:
#       scriptsVhostCertificate: space-separated base64 blobs from vhostcert import
#       scriptsVhostCertificateKeyFile: scripts-2048.key
# 5. On each server:
#       /etc/httpd/export-scripts-certs
#       systemctl reload httpd.service
#
# TODO: Make this script do the vhostedit automatically.

def debug_chain(chain):
    """Write a human-readable summary of a certificate chain to stderr.

    For each certificate, two lines are printed: its position in the chain
    followed by ``s:`` and the subject, then the position followed by
    ``i:`` and the issuer. A single blank line terminates the listing.

    Args:
        chain: Iterable of objects exposing ``get_subject()`` and
            ``get_issuer()``.
    """
    stream = sys.stderr
    for position, cert in enumerate(chain):
        subject_line = (position, 's:', cert.get_subject())
        print(*subject_line, file=stream)
        issuer_line = (position, 'i:', cert.get_issuer())
        print(*issuer_line, file=stream)
    # Blank line so the listing is visually separated from later output.
    print(file=stream)

def pem_to_scripts(data):
    """Convert a PEM certificate bundle to the Scripts vhost format.

    Every ``-----BEGIN CERTIFICATE-----`` block found in ``data`` is
    parsed, the certificates are ordered leaf-first by matching each
    certificate's issuer name against the other certificates' subject
    names, and a self-signed root, if present, is dropped. The resulting
    chain is also summarised on stderr via ``debug_chain``.

    Args:
        data: Bytes containing one or more PEM-encoded certificates, in
            any order.

    Returns:
        Bytes: the base64 encoding of each chain certificate's ASN.1 dump,
        leaf first, joined by single spaces.

    Raises:
        ValueError: If the input contains no certificates, does not have
            exactly one leaf, has an ambiguous or cyclic issuer link, or
            contains certificates that are not part of the leaf's chain.
            (Explicit raises are used instead of ``assert`` so the checks
            still run under ``python -O``.)
        Any error raised by ``crypto.load_certificate`` propagates
        unchanged.
    """
    certs = [
        crypto.load_certificate(crypto.FILETYPE_PEM, m.group(0))
        for m in
        re.finditer(
            b'-----BEGIN CERTIFICATE-----\r?\n.+?\r?\n-----END CERTIFICATE-----',
            data, re.DOTALL)
    ]
    if not certs:
        raise ValueError('no PEM certificates found in input')

    # The leaf is the certificate whose subject is not the issuer of any
    # certificate in the input (a self-signed certificate therefore never
    # qualifies, since it issues itself).
    leaves = [c for c in certs if not any(
        c1.get_issuer() == c.get_subject() for c1 in certs)]
    if len(leaves) != 1:
        raise ValueError(
            'expected exactly one leaf certificate, found {}'.format(
                len(leaves)))

    # Put the chain in the right order, and delete any self-signed root
    chain = [leaves[0]]
    count = 1  # certificates accounted for, including a dropped root
    while True:
        issuers = [c for c in certs
                   if chain[-1].get_issuer() == c.get_subject()]
        if not issuers:
            # The issuer of the last certificate is not in the input.
            break
        if len(issuers) != 1:
            raise ValueError(
                'ambiguous chain: {} certificates match issuer {}'.format(
                    len(issuers), chain[-1].get_issuer()))
        issuer = issuers[0]
        if issuer in chain:
            # Without this check a cyclic issuer link would loop forever.
            raise ValueError('certificate chain contains a cycle')
        count += 1
        if issuer.get_issuer() == issuer.get_subject():
            # Self-signed root: counted, but left out of the output.
            break
        chain.append(issuer)
    if count != len(certs):
        raise ValueError(
            'input has {} certificates but only {} belong to the '
            "leaf's chain".format(len(certs), count))

    debug_chain(chain)

    return b' '.join(base64.b64encode(
        crypto.dump_certificate(crypto.FILETYPE_ASN1, c)) for c in chain)

def scripts_to_pem(data):
    """Convert a Scripts vhost certificate value back into PEM.

    The input is split on single spaces; each piece is base64-decoded and
    loaded as an ASN.1 certificate. The chain is summarised on stderr via
    ``debug_chain`` and then re-emitted as PEM.

    Args:
        data: Bytes holding base64 blobs separated by single spaces.

    Returns:
        Bytes: the PEM dumps of the certificates, concatenated in input
        order.
    """
    chain = []
    for blob in data.split(b' '):
        der = base64.b64decode(blob)
        chain.append(crypto.load_certificate(crypto.FILETYPE_ASN1, der))

    debug_chain(chain)

    pem_parts = [
        crypto.dump_certificate(crypto.FILETYPE_PEM, cert) for cert in chain
    ]
    return b''.join(pem_parts)

def __main__():
    """Command-line entry point: run the ``import`` or ``export`` conversion.

    Reads all of stdin, converts it according to the single command-line
    argument, and prints the result to stdout. With any other arguments a
    usage message is printed to stderr.

    Returns:
        The process exit status: 0 after a successful conversion, 2 when
        the arguments are not exactly ``import`` or ``export`` (previously
        this case fell through and the script exited with status 0).
    """
    args = sys.argv[1:]
    if args == ['import']:
        print(pem_to_scripts(sys.stdin.read().encode()).decode())
    elif args == ['export']:
        # end='' emits the converted text verbatim, with no extra newline.
        print(scripts_to_pem(sys.stdin.read().encode()).decode(), end='')
    else:
        print('usage: {} {{import|export}}'.format(__file__), file=sys.stderr)
        return 2
    return 0

# Run the command-line entry point only when executed as a script, and hand
# its return value to the interpreter as the process exit status.
if __name__ == '__main__':
    exit_status = __main__()
    sys.exit(exit_status)
