#!/usr/bin/env python3

# Copyright (C) 2014-2016 Felix Geyer <debfx@fobos.de>
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 2 or (at your option)
# version 3 of the License.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program.  If not, see <http://www.gnu.org/licenses/>.

import argparse
import datetime
import dns.exception
import dns.flags
import dns.rdatatype
import dns.rdtypes.ANY.TLSA
import dns.resolver
import hashlib
import re
import socket
import struct
import ssl
import subprocess
import sys


VERSION = "1.1"


class ProtocolError(Exception):
    """Raised when a server does not follow the expected STARTTLS dialogue."""


def nagios_ok(msg: str) -> None:
    """Print *msg* with the "DANE OK" prefix and exit with status 0."""
    status_line = "DANE OK - " + msg
    print(status_line)
    raise SystemExit(0)


def nagios_warning(msg: str) -> None:
    """Print *msg* with the "DANE WARNING" prefix and exit with status 1."""
    status_line = "DANE WARNING - " + msg
    print(status_line)
    raise SystemExit(1)


def nagios_critical(msg: str) -> None:
    """Print *msg* with the "DANE CRITICAL" prefix and exit with status 2."""
    status_line = "DANE CRITICAL - " + msg
    print(status_line)
    raise SystemExit(2)


def nagios_unknown(msg: str) -> None:
    """Print *msg* with the "DANE UNKNOWN" prefix and exit with status 3."""
    print("DANE UNKNOWN - " + msg)
    sys.exit(3)


def create_resolver(dnssec: bool = True,
                    timeout: int = None,
                    nameserver: str = None) -> dns.resolver.Resolver:
    """Build a dnspython resolver configured from the given options.

    dnssec: when true, enable EDNS0 with the DO flag and a 1280 byte payload.
    timeout: query lifetime in seconds; None or 0 keeps the resolver's own value.
    nameserver: when set, the only nameserver the resolver is told to use.
    """
    res = dns.resolver.Resolver()

    if dnssec:
        # ask for DNSSEC data (DO bit) via EDNS0
        res.edns = 0
        res.ednsflags = dns.flags.DO
        res.payload = 1280

    if nameserver:
        res.nameservers = [nameserver]

    # None and 0 are both falsy, so either leaves the lifetime untouched
    if timeout:
        res.lifetime = timeout

    return res


def check_dns_response_auth(response: dns.resolver.Answer) -> bool:
    """Return True when the AD flag is set on the DNS response message."""
    return bool(response.response.flags & dns.flags.AD)


def extract_pubkey(cert_binary: bytes) -> bytes:
    """Return the DER encoded public key of a DER encoded certificate.

    The work is delegated to two runs of the ``openssl`` command line tool;
    subprocess.CalledProcessError is raised when one of them exits non-zero.
    """
    # should really be done with a python 3 module that exposes this openssl api
    cert_to_pem_pubkey = ["openssl", "x509", "-pubkey", "-inform", "der", "-noout"]
    pem_pubkey_to_der = ["openssl", "pkey", "-pubin", "-outform", "der"]

    # step 1: certificate (der) -> public key (pem)
    pem_pubkey = subprocess.check_output(cert_to_pem_pubkey, input=cert_binary)
    # step 2: public key (pem) -> public key (binary / der)
    return subprocess.check_output(pem_pubkey_to_der, input=pem_pubkey)


def get_tlsa_records(args: argparse.Namespace) -> dns.resolver.Answer:
    """Look up the TLSA records for ``_<port>._tcp.<host>``.

    Reads host, port, dnssec, timeout and nameserver from *args*.
    Exits through nagios_critical when no TLSA record exists and through
    nagios_unknown when the lookup fails or (with DNSSEC enabled) the answer
    does not carry the AD flag.
    """
    resolver = create_resolver(dnssec=args.dnssec, timeout=args.timeout, nameserver=args.nameserver)

    tlsa_domain = "_{}._tcp.{}".format(args.port, args.host)

    try:
        tlsa_records = resolver.query(tlsa_domain, dns.rdatatype.TLSA)
    except (dns.resolver.NXDOMAIN, dns.resolver.NoAnswer):
        # NoAnswer: the name exists but has no TLSA rrset
        nagios_critical("No DNS TLSA record found: {}".format(tlsa_domain))
    except dns.exception.Timeout:
        nagios_unknown("DNS query timeout: {}".format(tlsa_domain))
    except dns.resolver.NoNameservers as e:
        # no nameserver gave a usable answer (e.g. SERVFAIL)
        nagios_unknown("DNS query failed: {}\n{}".format(tlsa_domain, str(e)))

    if args.dnssec and not check_dns_response_auth(tlsa_records):
        nagios_unknown("DNS query not DNSSEC validated")

    return tlsa_records


