#!/usr/bin/python3 -sP
# Compare SMB packets from 2 network capture files
#
# Copyright (C) 2019 Aurelien Aptel <aurelien.aptel@gmail.com>
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program.  If not, see <http://www.gnu.org/licenses/>.

import curses
import smbcmp
import re

# curses color-pair ids: 0 is used as the default pair; the other ids
# are the ones registered with curses.init_pair() in curses_init()
COL_DEFAULT, COL_RED, COL_GREEN, COL_BLACK = range(4)

def main():
    """Parse the command line and start the matching curses UI.

    The diff view is used when a second capture (args.fileb) was given,
    the single-capture view otherwise.
    """
    args = smbcmp.parse_args()

    view = diff_view_main if args.fileb else single_view_main
    # curses.wrapper calls view(stdscr, args)
    curses.wrapper(view, args)

def curses_init():
    """Set up curses: hide the cursor and register the color pairs.

    Every pair uses -1 as background, i.e. the terminal default made
    available by use_default_colors().
    """
    curses.initscr()
    curses.curs_set(0)
    curses.use_default_colors()

    foregrounds = (
        (COL_RED, curses.COLOR_RED),
        (COL_GREEN, curses.COLOR_GREEN),
        (COL_BLACK, curses.COLOR_BLACK),
    )
    for pair_id, fg in foregrounds:
        curses.init_pair(pair_id, fg, -1)

def diff_split(w, h, vratio):
    """Compute the pane sizes of the two-capture (diff) layout.

    One column is reserved for the vertical separator between the left
    and right panes, and one row for the horizontal separator between
    the top panes and the bottom pane.

    Keyword arguments:
    w -- total width in columns
    h -- total height in rows
    vratio -- fraction of the usable rows given to the top panes

    Returns the tuple (lw, rw, th, bh): left width, right width, top
    height and bottom height, with lw + rw == w-1 and th + bh == h-1.

    The top height is clamped to [1, h-2] so that, whenever h >= 3, no
    pane ends up with a zero or negative height.  The callers change
    vratio in unbounded +/-0.1 steps and pass th and bh straight to
    window.resize().
    """
    usable_cols = w - 1
    lw = usable_cols // 2
    rw = usable_cols - lw

    usable_rows = h - 1
    th = int(usable_rows * vratio)
    # keep at least one row for the top panes and one for the bottom
    th = max(1, min(th, usable_rows - 1))
    bh = usable_rows - th
    return lw, rw, th, bh

def single_split(w, h, vratio):
    """Compute the pane heights of the single-capture layout.

    One row is reserved for the horizontal separator between the top
    and bottom panes.

    Keyword arguments:
    w -- total width in columns (unused, kept for symmetry with diff_split)
    h -- total height in rows
    vratio -- fraction of the usable rows given to the top pane

    Returns the tuple (th, bh) with th + bh == h-1.

    The top height is clamped to [1, h-2] so that, whenever h >= 3, no
    pane ends up with a zero or negative height.  The caller changes
    vratio in unbounded +/-0.1 steps and passes th and bh straight to
    window.resize().
    """
    usable_rows = h - 1
    th = int(usable_rows * vratio)
    # keep at least one row for each pane
    th = max(1, min(th, usable_rows - 1))
    bh = usable_rows - th
    return th, bh

