#!/usr/bin/env python3

# This file is part of BigBlueButton.
#
# BigBlueButton is free software: you can redistribute it and/or modify
# it under the terms of the GNU Lesser General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# BigBlueButton is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU Lesser General Public License for more details.
#
# You should have received a copy of the GNU Lesser General Public License
# along with BigBlueButton.  If not, see <http://www.gnu.org/licenses/>.

# To install dependencies on Ubuntu:
# apt-get install python3 python3-lxml python3-pyicu

from lxml import etree
from collections import deque
from fractions import Fraction
import io
from icu import Locale, BreakIterator, UnicodeString
import unicodedata
import html
import logging
import json
import sys
import os
import argparse

# Configure the root logger so that messages of level DEBUG and above are
# emitted
logging.basicConfig(level=logging.DEBUG)

# Module-level logger used by the functions and classes in this file
logger = logging.getLogger(__name__)


def webvtt_timestamp(ms):
    """Format a millisecond offset as a WebVTT ``hh:mm:ss.ttt`` timestamp.

    Every field is truncated to an integer. The hours field is not wrapped,
    so it simply grows wider than two digits for very long offsets.
    """
    total_seconds = ms / 1000
    total_minutes = total_seconds / 60
    total_hours = total_minutes / 60
    fields = (
        int(total_hours),
        int(total_minutes % 60),
        int(total_seconds % 60),
        int(ms % 1000),
    )
    return "{:02}:{:02}:{:02}.{:03}".format(*fields)


class CaptionLine:
    """A single line of caption text with the time span it covers."""

    def __init__(self):
        # A new line starts out empty, with both ends of its time span at 0
        self.text = ""
        self.start_time = self.end_time = 0