def validate_dane(cert_binary: bytes,
                  pkix_valid: bool,
                  tlsa_record: dns.rdtypes.ANY.TLSA.TLSA) -> bool:
    """Return True when the certificate matches the given TLSA record.

    cert_binary: the DER encoded certificate presented by the server.
    pkix_valid: whether the certificate passed PKIX validation; required
        for usage=1 records.
    tlsa_record: record providing usage, selector, mtype and cert.
    """
    usage = tlsa_record.usage
    # usage=3 matches on its own, usage=1 additionally needs PKIX validation;
    # every other usage (0, 2, ...) is unsupported
    if not (usage == 3 or (usage == 1 and pkix_valid)):
        return False

    # selector 0: whole certificate, selector 1: public key only
    selector = tlsa_record.selector
    if selector not in (0, 1):
        return False
    data = cert_binary if selector == 0 else extract_pubkey(cert_binary)

    # mtype 0: exact match, 1: SHA-256, 2: SHA-512
    digest_by_mtype = {1: hashlib.sha256, 2: hashlib.sha512}
    mtype = tlsa_record.mtype
    if mtype == 0:
        candidate = data
    elif mtype in digest_by_mtype:
        candidate = digest_by_mtype[mtype](data).digest()
    else:
        return False

    return candidate == tlsa_record.cert


def check_cert_expiry(args: argparse.Namespace,
                      cert: dict,
                      days_warning: int,
                      days_critical: int = None) -> datetime.timedelta:
    """Check how long the certificate remains valid.

    args: provides host and port for the status message.
    cert: certificate dict containing a "notAfter" entry.
    days_warning: exit with a warning when at most this many days remain.
    days_critical: exit as critical when at most this many days remain;
        None disables the critical threshold (0 is a valid threshold).

    Returns the remaining validity when no threshold is hit.
    """
    # notAfter is expressed in GMT, so it has to be compared with the current
    # UTC time rather than with local time
    not_after = datetime.datetime.strptime(cert["notAfter"], "%b %d %H:%M:%S %Y %Z")
    not_after = not_after.replace(tzinfo=datetime.timezone.utc)
    date_diff = not_after - datetime.datetime.now(datetime.timezone.utc)

    # compare against None: a threshold of 0 days must not be skipped
    if days_critical is not None:
        if date_diff <= datetime.timedelta(days_critical):
            nagios_critical("{}:{} cert expires in {} days".format(args.host, args.port, date_diff.days))

    if date_diff <= datetime.timedelta(days_warning):
        nagios_warning("{}:{} cert expires in {} days".format(args.host, args.port, date_diff.days))

    return date_diff