def single_view_main(stdscr, args):
    """Run the single-capture UI until 'q' is pressed.

    The top pane lists the packets of args.filea, the bottom pane shows
    the packet currently under the top cursor.

    Keyword arguments:
    stdscr -- curses standard screen, as passed by curses.wrapper
    args -- parsed command line arguments
    """
    curses_init()

    rows, cols = curses.LINES, curses.COLS
    ratio = smbcmp.CONF['global'].getfloat('vsplit_ratio', .5)
    top_h, bot_h = single_split(cols, rows, ratio)

    # newwin takes h, w, y, x
    top_win = curses.newwin(top_h, cols, 0, 0)
    trace = TraceViewer(top_win, smbcmp.strip_packet_no(args.filea))

    bot_win = curses.newwin(bot_h, cols, top_h+1, 0)
    detail = BufferViewer(bot_win, "")

    def show_selected():
        # load the packet under the top cursor into the bottom pane
        detail.set_data(smbcmp.smb_packet(*trace.get_packet()))

    empty = trace.is_empty()
    if not empty:
        show_selected()
    else:
        trace.buf.set_data("No SMB packet in capture")
        detail.set_data("No SMB packet")

    keys = smbcmp.KEY
    while True:
        stdscr.refresh()
        trace.refresh()
        detail.refresh()

        # separator line: y, x, ch, length
        stdscr.hline(top_h, 0, 0, cols)

        # remember which packet is displayed to detect a cursor change
        shown = None if empty else trace.get_packet()[1]

        relayout = False
        k = stdscr.getkey()
        if k == 'q':
            return

        elif k in (keys['lwin_next'], keys['rwin_next'], keys['top_next']):
            trace.move(+1)
        elif k in (keys['lwin_prev'], keys['rwin_prev'], keys['top_prev']):
            trace.move(-1)

        elif k == keys['bwin_next']:
            detail.move(+1)
        elif k == keys['bwin_prev']:
            detail.move(-1)

        elif k == keys['vsplit_up']:
            ratio -= .1
            relayout = True
        elif k == keys['vsplit_down']:
            ratio += .1
            relayout = True

        elif k == 'KEY_RESIZE':
            relayout = True

        if relayout:
            rows, cols = stdscr.getmaxyx()
            top_h, bot_h = single_split(cols, rows, ratio)
            top_win.resize(top_h, cols)
            # resize the bottom window before moving it
            bot_win.resize(bot_h, cols)
            bot_win.mvwin(top_h+1, 0)

        if not empty and trace.get_packet()[1] != shown:
            show_selected()

def diff_view_main(stdscr, args):
    """Run the two-capture diff UI until 'q' is pressed.

    The top-left and top-right panes list the packets of args.filea and
    args.fileb; the bottom pane shows the diff of the two packets under
    the top cursors.  When args.mode is 'pdml' the diff is rendered by a
    DiffViewer, otherwise by a plain BufferViewer in diff mode.

    Keyword arguments:
    stdscr -- curses standard screen, as passed by curses.wrapper
    args -- parsed command line arguments
    """
    curses_init()
    H = curses.LINES
    W = curses.COLS
    th_ratio = smbcmp.CONF['global'].getfloat('vsplit_ratio', .5)
    lw, rw, th, bh = diff_split(W, H, th_ratio)

    # h, w, y, x
    lwin = curses.newwin(th, lw, 0, 0)
    lbuf = TraceViewer(lwin, smbcmp.strip_packet_no(args.filea))

    rwin = curses.newwin(th, rw, 0, lw+1)
    rbuf = TraceViewer(rwin, smbcmp.strip_packet_no(args.fileb))

    pdml = args.mode == 'pdml'
    resize = False
    bwin = curses.newwin(bh, W, th+1, 0)

    # a diff needs a packet on both sides
    empty = lbuf.is_empty() or rbuf.is_empty()
    if lbuf.is_empty():
        lbuf.buf.set_data("No SMB packet in capture")
    if rbuf.is_empty():
        rbuf.buf.set_data("No SMB packet in capture")

    if pdml:
        bbuf = DiffViewer(bwin)
    else:
        bbuf = BufferViewer(bwin, "", diff=True)

    def show_diff():
        # diff the packets under both top cursors into the bottom pane
        pkt_a = lbuf.get_packet()
        pkt_b = rbuf.get_packet()
        if pdml:
            bbuf.diff_packets(pkt_a, pkt_b)
        else:
            bbuf.set_data(smbcmp.smb_diff(pkt_a, pkt_b))

    # show the first diff right away in both modes; pdml mode used to
    # start with a blank bottom pane until a cursor was moved
    if not empty:
        show_diff()

    while True:
        rbuf.refresh()
        lbuf.refresh()
        bbuf.refresh()

        # y, x, ch, length
        stdscr.vline(0, lw, 0, th)
        stdscr.hline(th, 0, 0, W)

        # remember the displayed packets to detect a cursor change
        if not empty:
            last_lpacket = lbuf.get_packet()[1]
            last_rpacket = rbuf.get_packet()[1]

        k = stdscr.getkey()
        if k == 'q':
            return

        elif k == smbcmp.KEY['rwin_next']:
            rbuf.move(+1)
        elif k == smbcmp.KEY['rwin_prev']:
            rbuf.move(-1)

        elif k == smbcmp.KEY['lwin_next']:
            lbuf.move(+1)
        elif k == smbcmp.KEY['lwin_prev']:
            lbuf.move(-1)

        elif k == smbcmp.KEY['bwin_next']:
            bbuf.move(+1)
        elif k == smbcmp.KEY['bwin_prev']:
            bbuf.move(-1)

        elif k == smbcmp.KEY['vsplit_up']:
            th_ratio += -.1
            resize = True
        elif k == smbcmp.KEY['vsplit_down']:
            th_ratio += +.1
            resize = True

        elif k == smbcmp.KEY['top_next']:
            lbuf.move(+1)
            rbuf.move(+1)
        elif k == smbcmp.KEY['top_prev']:
            lbuf.move(-1)
            rbuf.move(-1)

        elif k == smbcmp.KEY['toggle_ignore']:
            # only DiffViewer (pdml mode) can ignore fields, and only
            # once a diff exists; BufferViewer has no such method
            if pdml and not empty:
                bbuf.toggle_ignore_field()

        elif k == 'KEY_RESIZE':
            resize = True

        if resize:
            H, W = stdscr.getmaxyx()
            lw, rw, th, bh = diff_split(W, H, th_ratio)
            lwin.resize(th, lw)
            rwin.resize(th, rw)
            rwin.mvwin(0, lw+1)
            # resize the bottom window before moving it
            bwin.resize(bh, W)
            bwin.mvwin(th+1, 0)
            resize = False

        if not empty and (lbuf.get_packet()[1] != last_lpacket or
                          rbuf.get_packet()[1] != last_rpacket):
            show_diff()

