#!/usr/bin/env python3

"""
Parse a Markdown file and pretty-print it to the terminal, including colors, glyphs, and text wrapping.
The supported Markdown syntax is that of Python-Markdown: https://github.com/Python-Markdown/markdown
"""

import os
import re
import sys
import subprocess
import markdown
import xml.etree.ElementTree as ET
from html import unescape
from textwrap import wrap

from typing import List, Dict, Union, Tuple, Optional, Iterator, TextIO


class AnsiUtils:
    """
    Helpers for text decorated with ANSI escape codes: coloring, code-aware length, and code-aware wrapping.
    """

    # Matches the single-parameter escape codes this class emits, such as '\x1b[1m' or '\x1b[0m'.
    _ansi_re: re.Pattern = re.compile('\x1b\\[[0-9]+m')
    # Modifier name -> escape code, e.g. 'bold' -> '\x1b[1m'.
    _ansi_codes: Dict[str, str] = {k: '\x1b[{}m'.format(v) for k, v in {
        'reset': 0,
        'bold': 1, 'dark': 2, 'italic': 3, 'underline': 4, 'reverse': 7, 'strikethrough': 9,
        'grey': 30, 'red': 31, 'green': 32, 'yellow': 33, 'blue': 34, 'magenta': 35, 'cyan': 36, 'white': 37,
    }.items()}

    @classmethod
    def str_len(cls, text: Union[str, bytes, List[str]]) -> int:
        """
        Exclude zero-width ansi codes from string length calculation.

        Only `str` values are scanned for codes, because the pattern is a `str` pattern and would raise on `bytes`.
        Anything else is measured with plain len(), including bytes and the lists textwrap measures internally.
        """
        return len(cls._ansi_re.sub('', text)) if isinstance(text, str) else len(text)

    @classmethod
    def str_prefix(cls, rx_class: str, text: str) -> str:
        """
        Return the longest prefix of `text` that matches the regex `rx_class*`.

        `rx_class` is a single character or character class, such as ' '. The result may be empty.
        """
        return re.match(r'^' + rx_class + '*', text).group(0)  # type: ignore

    @classmethod
    def _line_continuation(cls, text: List[str], indent: int = 0) -> Iterator[str]:
        """
        Continue for example underlines in the next line, but after a possible space indent or prefix.

        A line that ends with codes still active (not reset) gets a reset appended. Those codes are re-opened on the
        next line, after its first `indent` characters and any spaces that follow them.
        """
        active_codes: List[str] = []
        for line in text:
            if active_codes:  # from previous line
                indent_width: int = indent + len(cls.str_prefix(' ', line[indent:]))
                line = line[:indent_width] + ''.join(active_codes) + line[indent_width:]
                active_codes.clear()  # will be re-found below

            for match in cls._ansi_re.finditer(line):
                code: str = match.group(0)
                if code == cls._ansi_codes['reset']:
                    active_codes.clear()
                else:
                    active_codes.append(code)

            yield line + cls._ansi_codes['reset'] if active_codes else line

    @classmethod
    def wrap(cls, *args, **kwargs) -> Iterator[str]:
        """
        Wrapper for textwrap.wrap() that skips ansi codes for string lengths and ends/continues codes across lines.
        Uses a similar hacky approach as the `ansiwrap` package, by monkey-patching the textwrap's len().

        Arguments are passed through to textwrap.wrap(). The shorter of `initial_indent` and `subsequent_indent` is the
        column after which continued codes are re-opened.
        """
        from unittest.mock import patch
        with patch('textwrap.len', cls.str_len):
            lines: List[str] = wrap(*args, **kwargs)
        yield from cls._line_continuation(lines, indent=min(len(kwargs.get("initial_indent", "")),
                                                            len(kwargs.get("subsequent_indent", ""))))

    @classmethod
    def colored(cls, text: Optional[str], modifiers: List[str], do_colors: bool = True) -> str:
        """
        Surround the string with ansi codes that correspond to the given modifiers.

        Leading whitespace stays outside the codes. Returns '' for empty or None text. Returns the text unchanged when
        coloring is disabled or no modifiers are given.
        """
        if not text or not do_colors or not modifiers:
            return text if text else ''
        prefix_len: int = len(cls.str_prefix('[ \t\n\v\f\r]', text))
        parts, text = [text[:prefix_len]], text[prefix_len:]
        parts.extend([cls._ansi_codes[_] for _ in modifiers])
        parts.extend([text, cls._ansi_codes['reset']])
        return ''.join(parts)