def connect_to_host(connect_host: str,
                    connect_port: int,
                    args: argparse.Namespace,
                    check_cert: bool) -> ssl.SSLSocket:
    """Connect to the service, run the optional STARTTLS dialogue and start TLS.

    connect_host / connect_port: address the TCP connection is opened to.
    args: provides timeout, starttls and host (used as TLS server name).
    check_cert: when false, certificate and hostname verification are disabled.

    Connection, STARTTLS and TLS errors end the process via nagios_unknown;
    only an SSLError with reason CERTIFICATE_VERIFY_FAILED is re-raised.
    """
    # a timeout of 0 switches the socket timeout off
    socket.setdefaulttimeout(args.timeout if args.timeout != 0 else None)

    tls_context = ssl.create_default_context()
    if not check_cert:
        tls_context.check_hostname = False
        tls_context.verify_mode = ssl.CERT_NONE

    # TODO: use our resolver
    try:
        sock = socket.create_connection((connect_host, connect_port))
    except OSError as e:
        nagios_unknown("Can't establish a connection to {}:{}:\n{}"
                       .format(connect_host, connect_port, str(e)))

    # addresses containing ":" (IPv6) are shown in brackets in messages
    peer_address = sock.getpeername()[0]
    peername = "[{}]".format(peer_address) if ":" in peer_address else peer_address

    starttls_handlers = {
        "smtp": lambda: connect_smtp(sock),
        "ftp": lambda: connect_ftp(sock),
        "imap": lambda: connect_imap(sock),
        "xmpp": lambda: connect_xmpp(sock, args.host),
        "quassel": lambda: connect_quassel(sock, args.host),
    }

    try:
        start_tls = starttls_handlers.get(args.starttls)
        if start_tls is not None:
            start_tls()
    except (ProtocolError, OSError) as e:
        nagios_unknown("Failed to initiate STARTTLS to {} ({}:{}):\n{}"
                       .format(connect_host, peername, connect_port, str(e)))

    try:
        return tls_context.wrap_socket(sock, server_hostname=args.host)
    except (ProtocolError, ssl.SSLError, OSError) as e:
        verify_failed = isinstance(e, ssl.SSLError) and e.reason == "CERTIFICATE_VERIFY_FAILED"
        if verify_failed:
            # pass exceptions concerning certificate validation to the caller
            raise
        nagios_unknown("Can't establish a TLS connection to {} ({}:{}):\n{}"
                       .format(connect_host, peername, connect_port, str(e)))


def smtp_read_response(sock: socket.socket) -> str:
    """Read one complete, possibly multi-line, reply from *sock*.

    Lines of the form "ddd-text" announce further lines; reading stops once
    the last complete line received does not have "-" as its fourth
    character, or when the peer closes the connection.

    Returns the text received so far, decoded leniently as ASCII.
    """
    response = b""
    while True:
        chunk = sock.recv(1024)
        if not chunk:
            # connection closed by the peer
            break
        response += chunk

        if not response.endswith(b"\n"):
            # the current line is still incomplete
            continue

        last_line = response.rstrip(b"\r\n").rsplit(b"\n", 1)[-1]
        # slicing keeps the comparison bytes-to-bytes (indexing yields an int)
        if len(last_line) < 4 or last_line[3:4] != b"-":
            break

    return response.decode("ASCII", errors="replace")


def connect_smtp(sock: socket.socket) -> None:
    """Run the SMTP dialogue up to the point where the TLS handshake starts.

    Raises ProtocolError when the server does not advertise STARTTLS or
    does not answer the STARTTLS command with a 220 reply.
    """
    # greeting
    smtp_read_response(sock)
    sock.sendall(b"EHLO openssl.client.net\r\n")

    # compare upper-cased so a lower-case "starttls" keyword is found too
    capabilities = smtp_read_response(sock)
    if "STARTTLS" not in capabilities.upper():
        raise ProtocolError("Server doesn't support STARTTLS")

    sock.sendall(b"STARTTLS\r\n")

    # only a 220 reply means the server is ready for the handshake
    response = smtp_read_response(sock).strip("\r\n\t ")
    if not response.startswith("220"):
        raise ProtocolError("Server refused STARTTLS ({})".format(response))


def connect_ftp(sock: socket.socket) -> None:
    """Send "AUTH TLS" to an FTP server and check for a 2xx reply.

    Raises ProtocolError when the reply does not start with a 2xx code.
    """
    # the greeting is read but its content is not looked at
    sock.recv(1024)

    sock.sendall(b"AUTH TLS\r\n")

    reply = sock.recv(1024)
    response = reply.decode("ASCII").strip("\r\n\t ")
    if re.search(r"^2\d\d", response) is None:
        raise ProtocolError("Server doesn't support STARTTLS ({})".format(response))