class Caption:
    def __init__(self, locale):
        """Create an empty caption stream for the given locale code."""
        self.locale = locale
        # The caption text is held in an ICU UnicodeString
        self.text = UnicodeString()
        # apply_edit() keeps these two lists the same length as each other:
        # it replaces the same [i:j] slice in both with one entry per
        # inserted character
        self.timestamps = []
        self._del_timestamps = []

    def apply_edit(self, i, j, timestamp, text):
        """Replace the characters in ``[i, j)`` with ``text``.

        Every inserted character is stamped with ``timestamp``, unless that
        would be later than the timestamp of the character currently at
        position ``i``; in that case an earlier timestamp is substituted (see
        the comments below).

        Args:
            i: Start index of the replaced range.
            j: End index (exclusive) of the replaced range.
            timestamp: Timestamp of the edit event.
            text: The replacement text; may be empty for a pure deletion.
        """
        has_char_at_i = i < len(self.timestamps)
        inserted_len = len(text)

        removed_ts = None
        if j > i and has_char_at_i:
            # Reuse a deletion timestamp already remembered at this position,
            # otherwise remember the timestamp of the first removed character
            removed_ts = self._del_timestamps[i]
            if removed_ts is None:
                removed_ts = self.timestamps[i]
            self._del_timestamps[i] = removed_ts
            logger.debug(
                "Removing text %s at %d:%d, del_ts: %d",
                repr(str(self.text[i:j])),
                i,
                j,
                removed_ts,
            )

        if inserted_len > 0:
            logger.debug(
                "Inserting text %s at %d:%d, ts: %d", repr(str(text)), i, j, timestamp
            )

            if has_char_at_i and timestamp > self.timestamps[i]:
                # The new text would be stamped later than the character it
                # is placed in front of. Fall back to the remembered deletion
                # timestamp, else the previous character's timestamp, else
                # the timestamp of the character at this position.
                timestamp = self._del_timestamps[i]
                if timestamp is None:
                    timestamp = self.timestamps[i - 1] if i > 0 else self.timestamps[i]
                logger.debug("Out of order timestamps, using ts: %d", timestamp)

        # Splice the deletion bookkeeping first, then re-store the deletion
        # timestamp on whatever entry now sits at position i
        self._del_timestamps[i:j] = [removed_ts] * inserted_len
        if i < len(self._del_timestamps):
            self._del_timestamps[i] = removed_ts

        self.text[i:j] = text
        self.timestamps[i:j] = [timestamp] * inserted_len

    def apply_record_events(self, events):
        """Adjust text and timestamps according to recording start/stop events.

        Walks the ``record_status`` entries of ``events`` in order. Each time
        recording starts, the characters stamped between the previous stop
        (or the beginning) and the start are replaced by a single newline,
        and the length of that gap is added to a running offset. Characters
        stamped while recording was on have the running offset subtracted
        from their timestamps.

        Characters stamped at or after a final stop event (one with no later
        start) are not modified by this method.

        Args:
            events: Iterable of event dicts; only those whose ``"name"`` is
                ``"record_status"`` are used (``"status"`` and
                ``"timestamp"`` keys are read).
        """
        record = False  # True while recording is on
        ts_offset = 0  # Sum of the gaps (stop -> next start) seen so far
        stop_ts = 0
        start_ts = None
        stop_pos = 0
        start_pos = None
        for event in events:
            if event["name"] == "record_status":
                status = event["status"]
                timestamp = event["timestamp"]

                # Transition: not recording -> recording
                if status and not record:
                    record = True
                    start_ts = timestamp
                    logger.debug("Recording started at ts: %d", start_ts)

                    # Find the position of the first character after recording
                    # started
                    start_pos = stop_pos
                    while (
                        start_pos < len(self.timestamps)
                        and self.timestamps[start_pos] < start_ts
                    ):
                        start_pos += 1

                    # Collapse everything stamped during the gap into one
                    # newline, stamped with the (offset-adjusted) stop time.
                    # A newline is inserted even when the range is empty.
                    logger.debug("Replacing characters %d:%d", stop_pos, start_pos)
                    self.text[stop_pos:start_pos] = "\n"
                    self.timestamps[stop_pos:start_pos] = [stop_ts - ts_offset]

                    # Continue from the character right after that newline
                    start_pos = stop_pos + 1
                    ts_offset += start_ts - stop_ts
                    logger.debug("Timestamp offset now %d", ts_offset)

                    stop_ts = None
                    stop_pos = None

                # Transition: recording -> not recording
                if not status and record:
                    record = False
                    stop_ts = timestamp
                    logger.debug("Recording stopped at ts: %d", stop_ts)

                    # Find the position of the first character after recording
                    # stopped, and apply ts offsets
                    stop_pos = start_pos
                    while (
                        stop_pos < len(self.timestamps)
                        and self.timestamps[stop_pos] < stop_ts
                    ):
                        self.timestamps[stop_pos] -= ts_offset
                        stop_pos += 1

        # Recording was still on after the last event: shift every remaining
        # character
        if record:
            logger.debug("No recording stop, applying final ts offsets")

            while start_pos < len(self.timestamps):
                self.timestamps[start_pos] -= ts_offset
                start_pos += 1

    @classmethod
    def from_events(cls, events, apply_record_events=True):
        """Build one caption stream per locale from a sequence of events.

        Args:
            events: Iterable of event dicts. It is iterated once for the
                caption edits and once more per locale when
                ``apply_record_events`` is true.
            apply_record_events: When true, every generated stream is
                post-processed with its ``apply_record_events()`` method.

        Returns:
            A dict mapping each locale code to its caption stream.
        """
        streams = {}

        # Replay every caption edit to build the full text with
        # per-character timestamps
        for ev in events:
            if ev["name"] != "edit_caption_history":
                continue

            locale = ev["locale"]
            start = ev["start_index"]
            end = ev["end_index"]
            ts = ev["timestamp"]
            new_text = UnicodeString(ev["text"])

            stream = streams.get(locale)
            if stream is None:
                logger.info("Started caption stream for locale '%s'", locale)
                stream = cls(locale)
                streams[locale] = stream

            stream.apply_edit(start, end, ts, new_text)

        if apply_record_events:
            for locale, stream in streams.items():
                logger.info("Applying recording events to locale '%s'", locale)
                stream.apply_record_events(events)

        logger.info("Generated %d caption stream(s)", len(streams))
        return streams

    def split_lines(self, max_length=32):
        """Split the caption text into lines using a greedy word-wrap.

        Line break opportunities come from an ICU line ``BreakIterator``
        created for this caption's locale. A line is extended one break
        opportunity at a time for as long as its text (with trailing
        whitespace, control characters and non-spacing marks removed, and
        normalized to NFC) is not longer than ``max_length``.

        Args:
            max_length: Maximum length of a line, measured with ``len()`` on
                the normalized Python string.

        Returns:
            A list of ``CaptionLine`` objects. ``start_time`` / ``end_time``
            of each line are the minimum / maximum timestamp of the
            characters considered for it.
        """
        lines = list()

        locale = Locale(self.locale)
        logger.debug("Using locale %s for word-wrapping", locale.getDisplayName(locale))

        break_iter = BreakIterator.createLineInstance(locale)
        break_iter.setText(self.text)

        line = CaptionLine()
        line_start = 0  # index where the line being built starts
        prev_break = 0  # last break opportunity already accepted
        next_break = break_iter.following(prev_break)  # candidate line end

        # Super simple "greedy" line break algorithm.
        while prev_break < len(self.text):
            # NOTE(review): after the "chop" branch below jumps back here with
            # `continue`, the iterator has not been moved, so this presumably
            # still reports the status of the last following() call - confirm
            # against the ICU documentation.
            status = break_iter.getRuleStatus()

            # Trim trailing whitespace, control characters (Cc) and
            # non-spacing marks (Mn) from the candidate line
            line_end = next_break
            logger.debug("text len: %d, line end: %d", len(self.text), line_end)
            while line_end > line_start and (
                self.text[line_end - 1].isspace()
                or unicodedata.category(self.text[line_end - 1]) in ["Cc", "Mn"]
            ):
                line_end -= 1

            do_break = False
            text_section = unicodedata.normalize(
                "NFC", str(self.text[line_start:line_end])
            )
            # The time span uses the untrimmed range, so it includes the
            # trailing characters removed from the text above
            timestamps_section = self.timestamps[line_start:next_break]
            start_time = min(timestamps_section)
            end_time = max(timestamps_section)
            if len(text_section) > max_length:
                if prev_break == line_start:
                    # Over-long string. Just chop it into bits
                    next_break = prev_break + max_length
                    continue
                else:
                    # The candidate is too long: end the line at the previous
                    # break opportunity. `line` still holds the text and times
                    # stored for that shorter candidate on an earlier pass.
                    next_break = prev_break
                    do_break = True

            else:
                # Status [100,200] indicates a required (hard) line break
                if next_break >= len(self.text) or (status >= 100 and status < 200):
                    line.text = text_section
                    line.start_time = start_time
                    line.end_time = end_time
                    do_break = True

            if do_break:
                logger.debug(
                    "text section %d -> %d (%d): %s",
                    line.start_time,
                    line.end_time,
                    len(line.text),
                    repr(line.text),
                )
                lines.append(line)
                line = CaptionLine()
                line_start = next_break
            else:
                # The candidate fits: remember it and try to extend the line
                # to the next break opportunity
                line.text = text_section
                line.start_time = start_time
                line.end_time = end_time

            prev_break = next_break
            next_break = break_iter.following(prev_break)

        return lines

    def write_webvtt(self, f):
        """Write this caption stream to ``f`` as a WebVTT file.

        The text is split with ``split_lines()`` and empty lines are dropped.
        Each remaining line becomes a cue, or is merged with the following
        line into a two-line cue when their display times would overlap.

        Args:
            f: A writable binary file object; all output is written as UTF-8
                encoded bytes.
        """
        # Write magic
        f.write("WEBVTT\n\n".encode("utf-8"))

        # Don't generate cues for empty lines
        lines = [line for line in self.split_lines() if len(line.text) > 0]

        for idx, line in enumerate(lines):
            text = line.text
            start_time = line.start_time

            # Apply some duration cleanup heuristics to give some reasonable
            # line durations:
            # - a minimum per-character display time (up to 3.2s for 32 chars)
            # - never show a caption (even a short one) for less than 1s
            # - make lines go away if they've been showing for >16 seconds
            duration = max(line.end_time - start_time, 100 * len(text), 1000)
            duration = min(duration, 16000)
            end_time = start_time + duration

            if idx + 1 < len(lines):
                next_line = lines[idx + 1]

                # If this end time would overlap with the next line, either
                # merge them into a single cue or un-overlap them if there's
                # already multiple lines.
                if end_time > next_line.start_time:
                    if len(text.splitlines()) < 2:
                        next_line.text = text + "\n" + next_line.text
                        next_line.start_time = start_time
                        continue
                    if next_line.start_time > start_time:
                        end_time = next_line.start_time
                    else:
                        # The next line starts no later than this cue does
                        # (e.g. a block of text inserted by a single edit, so
                        # every character has the same timestamp). Cutting
                        # this cue at the next start would produce a zero or
                        # negative length cue, which is invalid WebVTT and is
                        # never displayed. Keep this cue's duration and delay
                        # the next line until this cue ends instead.
                        next_line.start_time = end_time

                # If the next line is close after the current line, make the
                # timestamps continuous so the subtitle doesn't "flicker"
                if next_line.start_time - end_time < 500:
                    end_time = next_line.start_time

            cue = "{} --> {}\n{}\n\n".format(
                webvtt_timestamp(start_time),
                webvtt_timestamp(end_time),
                html.escape(text, quote=False),
            )
            f.write(cue.encode("utf-8"))

    def caption_desc(self):
        """Return a dict describing this stream.

        The dict holds the locale code under ``"locale"`` and the locale's
        display name, rendered in that same locale, under ``"localeName"``.
        """
        icu_locale = Locale(self.locale)
        display_name = icu_locale.getDisplayName(icu_locale)
        return {"locale": self.locale, "localeName": display_name}