class EtCliParser:
    """
    HTML pretty-printer that only supports tags and structures that markdown could generate.

    Renders the children of the given root element as terminal lines:
    - parse() dispatches block elements to their handlers.
    - _parse_string_block() flattens inline content into a single string, which _wrap() then wraps.
    - Unsupported elements are printed as their serialized HTML in reverse video.
    """

    def __init__(self, root: ET.Element, width: int, wrap_pre: bool) -> None:
        """
        :param root: element whose children are the top-level blocks to print
        :param width: maximum line width, used for wrapping and horizontal rules
        :param wrap_pre: prefix and wrap code/quote lines; otherwise print them unwrapped
        """
        self._root: ET.Element = root
        self._width: int = width
        self._wrap_pre: bool = wrap_pre
        self._indent: int = 2  # spaces per indentation level; tabs are expanded to this many spaces as well
        self._placeholder: str = u"\uFFFD"  # shown instead of stashed raw HTML and control characters

    def _wrap(self, line: str, indent: int = 0, pre: bool = False, do_wrap: bool = True,
              prefix: str = "", prefix_cont: str = "") -> Iterator[str]:
        """
        Transform a string into a list of strings by wrapping to the configured width, with the given indentation width
        and line-prefix.

        :param line: text to wrap, possibly containing ANSI codes
        :param indent: indentation level, each level being self._indent spaces
        :param pre: preformatted text. Leading whitespace is kept (and wrapped lines continue at that indentation)
                    instead of collapsing all whitespace, and lines are not broken on hyphens.
        :param do_wrap: if False, emit the whitespace-normalized line as a single, unwrapped line
        :param prefix: printed after the indentation on the first line
        :param prefix_cont: printed after the indentation on continuation lines
        """
        indentation: str = " " * indent * self._indent
        line = line.replace("\t", " " * self._indent)
        # Preformatted text only loses trailing whitespace; anything else has whitespace runs collapsed.
        line = line.rstrip() if pre else re.sub(r"[ \t\n\v\f\r]+", " ", line).strip()

        if not line or not do_wrap:
            yield "".join([indentation, prefix, line])
            return

        if pre:
            # Continuation lines of an indented code line start at that same indentation.
            line_indent: str = AnsiUtils.str_prefix(" ", line)
            prefix_cont += line_indent

        yield from AnsiUtils.wrap(
            line,
            width=self._width,
            initial_indent=indentation + prefix,
            subsequent_indent=indentation + prefix_cont,
            expand_tabs=True, tabsize=self._indent,
            replace_whitespace=True, drop_whitespace=True,
            fix_sentence_endings=False,
            break_long_words=True, break_on_hyphens=not pre,
        )

    def parse(self) -> Iterator[str]:
        """
        Successively parse the root tree and give formatted lines to be printed to the terminal.

        Top-level blocks are separated by an empty line. Non-blank text trailing a block (its tail) is printed after
        that block in reverse video, unless the block's handler already printed it.
        """
        first: bool = True
        for elem in self._root:
            if not first:
                yield ""
            first = False

            if elem.tag in ["h1", "h2", "h3", "h4", "h5", "h6"]:
                yield from self._parse_h(elem)
            elif elem.tag in ["p"]:
                yield from self._parse_p(elem)
            elif elem.tag in ["ul", "ol"]:
                yield from self._parse_ul(elem)
                continue  # _parse_ul() prints the list's tail itself
            elif elem.tag in ["pre"]:
                yield from self._parse_code(elem)
            elif elem.tag in ["hr"]:
                yield from self._parse_hr(elem)
            elif elem.tag in ["blockquote"]:
                yield from self._parse_quote(elem)
            else:
                yield from self._parse_unknown(elem)
                continue  # ET.tostring() output already includes the tail

            if elem.tail is not None and elem.tail.strip():
                yield from self._wrap(AnsiUtils.colored(self._decode_text(elem.tail), ["reverse"]))

    def _to_html(self, elem: ET.Element) -> str:
        """Serialize `elem` (including its tail) as XML text."""
        return ET.tostring(elem, encoding="unicode", method="xml")

    def _to_text(self, elem: ET.Element, include_tail: bool) -> str:
        """Concatenate all decoded text inside `elem`, optionally followed by its decoded tail."""
        return "".join(["".join(self._decode_text(_) for _ in elem.itertext()),
                        self._decode_text(elem.tail) if include_tail else ""])

    def _decode_text(self, text: Optional[str]) -> str:
        """
        Make element text printable.

        Replaces markdown's raw-HTML placeholders and control characters with self._placeholder, and unescapes HTML
        entities. None gives ''.
        """
        if not text:
            return ""
        text = markdown.util.HTML_PLACEHOLDER_RE.sub(self._placeholder, text)
        text = unescape(text)
        text = re.sub(r'[\x00-\x08\x0b\x0c\x0e-\x1f\x7f-\x9f]+', self._placeholder, text)  # control characters but \t\r\n
        return text

    def _single_child(self, elem: ET.Element) -> Optional[ET.Element]:
        """Return the only child of `elem` if there is no text around it, otherwise None."""
        child: Optional[ET.Element] = list(elem)[0] if len(elem) == 1 else None
        if child is not None and not elem.text and not child.tail:
            return child  # there is only one child, so we can "flatten" it in some cases
        return None

    def _parse_unknown(self, elem: ET.Element) -> Iterator[str]:
        """Print the serialized HTML of an unsupported element line by line, in reverse video."""
        for line in self._to_html(elem).splitlines(keepends=False):
            yield from self._wrap(AnsiUtils.colored(line, ["reverse"]), pre=True)

    def _parse_p(self, elem: ET.Element) -> Iterator[str]:
        """Print a paragraph as wrapped inline text, or as a code block if it consists of a single <code> only."""
        child: Optional[ET.Element] = self._single_child(elem)
        if child is not None and child.tag == "code":
            yield from self._parse_code(child)  # treat a p that only consists of code as pre, e.g. for ```-style blocks
        else:
            yield from self._wrap(self._parse_string_block(elem, include_tail=False))

    def _parse_h(self, elem: ET.Element) -> Iterator[str]:
        """Print a heading h1..h6: h1 and h2 bold and underlined, deeper levels only underlined."""
        depth: int = int(elem.tag[-1])
        yield from self._wrap(AnsiUtils.colored(self._to_text(elem, include_tail=False),
                                                ["bold", "underline"] if depth < 3 else ["underline"]))

    def _parse_hr(self, _: ET.Element) -> Iterator[str]:
        """Print a horizontal rule across the full width."""
        yield "─" * self._width

    def _parse_code(self, elem: ET.Element) -> Iterator[str]:
        """
        Print each text line of `elem` dark.

        When wrapping, lines get a "┃ " prefix and continuation lines a "┠ " prefix; otherwise they are printed
        unwrapped.
        """
        for line in self._to_text(elem, include_tail=False).splitlines(keepends=False):
            if self._wrap_pre:
                yield from self._wrap(AnsiUtils.colored(line, ["dark"]), pre=True, prefix="┃ ", prefix_cont="┠ ")
            else:
                yield from self._wrap(AnsiUtils.colored(line, ["dark"]), pre=True, do_wrap=False)

    def _parse_quote(self, elem: ET.Element, depth: int = 0) -> Iterator[str]:
        """
        Print the paragraphs of a blockquote line by line, dark, with one marker per nesting level.

        Markers are "│"/"├" when wrapping and ">" otherwise. Nested blockquotes recurse with depth + 1, and other
        children are printed as unknown elements.
        """
        for child in elem:
            if child.tag == "blockquote":
                yield from self._parse_quote(child, depth + 1)
            elif child.tag == "p":
                for line in self._to_text(child, include_tail=True).splitlines(keepends=False):
                    if self._wrap_pre:
                        prefix: str = "│" * depth
                        yield from self._wrap(AnsiUtils.colored(line, ["dark"]), pre=True,
                                              prefix=prefix + "│ ", prefix_cont=prefix + "├ ")
                    else:
                        prefix = (">" * (depth + 1)) + " "
                        yield from self._wrap(AnsiUtils.colored(line, ["dark"]), pre=True, do_wrap=False,
                                              prefix=prefix, prefix_cont=prefix)
            else:
                yield from self._parse_unknown(child)

    def _parse_ul(self, elem: ET.Element, depth: int = 0) -> Iterator[str]:
        """
        Print the items of a (nested) ul/ol list as "• " bullets, indented by nesting depth.

        Nested lists inside an item are printed after the item's text. A non-blank tail of the list is printed
        afterwards, with the placeholder character as its bullet. The element tree is not modified.
        """
        for child in elem:
            if child.tag in ["ul", "ol"]:
                yield from self._parse_ul(child, depth + 1)
            elif child.tag == "li":
                # Keep nested lists out of the item's inline text without removing them from the tree, so that
                # parsing the same tree again gives the same output.
                nested: List[ET.Element] = [_ for _ in child if _.tag in ["ul", "ol"]]
                inline: List[ET.Element] = [_ for _ in child if _.tag not in ["ul", "ol"]]
                text: str = "".join([
                    self._decode_text(child.text),
                    self._parse_string_blocks(inline, do_colors=True),
                    self._decode_text(child.tail),
                ])
                yield from self._wrap(text, indent=depth + 1, prefix="• ", prefix_cont="  ")
                for ul in nested:
                    yield from self._parse_ul(ul, depth + 1)
            else:
                yield from self._parse_unknown(child)
        if elem.tail and elem.tail.strip():
            yield from self._wrap(self._decode_text(elem.tail), indent=depth + 1,
                                  prefix=self._placeholder + " ", prefix_cont="  ")

    def _parse_string_blocks(self, elems: List[ET.Element], do_colors: bool) -> str:
        """Concatenate the inline renderings (including tails) of all `elems`."""
        return "".join(self._parse_string_block(_, include_tail=True, do_colors=do_colors) for _ in elems)

    def _parse_string_block(self, elem: ET.Element, include_tail: bool, do_colors: bool = True) -> str:
        """
        Helper for an inline text block, which gives a single line (to be wrapped).
        A bit messy, follows no software engineering, and does not yet support recursion into nested inline formats.

        :param elem: inline element (a, img, em, strong, code, br) or container (p, li); others are dumped as HTML
        :param include_tail: append the element's tail text
        :param do_colors: add ANSI codes; nested inline elements are rendered without them
        """
        if elem.tag in ["a", "img"]:
            if elem.tag == "a":
                title: str = self._decode_text(elem.attrib.get("title", ""))
                src: str = self._decode_text(elem.attrib.get("href", "#")) or "#"
                text: str = self._decode_text(elem.text) + self._parse_string_blocks(list(elem), do_colors=False)
            else:
                title, src, text = self._decode_text(elem.attrib.get("title")), \
                                   self._decode_text(elem.attrib.get("src", "#")) or "#", \
                                   self._decode_text(elem.attrib.get("alt"))
            if title and not text or title == text:
                title, text = "", title
            if src == text:
                src = ""
            return "".join([
                AnsiUtils.colored(text, ["blue", "underline"], do_colors=do_colors),
                "{}[{}{}]".format(" " if text else "", src, " – " + title if title else "") if src or title else "",
                self._decode_text(elem.tail) if include_tail else "",
            ])
        elif elem.tag in ["em", "strong", "code"]:
            modifiers: Dict[str, List[str]] = {
                "em": ["bold"],
                "strong": ["bold", "underline"],
                "code": ["bold", "dark"],
            }
            text = self._decode_text(elem.text) + self._parse_string_blocks(list(elem), do_colors=False)
            return AnsiUtils.colored(text, modifiers[elem.tag], do_colors=do_colors) + (self._decode_text(elem.tail)
                                                                                        if include_tail else "")
        elif elem.tag in ["br"]:
            # The line break itself is always emitted; only the tail depends on include_tail.
            return "\n" + (self._decode_text(elem.tail) if include_tail else "")
        elif elem.tag in ["p", "li"]:
            return "".join([
                self._decode_text(elem.text),
                self._parse_string_blocks(list(elem), do_colors=do_colors),
                self._decode_text(elem.tail) if include_tail else "",
            ])
        else:
            return AnsiUtils.colored(self._to_html(elem), ["reverse"], do_colors=do_colors)


