diff --git a/progit/hooks/__init__.py b/progit/hooks/__init__.py
index 69be6b0..e21b56b 100644
--- a/progit/hooks/__init__.py
+++ b/progit/hooks/__init__.py
@@ -20,17 +20,20 @@ class RequiredIf(wtforms.validators.Required):
     has a value.
     """
 
-    def __init__(self, other_field_name, *args, **kwargs):
-        self.other_field_name = other_field_name
+    def __init__(self, fields, *args, **kwargs):
+        if isinstance(fields, basestring):
+            fields = [fields]
+        self.fields = fields
         super(RequiredIf, self).__init__(*args, **kwargs)
 
     def __call__(self, form, field):
-        other_field = form._fields.get(self.other_field_name)
-        if other_field is None:
-            raise Exception(
-                'no field named "%s" in form' % self.other_field_name)
-        if bool(other_field.data):
-            super(RequiredIf, self).__call__(form, field)
+        for fieldname in self.fields:
+            nfield = form._fields.get(fieldname)
+            if nfield is None:
+                raise Exception(
+                    'no field named "%s" in form' % fieldname)
+            if bool(nfield.data):
+                super(RequiredIf, self).__call__(form, field)
 
 
 class BaseHook(object):