#
# Copyright 2016 Red Hat, Inc.
#
# This program is free software; you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation; either version 2 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program; if not, write to the Free Software
# Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
#
# Refer to the README and COPYING files for full details of the license
#
from __future__ import absolute_import

"""
Code to perform periodic maintenance and bookkeeping of the VMs.
"""

import logging
import threading

import libvirt

from vdsm import containersconnection
from vdsm import executor
from vdsm import host
from vdsm import libvirtconnection
from vdsm.config import config
from vdsm.virt import migration
from vdsm.virt import sampling
from vdsm.virt import virdomain
from vdsm.virt import vmstatus


# Sizing of the "periodic" executor created by start(); all the values
# are read from the 'sampling' section of the configuration.
# Just a made up number. Maybe should be equal to number of cores?
# NOTE(review): the TODO below looks already addressed, since the values
# come from the configuration - confirm and drop it.
# TODO: make them tunable through private, unsupported configuration items
_WORKERS = config.getint('sampling', 'periodic_workers')
_TASK_PER_WORKER = config.getint('sampling', 'periodic_task_per_worker')
# capacity of the executor task queue (passed as `max_tasks')
_TASKS = _WORKERS * _TASK_PER_WORKER
_MAX_WORKERS = config.getint('sampling', 'max_workers')


# Module state, (re)assigned by start():
# the list of the Operation instances being run...
_operations = []
# ...and the executor shared among them (None until start() is called).
_executor = None


def _timeout_from(interval):
    """
    Derive a timeout from the period of a periodic operation:
    the timeout is half of the given interval, always as float.
    """
    half_interval = interval / 2.0
    return half_interval


def start(cif, scheduler):
    """
    Create and start the shared "periodic" executor, then build and
    start all the periodic operations and the host statistics.

    cif: object providing getVMs(); it is also handed over to
         libvirtconnection.get() and to the host monitor.
    scheduler: Scheduler instance, used both by the executor and by
               every operation.
    """
    global _operations
    global _executor

    _executor = executor.Executor(name="periodic",
                                  workers_count=_WORKERS,
                                  max_tasks=_TASKS,
                                  scheduler=scheduler,
                                  max_workers=_MAX_WORKERS)
    _executor.start()

    def dispatched_to_vms(action, interval):
        # Each run of the returned Operation fans `action' out to all
        # the VMs, one executor task per VM.
        dispatcher = VmDispatcher(
            cif.getVMs, _executor, action, _timeout_from(interval))
        return Operation(dispatcher, interval, scheduler)

    # Needs dispatching because updating the volume stats needs
    # access to the storage, thus can block.
    volume_stats = dispatched_to_vms(
        UpdateVolumes,
        config.getint('irs', 'vol_size_sample_interval'))

    # Needs dispatching because it accesses FS and libvirt data.
    # Ignored by new engine, has to be kept for BC sake.
    numa_info = dispatched_to_vms(
        NumaInfoMonitor,
        config.getint('vars', 'vm_sample_numa_interval'))

    # Job monitoring need QEMU monitor access.
    block_jobs = dispatched_to_vms(
        BlockjobMonitor,
        config.getint('vars', 'vm_sample_jobs_interval'))

    # libvirt sampling using bulk stats can block, but unresponsive
    # domains are handled inside VMBulkSampler for performance reasons;
    # thus, does not need dispatching.
    bulk_sampler = sampling.VMBulkSampler(
        libvirtconnection.get(cif),
        cif.getVMs,
        sampling.stats_cache)
    bulk_stats = Operation(
        bulk_sampler,
        config.getint('vars', 'vm_sample_interval'),
        scheduler)

    # We do this only until we get high water mark notifications
    # from QEMU. It accesses storage and/or QEMU monitor, so can block,
    # thus we need dispatching.
    watermarks = dispatched_to_vms(
        DriveWatermarkMonitor,
        config.getint('vars', 'vm_watermark_interval'))

    host_monitor = sampling.HostMonitor(cif=cif)
    host_stats = Operation(
        host_monitor,
        config.getint('vars', 'host_sample_stats_interval'),
        scheduler)

    containers = Operation(
        containersconnection.monitor,
        config.getint('vars', 'vm_sample_interval'),
        scheduler)

    # The order below is also the order in which operations are started.
    _operations = [
        volume_stats,
        numa_info,
        block_jobs,
        bulk_stats,
        watermarks,
        host_stats,
        containers,
    ]

    host.stats.start()

    for operation in _operations:
        operation.start()


def stop():
    """
    Stop every operation created by start(), then ask the shared
    executor to stop without waiting for it.
    """
    running = _operations
    for operation in running:
        operation.stop()

    _executor.stop(wait=False)


