diff --git a/pagure/api/__init__.py b/pagure/api/__init__.py index fbadc3e..a3137c5 100644 --- a/pagure/api/__init__.py +++ b/pagure/api/__init__.py @@ -113,6 +113,10 @@ def get_authorized_api_project(session, repo, user=None, namespace=None): return repo +def get_request_data(): + return flask.request.form or flask.request.get_json() or {} + + def check_api_acls(acls, optional=False): ''' Checks if the user provided an API token with its request and if this token allows the user to access the endpoint desired. diff --git a/pagure/api/fork.py b/pagure/api/fork.py index 12ccc11..fbda5ef 100644 --- a/pagure/api/fork.py +++ b/pagure/api/fork.py @@ -22,7 +22,7 @@ import pagure.exceptions import pagure.lib import pagure.lib.tasks from pagure.api import (API, api_method, api_login_required, APIERROR, - get_authorized_api_project) + get_authorized_api_project, get_request_data) from pagure.config import config as pagure_config from pagure.utils import is_repo_committer, is_true @@ -371,7 +371,7 @@ def api_pull_request_merge(repo, requestid, username=None, namespace=None): output = {'message': 'Merging queued', 'taskid': task.id} - if flask.request.form.get('wait', True): + if get_request_data().get('wait', True): task.get() output = {'message': 'Changes merged!'} except pagure.exceptions.PagureException as err: @@ -729,7 +729,7 @@ def api_pull_request_add_flag(repo, requestid, username=None, namespace=None): if not request: raise pagure.exceptions.APIError(404, error_code=APIERROR.ENOREQ) - if 'status' in flask.request.form: + if 'status' in get_request_data(): form = pagure.forms.AddPullRequestFlagForm(csrf_enabled=False) else: form = pagure.forms.AddPullRequestFlagFormV1(csrf_enabled=False) @@ -739,7 +739,7 @@ def api_pull_request_add_flag(repo, requestid, username=None, namespace=None): comment = form.comment.data.strip() url = form.url.data.strip() uid = form.uid.data.strip() if form.uid.data else None - if 'status' in flask.request.form: + if 'status' in get_request_data(): status = form.status.data.strip() else: if percent is None: @@ -1001,12 +1001,12 @@ def api_pull_request_create(repo, username=None, namespace=None): if not form.validate_on_submit(): raise pagure.exceptions.APIError( 400, error_code=APIERROR.EINVALIDREQ, errors=form.errors) - branch_to = flask.request.form.get('branch_to') + branch_to = get_request_data().get('branch_to') if not branch_to: raise pagure.exceptions.APIError( 400, error_code=APIERROR.EINVALIDREQ, errors={'branch_to': ['This field is required.']}) - branch_from = flask.request.form.get('branch_from') + branch_from = get_request_data().get('branch_from') if not branch_from: raise pagure.exceptions.APIError( 400, error_code=APIERROR.EINVALIDREQ, diff --git a/pagure/api/issue.py b/pagure/api/issue.py index 4134a8b..14b2f57 100644 --- a/pagure/api/issue.py +++ b/pagure/api/issue.py @@ -21,7 +21,7 @@ import pagure.exceptions import pagure.lib from pagure.api import ( API, api_method, api_login_required, api_login_optional, APIERROR, - get_authorized_api_project + get_authorized_api_project, get_request_data ) from pagure.config import config as pagure_config from pagure.utils import ( @@ -264,11 +264,11 @@ def api_new_issue(repo, username=None, namespace=None): milestone = form.milestone.data or None private = is_true(form.private.data) priority = form.priority.data or None - assignee = flask.request.form.get( + assignee = get_request_data().get( 'assignee', '').strip() or None tags = [ tag.strip() - for tag in flask.request.form.get( + for tag in get_request_data().get( 'tag', '').split(',') if tag.strip()] @@ -1238,7 +1238,7 @@ def api_update_custom_field( 400, error_code=APIERROR.EINVALIDISSUEFIELD) key = fields[field] - value = flask.request.form.get('value') + value = get_request_data().get('value') if value: _check_link_custom_field(key, value) try: @@ -1348,7 +1348,7 @@ def api_update_custom_fields( issue = _get_issue(repo, issueid) _check_ticket_access(issue) - fields = flask.request.form + fields = get_request_data() if not fields: raise pagure.exceptions.APIError( diff --git a/pagure/api/project.py b/pagure/api/project.py index 3f15ee7..6aef797 100644 --- a/pagure/api/project.py +++ b/pagure/api/project.py @@ -24,7 +24,8 @@ import pagure.lib import pagure.lib.git import pagure.utils from pagure.api import (API, api_method, APIERROR, api_login_required, - get_authorized_api_project, api_login_optional) + get_authorized_api_project, api_login_optional, + get_request_data) from pagure.config import config as pagure_config @@ -865,7 +866,7 @@ def api_new_project(): output = {'message': 'Project creation queued', 'taskid': task.id} - if flask.request.form.get('wait', True): + if get_request_data().get('wait', True): result = task.get() project = pagure.lib._get_project( flask.g.session, name=result['repo'], @@ -987,7 +988,7 @@ def api_modify_project(repo, namespace=None): args = flask.request.get_json(force=True, silent=True) or {} retain_access = args.get('retain_access', False) else: - args = flask.request.form + args = get_request_data() retain_access = args.get('retain_access', '').lower() in ['true', '1'] if not args: @@ -1119,7 +1120,7 @@ def api_fork_project(): output = {'message': 'Project forking queued', 'taskid': task.id} - if flask.request.form.get('wait', True): + if get_request_data().get('wait', True): task.get() output = {'message': 'Repo "%s" cloned to "%s/%s"' % (repo.fullname, flask.g.fas_user.username, @@ -1207,7 +1208,7 @@ def api_generate_acls(repo, username=None, namespace=None): json = flask.request.get_json(force=True, silent=True) or {} wait = json.get('wait', False) else: - wait = pagure.utils.is_true(flask.request.form.get('wait')) + wait = pagure.utils.is_true(get_request_data().get('wait')) try: task = pagure.lib.git.generate_gitolite_acls( @@ -1291,7 +1292,7 @@ def api_new_branch(repo, username=None, namespace=None): # returned if it's invalid JSON. args = flask.request.get_json(force=True, silent=True) or {} else: - args = flask.request.form + args = get_request_data() branch = args.get('branch') from_branch = args.get('from_branch') diff --git a/tests/test_pagure_flask_api.py b/tests/test_pagure_flask_api.py index e5bf2c2..0111c7d 100644 --- a/tests/test_pagure_flask_api.py +++ b/tests/test_pagure_flask_api.py @@ -24,6 +24,7 @@ from mock import patch sys.path.insert(0, os.path.join(os.path.dirname( os.path.abspath(__file__)), '..')) +import pagure.api import pagure.flask_app import pagure.lib import tests @@ -32,6 +33,18 @@ import tests class PagureFlaskApitests(tests.SimplePagureTest): """ Tests for flask API controller of pagure """ + def test_api_get_request_data(self): + data = {'foo': 'bar'} + # test_request_context doesn't set flask.g, but some teardown + # functions try to use that, so let's exclude them + self._app.teardown_request_funcs = {} + with self._app.test_request_context('/api/0/version', data=data): + self.assertEqual(pagure.api.get_request_data()['foo'], 'bar') + data = json.dumps(data) + with self._app.test_request_context('/api/0/version', data=data, + content_type='application/json'): + self.assertEqual(pagure.api.get_request_data()['foo'], 'bar') + def test_api_version(self): """ Test the api_version function. """