model_wrappers module
Adversarial Training
Model wrapper classes for improving and monitoring fairness during training using adversaries.
class model_wrappers.FairModel(model, output_size, n_groups, n_hidden, layer_width, activation=<function relu>)

Wrapper class that attaches an adversarial network to a model and uses adversarial training to increase the model's fairness.
Parameters:
model: a pretrained model whose fairness is to be increased.
output_size: the dimension of the wrapped model's output, i.e. the expected input vector dimension of the adversarial network.
n_groups: the number of unique values of the sensitive feature.
n_hidden: the number of hidden layers in the adversarial network.
layer_width: the number of units in each hidden layer.
activation: the activation function used for the hidden layers of the adversarial network.
Examples:
```python
import torch
from torch import nn
from model_wrappers import FairModel

# input, output, and protected/sensitive feature (e.g. race, gender, etc.)
X = torch.randn(20, 3)
X[:10] = 10 * X[:10]
y = torch.sum(torch.Tensor([[1.0, 0.5, 3.14]]) * X, dim=1, keepdim=True) + 1.41
protected = torch.zeros(20, dtype=torch.long)
protected[:10] = 0
protected[10:] = 1

# define and pretrain model
# (Model is a placeholder for your own model class with a fit() method)
model = Model()
model.fit(X, y)

# wrap pretrained model in FairModel adversary and train
fm = FairModel(model, output_size=1, n_groups=2, n_hidden=1, layer_width=10)
fm.fit(X, y, protected, eta=0.5, steps=1000)

# adversary predictions are logits; softmax turns them into group probabilities
model_predictions, adversary_predictions = fm(X)
m = nn.Softmax(dim=1)
adversary_predictions = m(adversary_predictions)
```
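The constructor parameters describe the adversarial network only by its size and activation. The following is a minimal sketch of one architecture consistent with them, assuming the adversary is a plain multilayer perceptron that reads the wrapped model's output (output_size values) and emits one logit per sensitive group (n_groups values), and assuming the default activation is torch.nn.functional.relu. The class name Adversary is hypothetical; this is not the module's internal code.

```python
import torch
from torch import nn
import torch.nn.functional as F


class Adversary(nn.Module):
    """Hypothetical adversary: an MLP mapping model outputs to group logits."""

    def __init__(self, output_size, n_groups, n_hidden, layer_width, activation=F.relu):
        super().__init__()
        self.activation = activation
        # Layer sizes: model output -> n_hidden layers of layer_width units.
        sizes = [output_size] + [layer_width] * n_hidden
        self.hidden = nn.ModuleList(
            [nn.Linear(d_in, d_out) for d_in, d_out in zip(sizes[:-1], sizes[1:])]
        )
        # Final layer produces one logit per sensitive group.
        self.out = nn.Linear(sizes[-1], n_groups)

    def forward(self, model_output):
        h = model_output
        for layer in self.hidden:
            h = self.activation(layer(h))
        return self.out(h)  # raw logits, no softmax applied


adv = Adversary(output_size=1, n_groups=2, n_hidden=1, layer_width=10)
logits = adv(torch.randn(20, 1))
print(logits.shape)
```

With n_hidden=0 the sketch degenerates to a single linear layer from output_size to n_groups, which keeps the parameterization valid at every depth.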
fit(x, y, groups, eta, model_loss=<class 'nn.MSELoss'>, adversary_loss=<class 'nn.CrossEntropyLoss'>, optimizer=<class 'optim.Adam'>, steps=100, lr=0.001, verbose=True, grapher=None)

Pre-trains the adversarial network, then trains the wrapped model and the adversarial network simultaneously.
Parameters:
x: the input data.
y: the true labels.
groups: the group (protected/sensitive attribute) of each input sample.
eta: the weighting constant for adversarial training.
model_loss: the loss function class used for the wrapped (non-adversarial) model.
adversary_loss: the loss function class used for the adversarial network.
optimizer: the optimizer class, passed uninitialized.
steps: the number of steps used both for pre-training the adversary and for the simultaneous model/adversary training.
lr: the learning rate used for training.
verbose: True to produce output during training, False otherwise.
grapher: a function used for graphing when verbose is True.
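The two training phases described above can be sketched as follows. This is a hypothetical loop, not the module's implementation: it assumes the simple objective "task loss minus eta times adversary loss" for the model update, whereas the module may use a different formulation (for example gradient projection). The function name fit_sketch is hypothetical; the loss and optimizer arguments are passed as uninitialized classes, mirroring the documented signature.

```python
import torch
from torch import nn, optim


def fit_sketch(model, adversary, x, y, groups, eta,
               model_loss=nn.MSELoss, adversary_loss=nn.CrossEntropyLoss,
               optimizer=optim.Adam, steps=100, lr=0.001):
    """Hypothetical two-phase adversarial training loop (not the library's code)."""
    # Loss and optimizer arguments are classes, instantiated here.
    task_loss_fn, adv_loss_fn = model_loss(), adversary_loss()
    model_opt = optimizer(model.parameters(), lr=lr)
    adv_opt = optimizer(adversary.parameters(), lr=lr)

    def adversary_step():
        # Detach the model output so this update changes only the adversary.
        adv_opt.zero_grad()
        adv_loss_fn(adversary(model(x).detach()), groups).backward()
        adv_opt.step()

    # Phase 1: pre-train the adversary against the fixed pretrained model.
    for _ in range(steps):
        adversary_step()

    # Phase 2: train both simultaneously for the same number of steps.
    for _ in range(steps):
        adversary_step()
        model_opt.zero_grad()
        pred = model(x)
        # Fit y while increasing the adversary's loss, weighted by eta.
        loss = task_loss_fn(pred, y) - eta * adv_loss_fn(adversary(pred), groups)
        loss.backward()
        # Only model weights are stepped; the adversary gradients left over from
        # this backward pass are cleared by the next adversary_step().
        model_opt.step()
    return model, adversary


# Usage on toy data shaped like the FairModel example.
torch.manual_seed(0)
x = torch.randn(20, 3)
y = x.sum(dim=1, keepdim=True)
groups = torch.zeros(20, dtype=torch.long)
groups[10:] = 1
model = nn.Linear(3, 1)
adversary = nn.Sequential(nn.Linear(1, 10), nn.ReLU(), nn.Linear(10, 2))
fit_sketch(model, adversary, x, y, groups, eta=0.5, steps=50)
```

A larger eta trades task accuracy for making the groups harder to recover from the model's output; eta=0 reduces phase 2 to ordinary training of the model.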
forward(x)

Forward pass of x through the model and the adversarial network.
Parameters:
x: the input data, with a shape suitable for model.

Returns: model_prediction, adversary_prediction
model_prediction: the wrapped model's prediction, with shape model(x).shape.
adversary_prediction: the adversarial network's group predictions, as logits.
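Because adversary_prediction is returned as logits, it must be normalized before it can be read as group probabilities. The sketch below uses a hypothetical FairModelSketch class that mirrors the documented return contract, assuming the adversary consumes the wrapped model's output. It then computes the adversary's accuracy, which is one common way to monitor fairness: accuracy close to chance suggests the model's output reveals little about the sensitive attribute.

```python
import torch
from torch import nn


class FairModelSketch(nn.Module):
    """Hypothetical stand-in mirroring the documented forward() contract."""

    def __init__(self, model, adversary):
        super().__init__()
        self.model = model
        self.adversary = adversary

    def forward(self, x):
        model_prediction = self.model(x)  # same shape as model(x)
        adversary_prediction = self.adversary(model_prediction)  # group logits
        return model_prediction, adversary_prediction


torch.manual_seed(0)
fm = FairModelSketch(
    nn.Linear(3, 1),
    nn.Sequential(nn.Linear(1, 10), nn.ReLU(), nn.Linear(10, 2)),
)
x = torch.randn(20, 3)
groups = torch.zeros(20, dtype=torch.long)
groups[10:] = 1

model_prediction, adversary_prediction = fm(x)
probs = torch.softmax(adversary_prediction, dim=1)  # logits -> probabilities
# Accuracy near chance (0.5 here, with two balanced groups) suggests the
# model's output carries little information about the sensitive attribute.
adv_accuracy = (probs.argmax(dim=1) == groups).float().mean().item()
print(model_prediction.shape, probs.shape, adv_accuracy)
```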