class Operation(object):
    """
    Run a callable periodically, until someone stops it.

    Operation is built on top of a Scheduler and of an Executor:
    the scheduler fires every `period' seconds, and each time the
    Operation submits itself to the executor, which in turn runs
    the underlying "func".
    A new run is requested even if a former one is still blocked.
    """

    _log = logging.getLogger("virt.periodic.Operation")

    def __init__(self, func, period, scheduler, timeout=0, executor=None):
        """
        parameters:

        func: callable, without arguments (task interface).
        period: `func' will be invoked every `period' seconds.
                Please note that timing may not be exact due to
                (OS) scheduling constraints.
        timeout: same meaning of Executor.dispatch. If zero (default),
                 the timeout is computed from `period'.
        scheduler: Scheduler instance to use
        executor: Executor instance to use. If None (default),
                  the module-wide executor is used.
        """
        if timeout == 0:
            timeout = _timeout_from(period)
        if executor is None:
            executor = _executor

        self._func = func
        self._period = period
        self._timeout = timeout
        self._scheduler = scheduler
        self._executor = executor
        # protects _running and _call
        self._lock = threading.Lock()
        self._running = False
        # handle of the pending scheduled call, if any
        self._call = None

    def start(self):
        """
        Start the operation. Raises AssertionError if already running.
        """
        with self._lock:
            if self._running:
                raise AssertionError("Operation already running")
            self._log.debug("starting operation %s", self._func)
            self._running = True
            # dispatch right away, instead of just scheduling the first
            # run, to have some data as soon as possible
            self._dispatch()

    def stop(self):
        """
        Stop the operation, cancelling the pending scheduled call.
        Does nothing if the operation is not running.
        """
        with self._lock:
            if not self._running:
                return
            self._log.debug("stopping operation %s", self._func)
            self._running = False
            pending = self._call
            if pending:
                pending.cancel()
                self._call = None

    def __call__(self):
        """
        Task interface: run `func', logging (not propagating) failures.
        """
        try:
            self._func()
        except Exception:
            self._log.exception("%s operation failed", self._func)

    def _step(self):
        """
        Ask the scheduler to trigger the next run of `func' after
        `period' seconds.
        """
        next_call = self._scheduler.schedule(self._period,
                                             self._try_to_dispatch)
        self._call = next_call

    def _try_to_dispatch(self):
        """
        Scheduler callback: dispatch another run, but only if the
        Operation is still running.
        """
        with self._lock:
            if not self._running:
                return
            self._dispatch()

    def _dispatch(self):
        """
        Submit this Operation to the Executor, to run `func' as soon as
        possible; the next run is scheduled regardless of the outcome.
        """
        self._call = None
        try:
            self._executor.dispatch(self, self._timeout)
        except executor.TooManyTasks:
            self._log.warning('could not run %s, executor queue full',
                              self._func)
        finally:
            self._step()

    def __repr__(self):
        details = (self._func, id(self))
        return '<Operation action=%s at 0x%x>' % details


class VmDispatcher(object):
    """
    Adapter class. Dispatch an Operation to all VMs, to improve
    isolation among them.

    On each call, a per-VM callable is built through `create' for
    every known VM, and submitted to the executor as a separate task.
    """

    _log = logging.getLogger("virt.periodic.VmDispatcher")

    def __init__(self, get_vms, executor, create, timeout):
        """
        get_vms: callable which will return a dict which maps
                 vm_ids to vm_instances
        executor: executor.Executor instance
        create: callable to obtain the real callable to
                dispatch, with its timeout
        timeout: per-vm operation timeout, in seconds
                 (fractions allowed).
        """
        self._get_vms = get_vms
        self._executor = executor
        self._create = create
        self._timeout = timeout

    def __call__(self):
        """
        Run one dispatching round over all the VMs.

        Returns the list of the IDs of the VMs which were skipped,
        either because the operation was not runnable on them or
        because the executor queue was full. VMs on which the operation
        is not required, or could not be prepared, are not reported.
        """
        vms = self._get_vms()
        skipped = []

        # items() is available on both python 2 and python 3 dicts.
        for vm_id, vm_obj in vms.items():
            try:
                op = self._create(vm_obj)

                if not op.required:
                    continue
                # When dealing with blocked domains, we also want to avoid
                # to pile up jobs that libvirt can't handle and that will
                # eventually clog it.
                # We don't care too much about precise tracking, so it is
                # still OK if occasional misdetection occurs, but we
                # definitely want to avoid known-bad situation and to
                # needlessly overload libvirt.
                if not op.runnable:
                    skipped.append(vm_id)
                    continue

            except Exception:
                # we want to make sure to have VM UUID logged.
                # We cannot use `op' here: if create() itself failed it
                # is either unbound or left over from the previous VM.
                self._log.exception("while dispatching %s to VM '%s'",
                                    self._create, vm_id)
            else:
                # reached only if the checks above passed without
                # raising and without skipping this VM.
                try:
                    self._executor.dispatch(op, self._timeout)
                except executor.TooManyTasks:
                    skipped.append(vm_id)

        if skipped:
            self._log.warning('could not run %s on %s',
                              self._create, skipped)
        return skipped  # for testing purposes

    def __repr__(self):
        return '<VmDispatcher operation=%s at 0x%x>' % (
            self._create, id(self)
        )