class MarkdownCommentPreProcessor(markdown.preprocessors.Preprocessor):  # type: ignore
    """
    Remove `<!-- ... -->` comments from the raw markdown source lines.

    HTML-style comments in markdown sources are not supported by the HTML parser preprocessor, so they are cut out
    beforehand (registered early, even before whitespace/newline normalization). The lines are joined first, so
    comments spanning several lines are removed as well. An unterminated `<!--` is left in place.
    """

    def run(self, lines: List[str]) -> List[str]:
        text: str = "\n".join(lines)
        kept: List[str] = []
        cursor: int = 0
        while True:
            opening: int = text.find("<!--", cursor)
            closing: int = text.find("-->", opening + 4) if opening >= 0 else -1
            if closing < 0:  # no further comment, or one that is never closed
                break
            kept.append(text[cursor:opening])
            cursor = closing + 3  # resume right after the closing "-->"
        kept.append(text[cursor:])
        return "".join(kept).split("\n")

    @classmethod
    def register(cls, md: markdown.Markdown) -> None:
        """Add an instance of this preprocessor to `md` with priority 35."""
        instance = cls(md)
        md.preprocessors.register(instance, cls.__name__, 35)


class MarkdownCodeBlockProcessor(markdown.blockprocessors.BlockProcessor):  # type: ignore
    """
    Support ```-style code blocks that seem to be quite popular nowadays (the 'fenced_code' extension would be rather
    strict and doesn't seem to reliably work anyway).
    Use the same <pre><code> approach such that other parsers like EmptyBlockProcessor can recognize it.
    Not properly started or ended blocks are treated by splitting.
    Indented blocks, such as for example in list items, are not supported.
    """

    def __init__(self, *args, **kwargs) -> None:
        super().__init__(*args, **kwargs)
        self._sep: str = "```"  # fence marker

    def test(self, parent: ET.Element, block: str) -> bool:
        """Claim blocks that start with a fence or contain a line starting with one."""
        return block.startswith(self._sep) or "\n" + self._sep in block

    def _split_list(self, lines: List[str], sep: str, is_prefix: bool = False) -> Tuple[List[str], Optional[List[str]]]:
        """
        Split `lines` before the first line that equals `sep` (or, with `is_prefix`, starts with it).

        :return: (lines before the match, lines from the match on), or (all lines, None) if no line matches
        """
        for i, line in enumerate(lines):
            if line == sep or (is_prefix and line.startswith(sep)):
                return lines[:i], lines[i:]
        return lines, None

    def run(self, parent: ET.Element, blocks: List[str]) -> bool:
        """
        Consume a fenced code block starting at blocks[0] and append it to `parent` as <pre><code>.

        If blocks[0] has its fence only on a later line, it is split in two and False is returned, so that the next
        processors handle the text before the fence. Otherwise, blocks are consumed up to the closing fence, or to the
        end of the document. Any text after the closing fence is put back as a new block.
        """
        if not blocks[0].startswith(self._sep):  # treat blocks that are started within a block by splitting
            block: str = blocks.pop(0)
            pre_lines, post_lines = self._split_list(block.splitlines(keepends=False), self._sep, is_prefix=True)
            if post_lines is not None:
                blocks.insert(0, "\n".join(post_lines))
                blocks.insert(0, "\n".join(pre_lines))
            else:  # should not happen, tested before -- but never drop the block
                blocks.insert(0, block)
            return False  # continue with next processor on blocks[0]

        lang: str = ""
        code_blocks: List[str] = []
        while blocks:
            # directly consume subsequent blocks as we cannot recognize them such as classically indented code blocks.
            lines: List[str] = blocks.pop(0).splitlines(keepends=False)

            if not code_blocks:  # first block line
                lang = lines.pop(0)[len(self._sep):]
                if lang.endswith(self._sep):  # and one line only...
                    code_blocks.append(lang[:-len(self._sep)])
                    lang = ""
                    if lines:
                        blocks.insert(0, "\n".join(lines))
                    break

            lines, rest_lines = self._split_list(lines, self._sep, is_prefix=False)
            code_blocks.append("\n".join(lines))
            if rest_lines is not None:  # closing fence found
                # Only re-queue actual content after the fence. An empty block would go to Python-Markdown's
                # EmptyBlockProcessor, which appends blank lines to a preceding <pre><code>, i.e. to this code block.
                if len(rest_lines) > 1:
                    blocks.insert(0, "\n".join(rest_lines[1:]))
                break  # end found in any case

        pre: ET.Element = ET.SubElement(parent, "pre")
        code: ET.Element = ET.SubElement(pre, "code")
        code.text = markdown.util.AtomicString("\n\n".join(markdown.util.code_escape(block) for block in code_blocks))
        lang = lang.strip()  # e.g. "```python " should give "language-python", not "language-python-"
        if lang:
            code.set("class", "language-" + lang.lower().replace(" ", "-"))

        return True  # re-start all processors on blocks[0]

    @classmethod
    def register(cls, md: markdown.Markdown) -> None:
        """Add an instance of this block processor to `md`'s parser with priority 75."""
        md.parser.blockprocessors.register(cls(md.parser), cls.__name__, 75)


