#!/usr/bin/env python
# -*- coding: utf-8 -*-

""" Verify a v2 format migration stream """

import sys
import struct
import os, os.path
import syslog
import traceback

from xen.streamv2.format import *

fin = None             # Input file object; assigned in main(), read via stream_read()
log_to_syslog = False  # Boolean - info() logs to syslog instead of stdout; err() logs to syslog as well as stderr
verbose = False        # Boolean - Summarise stream contents (enables info() output)
quiet = False          # Boolean - Suppress all info() and err() output

def info(msg):
    """Info message, routed to appropriate destination"""
    if not quiet and verbose:
        if log_to_syslog:
            for line in msg.split("\n"):
                syslog.syslog(syslog.LOG_INFO, line)
        else:
            print msg

def err(msg):
    """Emit an error message, unless quiet"""
    if quiet:
        return

    if log_to_syslog:
        # Hand the message to syslog one line at a time
        for text in msg.split("\n"):
            syslog.syslog(syslog.LOG_ERR, text)

    # stderr gets the message whether or not it was also sent to syslog
    print >> sys.stderr, msg

def stream_read(_ = None):
    """Read from the input file, passing '_' through as the read() argument"""
    data = fin.read(_)
    return data

def rdexact(nr_bytes):
    """Read precisely nr_bytes from the input, or raise IOError if fewer
    (or more) bytes came back"""
    data = stream_read(nr_bytes)
    if len(data) == nr_bytes:
        return data
    raise IOError("Stream truncated")

def unpack_exact(fmt):
    """Read as many bytes as 'fmt' describes from the input, and return them
    unpacked as a tuple"""
    raw = rdexact(struct.calcsize(fmt))
    return struct.unpack(fmt, raw)

class StreamError(StandardError):
    """Exception raised for a problem with the stream"""

class RecordError(StandardError):
    """Exception raised for a problem with a single record in the stream"""

def skip_xl_header():
    """Consume an xl header from the front of the stream

    Raises StreamError if the first 32 bytes are not the expected xl marker.
    """

    expected = "Xen saved domain, xl format\n \0 \r"
    if rdexact(32) != expected:
        raise StreamError("No xl header")

    # Four 32bit fields follow the marker.  Only the last one is used here:
    # it is the number of further bytes to discard.
    fields = unpack_exact("=IIII")
    rdexact(fields[3])

    info("Skipped xl header")

def verify_ihdr():
    """ Read an image header from the input and check each of its fields

    Raises StreamError on the first field found to be unacceptable.
    """

    magic, image_id, vers, opts, rsvd1, rsvd2 = unpack_exact(IHDR_FORMAT)

    if magic != IHDR_MARKER:
        raise StreamError("Bad image marker: Expected 0x%x, got 0x%x"
                          % (IHDR_MARKER, magic))

    if image_id != IHDR_IDENT:
        raise StreamError("Bad image id: Expected 0x%x, got 0x%x"
                          % (IHDR_IDENT, image_id))

    if vers != 2:
        raise StreamError("Unknown image version: Expected 2, got %d"
                          % (vers, ))

    rsvd_opts = opts & IHDR_OPT_RESZ_MASK
    if rsvd_opts:
        raise StreamError("Reserved bits set in image options field: 0x%x"
                          % (rsvd_opts))

    if not (rsvd1 == 0 and rsvd2 == 0):
        raise StreamError("Reserved bits set in image header: 0x%04x:0x%08x"
                          % (rsvd1, rsvd2))

    # The endianness check is only made when running little endian
    if sys.byteorder == "little":
        if (opts & IHDR_OPT_ENDIAN_) != IHDR_OPT_LE:
            raise StreamError("Stream is not native endianess - unable to validate")

    if not (opts & IHDR_OPT_BE):
        info("Image Header: little endian")
    else:
        info("Image Header: big endian")

def verify_dhdr():
    """ Read a domain header from the input and check each of its fields

    Raises StreamError on the first field found to be unacceptable.
    """

    gtype, shift, rsvd, xen_major, xen_minor = unpack_exact(DHDR_FORMAT)

    if gtype not in dhdr_type_to_str:
        raise StreamError("Unrecognised domain type 0x%x" % (gtype, ))

    if rsvd != 0:
        raise StreamError("Reserved bits set in domain header 0x%04x"
                          % (rsvd, ))

    if shift != 12:
        raise StreamError("Page shift expected to be 12.  Got %d"
                          % (shift, ))

    type_name = dhdr_type_to_str[gtype]

    # A major version of 0 is reported as a legacy converted stream
    if xen_major != 0:
        info("Domain Header: %s from Xen %d.%d"
             % (type_name, xen_major, xen_minor))
    else:
        info("Domain Header: legacy converted %s"
             % (type_name, ))