class _RunnableOnVm(object):
    """
    Base class for the operations which run on a single VM.
    Subclasses must implement _execute(), which is invoked by
    __call__ with the handling of the known benign errors in place.
    """

    def __init__(self, vm):
        self._vm = vm

    @property
    def required(self):
        # Disable everything until the migration destination VM
        # is fully started, to avoid false positives log spam.
        vm = self._vm
        return vm.monitorable

    @property
    def runnable(self):
        vm = self._vm
        return vm.isDomainReadyForCommands()

    def __call__(self):
        # sampled before running: the migration could be over by the
        # time an error is reported.
        was_migrating = self._vm.isMigrating()
        try:
            self._execute()
        except virdomain.NotConnectedError:
            # race on startup:  no worries, let's retry again next cycle.
            # race on shutdown: next cycle won't pick up this VM.
            # both cases: let's reduce the log spam.
            self._vm.log.warning('could not run on %s: domain not connected',
                                 self._vm.id)
        except libvirt.libvirtError as err:
            if self._vm.post_copy != migration.PostCopyPhase.NONE:
                # race on entering post-copy, VM paused now
                return
            # race on shutdown/migration completion
            expected_codes = (libvirt.VIR_ERR_NO_DOMAIN,)
            if err.get_error_code() in expected_codes:
                # known benign cases: migration in progress or completed
                if was_migrating or self._vm.lastStatus == vmstatus.DOWN:
                    return
            raise

    def _execute(self):
        raise NotImplementedError

    def __repr__(self):
        details = (self.__class__.__name__, self._vm.id, id(self))
        return '<%s vm=%s at 0x%x>' % details


class UpdateVolumes(_RunnableOnVm):
    """
    Per-VM operation: update the volume data of each disk device
    of the VM.
    """

    @property
    def required(self):
        monitorable = super(UpdateVolumes, self).required
        if not monitorable:
            return monitorable
        # Avoid queries from storage during recovery process
        return self._vm.driveMonitorEnabled()

    def _execute(self):
        vm = self._vm
        for disk in vm.getDiskDevices():
            # TODO: If this blocks (is it actually possible?)
            # we must make sure we don't overwrite good data
            # with stale old data.
            vm.updateDriveVolume(disk)


class NumaInfoMonitor(_RunnableOnVm):
    """
    Per-VM operation: update the NUMA information of the VM.
    Required only if the VM reports to have guest NUMA nodes.
    """

    @property
    def required(self):
        monitorable = super(NumaInfoMonitor, self).required
        if not monitorable:
            return monitorable
        return self._vm.hasGuestNumaNode

    @property
    def runnable(self):
        # Always runnable: NUMA operations don't require
        # QEMU monitor access (inspected libvirt sources v1.2.17)
        return True

    def _execute(self):
        vm = self._vm
        vm.updateNumaInfo()


class BlockjobMonitor(_RunnableOnVm):
    """
    Per-VM operation: update the jobs of the VM.
    """

    @property
    def required(self):
        # For performance reasons, we must avoid as much
        # as possible to create per-vm executor tasks, even
        # though they will do nothing but a few checks and exit
        # early, as they do if a VM doesn't have Block Jobs to
        # monitor (most often true).
        monitorable = super(BlockjobMonitor, self).required
        if not monitorable:
            return monitorable
        return self._vm.hasVmJobs

    def _execute(self):
        vm = self._vm
        vm.updateVmJobs()


class DriveWatermarkMonitor(_RunnableOnVm):
    """
    Per-VM operation: extend the drives of the VM, if needed.
    Required only if the VM reports to need drive monitoring.
    """

    @property
    def required(self):
        monitorable = super(DriveWatermarkMonitor, self).required
        if not monitorable:
            return monitorable
        return self._vm.needsDriveMonitoring()

    def _execute(self):
        vm = self._vm
        vm.extendDrivesIfNeeded()
