Blob Blame Raw
# coding=utf-8
"""
 (c) 2015-2018 - Copyright Red Hat Inc

 Authors:
   Slavek Kabrda <bkabrda@redhat.com>

"""

from __future__ import unicode_literals, absolute_import

import os
import sys

from mock import patch

sys.path.insert(
    0, os.path.join(os.path.dirname(os.path.abspath(__file__)), "..")
)

import pagure.hooks
import pagure.lib.plugins
import tests


class EnabledForAll(pagure.hooks.BaseHook):
    name = "EnabledForAll"

    @classmethod
    def is_enabled_for(cls, project):
        return True


class DisabledForAll(pagure.hooks.BaseHook):
    name = "DisabledForAll"
    # disabled for all is the default


class PagureLibtests_plugins(tests.Modeltests):
    """
    Test the pagure.lib.plugins module
    """

    maxDiff = None

    @patch("pagure.lib.plugins.load")
    def test_plugin_is_enabled_for(self, load):
        """ Test the is_enabled_for method of plugins is properly
        handled by pagure.lib.plugins.get_enabled_plugins.
        """
        tests.create_projects(self.session)
        project = pagure.lib.query._get_project(self.session, "test")

        load.return_value = [EnabledForAll]
        self.assertEqual(
            pagure.lib.plugins.get_enabled_plugins(project),
            [(EnabledForAll, None)],
        )

        load.return_value = [DisabledForAll]
        self.assertEqual(pagure.lib.plugins.get_enabled_plugins(project), [])

    @patch("pagure.lib.plugins.load")
    def test_get_plugin_names(self, load):
        """ Test the get_plugin_names method with plugins that don't
        have backref.
        """
        load.return_value = [EnabledForAll]
        self.assertEqual(pagure.lib.plugins.get_plugin_names(), [])
        self.assertEqual(
            pagure.lib.plugins.get_plugin_names(without_backref=True),
            ["EnabledForAll"],
        )