class MarkdownETProcessor(markdown.treeprocessors.Treeprocessor):  # type: ignore
    """
    Tree processor that captures Markdown's element tree instead of letting it be serialized.

    convert() sets up a Markdown instance with the custom pre- and block processors of this file, registers an
    instance of this class, and returns whatever root element was captured.
    """

    def __init__(self, *args, **kwargs) -> None:
        super().__init__(*args, **kwargs)
        self._root_tag: str = 'div'  # tag of the root element to capture
        self._et: Optional[ET.Element] = None  # captured tree; None until run() saw a matching root

    def run(self, root: ET.Element) -> Optional[ET.Element]:
        """
        Remember `root` if it has the expected tag, and hand back an empty replacement root.

        The fresh replacement is an optimization, and should prevent in-place modifications of the captured tree.
        For any other root, nothing is captured and None is returned.
        """
        if root.tag != self._root_tag:
            return None
        self._et = root
        return ET.Element(self._root_tag)

    def _get_root(self) -> Optional[ET.Element]:
        """Return the captured tree, or None if nothing was captured."""
        return self._et

    @classmethod
    def convert(cls, text: str) -> Optional[ET.Element]:
        """
        Convert markdown `text` into an element tree.

        :return: the captured root element, or None if no root with the expected tag was seen
        """
        md: markdown.Markdown = markdown.Markdown()
        for custom in (MarkdownCommentPreProcessor, MarkdownCodeBlockProcessor):
            custom.register(md)

        capture: MarkdownETProcessor = cls(md)
        md.treeprocessors.register(capture, type(capture).__name__, 0)
        md.convert(text)  # serialized output is irrelevant: run() swapped the root for an empty one
        return capture._get_root()


