diff --git a/tests/__init__.py b/tests/__init__.py index 521316b..b0d21ad 100644 --- a/tests/__init__.py +++ b/tests/__init__.py @@ -312,10 +312,11 @@ class Modeltests(unittest.TestCase): shutil.rmtree(self.path) self.path = None - def get_csrf(self, url='/new'): + def get_csrf(self, url='/new', output=None): """Retrieve a CSRF token from given URL.""" - output = self.app.get(url) - self.assertEqual(output.status_code, 200) + if output is None: + output = self.app.get(url) + self.assertEqual(output.status_code, 200) return output.data.split( 'name="csrf_token" type="hidden" value="')[1].split('">')[0] diff --git a/tests/test_pagure_flask_ui_fork.py b/tests/test_pagure_flask_ui_fork.py index fd631df..66ce171 100644 --- a/tests/test_pagure_flask_ui_fork.py +++ b/tests/test_pagure_flask_ui_fork.py @@ -293,8 +293,7 @@ class PagureFlaskForktests(tests.Modeltests): output = self.app.get('/test/pull-request/1') self.assertEqual(output.status_code, 200) - csrf_token = output.data.split( - 'name="csrf_token" type="hidden" value="')[1].split('">')[0] + csrf_token = self.get_csrf(output=output) # No CSRF output = self.app.post( @@ -454,8 +453,7 @@ class PagureFlaskForktests(tests.Modeltests): output = self.app.get('/test/pull-request/1') self.assertEqual(output.status_code, 200) - csrf_token = output.data.split( - 'name="csrf_token" type="hidden" value="')[1].split('">')[0] + csrf_token = self.get_csrf(output=output) data = { 'csrf_token': csrf_token, @@ -491,8 +489,7 @@ class PagureFlaskForktests(tests.Modeltests): output = self.app.get('/test/pull-request/1') self.assertEqual(output.status_code, 200) - csrf_token = output.data.split( - 'name="csrf_token" type="hidden" value="')[1].split('">')[0] + csrf_token = self.get_csrf(output=output) data = { 'csrf_token': csrf_token, @@ -525,8 +522,7 @@ class PagureFlaskForktests(tests.Modeltests): output = self.app.get('/test/pull-request/1') self.assertEqual(output.status_code, 200) - csrf_token = output.data.split( - 'name="csrf_token" type="hidden" value="')[1].split('">')[0] + csrf_token = self.get_csrf(output=output) data = { 'csrf_token': csrf_token, @@ -1132,8 +1128,7 @@ index 0000000..2a552bb output = self.app.get('/test/pull-request/1') self.assertEqual(output.status_code, 200) - csrf_token = output.data.split( - 'name="csrf_token" type="hidden" value="')[1].split('">')[0] + csrf_token = self.get_csrf(output=output) data = { 'csrf_token': csrf_token, @@ -1227,8 +1222,7 @@ index 0000000..2a552bb output = self.app.get('/test/pull-request/1') self.assertEqual(output.status_code, 200) - csrf_token = output.data.split( - 'name="csrf_token" type="hidden" value="')[1].split('">')[0] + csrf_token = self.get_csrf(output=output) data = { 'user': 'pingou', @@ -1344,8 +1338,7 @@ index 0000000..2a552bb self.assertEqual(output.status_code, 200) self.assertIn('Create new Project', output.data) - csrf_token = output.data.split( - 'name="csrf_token" type="hidden" value="')[1].split('">')[0] + csrf_token = self.get_csrf(output=output) data = { 'csrf_token': csrf_token, @@ -1424,8 +1417,7 @@ index 0000000..2a552bb '', output.data) - csrf_token = output.data.split( - 'name="csrf_token" type="hidden" value="')[1].split('">')[0] + csrf_token = self.get_csrf(output=output) # Case 1 - Add an initial comment data = { @@ -1492,8 +1484,7 @@ index 0000000..2a552bb '', output.data) - csrf_token = output.data.split( - 'name="csrf_token" type="hidden" value="')[1].split('">')[0] + csrf_token = self.get_csrf(output=output) # Case 1 - Add an initial comment data = { @@ -1550,8 +1541,7 @@ index 0000000..2a552bb 'pull request with
', output.data) output = self.app.get('/test/new_issue') - csrf_token = output.data.split( - 'name="csrf_token" type="hidden" value="')[1].split('">')[0] + csrf_token = self.get_csrf(output=output) data = { 'csrf_token': csrf_token, @@ -1623,8 +1613,7 @@ index 0000000..2a552bb self.assertTrue( output.data.startswith('\n
This look alright but we can do better
', output.data) - csrf_token = output.data.split( - 'name="csrf_token" type="hidden" value="')[1].split('">')[0] + csrf_token = self.get_csrf(output=output) # Invalid comment id data = { @@ -1759,8 +1747,7 @@ index 0000000..2a552bb self.assertTrue( output.data.startswith('\nSettings for test
" in output.data) - csrf_token = output.data.split( - 'name="csrf_token" type="hidden" value="')[1].split('">')[0] + csrf_token = self.get_csrf(output=output) data = {'tag': 'tag1'} @@ -2357,8 +2346,7 @@ class PagureFlaskIssuestests(tests.Modeltests): '