RotoGradNorm
- class rotograd.RotoGradNorm(backbone, heads, rotation_shape, *args, alpha, post_shape=(), normalize_losses=False, burn_in_period=20)
Bases: RotoGrad
Implementation of RotoGrad, as described in the original paper [1], combined with GradNorm [2] to homogenize both the direction and the magnitude of the task gradients.
- Parameters:
  - backbone (Module) – Shared module.
  - heads (Sequence[Module]) – Task-specific modules.
  - rotation_shape (Union[int, Size]) – Shape of the shared representation to be rotated. Usually this is just the size of the backbone’s output; passing a full shape is useful, for example, if you want to rotate an image with shape width x height.
  - alpha (float) – \(\alpha\) hyper-parameter as described in GradNorm [2], used to compute the reference direction.
  - post_shape (optional, default=()) – Shape of the part of the shared representation that follows the rotated part (if any). This part is kept as it is, which is useful, for example, if you want to rotate only the channels of an image.
  - burn_in_period (optional, default=20) – When back-propagating towards the shared parameters, each task loss is normalized by dividing it by its initial value, \({L_k(t)}/{L_k(t_0 = 0)}\). This parameter sets a number of iterations after which the denominator is replaced by the value of the loss at that iteration, that is, \(t_0 = burn\_in\_period\). This mitigates instabilities caused by losses changing quickly during the first iterations.
  - normalize_losses (optional, default=False) – Whether to also use these normalized losses when back-propagating through the task-specific parameters.
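The burn-in normalization described above can be illustrated with a small, dependency-free sketch. The helper `normalized_loss` is hypothetical (the library computes this internally); it only shows the arithmetic: the latest loss is divided by the loss at \(t_0\), where \(t_0 = 0\) until `burn_in_period` iterations have elapsed and \(t_0 = burn\_in\_period\) afterwards.

```python
def normalized_loss(history, burn_in_period=20):
    """Return L(t) / L(t0) for the latest iteration t in `history`.

    `history` holds one loss value per iteration for a single task.
    t0 is 0 until the burn-in period has elapsed, then burn_in_period.
    (Hypothetical helper, not part of the rotograd API.)
    """
    t = len(history) - 1
    t0 = burn_in_period if t >= burn_in_period else 0
    return history[t] / history[t0]


# Early in training, losses are normalized by their initial value:
early = normalized_loss([4.0, 2.0, 1.0], burn_in_period=20)  # 1.0 / 4.0

# Once past the burn-in period, the denominator is the loss at t0 = 20:
long_history = [float(25 - i) for i in range(25)]
late = normalized_loss(long_history, burn_in_period=20)      # 1.0 / 5.0
```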
- num_tasks
Number of tasks/heads of the module.
- backbone
Shared module.
- heads
Sequence with the (rotated) task-specific heads.
- rep
Current output of the backbone (after calling forward during training).
References
[1] Javaloy, Adrián, and Isabel Valera. “RotoGrad: Gradient Homogenization in Multitask Learning.” ICLR 2022.
[2] Chen, Zhao, et al. “GradNorm: Gradient Normalization for Adaptive Loss Balancing in Deep Multitask Networks.” ICML 2018.
- forward(x)
Forwards the input x through the backbone and all heads, returning a list with all the task predictions. It can be thought of as something similar to:

```python
preds = []
z = backbone(x)
for R_i, head in zip(rotations, heads):
    z_i = rotate(R_i, z)
    preds.append(head(z_i))
return preds
```
- Return type:
Sequence[Any]
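The data flow above can be emulated with a dependency-free sketch. The backbone, heads, and rotation matrices below are hypothetical stand-ins (plain Python instead of torch modules); the point is only that the shared representation is computed once and each task sees its own rotated copy of it.

```python
def rotate(R, z):
    """Apply a rotation matrix R (given as a list of rows) to vector z."""
    return [sum(r_ij * z_j for r_ij, z_j in zip(row, z)) for row in R]

# Hypothetical stand-ins for the shared backbone and two task heads.
backbone = lambda x: [x[0] + x[1], x[0] - x[1]]
heads = [lambda z: z[0] * 2, lambda z: z[0] + 1]

# One rotation matrix per task: the identity and a 90-degree rotation.
rotations = [
    [[1.0, 0.0], [0.0, 1.0]],
    [[0.0, -1.0], [1.0, 0.0]],
]

def forward(x):
    preds = []
    z = backbone(x)           # shared representation, computed once
    for R_i, head in zip(rotations, heads):
        z_i = rotate(R_i, z)  # task-specific rotated copy of z
        preds.append(head(z_i))
    return preds

preds = forward([3.0, 1.0])   # one prediction per task
```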
- backward(losses, backbone_loss=None, **kwargs)
Runs the backward pass for the entire model (that is, both shared and task-specific modules). It also computes the gradients for the rotation matrices.
- Parameters:
  - losses (Sequence[Tensor]) – Sequence of task losses from which to back-propagate.
  - backbone_loss (optional) – Loss exclusive to the backbone (for example, a regularization term).
- Return type:
None
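Part of what this backward pass balances is the magnitude of the task gradients, following GradNorm [2]. A dependency-free sketch of that idea: each task gets a target gradient norm equal to the mean norm scaled by its relative training rate raised to \(\alpha\). The helper below is hypothetical (the library performs this internally on tensors), shown only to make the role of alpha concrete.

```python
def gradnorm_target_norms(grad_norms, loss_ratios, alpha):
    """GradNorm-style target gradient magnitudes, one per task.

    grad_norms:  current gradient norm of each task loss.
    loss_ratios: normalized losses L_k(t) / L_k(t0) (see burn_in_period).
    alpha:       restoring-force hyper-parameter from GradNorm [2].
    (Hypothetical helper, not part of the rotograd API.)
    """
    mean_norm = sum(grad_norms) / len(grad_norms)
    mean_ratio = sum(loss_ratios) / len(loss_ratios)
    # Tasks that train slower (higher relative loss ratio) get a
    # larger target norm, pulling their gradients up towards parity.
    return [mean_norm * (r / mean_ratio) ** alpha for r in loss_ratios]


# With equal loss ratios, every task is pulled to the mean norm:
targets = gradnorm_target_norms([1.0, 3.0], [1.0, 1.0], alpha=1.0)
```

With alpha=0 the ratios are ignored and all tasks share the mean norm; larger alpha pushes harder towards equalizing training rates.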
- property rotation: Sequence[Tensor]
List of rotation matrices, one per task. These are trainable; make sure to call detach() before using them elsewhere.
- property weight: Sequence[Tensor]
List of task weights, one per task. These are trainable; make sure to call detach() before using them elsewhere.