def spawn_pager() -> Optional[subprocess.Popen]:
    """
    Start the pager command given by $PAGER (default: '/usr/bin/less -SR'), writing to it through a UTF-8 text pipe.

    The command line is split with shell-like quoting rules (shlex), so quoted paths and repeated spaces work, but it
    is not run through a shell. An unset or blank $PAGER falls back to the default.

    :return: the running process with its stdin as a pipe, or None if the command could not be parsed or started
    """
    import shlex

    try:
        # shlex.split() raises ValueError on unbalanced quotes, which is handled below like a failed start.
        pager: List[str] = shlex.split(os.getenv("PAGER", "")) or shlex.split("/usr/bin/less -SR")
        return subprocess.Popen(pager, shell=False, executable=None,
                                cwd=None, restore_signals=True, env=None,
                                bufsize=0, stdin=subprocess.PIPE, close_fds=True,
                                encoding="utf-8", errors="replace")
    except (OSError, ValueError, subprocess.SubprocessError):
        return None


def parse(text: str, columns: int, as_html: bool, wrap_pre: bool, pager: bool) -> bool:
    """
    Pretty-print Markdown text to stdout, or into a spawned pager.

    :param text: markdown source
    :param columns: line width used for wrapping
    :param as_html: dump the parsed element tree as XML instead of pretty-printing it
    :param wrap_pre: prefix and wrap code/quote blocks
    :param pager: write into a pager process (see spawn_pager()) instead of stdout
    :return: False if parsing failed, the pager could not be started, or the output pipe broke; True otherwise
    """
    root: Optional[ET.Element] = MarkdownETProcessor.convert(text)
    if root is None:
        print("Cannot parse markdown", file=sys.stderr, flush=True)
        return False
    if as_html:
        ET.dump(root)
        return True

    printer: EtCliParser = EtCliParser(root, columns, wrap_pre)
    proc: Optional[subprocess.Popen] = None
    try:
        if pager:
            proc = spawn_pager()
            if proc is None:
                print("Cannot spawn pager", file=sys.stderr, flush=True)
                return False
            sink: TextIO = proc.stdin  # type: ignore
        else:
            sink = sys.stdout

        for line in printer.parse():
            try:
                print(line, file=sink, flush=True)
            except BrokenPipeError:
                # NOTE(review): presumably stderr is closed so that the broken pipe is not reported again at
                # interpreter shutdown -- confirm.
                sys.stderr.close()
                return False
        sink.close()

        if proc is not None:
            proc.wait()
            proc = None  # already finished: nothing left for the cleanup below
    finally:
        # Reached with a live process only on an early exit (broken pipe, exception, interrupt).
        if proc is not None:
            proc.send_signal(15)  # SIGTERM
            proc.wait()

    return True