class BufferViewer(object):
    """Scrollable, line-oriented text view drawn on a curses window.

    Keyword arguments:
    win -- curses window object to display on
    data -- initial text, split on newlines into display lines
    hl -- when true, draw the cursor line in reverse video
    diff -- when true, draw lines starting with '-' in red and lines
            starting with '+' in green
    rxhl -- optional list of (regex pattern, curses attribute) pairs;
            group 1 of every match is redrawn with the attribute
            (not applied in diff mode)
    """
    def __init__(self, win, data, hl=True, diff=False, rxhl=None):
        self.win = win
        self.top = 0        # index of the first displayed line
        self.cursor = 0     # index of the selected line
        self.hl = hl
        self.diff = diff
        self.rxhl = None
        if rxhl:
            self.set_rxhl(rxhl)
        self.set_data(data)

    def set_data(self, data):
        """Replace the displayed text and redraw.

        top and cursor are kept where possible and clamped to the new
        number of lines.
        """
        self.data = data
        self.lines = self.data.split("\n")
        # split() always returns at least one element, so last >= 0
        last = len(self.lines)-1
        self.top = max(0, min(last, self.top))
        self.cursor = max(self.top, min(last, self.cursor))
        self.refresh()

    def set_rxhl(self, rxhl):
        """Compile and store the (regex pattern, attribute) highlight pairs."""
        self.rxhl = [(re.compile(rx), col) for rx, col in rxhl]

    def height(self):
        """Return the current height of the window in rows."""
        return self.win.getmaxyx()[0]

    def width(self):
        """Return the current width of the window in columns."""
        return self.win.getmaxyx()[1]

    def is_cursor_at_bot(self):
        """Return True if the cursor is on (or below) the last visible row.

        The cursor can be below the last row after the window shrank.
        """
        return self.cursor - self.top >= self.height()-1

    def is_cursor_at_top(self):
        """Return True if the cursor is on the first visible row."""
        return self.cursor == self.top

    def _scroll_to_cursor(self):
        """Adjust top by the minimum amount that makes the cursor visible."""
        if self.cursor < self.top:
            self.top = self.cursor
            return
        lowest_top = self.cursor - self.height() + 1
        if self.top < lowest_top:
            self.top = lowest_top

    def move(self, direction):
        """Move the cursor one line down (direction > 0) or up (direction < 0).

        The view scrolls when the cursor leaves the visible rows.  Nothing
        happens at the first/last line or when direction is 0.
        """
        if direction > 0 and self.cursor < len(self.lines)-1:
            self.cursor += 1
        elif direction < 0 and self.cursor > 0:
            self.cursor -= 1
        else:
            return
        # Re-anchor on the cursor instead of scrolling by exactly one line
        # when it sits on the edge row: if the window shrank, the cursor
        # may already be below the last row and would never scroll again.
        self._scroll_to_cursor()

    def refresh(self):
        """Redraw the visible lines (screen is updated on the next doupdate)."""
        self.win.erase()
        height, width = self.win.getmaxyx()
        # the last column is never written to
        text_w = width - 1
        cursor_row = self.cursor - self.top

        for nr, line in enumerate(self.lines[self.top:self.top+height]):
            line = line.replace('\t', '    ')[:text_w].ljust(text_w)
            reverse = 0
            if self.hl and nr == cursor_row:
                reverse = curses.A_REVERSE

            attr = curses.A_NORMAL
            if self.diff:
                # startswith() instead of line[0]: the line is empty on a
                # one column wide window
                if line.startswith('-'):
                    attr |= curses.color_pair(COL_RED)
                elif line.startswith('+'):
                    attr |= curses.color_pair(COL_GREEN)

            self.win.addstr(nr, 0, line, attr | reverse)

            if not self.diff and self.rxhl:
                # redraw the highlighted spans over the plain line
                for rx, col in self.rxhl:
                    for m in rx.finditer(line):
                        self.win.addstr(nr, m.start(1), m.group(1),
                                        col | reverse)

        self.win.noutrefresh()


