RotoGrad

class rotograd.RotoGrad(backbone, heads, rotation_shape, *args, post_shape=(), normalize_losses=False, burn_in_period=20)[source]

Bases: RotateOnly

Implementation of RotoGrad as described in the original paper [1]. A construction sketch is shown below, after the parameter list.

Parameters:
  • backbone (Module) – Shared module.

  • heads (Sequence[Module]) – Task-specific modules.

  • rotation_shape (Union[int, Size]) – Shape of the shared representation to be rotated, which is usually just the size of the backbone’s output. Passing a shape is useful, for example, if you want to rotate an image of shape width x height.

  • post_shape (optional, default=()) – Shape of the part of the shared representation that follows the part to be rotated (if any); this part is kept as is. This is useful, for example, if you want to rotate only the channels of an image.

  • burn_in_period (optional, default=20) – When back-propagating towards the shared parameters, each task loss is normalized by dividing it by its initial value, \({L_k(t)}/{L_k(t_0 = 0)}\). This parameter sets the number of iterations after which the denominator is replaced by the value of the loss at that iteration, that is, \(t_0 = \text{burn\_in\_period}\). This is done to avoid problems caused by losses changing rapidly during the first iterations.

  • normalize_losses (optional, default=False) – Whether to use these normalized losses to back-propagate through the task-specific parameters as well.

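A minimal construction sketch, assuming RotoGrad behaves as a regular torch.nn.Module whose parameters (including the rotation matrices) can be handed to a standard optimizer; the backbone, heads, and hyper-parameters below are purely illustrative:

import torch
import torch.nn as nn
from rotograd import RotoGrad

# Illustrative shared encoder and two task-specific heads.
backbone = nn.Sequential(nn.Linear(10, 32), nn.ReLU())
heads = [nn.Linear(32, 1), nn.Linear(32, 1)]

# Rotate the full 32-dimensional shared representation.
model = RotoGrad(backbone, heads, 32, burn_in_period=20)

# Assumption: the rotation matrices are registered as parameters of the module,
# so a single optimizer over model.parameters() covers the whole model.
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
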
num_tasks

Number of tasks/heads of the module.

backbone

Shared module.

heads

Sequence with the (rotated) task-specific heads.

rep

Current output of the backbone (after calling forward during training).

References

[1] Javaloy, Adrián, and Isabel Valera. “RotoGrad: Gradient Homogenization in Multitask Learning.” International Conference on Learning Representations (ICLR), 2022.

forward(x)

Forwards the input x through the backbone and all heads, returning a list with all the task predictions. It can be thought of as something similar to:

preds = []
z = backbone(x)                # shared representation
for R_i, head in zip(rotations, heads):
    z_i = rotate(R_i, z)       # task-specific rotation of the shared representation
    preds.append(head(z_i))    # task-specific prediction
return preds
Return type:

Sequence[Any]
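
A small usage sketch, reusing the illustrative model from the construction example above; x is a hypothetical batch of inputs:

x = torch.randn(16, 10)    # batch of 16 inputs of size 10
preds = model(x)           # sequence with one prediction per head
y_hat_0, y_hat_1 = preds   # task-specific predictions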

backward(losses, backbone_loss=None, **kwargs)

Performs the backward pass for the entire model (that is, both the shared and the task-specific modules). It also computes the gradients for the rotation matrices.

Parameters:
  • losses (Sequence[Tensor]) – Sequence of the task losses from which to back-propagate.

  • backbone_loss – Loss exclusive to the backbone (for example, a regularization term).

Return type:

None
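
A sketch of one training step under the same assumptions as above; model.backward(losses) replaces the usual call to .backward() on a summed loss, since it back-propagates through the heads, the rotation matrices, and the shared backbone. The loss function and targets are placeholders:

criterion = nn.MSELoss()

def train_step(x, y0, y1):
    optimizer.zero_grad()
    preds = model(x)
    losses = [criterion(preds[0], y0), criterion(preds[1], y1)]
    model.backward(losses)   # back-propagates through heads, rotations, and backbone
    optimizer.step()
    return [loss.item() for loss in losses]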

property rotation: Sequence[Tensor]

List of rotation matrices, one per task. These are trainable; make sure to call detach() if you do not want to back-propagate through them.
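
For instance, to inspect the learned rotation of the first task without tracking gradients (an illustrative sketch; the orthogonality check simply reflects that rotation matrices are orthogonal by definition):

R0 = model.rotation[0].detach()   # rotation matrix of the first task
print(R0.shape)                   # e.g. torch.Size([32, 32]) for rotation_shape=32
print(torch.allclose(R0 @ R0.T, torch.eye(R0.shape[0]), atol=1e-4))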