def connect_imap(sock: socket.socket) -> None:
    """Run the IMAP dialogue up to the point where the TLS handshake starts.

    Commands use "." as tag and are terminated with CRLF.
    Raises ProtocolError when the server does not advertise STARTTLS or
    does not answer the STARTTLS command with a tagged OK.
    """
    # greeting
    sock.recv(1024)
    sock.sendall(b". CAPABILITY\r\n")

    response = smtp_read_response(sock)
    if "STARTTLS" not in response.upper():
        raise ProtocolError("Server doesn't support STARTTLS")

    sock.sendall(b". STARTTLS\r\n")

    # the handshake may only start after the tagged OK for our command
    response = smtp_read_response(sock)
    accepted = any(line.upper().startswith(". OK") for line in response.splitlines())
    if not accepted:
        raise ProtocolError("Server refused STARTTLS ({})".format(response.strip("\r\n\t ")))


def _xmpp_read_until(sock: socket.socket, is_complete) -> str:
    """Receive from *sock* until is_complete(text) is true or the peer closes."""
    received = ""
    while not is_complete(received):
        chunk = sock.recv(1024)
        if not chunk:
            break
        received += chunk.decode("UTF-8", errors="replace")
    return received


def connect_xmpp(sock: socket.socket, host: str) -> None:
    """Open an XMPP client stream to *host* and request STARTTLS.

    Raises ProtocolError when the server does not offer STARTTLS or does
    not answer the request with <proceed/>.
    """
    TLS_NS = "urn:ietf:params:xml:ns:xmpp-tls"
    FEATURES_END = ("</stream:features>", "<stream:features/>", "</stream:stream>")

    sock.sendall("<stream:stream xmlns:stream='http://etherx.jabber.org/streams' "
                 "xmlns='jabber:client' to='{}' version='1.0'>".format(host).encode("ASCII"))

    # stream header and features may arrive in separate chunks, so keep
    # reading until the TLS namespace or the end of the features shows up
    features = _xmpp_read_until(
        sock, lambda text: TLS_NS in text or any(marker in text for marker in FEATURES_END))
    # substring checks accept single as well as double quoted attributes
    if "<starttls" not in features or TLS_NS not in features:
        raise ProtocolError("Server doesn't support STARTTLS")

    sock.sendall(b"<starttls xmlns='urn:ietf:params:xml:ns:xmpp-tls'/>")

    # consume the complete reply element: bytes left unread would be
    # handed to the TLS handshake afterwards
    reply = _xmpp_read_until(
        sock,
        lambda text: (re.search(r"<(proceed|failure)\b[^>]*>", text) is not None
                      or "</stream:stream>" in text))
    if re.search(r"<proceed\b[^>]*>", reply) is None:
        raise ProtocolError("Server doesn't support STARTTLS")


def connect_quassel(sock: socket.socket, host: str) -> None:
    """Negotiate an encrypted connection with a Quassel core.

    Sends the magic number with the encryption feature bit followed by a
    single protocol entry, then checks the 32 bit reply of the server.
    *host* is accepted but not used.

    Raises ProtocolError when the reply is not 4 bytes long or lacks the
    requested protocol or encryption bit.
    """
    MAGIC = 0x42b33f00
    FEATURE_ENCRYPTION = 0x1
    PROTOCOL_DATAGRAM = 0x2
    PROTOCOL_END = (0x1 << 31)

    # all values on the wire are unsigned 32 bit integers in network order
    uint32 = struct.Struct("!I")

    sock.sendall(uint32.pack(MAGIC | FEATURE_ENCRYPTION))
    sock.sendall(uint32.pack(PROTOCOL_DATAGRAM | PROTOCOL_END))

    try:
        (response,) = uint32.unpack(sock.recv(4))
    except struct.error:
        raise ProtocolError("No valid response from server")

    # lowest byte: protocol type, highest byte: connection features
    protocol_type = response & 0xff
    connection_features = response >> 24

    if (protocol_type & PROTOCOL_DATAGRAM) == 0:
        raise ProtocolError("Server doesn't support the protocol")

    if (connection_features & FEATURE_ENCRYPTION) == 0:
        raise ProtocolError("Server doesn't support TLS")