def verify_record_end(content):
    """ End record: raises RecordError unless the content is empty """

    if content:
        raise RecordError("End record with non-zero length")

def verify_page_data(content):
    """ Page Data record

    'content' is expected to hold, in order: a PAGE_DATA_FORMAT header
    giving a count of pfns, 'count' 64bit pfn entries, and 4096 bytes of
    page data for every pfn whose type lies between PAGE_DATA_TYPE_NOTAB and
    PAGE_DATA_TYPE_L4TAB inclusive.

    Raises RecordError if the content is malformed, or StreamError if the
    reserved field of the header is non-zero.
    """
    minsz = struct.calcsize(PAGE_DATA_FORMAT)

    if len(content) <= minsz:
        raise RecordError("PAGE_DATA record must be at least %d bytes long"
                          % (minsz, ))

    count, res1 = struct.unpack(PAGE_DATA_FORMAT, content[:minsz])

    if res1 != 0:
        raise StreamError("Reserved bits set in PAGE_DATA record 0x%04x"
                          % (res1, ))

    # Each pfn entry is a 64bit quantity
    pfnsz = count * 8
    if (len(content) - minsz) < pfnsz:
        raise RecordError("PAGE_DATA record must contain a pfn record for "
                          "each count")

    pfns = struct.unpack("=%dQ" % (count,), content[minsz:minsz + pfnsz])

    nr_pages = 0
    for idx, pfn in enumerate(pfns):

        # The messages below are formatted with '%' before being handed to
        # RecordError.  Passing the values as extra exception arguments
        # would leave the placeholders unexpanded.
        if pfn & PAGE_DATA_PFN_RESZ_MASK:
            raise RecordError("Reserved bits set in pfn[%d]: 0x%016x"
                              % (idx, pfn & PAGE_DATA_PFN_RESZ_MASK))

        if pfn >> PAGE_DATA_TYPE_SHIFT in (5, 6, 7, 8):
            raise RecordError("Invalid type value in pfn[%d]: 0x%016x"
                              % (idx, pfn & PAGE_DATA_TYPE_LTAB_MASK))

        # We expect page data for each normal page or pagetable
        if PAGE_DATA_TYPE_NOTAB <= (pfn & PAGE_DATA_TYPE_LTABTYPE_MASK) \
                <= PAGE_DATA_TYPE_L4TAB:
            nr_pages += 1

    # The record must be exactly header + pfns + page data, nothing more
    pagesz = nr_pages * 4096
    if len(content) != minsz + pfnsz + pagesz:
        raise RecordError("Expected %u + %u + %u, got %u"
                          % (minsz, pfnsz, pagesz, len(content)))


def verify_record_x86_pv_vcpu_generic(content, name):
    """ Shared check for the REC_TYPE_x86_pv_vcpu_{basic,extended,xsave,msrs}
    records; 'name' is only used in the messages produced """
    hdrsz = struct.calcsize(X86_PV_VCPU_HDR_FORMAT)
    total = len(content)

    # There must be some content beyond the header itself
    if not total > hdrsz:
        raise RecordError("X86_PV_VCPU_%s record length must be at least %d"
                          " bytes long" % (name, hdrsz))

    vcpuid, rsvd = struct.unpack(X86_PV_VCPU_HDR_FORMAT, content[:hdrsz])

    if rsvd != 0:
        raise StreamError("Reserved bits set in x86_pv_vcpu_%s record 0x%04x"
                          % (name, rsvd))

    info("  vcpu%d %s context, %d bytes" % (vcpuid, name, total - hdrsz))


def verify_x86_pv_info(content):
    """ x86 PV Info record: fixed size, holding the guest width and the
    number of pagetable levels """

    wanted = struct.calcsize(X86_PV_INFO_FORMAT)
    actual = len(content)
    if actual != wanted:
        raise RecordError("x86_pv_info: expected length of %d, got %d"
                          % (wanted, actual))

    width, levels, rsvd1, rsvd2 = struct.unpack(X86_PV_INFO_FORMAT, content)

    if width not in (4, 8):
        raise RecordError("Expected width of 4 or 8, got %d" % (width, ))

    if levels not in (3, 4):
        raise RecordError("Expected levels of 3 or 4, got %d" % (levels, ))

    if not (rsvd1 == 0 and rsvd2 == 0):
        raise StreamError("Reserved bits set in X86_PV_INFO: 0x%04x 0x%08x"
                          % (rsvd1, rsvd2))

    # width is 4 or 8 at this point, giving 32 or 64
    info("  %sbit guest, %d levels of pagetables" % (width * 8, levels))