class BufferViewerCSIColors(BufferViewer):
    """BufferViewer variant that interprets CSI color escapes in the text.

    Escape sequences of the form ESC [ n ; ... m are translated to curses
    attributes.  Supported values are 0 (reset), 1 (bold), 30 (black),
    31 (red) and 32 (green); any other value raises an Exception.
    """

    # newline | run of plain text | escape sequence
    _TOKEN_RX = re.compile(r'(\n)|([^\x1b\n]+)|(\x1b\[.+?m)')

    def getselected(self):
        # NOTE(review): `selected` is not assigned anywhere in the code
        # visible here; confirm it is set elsewhere before relying on this
        return self.selected

    def refresh(self):
        """Redraw the visible lines (screen is updated on the next doupdate)."""
        self.win.erase()
        height, width = self.win.getmaxyx()

        # Lines are padded but not truncated here: their length includes
        # the escape sequences, so clipping is done when writing.
        lines = [line.replace('\t', '    ').ljust(width-1)
                 for line in self.lines[self.top:self.top+height]]
        text = '\n'.join(lines)

        t_newline = 1
        t_text = 2
        t_esc = 3

        cursor_row = self.cursor - self.top
        # text drawn in bold black is not reversed on the cursor row
        no_reverse = curses.color_pair(COL_BLACK) | curses.A_BOLD

        y = 0
        x = 0
        attr = curses.color_pair(COL_DEFAULT)

        for tok in self._TOKEN_RX.finditer(text):
            typ = tok.lastindex
            val = tok.group()
            if typ == t_text:
                # Columns left before the last one, which is never written
                # to.  Skip the segment when there is none: x is then at or
                # past the right edge, where a negative slice bound would
                # keep text and addstr() would be asked to write outside
                # of the window.
                room = width - 1 - x
                if room > 0:
                    a = attr
                    if self.hl and cursor_row == y and attr != no_reverse:
                        a |= curses.A_REVERSE
                    self.win.addstr(y, x, val[:room], a)
                x += len(val)
            elif typ == t_newline:
                y += 1
                x = 0
            elif typ == t_esc:
                # val is ESC [ <params> m (guaranteed by the pattern)
                for num in val[2:-1].split(';'):
                    num = int(num)
                    if num == 0:
                        attr = curses.color_pair(COL_DEFAULT)
                    elif num == 30:
                        attr = curses.color_pair(COL_BLACK)
                    elif num == 31:
                        attr = curses.color_pair(COL_RED)
                    elif num == 32:
                        attr = curses.color_pair(COL_GREEN)
                    elif num == 1:
                        attr |= curses.A_BOLD
                    else:
                        raise Exception("unsupported CSI color")
            else:
                raise Exception("invalid token %d" % typ)
        self.win.noutrefresh()