def main() -> None:
    """Run the check and report the result as a Nagios status.

    Parses the command line, fetches the TLSA records, retrieves the
    certificate from the service and compares both. The process always
    ends through one of the nagios_* helpers.
    """
    parser = argparse.ArgumentParser(description=""" Nagios/Icinga plugin for checking DANE/TLSA records.
                                     It compares the DANE/TLSA record against the TLS certificate provided
                                     by a service.""")

    parser.add_argument("--host", "-H", dest="host", required=True, help="Hostname to check.")
    parser.add_argument("--port", "-p", type=int, required=True, help="TCP port to check.")
    parser.add_argument("--connect-host", "--ip", "-I", dest="connect_host",
                        help="Connect to this host instead of --host.")
    parser.add_argument("--connect-port", dest="connect_port", help="Connect to this port instead of --port.")
    parser.add_argument("--starttls",
                        choices=["smtp", "ftp", "imap", "xmpp", "quassel"],
                        help="Send the protocol-specific messages to enable TLS.")
    parser.add_argument("--check-pkix", action="store_true",
                        help="Additionally perform traditional checks on the certificate "
                             "(ca trust path, hostname, expiry).")
    parser.add_argument("--min-days-valid", help="Minimum number of days a certificate has to be valid. "
                                                 "Format: INTEGER[,INTEGER]. "
                                                 "1st is #days for warning, 2nd is critical.")
    parser.add_argument("--no-dnssec", dest="dnssec", action="store_false",
                        help="Continue even when DNS replies aren't DNSSEC authenticated.")
    parser.add_argument("--nameserver", help="Use a custom nameserver.")
    parser.add_argument("--timeout", type=int, default=10, help="Network timeout in sec. Default: 10")
    parser.add_argument("--version", action="version", version="%(prog)s " + VERSION)
    args = parser.parse_args()

    pyver = sys.version_info
    if pyver[0] < 3 or (pyver[0] == 3 and pyver[1] < 4):
        nagios_unknown("check_dane requires Python >= 3.4")

    if args.port < 1 or args.port > 65535:
        nagios_unknown("Invalid port")

    if args.min_days_valid and not re.search(r"^\d+(,\d+)?$", args.min_days_valid):
        nagios_unknown("--min-days-valid takes INTEGER[,INTEGER] as arguments")

    if args.timeout < 0:
        nagios_unknown("Invalid timeout argument")

    # the connection target defaults to the checked host/port
    connect_host = args.connect_host if args.connect_host else args.host
    connect_port = args.connect_port if args.connect_port else args.port

    tlsa_records = get_tlsa_records(args)

    has_usage1_tlsa = any(record.usage == 1 for record in tlsa_records)

    initial_check_pkix = (args.check_pkix or has_usage1_tlsa)

    # stays empty unless PKIX validation fails below
    pkix_error = ""

    try:
        # validate against PKIX if manually requested or we've found a usage=1 tlsa record
        ssl_sock = connect_to_host(connect_host, connect_port, args, initial_check_pkix)
        pkix_valid = initial_check_pkix
    except (ssl.CertificateError, ssl.SSLError) as e:
        if args.check_pkix:
            nagios_critical(str(e))
        else:
            # retry without validation so usage=3 records can still be checked
            ssl_sock = connect_to_host(connect_host, connect_port, args, False)
            pkix_valid = False
            pkix_error = str(e)

    cert_binary = ssl_sock.getpeercert(binary_form=True)
    cert_dict = ssl_sock.getpeercert()
    ssl_sock.close()

    dane_valid_cert = any(validate_dane(cert_binary, pkix_valid, tlsa) for tlsa in tlsa_records)

    if not dane_valid_cert:
        # test if it would match if it were pkix_valid
        additional_msg = ""
        if any(validate_dane(cert_binary, True, tlsa) for tlsa in tlsa_records):
            additional_msg = "\nIt matches a TLSA usage=1 record but fails PKIX validation:\n" + pkix_error

        nagios_critical("Certificate doesn't match TLSA record" + additional_msg)

    if pkix_valid and args.min_days_valid:
        # the format check above guarantees one or two integers
        day_limits = [int(part) for part in args.min_days_valid.split(",")]
        timedelta_valid = check_cert_expiry(args, cert_dict, *day_limits)

        expire_str = ", expires in {} days".format(timedelta_valid.days)
    else:
        expire_str = ""

    message = "{}:{} cert matches TLSA record".format(args.host, args.port)
    if not args.dnssec:
        message += " (DNSSEC not validated)"
    message += expire_str

    nagios_ok(message)


if __name__ == "__main__":
    main()

# kate: space-indent on; indent-width 4;