def parse_record_status(event, element):
    """Fill ``event`` in place from a record status XML element.

    Sets ``"name"`` to ``"record_status"``, ``"user_id"`` to the text of the
    ``userId`` child and ``"status"`` to whether the text of the ``status``
    child is exactly ``"true"``.
    """
    event["name"] = "record_status"
    event["user_id"] = element.find("userId").text
    event["status"] = element.find("status").text == "true"


def parse_caption_edit(event, element):
    """Fill ``event`` in place from a caption edit XML element.

    Reads the ``locale``, ``localeCode``, ``text``, ``startIndex`` and
    ``endIndex`` children of ``element``. The locale code falls back to
    ``"en"`` when there is no ``localeCode`` child, and an empty ``text``
    child yields an empty string.
    """
    locale_name_el = element.find("locale")
    locale_code_el = element.find("localeCode")
    text_el = element.find("text")
    start_el = element.find("startIndex")
    end_el = element.find("endIndex")

    event["name"] = "edit_caption_history"
    event["locale_name"] = locale_name_el.text
    # Fallback for missing 'localeCode'
    event["locale"] = "en" if locale_code_el is None else locale_code_el.text
    event["text"] = "" if text_el.text is None else text_el.text
    event["start_index"] = int(start_el.text)
    event["end_index"] = int(end_el.text)


