feat(aggregation): Add UW #721
Open
ppraneth wants to merge 5 commits into
Contributor
PierreQuinton
left a comment
I think I would prefer a usage example of how to co-train parameters of a model and the log_var parameters in the docstring.
Contributor
Author
Thanks, addressed all four:
Adds `UW`, the uncertainty weighting scalarizer from Multi-Task Learning Using Uncertainty to Weigh Losses for Scene Geometry and Semantics (Kendall et al., CVPR 2018). This is the first stateful, trainable scalarizer, so the PR also moves the `Stateful` mixin to a shared location so both the aggregation and scalarization packages can use it.

UW

Each value $L_i$ (typically a per-task loss) is assigned a learnable log-variance $s_i = \log \sigma_i^2$, and the values are combined as:

$$\sum_i \left( \frac{1}{2} e^{-s_i} L_i + \frac{1}{2} s_i \right)$$

This is the regression objective (eq. 7 in the paper) after substituting $s_i = \log \sigma_i^2$, which matches the LibMTL implementation.
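The substitution can be spelled out step by step. This is a sketch that assumes eq. 7 has the two-task form $\frac{1}{2\sigma_1^2} L_1 + \frac{1}{2\sigma_2^2} L_2 + \log \sigma_1 \sigma_2$ and extends it to $n$ values by summing the per-value terms; the count $n$ is introduced here only for illustration.

$$
\begin{aligned}
\sum_{i=1}^{n} \left( \frac{1}{2\sigma_i^2} L_i + \log \sigma_i \right)
&= \sum_{i=1}^{n} \left( \frac{1}{2} e^{-\log \sigma_i^2} L_i + \frac{1}{2} \log \sigma_i^2 \right) \\
&= \sum_{i=1}^{n} \left( \frac{1}{2} e^{-s_i} L_i + \frac{1}{2} s_i \right)
\end{aligned}
$$

With every $s_i = 0$ (that is, $\sigma_i^2 = 1$), the expression reduces to $\frac{1}{2} \sum_i L_i$.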
Following the paper, the model learns the log-variance $s_i$ rather than the variance $\sigma_i^2$ directly. This is numerically more stable (the combination never divides by zero) and keeps $s_i$ unconstrained, since $e^{-s_i}$ is always positive. The $s_i$ are stored as an `nn.Parameter`, so the scalarizer's parameters must be passed to the optimizer to be learned jointly with the model.

Design notes:
- `shape` is given at construction (`UW(3)` or `UW((2, 3))`), since the parameter has to exist before the optimizer is built. The shape is validated against the input at call time, like `Constant`.
- […] `0` (so […] `-0.5` instead.
- `reset()` (from `Stateful`), which zeros the log-variances.

`Stateful` move

`Stateful` was defined in `aggregation/_mixins.py`, but it is now needed by scalarization too. It moves to the shared `torchjd/_mixins.py`. This is backward compatible:

- `torchjd.aggregation.Stateful` still works (re-exported).
- `torchjd.Stateful` is added as the new top-level path, since the mixin is now cross-cutting.

Internal aggregation modules (`_cr_mogm`, `_gradvac`, `_nash_mtl`) had their imports and docstring cross-references updated. No behavior change for existing aggregators.

Tests
`tests/unit/scalarization/test_uw.py` covers the value at init, int-vs-tuple shape equivalence, scalar output and gradient flow over all input shapes (0-dim, vector, matrix, higher-dim), gradient flow to `log_var`, shape validation, `reset()`, that negative inputs are allowed (unlike `GeometricMean`), trainability via an optimizer step, and the representations. The full `tests/unit` suite was run as a regression check since the `Stateful` move touches the aggregation package.

Question on the shape API
Unlike the stateless scalarizers, `UW` cannot be shape-agnostic: it holds one learnable log-variance per value, and that parameter has to be created at construction time so it can be handed to the optimizer before training starts.

I went with `shape: int | Sequence[int]`, so `UW(3)` builds a length-3 vector and `UW((2, 3))` builds a 2D grid. Reasons:

- `Constant` already establishes the "fix the shape at construction, validate at call time" pattern for shape-bound scalarizers, and `UW` follows it rather than being a 1D-only special case.
- `int` just collapses to `(n,)` internally, so `UW(3)` is exactly as ergonomic as a `num_tasks` argument would be, while higher-dim losses still work for free.
- `int | Sequence[int]` for a shape matches how many torch constructors behave.

The alternative is `num_tasks: int` only (1D vectors of m task losses, matching LibMTL exactly). It is slightly simpler conceptually, since "uncertainty per task" is naturally 1D, but it cannot scalarize higher-dim loss tensors and makes `UW` inconsistent with the other scalarizers. Happy to switch if you prefer that.
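To make the shape discussion and the optimizer requirement concrete, here is a minimal sketch of how the `shape: int | Sequence[int]` pattern and joint training could look. `UWSketch` is a hypothetical stand-in written for this illustration, not the class added in this PR. Zero initialization of the log-variances is an assumption, and the toy model, data, and learning rate are made up.

```python
from collections.abc import Sequence
from typing import Union

import torch
from torch import Tensor, nn


class UWSketch(nn.Module):
    """Hypothetical stand-in for UW; not the class added in this PR."""

    def __init__(self, shape: Union[int, Sequence[int]]) -> None:
        super().__init__()
        # An int collapses to a 1-D shape, so UWSketch(3) behaves like UWSketch((3,)).
        normalized = (shape,) if isinstance(shape, int) else tuple(shape)
        # One learnable log-variance s_i per value. Zero init is an assumption.
        self.log_var = nn.Parameter(torch.zeros(normalized))

    def forward(self, values: Tensor) -> Tensor:
        # Validate at call time against the shape fixed at construction.
        if values.shape != self.log_var.shape:
            raise ValueError(
                f"Expected values of shape {tuple(self.log_var.shape)}, "
                f"got {tuple(values.shape)}."
            )
        # sum_i (0.5 * exp(-s_i) * L_i + 0.5 * s_i)
        return (0.5 * torch.exp(-self.log_var) * values + 0.5 * self.log_var).sum()

    def reset(self) -> None:
        # Zero the log-variances in place without recording the op in autograd.
        with torch.no_grad():
            self.log_var.zero_()


torch.manual_seed(0)
model = nn.Linear(4, 2)  # toy model: one output column per task
uw = UWSketch(2)

# Both parameter sets go to the optimizer so they are learned jointly.
optimizer = torch.optim.SGD([*model.parameters(), *uw.parameters()], lr=0.1)

inputs, targets = torch.randn(8, 4), torch.randn(8, 2)
per_task_losses = ((model(inputs) - targets) ** 2).mean(dim=0)  # shape (2,)

optimizer.zero_grad()
uw(per_task_losses).backward()
optimizer.step()  # updates the model weights and uw.log_var together
```

If `uw.parameters()` were left out of the optimizer, `log_var` would still receive gradients but would never be updated, which is why the parameter has to exist before the optimizer is built.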