def verify_x86_pv_p2m_frames(content):
    """ x86 PV p2m frames record

    The content must start with two 32bit values (the start and end pfn,
    which are reported via info()) and be a multiple of 8 bytes long.

    Raises RecordError if the content is too short or of the wrong length.
    """

    # An empty record is a multiple of 8 bytes long, so it has to be rejected
    # explicitly.  Otherwise it would fail inside struct.unpack() below
    # instead of being reported as a RecordError.
    if len(content) < 8:
        raise RecordError("x86_pv_p2m_frames: record length must be at"
                          " least 8 bytes long")

    if len(content) % 8 != 0:
        raise RecordError("Length expected to be a multiple of 8, not %d"
                          % (len(content), ))

    start, end = struct.unpack("=II", content[:8])

    info("  Start pfn 0x%x, End 0x%x" % (start, end))

def verify_record_shared_info(content):
    """ shared info record

    Raises RecordError unless the content is exactly 4096 bytes long.
    """

    contentsz = len(content)
    if contentsz != 4096:
        # The message quotes the same size as the check above
        raise RecordError("Length expected to be 4096 bytes, not %d"
                          % (contentsz, ))

def verify_record_tsc_info(content):
    """ tsc info record: fixed size, with a reserved field which must be 0 """

    wanted = struct.calcsize(TSC_INFO_FORMAT)
    if len(content) != wanted:
        raise RecordError("Length should be %u bytes" % (wanted, ))

    fields = struct.unpack(TSC_INFO_FORMAT, content)
    mode, khz, nsec, incarn, rsvd = fields

    if rsvd != 0:
        raise StreamError("Reserved bits set in TSC_INFO: 0x%08x" % (rsvd, ))

    info("  Mode %u, %u kHz, %u ns, incarnation %d" % (mode, khz, nsec, incarn))

def verify_record_hvm_context(content):
    """ hvm context record: raises RecordError if the content is empty """

    if not content:
        raise RecordError("Zero length HVM context")

def verify_record_hvm_params(content):
    """ hvm params record: a header carrying an entry count, followed by
    exactly that many fixed size entries """

    hdrsz = struct.calcsize(HVM_PARAMS_FORMAT)
    actual = len(content)

    if actual < hdrsz:
        raise RecordError("Length should be at least %u bytes" % (hdrsz, ))

    count, rsvd = struct.unpack(HVM_PARAMS_FORMAT, content[:hdrsz])

    if rsvd != 0:
        raise RecordError("Reserved field not zero (0x%04x)" % (rsvd, ))

    # Header plus 'count' entries must account for the whole record
    wanted = hdrsz + count * struct.calcsize(HVM_PARAMS_ENTRY_FORMAT)

    if actual != wanted:
        raise RecordError("Length should be %u bytes" % (wanted, ))

def verify_toolstack(_):
    """ toolstack record

    The record content (the ignored '_' argument) is not inspected at all;
    only an info() message is produced.
    """
    info("  TODO: remove")

def verify_record_verify(content):
    """ verify record: raises RecordError unless the content is empty """

    if content:
        raise RecordError("Verify record with non-zero length")

# Dispatch table mapping a record type to the function which verifies the
# content of such a record.  Each entry is called with the record content as
# its only argument.
record_verifiers = {
    REC_TYPE_end : verify_record_end,
    REC_TYPE_page_data : verify_page_data,

    REC_TYPE_x86_pv_info: verify_x86_pv_info,
    REC_TYPE_x86_pv_p2m_frames: verify_x86_pv_p2m_frames,

    # The four vcpu record types share one verifier; the lambdas supply the
    # name used in its messages.
    REC_TYPE_x86_pv_vcpu_basic :
        lambda x: verify_record_x86_pv_vcpu_generic(x, "basic"),
    REC_TYPE_x86_pv_vcpu_extended :
        lambda x: verify_record_x86_pv_vcpu_generic(x, "extended"),
    REC_TYPE_x86_pv_vcpu_xsave :
        lambda x: verify_record_x86_pv_vcpu_generic(x, "xsave"),
    REC_TYPE_x86_pv_vcpu_msrs :
        lambda x: verify_record_x86_pv_vcpu_generic(x, "msrs"),

    REC_TYPE_shared_info: verify_record_shared_info,
    REC_TYPE_tsc_info: verify_record_tsc_info,

    REC_TYPE_hvm_context: verify_record_hvm_context,
    REC_TYPE_hvm_params: verify_record_hvm_params,
    REC_TYPE_toolstack: verify_toolstack,
    REC_TYPE_verify: verify_record_verify,
}