def parse_events(directory="."):
    """Read ``events.xml`` from ``directory`` and extract the relevant events.

    Only ``RecordStatusEvent`` and ``EditCaptionHistoryEvent`` events from
    the ``CAPTION`` and ``PARTICIPANT`` modules are kept. Caption edits whose
    locale is unset are skipped with a warning.

    Args:
        directory: Directory containing the ``events.xml`` file.

    Returns:
        A deque of event dicts in document order, each with a ``"timestamp"``
        relative to the first event in the file. If the file contains no
        record status event, a synthetic "recording started" event with
        timestamp 0 is placed at the front.
    """
    start_time = None
    have_record_events = False
    events = deque()

    root = etree.parse(os.path.join(directory, "events.xml"))
    for element in root.iter("event"):
        try:
            event = {}

            # Make timestamps relative to the first event in the file (the
            # rest of this file treats them as milliseconds). Compare with
            # None so that a first timestamp of 0 is not mistaken for "unset".
            timestamp = int(element.attrib["timestamp"])
            if start_time is None:
                start_time = timestamp
            timestamp -= start_time

            # Only need events from these modules
            if element.attrib["module"] not in ("CAPTION", "PARTICIPANT"):
                continue

            event["name"] = name = element.attrib["eventname"]
            event["timestamp"] = timestamp

            if name == "RecordStatusEvent":
                parse_record_status(event, element)
                have_record_events = True
            elif name == "EditCaptionHistoryEvent":
                parse_caption_edit(event, element)
                if event["locale"] is None:
                    # Logger.warn() is a deprecated alias of warning()
                    logger.warning(
                        "Skipping invalid caption event with unset locale. See https://github.com/bigbluebutton/bigbluebutton/issues/19178 for details"
                    )
                    continue
            else:
                logger.debug("Unhandled event: %s", name)
                continue

            events.append(event)
        finally:
            # Clear every element once it has been handled, whether it was
            # kept, skipped, or raised
            element.clear()

    if not have_record_events:
        # Add a fake record start event to the events list
        event = {
            "name": "record_status",
            "user_id": None,
            "timestamp": 0,
            "status": True,
        }
        events.appendleft(event)

    return events


if __name__ == "__main__":
    # Command line: -i/--input is the directory holding events.xml,
    # -o/--output is where the generated files are written. Both default to
    # the current directory.
    parser = argparse.ArgumentParser(
        description="Generate WebVTT files from BigBlueButton captions",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )
    parser.add_argument(
        "-i",
        "--input",
        metavar="PATH",
        default=os.curdir,
        help="input directory with events.xml file",
    )
    parser.add_argument(
        "-o",
        "--output",
        metavar="PATH",
        default=os.curdir,
        help="output directory",
    )
    args = parser.parse_args()
    rawdir, outputdir = args.input, args.output

    logger.info("Reading recording events file")
    events = parse_events(rawdir)

    logger.info("Generating caption data from recording events")
    captions = Caption.from_events(events)

    # One WebVTT file per caption locale
    for locale, caption in captions.items():
        filename = os.path.join(outputdir, "caption_{}.vtt".format(locale))
        logger.info("Writing captions for locale %s to %s", locale, filename)
        with open(filename, "wb") as f:
            caption.write_webvtt(f)

    # Index file describing every generated caption stream
    filename = os.path.join(outputdir, "captions.json")
    logger.info("Writing captions index file to %s", filename)

    caption_descs = []
    for caption in captions.values():
        caption_descs.append(caption.caption_desc())
    with open(filename, "w") as f:
        json.dump(caption_descs, f)
