model_wrappers module
Adversarial Training
Model wrapper classes for improving and monitoring fairness during training using adversaries.
class model_wrappers.FairModel(model, output_size, n_groups, n_hidden, layer_width, activation=<function relu>)
Wrapper class that applies adversarial training to a model to increase its fairness.
- model: a pretrained model whose fairness is to be increased.
- output_size: the expected input dimension of the adversarial network, which is the dimension of the wrapped model's output.
- n_groups: the number of unique values of the sensitive feature.
- n_hidden: the number of hidden layers in the adversarial network.
- layer_width: the number of units in each hidden layer.
- activation: the activation function used for the hidden layers of the adversarial network.
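To make the constructor parameters concrete, the following is a minimal sketch of how an adversarial network with this shape could be built. It is not the module's actual implementation: `build_adversary` and the `Adversary` class are hypothetical names, and the sketch assumes the adversary is a plain multilayer perceptron that takes the model's output (of size `output_size`) and emits one logit per group.

```python
import torch
from torch import nn
import torch.nn.functional as F


def build_adversary(output_size, n_groups, n_hidden, layer_width, activation=F.relu):
    """Hypothetical helper: an MLP mapping model outputs to group logits."""

    class Adversary(nn.Module):
        def __init__(self):
            super().__init__()
            # Layer sizes: model output -> n_hidden layers of layer_width units.
            sizes = [output_size] + [layer_width] * n_hidden
            self.hidden = nn.ModuleList(
                [nn.Linear(n_in, n_out) for n_in, n_out in zip(sizes[:-1], sizes[1:])]
            )
            # Final layer produces one logit per sensitive group.
            self.out = nn.Linear(sizes[-1], n_groups)

        def forward(self, z):
            for layer in self.hidden:
                z = activation(layer(z))
            return self.out(z)  # raw logits, no softmax applied

    return Adversary()


adversary = build_adversary(output_size=1, n_groups=2, n_hidden=1, layer_width=10)
logits = adversary(torch.randn(20, 1))
print(logits.shape)
```

With the values from the example (`output_size=1`, `n_groups=2`), the sketch maps a batch of 20 scalar model outputs to a 20 by 2 tensor of logits.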
Examples:
    import torch
    from torch import nn
    from model_wrappers import FairModel

    # input, output, and protected/sensitive feature (e.g. race, gender, etc.)
    X = torch.randn(20, 3)
    X[:10] = 10 * X[:10]
    y = torch.sum(torch.tensor([[1.0, 0.5, 3.14]]) * X, dim=1, keepdim=True) + 1.41
    protected = torch.zeros(20, dtype=torch.long)
    protected[:10] = 0
    protected[10:] = 1

    # define and pretrain model
    # (Model stands for any user-defined model class with a fit method)
    model = Model()
    model.fit(X, y)

    # wrap pretrained model in FairModel adversary and train (eta = 0.5)
    fm = FairModel(model, output_size=1, n_groups=2, n_hidden=1, layer_width=10)
    fm.fit(X, y, protected, 0.5, steps=1000)

    # forward pass returns model predictions and adversary logits
    model_predictions, adversary_predictions = fm(X)

    # convert the adversary logits to per-group probabilities
    m = nn.Softmax(dim=1)
    adversary_predictions = m(adversary_predictions)
fit(x, y, groups, eta, model_loss=<class 'nn.MSELoss'>, adversary_loss=<class 'nn.CrossEntropyLoss'>, optimizer=<class 'optim.Adam'>, steps=100, lr=0.001, verbose=True, grapher=None)
Pre-trains the adversarial network, then simultaneously trains the wrapped model and the adversarial network.
- x: the input data.
- y: the true labels.
- groups: the group (protected/sensitive) attribute for each input sample.
- eta: a weighting constant for adversarial training.
- model_loss: the loss function used for the non-adversarial model.
- adversary_loss: the loss function used for the adversarial model.
- optimizer: the uninitialized optimizer class.
- steps: the number of steps used for pre-training the adversary and for the simultaneous model/adversary training.
- lr: the learning rate for training.
- verbose: True if output during training is desired, False otherwise.
- grapher: a function for graphing, used if verbose is True.
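The two training phases can be illustrated with a short, self-contained sketch. This is not the code behind `fit`: the documentation does not state the exact objective, so the sketch assumes a common formulation in which the model minimizes `model_loss - eta * adversary_loss`, and it realizes "simultaneous" training as alternating adversary and model updates. The stand-in `model` and `adversary` networks and the toy data are hypothetical.

```python
import torch
from torch import nn, optim

torch.manual_seed(0)

# Toy data: 20 samples, 3 features, binary sensitive attribute.
X = torch.randn(20, 3)
y = X.sum(dim=1, keepdim=True)
groups = torch.zeros(20, dtype=torch.long)
groups[10:] = 1

model = nn.Linear(3, 1)  # stand-in for the wrapped, pretrained model
adversary = nn.Sequential(nn.Linear(1, 10), nn.ReLU(), nn.Linear(10, 2))

model_loss = nn.MSELoss()
adversary_loss = nn.CrossEntropyLoss()
model_opt = optim.Adam(model.parameters(), lr=0.001)
adv_opt = optim.Adam(adversary.parameters(), lr=0.001)
eta, steps = 0.5, 100

# Phase 1: pre-train the adversary on the (detached) model outputs.
for _ in range(steps):
    adv_opt.zero_grad()
    loss_a = adversary_loss(adversary(model(X).detach()), groups)
    loss_a.backward()
    adv_opt.step()

# Phase 2: alternate adversary and model updates.
for _ in range(steps):
    # Adversary step: improve at predicting the group from the model output.
    adv_opt.zero_grad()
    loss_a = adversary_loss(adversary(model(X).detach()), groups)
    loss_a.backward()
    adv_opt.step()

    # Model step (assumed objective): fit y while making the adversary's
    # task harder. Gradients that reach the adversary here are discarded
    # by adv_opt.zero_grad() at the start of the next iteration.
    model_opt.zero_grad()
    pred = model(X)
    loss_m = model_loss(pred, y) - eta * adversary_loss(adversary(pred), groups)
    loss_m.backward()
    model_opt.step()

print(f"model objective: {loss_m.item():.4f}, adversary loss: {loss_a.item():.4f}")
```

Under this assumed objective, a larger `eta` puts more weight on hiding group membership from the adversary, at the cost of the model's own fit to `y`.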
forward(x)
Forward pass of x through the model and adversarial network.
- x: the input data, with a shape suitable for model.
Returns: model_prediction, adversary_prediction
- model_prediction: has shape model(x).shape.
- adversary_prediction: returned as logits.
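Because the adversary's output is returned as logits, callers need a softmax to read it as group probabilities. The sketch below shows that post-processing on a hypothetical stand-in wrapper (`TwoHeadWrapper` is not part of this module), assuming the adversary consumes the model's prediction.

```python
import torch
from torch import nn


class TwoHeadWrapper(nn.Module):
    """Hypothetical stand-in that mimics the documented return values."""

    def __init__(self, model, adversary):
        super().__init__()
        self.model = model
        self.adversary = adversary

    def forward(self, x):
        model_prediction = self.model(x)
        # Assumption: the adversary sees only the model's prediction.
        adversary_prediction = self.adversary(model_prediction)  # logits
        return model_prediction, adversary_prediction


wrapper = TwoHeadWrapper(nn.Linear(3, 1), nn.Linear(1, 2))
model_prediction, adversary_logits = wrapper(torch.randn(20, 3))

# Softmax over the group dimension turns logits into probabilities.
group_probs = torch.softmax(adversary_logits, dim=1)
# The most likely group for each sample, according to the adversary.
predicted_group = group_probs.argmax(dim=1)
```

If the adversary's predicted probabilities stay close to the base rates of the groups, the model's output carries little information about the sensitive feature.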