RotoGradNorm

class rotograd.RotoGradNorm(backbone, heads, rotation_shape, *args, alpha, post_shape=(), normalize_losses=False, burn_in_period=20)[source]

Bases: RotoGrad

Implementation of RotoGrad, as described in the original paper [1], combined with GradNorm [2] to homogenize both the direction and the magnitude of the task gradients.

Parameters:
  • backbone (Module) – Shared module.

  • heads (Sequence[Module]) – Task-specific modules.

  • rotation_shape (Union[int, Size]) – Shape of the part of the shared representation to be rotated, which is usually just the size of the backbone’s output. Passing a full shape is useful, for example, if you want to rotate an image of shape width x height.

  • alpha (float) – \(\alpha\) hyper-parameter as described in GradNorm [2], used to compute the reference direction.

  • post_shape (optional, default=()) – Shape of the part of the shared representation that follows the rotated part (if any); this part is kept unchanged. This is useful, for example, if you want to rotate only the channels of an image.

  • burn_in_period (optional, default=20) – When back-propagating towards the shared parameters, each task loss is normalized by dividing it by its initial value, \(L_k(t) / L_k(t_0 = 0)\). This parameter sets a number of iterations after which the denominator is replaced by the value of the loss at that iteration, that is, \(t_0 = \text{burn\_in\_period}\). This is done to overcome problems with losses changing quickly during the first iterations.

  • normalize_losses (optional, default=False) – Whether to also use these normalized losses to back-propagate through the task-specific parameters.

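A minimal construction sketch is shown below. The backbone and head architectures, the two-task setup, and the use of a single optimizer over model.parameters() are illustrative assumptions, not requirements of the API:

import torch
from torch import nn
from rotograd import RotoGradNorm

# Illustrative two-task setup; any nn.Module works as backbone or head.
backbone = nn.Sequential(nn.Linear(10, 64), nn.ReLU(), nn.Linear(64, 32))
heads = [nn.Linear(32, 1), nn.Linear(32, 1)]

model = RotoGradNorm(
    backbone,
    heads,
    rotation_shape=32,   # size of the backbone's output to be rotated
    alpha=1.0,           # GradNorm hyper-parameter
    burn_in_period=20,
)

# Assumption: model.parameters() exposes the backbone, head, rotation and
# weight parameters, so a single optimizer suffices for this sketch.
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)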

num_tasks

Number of tasks/heads of the module.

backbone

Shared module.

heads

Sequence with the (rotated) task-specific heads.

rep

Current output of the backbone (after calling forward during training).

References

[1] Javaloy, Adrián, and Isabel Valera. “RotoGrad: Gradient Homogenization in Multitask Learning.” International Conference on Learning Representations (ICLR), 2022.

[2] Chen, Zhao, Vijay Badrinarayanan, Chen-Yu Lee, and Andrew Rabinovich. “GradNorm: Gradient Normalization for Adaptive Loss Balancing in Deep Multitask Networks.” International Conference on Machine Learning (ICML), 2018.

forward(x)

Forwards the input x through the backbone and all heads, returning a list with all the task predictions. It can be thought of as something similar to:

preds = []
z = backbone(x)                          # shared representation
for R_i, head in zip(rotations, heads):
    z_i = rotate(R_i, z)                 # task-specific rotation of z
    preds.append(head(z_i))              # task-specific prediction
return preds
Return type:

Sequence[Any]
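For example, assuming the module is called like a regular torch module, and given an assumed batch x with per-task targets y1 and y2 (the loss functions are also illustrative), the returned list pairs up with the heads:

x = torch.randn(8, 10)
y1, y2 = torch.randn(8, 1), torch.randn(8, 1)

preds = model(x)                         # one prediction per head
loss1 = nn.functional.mse_loss(preds[0], y1)
loss2 = nn.functional.mse_loss(preds[1], y2)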

backward(losses, backbone_loss=None, **kwargs)

Performs the backward pass for the entire model (that is, both the shared and the task-specific modules). It also computes the gradients for the rotation matrices.

Parameters:
  • losses (Sequence[Tensor]) – Sequence of task losses from which to back-propagate.

  • backbone_loss – Loss exclusive to the backbone (for example, a regularization term).

Return type:

None
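A single training step can then be sketched as follows, reusing the assumed model, optimizer, data and loss functions from the sketches above; note that the task losses are handed to backward() rather than calling loss.backward() manually:

optimizer.zero_grad()

preds = model(x)                           # forward through backbone and heads
losses = [
    nn.functional.mse_loss(pred, y)        # illustrative per-task losses
    for pred, y in zip(preds, (y1, y2))
]

model.backward(losses)                     # fills .grad for shared and task-specific parameters
optimizer.step()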

property rotation: Sequence[Tensor]

List of rotation matrices, one per task. These are trainable; make sure to call detach().

property weight: Sequence[Tensor]

List of task weights, one per task. These are trainable; make sure to call detach().
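Following the advice above, a minimal sketch of reading both properties, for example to log them during training (detaching and moving to CPU here is only illustrative):

with torch.no_grad():
    rotations = [R.detach().cpu() for R in model.rotation]  # one matrix per task
    weights = [w.detach().cpu() for w in model.weight]      # one weight per task
print(weights)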