feat(aggregation): Add MoDoWeighting #717

Merged
PierreQuinton merged 6 commits into SimplexLab:main from KhusPatel4450:feat/modo-weighting
Jun 3, 2026

Conversation

@KhusPatel4450
Contributor

Adds MoDoWeighting from Three-Way Trade-Off in Multi-Objective Learning: Optimization, Generalization and Conflict-Avoidance (JMLR 2024).

It's a stateful Weighting[PSDMatrix] implementing the λ-update from Algorithm 2:

  • λ_{t+1} = softmax(λ_t − γ·(G·λ_t + ρ·λ_t))

Per the discussion with @PierreQuinton and @ValerianRey on Discord, this follows the official LibMTL implementation which uses softmax rather than the paper's hard simplex projection.
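A minimal sketch of the softmax form of the λ-update written above, using plain Python lists instead of torch tensors so that it stays dependency-free. The helper names (`softmax`, `modo_softmax_step`) and the toy matrix are illustrative assumptions, not part of the PR:

```python
import math


def softmax(values):
    """Numerically stable softmax over a list of floats."""
    shift = max(values)
    exps = [math.exp(v - shift) for v in values]
    total = sum(exps)
    return [e / total for e in exps]


def modo_softmax_step(matrix, lambd, gamma, rho):
    """One step of lambda <- softmax(lambda - gamma * (G @ lambda + rho * lambda))."""
    m = len(lambd)
    grad = [
        sum(matrix[i][j] * lambd[j] for j in range(m)) + rho * lambd[i]
        for i in range(m)
    ]
    return softmax([lambd[i] - gamma * grad[i] for i in range(m)])


# Hypothetical 2x2 matrix and uniform starting weights, for illustration only.
G = [[2.0, 0.0], [0.0, 1.0]]
lambd = [0.5, 0.5]
new_lambd = modo_softmax_step(G, lambd, gamma=0.1, rho=0.0)
```

In this toy case the first objective has the larger entry in `G @ lambd`, so its weight decreases relative to the second one, while the output stays on the simplex.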

Designed to be composed with autogram.Engine in a two-batch training loop so that MoDo's double-sampling property is preserved (Gramian comes from batch 1; backward uses batch 2).

Test plan

  • Unit tests in tests/unit/aggregation/test_modo.py (12 functions, 72 cases — structural, reset, parameter validation, softmax boundary cases, recurrence verification)
  • Full unit suite passes: 3098 passed, 66 skipped, 33 xfailed
  • ty check passes on _modo.py
  • Sphinx doctest: 97 tests, 0 failures
  • HTML build clean with -W --keep-going -n

@PierreQuinton PierreQuinton added the cc: feat and package: aggregation labels May 29, 2026
Contributor

@PierreQuinton PierreQuinton left a comment

Looking good. It still needs a few changes, but once this is merged, I think it will make #676 easier to merge.

Comment thread src/torchjd/aggregation/_modo.py Outdated
Comment thread src/torchjd/aggregation/_modo.py Outdated
Comment thread src/torchjd/aggregation/_modo.py Outdated
Comment thread src/torchjd/aggregation/_modo.py Outdated
Comment thread src/torchjd/aggregation/_modo.py Outdated

with torch.no_grad():
    grad = gramian @ lambd + self._rho * lambd
    lambd = torch.softmax(lambd - self._gamma * grad, dim=-1)
Contributor

So in the end, this is a softmax. @rkhosrowshahi I think this means that MoCo is essentially just a composition with this weighting, where you give yy_t to it, and then multiply yy_t by the obtained weights. Is that correct? If yes, I think we should change #676 accordingly.

Comment thread src/torchjd/aggregation/_modo.py Outdated
Contributor

@PierreQuinton PierreQuinton left a comment

For me this is ready, let's still wait for @ValerianRey's review.

Contributor

@ValerianRey ValerianRey left a comment

I don't think this is equivalent to the paper, the official implementation, or the LibMTL implementation. In all of these, the matrix is computed as J_1 @ J_2^T (I think this aggregator makes 0 sense for IWRM because of that, so we gotta think of it in an MTL context). Autojac's gramian computed on losses_1 would be J_1 @ J_1^T though, so I think this PR's usage example is wrong.

See equation 2.9a, step 3 of the algorithm, or line 55 of the libmtl implementation.

I think the only way to add MoDo to torchjd would be with the same implementation but different usage example:

  • user computes J_1 and J_2 using autojac.jac
  • user computes G = J_1 @ J_2^T
  • user computes weights by applying a MoDoWeighting to G
  • user does an extra backward pass with some new losses (losses_3) weighted with the obtained weights.

I think depending on the implementation they either use only losses_1 and losses_2, or they also use losses_3. Idk what's best.

Note that here G is not a gramian, and is not PSD in general. Gotta type MoDoWeighting properly.
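To make the last point concrete, here is a small sketch with hand-picked 2x2 Jacobians (hypothetical values, plain Python lists rather than tensors) showing that J_1 @ J_1^T is symmetric while J_1 @ J_2^T need not be symmetric or PSD:

```python
def matmul_transpose(a, b):
    """Return a @ b^T for matrices stored as lists of rows."""
    return [[sum(x * y for x, y in zip(row_a, row_b)) for row_b in b] for row_a in a]


# Hypothetical Jacobians from two independent batches (2 objectives, 2 parameters).
J_1 = [[1.0, 0.0], [0.0, 1.0]]
J_2 = [[1.0, 1.0], [0.0, -1.0]]

gramian = matmul_transpose(J_1, J_1)  # J_1 @ J_1^T: symmetric and PSD by construction
cross = matmul_transpose(J_1, J_2)    # J_1 @ J_2^T: a general square matrix

is_symmetric = cross[0][1] == cross[1][0]
# A negative diagonal entry means e_i^T G e_i < 0 for a basis vector e_i, so G is not PSD.
has_negative_diagonal = any(cross[i][i] < 0.0 for i in range(len(cross)))
```

Here `cross` is `[[1.0, 0.0], [1.0, -1.0]]`: it is neither symmetric nor PSD, which is why a weighting that consumes it cannot be typed as taking a PSD matrix.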

What do you think @PierreQuinton @KhusPatel4450 ?

Comment thread CHANGELOG.md Outdated
Comment on lines +13 to +15
- Added `MoDoWeighting` from [Three-Way Trade-Off in Multi-Objective Learning: Optimization,Generalization and Conflict-Avoidance](https://www.jmlr.org/papers/volume25/23-1287/23-1287.pdf) (JMLR 2024). It is a stateful `Weighting` that maintains task weights across calls via a
softmax-projected gradient step on the Gramian, intended to be composed with `autogram.Engine`
in a two-batch training loop.
Contributor

Can remove the description of what MoDo is, or fix it (it's not doable with autogram with the changes I suggest).

Comment thread src/torchjd/aggregation/_modo.py Outdated
from ._weighting_bases import _GramianWeighting


class MoDoWeighting(_GramianWeighting, Stateful, _NonDifferentiable):
Contributor

Can't inherit from _GramianWeighting anymore with the changes I propose. Need to either inherit from _MatrixWeighting, or inherit from Weighting[Matrix] and override call to make a more specific docstring (e.g. explaining exactly what the matrix is: J_1 @ J_2^T)

Also need to fix the main docstring accordingly.

Comment thread src/torchjd/aggregation/_modo.py Outdated
Comment on lines +19 to +20
<https://www.jmlr.org/papers/volume25/23-1287/23-1287.pdf>`_ (JMLR 2024), commonly referred
to as MoDo (Multi-Objective gradient with Double sampling).
Contributor

Arguably the whole method (given in the usage example) is called MoDo. Can remove this last part. The part explaining the acronym (in parentheses) is nice though.

Comment thread src/torchjd/aggregation/_modo.py Outdated
<https://www.jmlr.org/papers/volume25/23-1287/23-1287.pdf>`_ (JMLR 2024), commonly referred
to as MoDo (Multi-Objective gradient with Double sampling).

Given a Gramian :math:`G`, the weights :math:`\lambda` are updated at each call by a
Contributor

Not a gramian anymore with the changes I propose.

Comment thread src/torchjd/aggregation/_modo.py Outdated
Comment on lines +40 to +44
.. warning::
MoDo's convergence guarantees rely on **double sampling**: the Gramian passed to this
weighting must come from a mini-batch that is independent of the one used for the
subsequent parameter update. The Gramian can be computed efficiently from a batch of
losses using the :class:`~torchjd.autogram.Engine`. See the usage example below.
Contributor

This is a bit wrong, we can't talk about a Gramian there and it should itself come from 2 batches. The last part about autogram should be removed.

Comment thread src/torchjd/aggregation/_modo.py Outdated
Comment on lines +51 to +53
Train a model using MoDo with two independent mini-batches per step. The first batch
drives the :math:`\lambda` update via the Gramian; the second batch drives the parameter
update via the usual backward pass.
Contributor

I would simplify this explanation to make it clear that this is just doing basic MoDo, so that people actually follow this usage example if they wanna reproduce MoDo.

Something like:
Train a model with MoDo.

The role of their paper is to explain what MoDo is. The role of TorchJD is to make it extremely easy to reproduce MoDo, but not to explain it IMO (especially since it's quite complex).

Comment thread src/torchjd/aggregation/_modo.py Outdated
@KhusPatel4450
Contributor Author

Hello,

I think you are correct. Going back to the paper and looking at equations 2.8 and 2.9, you are right. We originally discussed using autogram to compute the Gramian efficiently; at the time it seemed like the right fit, since autogram gives J_1 @ J_1^⊤ and the λ update needs an [m, m] matrix.

The matrix in the λ update is ∇F_{z1}^⊤ ∇F_{z2}, which in TorchJD notation (where J is [m, d]) is J_1 @ J_2^⊤. I see now why using J_1 @ J_1^⊤ makes it biased and misses the whole point of double sampling.
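The bias can be seen on a toy scalar case. The sketch below assumes a single objective with true gradient 2.0 and additive batch noise that is -1 or +1 with equal probability (a made-up noise model, purely for illustration), and computes both expectations exactly by enumeration:

```python
from itertools import product

true_grad = 2.0
noise_values = [-1.0, 1.0]  # assumed zero-mean batch noise, each value equally likely

# Same batch used twice: E[(g + e)^2] = g^2 + Var(e), biased upward by the noise variance.
same_batch = sum((true_grad + e) ** 2 for e in noise_values) / len(noise_values)

# Two independent batches: E[(g + e1) * (g + e2)] = g^2, since the noise terms decouple.
pairs = list(product(noise_values, repeat=2))
double_sampled = sum((true_grad + e1) * (true_grad + e2) for e1, e2 in pairs) / len(pairs)

target = true_grad ** 2
```

With these numbers, `same_batch` is 5.0 while `double_sampled` equals the target 4.0, which is the scalar analogue of J_1 @ J_1^⊤ versus J_1 @ J_2^⊤.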

So the changes needed are:

  1. MoDoWeighting should take a general Matrix not a PSDMatrix since J_1 @ J_2^⊤ is not symmetric or PSD in general
  2. The usage example needs to be rewritten, user computes J_1 and J_2 via autojac.jac, manually computes G = J_1 @ J_2^⊤, passes G to MoDoWeighting
  3. Remove the autogram usage

One open question on the model update (equation 2.9b): it uses a fresh Z_{t+1} = {z_{t+1,1}, z_{t+1,2}}, so technically it's 3 separate batches per step. Looking at the LibMTL implementation, they default to 2 batches and make 3 an opt-in via a three_grad flag. Should we follow the same pattern?

@ValerianRey
Contributor

Thanks for the quick reply. I agree. Please go ahead.

One open question on the model update (equation 2.9b): it uses a fresh Z_{t+1} = {z_{t+1,1}, z_{t+1,2}}, so technically it's 3 separate batches per step. Looking at the LibMTL implementation, they default to 2 batches and make 3 an opt-in via a three_grad flag. Should we follow the same pattern?

Idk, gotta investigate that. I'll look at that soon-ish.

@ValerianRey
Contributor

For now you could just add the two usage examples.

Later, we can decide to keep only one, or to change the description to say something like: "this is the default behavior of MoDo in the official implementation and in LibMTL" and "this is an alternative implementation that has the advantage of [...]"

@SimplexLab SimplexLab deleted a comment from opencode-agent Bot May 31, 2026
@ValerianRey
Contributor

My main concern is now fixed, tyvm @KhusPatel4450. I still need to review more in-depth (since there's a lot of code and tests). But I'm pretty sure we'll be able to merge this very soon!

@ValerianRey ValerianRey self-requested a review June 2, 2026 10:13
Contributor

@ValerianRey ValerianRey left a comment

We're getting there! I just made a thorough review so I think we can merge whenever everything is fixed.

Comment thread src/torchjd/aggregation/_modo.py Outdated

The following example reproduces basic MoDo using two independent mini-batches per step.

.. code-block:: python
Contributor

Need to use the testcode directive so that this usage example is tested by doctest (in the CI or manually using uv run make -C ./docs doctest), or no block and prefix each line with >>>

Otherwise the example isn't tested by doctest and we have no guarantee that it works and that it keeps working in future updates.

Same comment for the other usage example, and even for any future example you may write.
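As an illustration of the `>>>` option, the sketch below uses Python's standard `doctest` module (not the Sphinx doctest extension that the CI runs) to show that prompt-prefixed lines in a docstring are executed and checked. The `scale` function is a made-up placeholder:

```python
import doctest


def scale(values, factor):
    """Multiply every entry of ``values`` by ``factor``.

    >>> scale([1, 2, 3], 2)
    [2, 4, 6]
    """
    return [v * factor for v in values]


# Parse the docstring into a DocTest and run it; a code-block directive would be skipped.
parser = doctest.DocTestParser()
test = parser.get_doctest(scale.__doc__, {"scale": scale}, "scale", None, 0)
runner = doctest.DocTestRunner(verbose=False)
results = runner.run(test)
```

`results.attempted` counts the examples that were actually executed, so an example that silently stops being tested shows up as a drop in that number.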

Contributor

Would it be possible to have that inside a skill related to documentation for agents?

Contributor

Too simple for a skill but we should definitely have that explained in AGENTS.md and / or contributing.md, or even have a script checking that we don't use any code-block directive. Feel free to make a PR with what you think is best.

Comment thread src/torchjd/aggregation/_modo.py Outdated
params = list(model.parameters())

# loader_1 and loader_2 must yield independent draws of the same size.
for batch_1, batch_2 in zip(loader_1, loader_2):
Contributor

We need the example to be self-sufficient, so we either need to define loader_1 and loader_2 (which is tedious) or just do something like:

inputs = ...
targets = ...

for i in range(len(inputs) // 2):
    input_1, input_2 = inputs[2*i], inputs[2*i + 1]
    target_1, target_2 = targets[2*i], targets[2*i + 1]
    ...

With the ... filled with the appropriate code.

Similar comment for the other example.

Contributor

To better reflect the new usage, I think we should use G = J1 @ J2.T instead of G = J @ J.T in test_reset_restores_first_step_behavior, test_output_lies_on_simplex, test_update_recurrence and test_changing_m_auto_resets.

Similarly, we should use G1 = J1 @ J2.T and G2 = J3 @ J4.T in test_two_consecutive_steps.

Comment thread tests/unit/aggregation/test_modo.py Outdated
Comment on lines +68 to +78
def test_small_gamma_stays_near_uniform() -> None:
    """With a tiny gamma, one step barely moves lambda from the uniform initialisation."""

    J = randn_((3, 8))
    G = J @ J.T
    m = J.shape[0]
    W = MoDoWeighting(gamma=1e-8)
    uniform = tensor_([1.0 / m] * m)
    assert_close(W(G), uniform, atol=1e-6, rtol=1e-6)


Contributor

I think we can remove this test.

Suggested change
def test_small_gamma_stays_near_uniform() -> None:
    """With a tiny gamma, one step barely moves lambda from the uniform initialisation."""
    J = randn_((3, 8))
    G = J @ J.T
    m = J.shape[0]
    W = MoDoWeighting(gamma=1e-8)
    uniform = tensor_([1.0 / m] * m)
    assert_close(W(G), uniform, atol=1e-6, rtol=1e-6)

Comment thread src/torchjd/aggregation/_modo.py Outdated
optimizer.zero_grad()
"""

def __init__(self, gamma: float = 0.1, rho: float = 0.0) -> None:
Contributor

In the official implementation and in LibMTL, the default value of rho is 0.1. It would be better to match this IMO.

Comment thread src/torchjd/aggregation/_modo.py Outdated
G = J_1 @ J_2.T
weights = weighting(G)

losses_2.backward(weights)
Contributor

If we follow Equation 2.9b from the paper, this should be

losses = ((losses_1 + losses_2) / 2.0)
losses.backward(weights)

In the official implementation, it's also what they do, except that they forgot the division by 2.
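A small arithmetic sketch of what the missing division changes. Calling `losses.backward(weights)` differentiates the weighted sum of the losses, so the scalar objective below is a stand-in for that call; the loss values and weights are made up for illustration:

```python
def weighted_objective(losses, weights):
    """Scalar whose gradient a weighted backward pass would compute: sum_i w_i * l_i."""
    return sum(w * l for w, l in zip(weights, losses))


# Hypothetical per-objective losses on two batches, and weights on the simplex.
weights = [0.8, 0.2]
losses_1 = [1.0, 3.0]
losses_2 = [2.0, 5.0]

averaged = [(a + b) / 2.0 for a, b in zip(losses_1, losses_2)]  # as in Equation 2.9b
summed = [a + b for a, b in zip(losses_1, losses_2)]            # division by 2 omitted

objective_avg = weighted_objective(averaged, weights)
objective_sum = weighted_objective(summed, weights)
```

The summed objective is exactly twice the averaged one, and the same factor carries over to its gradient, so with plain gradient descent omitting the division behaves like doubling the learning rate.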

Comment thread src/torchjd/aggregation/_modo.py Outdated
lambd = cast(Tensor, self._lambda)

grad = matrix @ lambd + self._rho * lambd
lambd = torch.softmax(lambd - self._gamma * grad, dim=-1)
Contributor

I think there was some confusion on discord when we talked about how to project onto the simplex. We all thought that the official implementation was using a softmax, but it (and LibMTL) actually uses:

    def _projection2simplex(self, y):
        m = len(y)
        sorted_y = torch.sort(y, descending=True)[0]
        tmpsum = 0.0
        tmax_f = (torch.sum(y) - 1.0)/m
        for i in range(m-1):
            tmpsum+= sorted_y[i]
            tmax = (tmpsum - 1)/ (i+1.0)
            if tmax > sorted_y[i+1]:
                tmax_f = tmax
                break
        return torch.max(y - tmax_f, torch.zeros(m).to(y.device))

Should we use this way of projecting @PierreQuinton ?

If we do that, we'll need to say that parts of this file were adapted from the official implementation, add a link to it, and add a notice in NOTICES @KhusPatel4450.
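To illustrate how the two options differ, the sketch below compares a softmax with a standard sort-based Euclidean projection onto the simplex, written from scratch on plain Python lists. It is not the official code, and the input vector is a made-up example:

```python
import math


def softmax(values):
    """Numerically stable softmax; every output is strictly positive."""
    shift = max(values)
    exps = [math.exp(v - shift) for v in values]
    total = sum(exps)
    return [e / total for e in exps]


def project_onto_simplex(y):
    """Euclidean projection of y onto {x : x_i >= 0, sum(x) = 1} via sorting."""
    m = len(y)
    sorted_y = sorted(y, reverse=True)
    threshold = (sum(y) - 1.0) / m  # used when no coordinate gets clipped to zero
    running_sum = 0.0
    for i in range(m - 1):
        running_sum += sorted_y[i]
        candidate = (running_sum - 1.0) / (i + 1)
        if candidate > sorted_y[i + 1]:
            threshold = candidate
            break
    return [max(v - threshold, 0.0) for v in y]


y = [0.9, 0.3, -0.2]  # hypothetical pre-projection weights
projected = project_onto_simplex(y)
soft = softmax(y)
```

Both outputs sum to 1, but the projection sets the third weight exactly to zero (roughly `[0.8, 0.2, 0.0]`), whereas the softmax keeps every weight strictly positive, so the two choices can produce noticeably different weights.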

Contributor Author

I think I know what happened now: the code I was told to read was from Rasa's MoCo.py, which used torch.softmax, but now I see that the official implementation uses this projection.

I personally think we should follow this.


.. admonition:: Example (three batches per step)

The following example reproduces basic MoDo using three independent mini-batches per step,
Contributor

Maybe we could add that this is the behavior of MoDo in LibMTL and in the official implementation when three_grads is True.


.. admonition:: Example (two batches per step)

The following example reproduces basic MoDo using two independent mini-batches per step.
Contributor

Maybe we could add that this is MoDo as described in the paper, and it's the behavior of the official implementation when three_grads is False.

@KhusPatel4450
Contributor Author

This commit has everything EXCEPT for the projection onto simplex change. I will add a new commit for that once we reach a conclusion @ValerianRey

@KhusPatel4450
Contributor Author

this commit now has the projection onto simplex as discussed on discord

@ValerianRey ValerianRey self-requested a review June 3, 2026 07:07
Contributor

@ValerianRey ValerianRey left a comment

LGTM, we can merge! Thanks a lot for the work @KhusPatel4450

@ValerianRey
Contributor

@PierreQuinton feel free to merge or re-review if you prefer

@PierreQuinton PierreQuinton merged commit 57974a0 into SimplexLab:main Jun 3, 2026
21 checks passed
@PierreQuinton
Contributor

Congrats on the merge @KhusPatel4450 and thanks again.

Labels

cc: feat Conventional commit type for new features. package: aggregation

3 participants