squahsed_page_data_records = 0
def verify_record():
    """ Read one record from the input, verify it, and return its type

    Consecutive Page Data records are not reported one by one; they are
    counted, and the count is reported when the next record of another type
    is seen.
    """
    global squahsed_page_data_records

    rtype, length = unpack_exact(RH_FORMAT)

    if rtype not in rec_type_to_str:
        raise StreamError("Unrecognised record type %x" % (rtype, ))

    # The content is read rounded up to the next multiple of 8 bytes
    raw = rdexact((length + 7) & ~7)
    body, padding = raw[:length], raw[length:]

    if padding != "\x00" * len(padding):
        raise StreamError("Padding containing non0 bytes found")

    if rtype == REC_TYPE_page_data:
        squahsed_page_data_records += 1
    else:
        # Report (and reset) any run of Page Data records which just ended
        if squahsed_page_data_records > 0:
            info("Squashed %d Page Data records together"
                 % (squahsed_page_data_records, ))
            squahsed_page_data_records = 0

        info("Record: %s, length %d" % (rec_type_to_str[rtype], length))

    verifier = record_verifiers.get(rtype)
    if verifier is None:
        raise RuntimeError("No verification function")

    verifier(body)

    return rtype

def read_stream():
    """ Verify a whole stream: image header, domain header, then records up
    to and including the End record

    Returns 0 on success, 1 if the stream is at fault, or 2 if the script
    itself hit an unexpected error.
    """

    try:
        verify_ihdr()
        verify_dhdr()

        # Keep going until the End record has been verified
        while True:
            if verify_record() == REC_TYPE_end:
                break

    except (IOError, StreamError, RecordError):
        err("Stream Error:")
        err(traceback.format_exc())
        return 1

    except StandardError:
        err("Script Error:")
        err(traceback.format_exc())
        err("Please fix me")
        return 2

    else:
        return 0

def open_file_or_fd(val, mode, buffering):
    """
    If 'val' looks like a decimal integer, open it as an fd.  If not, try to
    open it as a regular file.
    """

    fd = -1
    try:
        # Does it look like an integer?
        try:
            fd = int(val, 10)
        except ValueError:
            pass

        # Try to open it...
        if fd != -1:
            return os.fdopen(fd, mode, buffering)
        else:
            return open(val, mode, buffering)

    except StandardError, e:
        if fd != -1:
            err("Unable to open fd %d: %s: %s" %
                (fd, e.__class__.__name__, e))
        else:
            err("Unable to open file '%s': %s: %s" %
                (val, __class__.__name__, e))

    raise SystemExit(2)

def main(argv):
    """
    Parse the command line in 'argv' (program name first, as in sys.argv),
    open the input, and verify the stream read from it.

    Returns the result of read_stream(), or 1 if an xl header was requested
    but could not be skipped.
    """
    from optparse import OptionParser
    global fin, quiet, verbose

    # Change stdout to be line-buffered.
    sys.stdout = os.fdopen(sys.stdout.fileno(), 'w', 1)

    parser = OptionParser(usage = "%prog [options]",
                          description =
                          "Verify a stream according to the v2 spec")

    # Optional options
    parser.add_option("-i", "--in", dest = "fin", metavar = "<FD or FILE>",
                      default = "0",
                      help = "v2 format stream to verify (defaults to stdin)")
    parser.add_option("-v", "--verbose", action = "store_true", default = False,
                      help = "Summarise stream contents")
    parser.add_option("-q", "--quiet", action = "store_true", default = False,
                      help = "Suppress all logging/errors")
    parser.add_option("-x", "--xl", action = "store_true", default = False,
                      help = ("Is an `xl` header present in the stream?"
                              " (default no)"))
    parser.add_option("--syslog", action = "store_true", default = False,
                      help = "Log to syslog instead of stdout")

    # Parse the arguments we were given, rather than implicitly falling back
    # to sys.argv.  argv[0] is the program name and is not an option.
    opts, _ = parser.parse_args(argv[1:])

    if opts.syslog:
        global log_to_syslog

        syslog.openlog("verify-stream-v2", syslog.LOG_PID)
        log_to_syslog = True

    verbose = opts.verbose
    quiet = opts.quiet
    fin = open_file_or_fd(opts.fin, "rb", 0)

    if opts.xl:
        # Report a bad or truncated xl header in the same way read_stream()
        # reports stream problems, so the quiet and syslog settings apply
        # here too.
        try:
            skip_xl_header()
        except (IOError, StreamError):
            err("Stream Error:")
            err(traceback.format_exc())
            return 1

    return read_stream()

# Script entry point: the value returned by main() becomes the exit status.
if __name__ == "__main__":
    try:
        sys.exit(main(sys.argv))
    except SystemExit, e:
        # Exit with the code carried by the SystemExit which was raised
        sys.exit(e.code)
    except KeyboardInterrupt:
        # Interrupted from the keyboard: exit with status 2
        sys.exit(2)