def get_term_width(default: int = 0) -> int:
    """
    Try to detect the current terminal's width by $COLUMNS, the terminal size of stdout or stderr, or given default.

    Only positive widths are accepted, since textwrap rejects anything else. os.get_terminal_size() is used instead of
    a hard-coded, Linux-specific TIOCGWINSZ ioctl, so detection also works on other platforms.

    :param default: width returned when nothing could be detected
    :return: the detected width, or `default`
    """
    try:
        columns: int = int(os.getenv("COLUMNS", ""))
    except ValueError:
        pass
    else:
        if columns > 0:
            return columns

    for stream in (sys.stdout, sys.stderr):  # try stderr in case stdout is piped somewhere
        try:
            columns = os.get_terminal_size(stream.fileno()).columns
        except (AttributeError, OSError, ValueError):  # no stream, no real file descriptor, or not a terminal
            continue
        if columns > 0:
            return columns

    return default


def get_input(filename: str) -> Optional[str]:
    """
    Read the whole markdown source from a file, or from stdin if `filename` is '-'.

    Files are decoded as UTF-8; stdin uses its configured encoding. Read and decode errors are reported on stderr.
    After a successful read, stdin is closed; an OSError from closing it is ignored.

    :return: the text, or None on error
    """
    try:
        if filename == "-":
            text: str = sys.stdin.read()
        else:
            with open(filename, "r", encoding="utf-8") as fp:
                text = fp.read()
    except (OSError, UnicodeDecodeError) as e:
        print(f"Cannot read from '{filename}': {str(e)}", file=sys.stderr, flush=True)
        return None

    try:
        sys.stdin.close()
    except OSError:
        pass

    return text


