
feat(aggregation): Add UW #721

Open
ppraneth wants to merge 5 commits into SimplexLab:main from ppraneth:scalarization-4
ppraneth (Contributor) commented Jun 1, 2026

Adds UW, the uncertainty weighting scalarizer from Multi-Task Learning Using Uncertainty to Weigh Losses for Scene Geometry and Semantics (Kendall et al., CVPR 2018). This is the first stateful, trainable scalarizer, so the PR also moves the Stateful mixin to a shared location where both the aggregation and scalarization packages can use it.

UW

Each value $L_i$ (typically a per-task loss) is assigned a learnable log-variance $s_i = \log \sigma_i^2$, and the values are combined as:

$$\sum_i \left( \frac{1}{2} e^{-s_i} L_i + \frac{1}{2} s_i \right)$$

This is the regression objective (eq. 7 in the paper) after substituting $s_i = \log \sigma_i^2$, which matches the LibMTL implementation.
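
For reference, the substitution can be spelled out per task. This assumes the paper's per-task regression term has the form $\frac{1}{2\sigma_i^2} L_i + \log \sigma_i$; with $\sigma_i^2 = e^{s_i}$ and $\log \sigma_i = \frac{1}{2} \log \sigma_i^2 = \frac{1}{2} s_i$:

$$\frac{1}{2\sigma_i^2} L_i + \log \sigma_i = \frac{1}{2} e^{-s_i} L_i + \frac{1}{2} s_i$$

Summing over $i$ gives the expression above.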

Following the paper, the model learns the log-variance $s_i$ rather than the variance $\sigma_i^2$ directly. This is numerically more stable (the combination never divides by zero) and keeps $s_i$ unconstrained, since $e^{-s_i}$ is always positive. The $s_i$ are stored as an nn.Parameter, so the scalarizer's parameters must be passed to the optimizer to be learned jointly with the model.
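
To make the "pass the parameters to the optimizer" point concrete, here is a minimal sketch of one co-training step. It does not use the UW class from this PR: the log-variances are a bare nn.Parameter, and the model, data, and learning rate are made up for illustration.

```python
import torch
from torch import nn

torch.manual_seed(0)

# Hypothetical two-task setup: one linear model with one output per task.
model = nn.Linear(4, 2)
# One learnable log-variance per task, initialized to 0 (sigma^2 = 1).
log_var = nn.Parameter(torch.zeros(2))

# The model parameters and the log-variances share a single optimizer.
optimizer = torch.optim.SGD([*model.parameters(), log_var], lr=0.1)

x = torch.randn(8, 4)
y = torch.randn(8, 2)

# Per-task mean squared errors, shape (2,).
losses = ((model(x) - y) ** 2).mean(dim=0)
# Uncertainty-weighted combination: sum_i (0.5 * exp(-s_i) * L_i + 0.5 * s_i).
total = (0.5 * torch.exp(-log_var) * losses + 0.5 * log_var).sum()

optimizer.zero_grad()
total.backward()
optimizer.step()
```

If log_var were left out of the optimizer, it would still receive a gradient but would stay at its initial value, so the weights would remain uniform.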

Design notes:

  • shape is given at construction (UW(3) or UW((2, 3))), since the parameter has to exist before the optimizer is built. The shape is validated against the input at call time, like Constant.
  • Log-variances are initialized to 0 (so $\sigma_i^2 = 1$, uniform weights at the start). The paper reports the result is robust to this initialization. LibMTL uses -0.5 instead.
  • Implements reset() (from Stateful), which zeros the log-variances.
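
The design notes above can be sketched roughly as follows. UWSketch is an illustrative stand-in, not the class added in this PR: the name, the error message, and the plain nn.Module base (no Stateful mixin) are assumptions.

```python
from collections.abc import Sequence
from typing import Union

import torch
from torch import Tensor, nn


class UWSketch(nn.Module):
    """Illustrative stand-in for UW; not the code from this PR."""

    def __init__(self, shape: Union[int, Sequence[int]]) -> None:
        super().__init__()
        # The parameter exists as soon as the object is built, so it can be
        # handed to an optimizer before the first call. Initialized to 0.
        self.log_var = nn.Parameter(torch.zeros(shape))

    def forward(self, values: Tensor) -> Tensor:
        # The shape fixed at construction is validated at call time.
        if values.shape != self.log_var.shape:
            raise ValueError(
                f"Expected values of shape {tuple(self.log_var.shape)}, "
                f"got {tuple(values.shape)}."
            )
        weighted = 0.5 * torch.exp(-self.log_var) * values + 0.5 * self.log_var
        return weighted.sum()

    def reset(self) -> None:
        # Restore the initial state: all log-variances back to 0.
        with torch.no_grad():
            self.log_var.zero_()


uw = UWSketch(3)
out = uw(torch.tensor([1.0, 2.0, 3.0]))  # 0.5 * (1 + 2 + 3) at initialization
```

At initialization every weight is 0.5 and the regularization term is 0, so the output is half the sum of the inputs.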

Stateful move

Stateful was defined in aggregation/_mixins.py, but it is now needed by scalarization too. It moves to the shared torchjd/_mixins.py. This is backward compatible:

  • torchjd.aggregation.Stateful still works (re-exported).
  • torchjd.Stateful is added as the new top-level path, since the mixin is now cross-cutting.

Internal aggregation modules (_cr_mogm, _gradvac, _nash_mtl) had their imports and docstring cross-references updated. No behavior change for existing aggregators.

Tests

tests/unit/scalarization/test_uw.py covers the value at init, int-vs-tuple shape equivalence, scalar output and gradient flow over all input shapes (0-dim, vector, matrix, higher-dim), gradient flow to log_var, shape validation, reset(), that negative inputs are allowed (unlike GeometricMean), trainability via an optimizer step, and the representations. The full tests/unit suite was run as a regression check since the Stateful move touches the aggregation package.

Question on the shape API

Unlike the stateless scalarizers, UW cannot be shape-agnostic: it holds one learnable log-variance per value, and that parameter has to be created at construction time so it can be handed to the optimizer before training starts.

I went with shape: int | Sequence[int], so UW(3) builds a length-3 vector and UW((2, 3)) builds a 2D grid. Reasons:

  • Consistency with the package. Every other scalarizer accepts any input shape and reduces over all elements. Constant already establishes the "fix the shape at construction, validate at call time" pattern for shape-bound scalarizers, and UW follows it rather than being a 1D-only special case.
  • No cost to the common case. An int just collapses to (n,) internally, so UW(3) is exactly as ergonomic as a num_tasks argument would be, while higher-dim losses still work for free.
  • torch-idiomatic. Accepting int | Sequence[int] for a shape matches how many torch constructors behave.

The alternative is num_tasks: int only (1D vectors of m task losses, matching LibMTL exactly). It is slightly simpler conceptually, since "uncertainty per task" is naturally 1D, but it cannot scalarize higher-dim loss tensors and makes UW inconsistent with the other scalarizers. Happy to switch if you prefer that.

ppraneth added 2 commits June 1, 2026 09:59
Signed-off-by: ppraneth <pranethparuchuri@gmail.com>
Signed-off-by: ppraneth <pranethparuchuri@gmail.com>
@ppraneth ppraneth requested review from a team, PierreQuinton and ValerianRey as code owners June 1, 2026 04:52
ppraneth added 2 commits June 1, 2026 10:24
Signed-off-by: ppraneth <pranethparuchuri@gmail.com>
PierreQuinton (Contributor) left a comment

I think I would prefer a usage example of how to co-train parameters of a model and the log_var parameters in the docstring.

Comment thread src/torchjd/aggregation/_cr_mogm.py Outdated
Comment thread src/torchjd/aggregation/__init__.py
Comment thread src/torchjd/scalarization/_uw.py Outdated
Signed-off-by: ppraneth <pranethparuchuri@gmail.com>
@ppraneth ppraneth requested a review from PierreQuinton June 1, 2026 06:29
ppraneth (Contributor, Author) commented Jun 1, 2026

Thanks, addressed all four:

  1. Added a co-training usage example to the UW docstring (model params + log_var params in one optimizer).
  2. Cross-refs switched to ~torchjd._mixins.Stateful everywhere.
  3. Dropped Stateful from torchjd/aggregation/__init__.py; it is now only torchjd.Stateful. This removes the old torchjd.aggregation.Stateful, so I added a BREAKING changelog note (can add a deprecation instead if you prefer).
  4. Simplified to nn.Parameter(torch.zeros(shape)). Confirmed the int | Sequence[int] annotation still typechecks since torch.zeros has overloads for both.