class DiffViewer:
    def __init__(self, win):
        """Special BufferViewer wrapper for PDML diffs

        Keyword arguments:
        win -- curses window object to display on
        """
        self.buf = BufferViewerCSIColors(win, '')
        # current smbcmp.PDMLDiff and its rendered output; both stay None
        # until diff_packets() is called for the first time
        self.diff = None
        self.out = None

    def diff_packets(self, pkt_a, pkt_b):
        """Diff two packets and display the result.

        The ignored fields of the previous diff, if any, are carried over
        to the new one.
        """
        ign = None
        if self.diff:
            ign = self.diff.ignored_fields
        self.diff = smbcmp.PDMLDiff(pkt_a, pkt_b, ign)
        self.out = self.diff.smb_diff()
        self.buf.set_data(self.out.get_text())

    def toggle_ignore_field(self):
        """Toggle the ignore state of the field under the cursor and redraw.

        Does nothing when no diff has been computed yet.
        """
        if self.diff is None or self.out is None:
            # nothing displayed yet, so there is no field to toggle
            return
        field = self.out.get_item_at_line(self.buf.cursor)
        self.diff.toggle_ignore_field(field.name)
        self.out = self.diff.smb_diff()
        self.buf.set_data(self.out.get_text())

    def move(self, direction):
        """Move the cursor of the underlying buffer viewer."""
        self.buf.move(direction)

    def refresh(self):
        """Redraw the underlying buffer viewer."""
        self.buf.refresh()


class TraceViewer(object):
    """List the packet summaries of a capture file and track a cursor

    Keyword arguments:
    win -- curses windows object to display on
    cap -- name of the capture file
    """
    def __init__(self, win, cap):
        viewer = BufferViewer(win, '')
        # draw the word "Error" in bold using color pair 1
        error_attr = curses.color_pair(1) | curses.A_BOLD
        viewer.set_rxhl([(r'''(Error)''', error_attr)])
        self.buf = viewer
        self.set_cap(cap)

    def set_cap(self, cap):
        """Load the summaries of capture `cap` and display them.

        Summaries are listed in ascending packet number order, one per
        line, so the cursor line is an index into self.nos.
        """
        self.cap = cap
        self.pkts = smbcmp.smb_summaries(cap)
        self.nos = sorted(self.pkts.keys())
        summaries = [self.pkts[no] for no in self.nos]
        self.buf.set_data('\n'.join(summaries))

    def is_empty(self):
        """Return True if the capture has no packet to list."""
        return not self.nos

    def get_packet(self):
        """Return (capture name, packet number) for the cursor line."""
        selected = self.nos[self.buf.cursor]
        return self.cap, selected

    def move(self, direction):
        """Move the cursor of the underlying buffer viewer."""
        self.buf.move(direction)

    def refresh(self):
        """Redraw the underlying buffer viewer."""
        self.buf.refresh()

def debug():
    """Leave curses mode and drop into the pdb debugger."""
    import pdb

    # undo the terminal setup so that pdb gets a usable terminal
    for restore in (curses.nocbreak, curses.echo, curses.endwin):
        restore()
    pdb.set_trace()

# Start the UI only when this file is executed as a script.
if __name__ == '__main__':
    main()