def main() -> int:
    """
    Command line entry point: parse arguments, read the input, and pretty-print it.

    :return: process exit code, 0 on success (including blank input) and 1 on failure or interrupt
    """
    from argparse import ArgumentParser, ArgumentDefaultsHelpFormatter, RawDescriptionHelpFormatter

    class Formatter(ArgumentDefaultsHelpFormatter, RawDescriptionHelpFormatter):
        pass
    parser = ArgumentParser(description=__doc__, formatter_class=Formatter)

    parser.add_argument("--columns", "-c", type=int, default=None,
                        help="terminal width, autodetect from $COLUMNS or ioctl() if not given, fallback 80"
                             " – tip: export or set --columns=$COLUMNS to take the shell width when using a pager")
    parser.add_argument("--html", action="store_true",
                        help="print parsed HTML output instead, mostly useful for debugging")
    parser.add_argument("--no-wrap-pre", action="store_true",
                        help="don't prefix and wrap code/quote blocks"
                        " – useful if you want to copy/paste or are using a pager with horizontal scroll")
    parser.add_argument("--pager", "-p", action="store_true",
                        help="don't write to stdout but spawn a pager instead, as given by $PAGER or 'less' by default")
    parser.add_argument("file", metavar="file.md",
                        help="markdown file to print, '-' for stdin")
    args = parser.parse_args()
    if args.columns is not None and args.columns <= 0:
        # textwrap cannot wrap to a non-positive width; report it as a usage error (exit code 2) instead of crashing.
        parser.error("argument --columns/-c: must be a positive integer")

    try:
        text: Optional[str] = get_input(args.file)
        if text is None:
            return 1
        elif not text.strip():
            return 0
        else:
            return 0 if parse(text,
                              columns=args.columns if args.columns is not None else get_term_width(80),
                              as_html=args.html,
                              wrap_pre=not args.no_wrap_pre,
                              pager=args.pager) else 1
    except KeyboardInterrupt:
        print("Interrupt", file=sys.stderr, flush=True)
        return 1


if __name__ == "__main__":
    raise SystemExit(main())  # main()'s return value becomes the process exit code