Skip to content

Upgrading transformers version to 4.51.3 to support recent models#175

Merged
Dornavineeth merged 12 commits into
locuslab:mainfrom
filyp:pr-transformers-4.51.3
Mar 7, 2026
Merged

Upgrading transformers version to 4.51.3 to support recent models#175
Dornavineeth merged 12 commits into
locuslab:mainfrom
filyp:pr-transformers-4.51.3

Conversation

@filyp
Copy link
Copy Markdown
Contributor

@filyp filyp commented Jan 26, 2026

What does this PR do?

  • It upgrades transformers to 4.51.3, to support newer models like gemma3 and qwen3
  • Makes UnlearnTrainer implementation more future proof. (See Upgrading transformers version to support recent models #173)
  • Makes other necessary changes to be compatible with new transformers version
    • Adds num_items_in_batch to compute_loss signature
    • Prevent exceptions from evaluating when eval_dataset=None
    • Uses trainer.processing_class instead of trainer.tokenizer (it's depracated and transformers==5 removes it)

Additionally it:

  • Simplifies the installation of lm_eval (no need to have a special install group in setup.py)
  • Minor fix to the leaderboard doc
  • Makes .gitignore more comprehensive

Related issues: #173 and #155

Before submitting

  • This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
  • Have you gone through the contributions guide?
  • Are your changes documented? Read documentation guidelines here.

Tests

I manually tested the new evaluation, with removed prediction_step from UnlearninTrainer, and it works the same.

When I tested unlearning, the unlearning trajectories are exactly the same when using gradient_accumulation_steps=1. But when it's =/=1, the upgrade changes the scale of the logged training/loss (it's 4x higher.), and subtly changes the unlearning trajectory. This most likely comes from the gradient accumulation fix in transformers=4.46.

image

(Tested unlearning with this command:

python src/train.py --config-name=unlearn.yaml experiment=unlearn/tofu/default eval=tofu_simple question_key=paraphrased_question eval.tofu.batch_size=16 trainer.args.report_to=wandb trainer=NPO task_name=...

Where tofu_simple is a config with fewer eval metrics.)

compute_loss docstring now states:

        Subclass and override for custom behavior. If you are not using `num_items_in_batch` when computing your loss,
        make sure to overwrite `self.model_accepts_loss_kwargs` to `False`. Otherwise, the loss calculating might be slightly inaccurate when performing gradient accumulation.

When I tried setting trainer.model_accepts_loss_kwargs = False it restores the previous scale of the training loss, but it doesn't affect the unlearning trajectory at all, only that logged loss scale.

filyp added 4 commits January 26, 2026 16:15
make UnlearnTrainer implementation more future proof
make other necessary changes to be compatible with new transformers version
fix leaderboard docs
wider gitignore
@filyp
Copy link
Copy Markdown
Contributor Author

filyp commented Feb 8, 2026

I was actually planning to upgrade transformers to 5.0.0, because it has some major MoE optimizations + I'm working on an unlearning method that depends on the MoE implementation, so I'd rather already work on the better implementation.

But I'd rather wait until you merge or at least review this upgrade to 4.51.3, because I don't want to change too many things at once. What do you think?

@molereddy
Copy link
Copy Markdown
Collaborator

Looks great!! Can you clean up the linter errors?

@filyp
Copy link
Copy Markdown
Contributor Author

filyp commented Feb 11, 2026

Ah, right, fixed now.

@molereddy
Copy link
Copy Markdown
Collaborator

Seems the tests are still failing -- maybe a version mismatch? Our instructions for formatting are here. Let me know if you are still stuck!

@filyp
Copy link
Copy Markdown
Contributor Author

filyp commented Feb 18, 2026

Ah I see locally it passed for me because I had a newer ruff version. I reformatted to the one that the github action uses, so should be good now

@Dornavineeth
Copy link
Copy Markdown
Collaborator

Hey @filyp

Really appreciate all the work you put into upgrading the Transformers version. This is a pretty substantial change. I left a few comments; most are minor, except for one issue I called out here that I think we should address before merging.

Once that’s resolved, I’m happy to merge. This upgrade will make it much easier for folks to try the latest models and benefit from the newer bug fixes.

On a side note: I’d love to hear your thoughts on #145 and if it’s easy to incorporate the fix into this PR, that would be great.

Copy link
Copy Markdown
Collaborator

@Dornavineeth Dornavineeth left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Few changes. Overall, pretty impressive work! Thank you so much.

Comment thread src/trainer/base.py Outdated
self.log(eval_metrics)
return eval_metrics

if eval_dataset is None or eval_dataset == "dummy":
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why is this hardcoded? Should we remove this?

eval_dataset == "dummy"

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm not fully happy with that fix, but that's related to this line in train.py:

        eval_dataset=data.get("eval", "dummy"),  # None would trigger Trainer exception

It's just that in the new transformers version, the trainer asserts that eval is not None, if we tell it to do evaluations, which breaks our custom evaluators setup. LMK if you see a better solution.

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Okay. Seems reasonable.

Can you just define a variable at the top _EVAL_PLACEHOLDER = "_EVAL_PLACEHOLDER" and use this variable at all places?

Comment thread src/trainer/unlearn/base.py Outdated
labels = nested_detach(tuple(inputs.get(name) for name in self.label_names))
if len(labels) == 1:
labels = labels[0]
def compute_loss(self, model, inputs, **kwargs):
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I see why you renamed compute_loss to compute_unlearn_loss, but this could be a breaking change for anyone pulling the update into their forks; it could lead to runtime errors like compute_unlearn_loss is not defined.

I’d prefer to keep the backward compatibility. Can we keep the previous pattern (per the earlier docstrings): have prediction_step and then call super().compute_loss() for the actual loss computation.

NOTE: You need to copy the prediction_step from the transformers 4.51.3

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah, good point. I committed a 3rd alternative now, take a look at current base.py. It preserves backwards compatibility and also is interoperable with other transformers versions.

I regression tested on both tofu unlearning, and some prediction, and it works the same.

else:
if has_labels or loss_without_labels:
with self.compute_loss_context_manager():
### Call compute_loss of super class since overridden compute_loss is not be applicable to eval_dataset.
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ignore_keys: Optional[List[str]] = None,
) -> Tuple[Optional[torch.Tensor], Optional[torch.Tensor], Optional[torch.Tensor]]:
"""
The only change to this function is calling the Trainer's compute_loss, as it's often overridden by unlearning methods, and we want to maintain the Trainer's evaluation setup.
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you also include these docstrings in the updated code including prediction_step

Comment thread docs/components.md Outdated
...

def compute_loss(self, model, inputs, return_outputs=False):
def compute_unlearn_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I would like to retain it as compute_loss

see https://github.com/locuslab/open-unlearning/pull/175/changes#r2837139317

Comment thread src/train.py Outdated
train_dataset=data.get("train", None),
eval_dataset=data.get("eval", None),
tokenizer=tokenizer,
eval_dataset=data.get("eval", "dummy"), # None would trigger Trainer exception
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Remove the "dummy" and make it
eval_dataset=data.get("eval", None),

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The thing is that using None triggers a Trainer exception (see #175 (comment))

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

FYI, I moved this hacky fix into the FinetuneTrainer to hide it more. I don't see any better way to fix it.

Comment thread .gitignore
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looks like these additions in gitignore are specific to your usecases.
can you revert this?

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yup, done

Comment thread README.md Outdated
conda create -n unlearning python=3.11
conda activate unlearning
pip install .[lm_eval]
pip install .
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Any reason you removed this?
I would like to give this as an option instead of keeping it requirements.txt because the lm_eval harness is a more involved build.

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ah, I see why I didn't notice lm_eval heaviness, I think I was using version 0.4.11 which installs in 20s (I tested it now in a fresh venv), while the older 0.4.8 in 70s. Should I bump it then? Or revert?
(And the reason I moved it is to simplify the installation.)

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You can bump the version and still keep it optional for people having
pip install .[lm_eval]

Comment thread requirements.txt Outdated
huggingface-hub==0.36.0
transformers==4.51.3
hf-xet==1.2.0
lm-eval==0.4.8
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Comment thread setup.py
@@ -17,13 +17,10 @@
packages=find_packages(),
install_requires=requirements, # Uses requirements.txt
extras_require={
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@filyp
Copy link
Copy Markdown
Contributor Author

filyp commented Feb 22, 2026

Thanks! I'll go through the review soon. First, about the #145 you mentioned:

It looks pretty complicated, I left a comment there with some alternative simpler fix. But I'd rather keep it separate from this PR, because in my setup I can't test #145 easily. (Unless you want to go with the simple fix, then I can add it, but I still can't test it fully.)

@filyp
Copy link
Copy Markdown
Contributor Author

filyp commented Feb 22, 2026

I also committed now runners/modal_runner.py, which is not related to PR, but people may find it useful for running with less setup required. I also plan to add more runners in the future (slurm, maybe runpod).

Besides that, I addressed all your comments.

@Dornavineeth
Copy link
Copy Markdown
Collaborator

Dornavineeth commented Feb 23, 2026

All good! Left comments for 2 minor nits. Once fixed, I will merge it.

@filyp
Copy link
Copy Markdown
Contributor Author

filyp commented Feb 26, 2026

Hey, somehow I can't see these 2 new comments. Could you link them? (Or maybe you have a "review" started and didn't "submit" the comments, and then only you see them. I remember once having this problem because the UI is quite confusing.)

@filyp
Copy link
Copy Markdown
Contributor Author

filyp commented Mar 2, 2026

Hm, I'm still having problems accessing these comments for some reason :/
(Can you access them when you're not logged in to your github account?)
Maybe simply paste them here?

@molereddy
Copy link
Copy Markdown
Collaborator

From my side, I can see unresolved comments in the "Files Changed" viewer. But only when logged in.
Unsure if there's a permissions issue preventing you from viewing them.
image

@filyp
Copy link
Copy Markdown
Contributor Author

filyp commented Mar 2, 2026

@molereddy yes I see the same ones. I'm understanding that @Dornavineeth added some new ones that we both don't see?

And these unresolved ones are either outdated or I left a question in them.

@molereddy
Copy link
Copy Markdown
Collaborator

Yes, the links @Dornavineeth shard don't open up any comment for me.

Copy link
Copy Markdown
Collaborator

@Dornavineeth Dornavineeth left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please see the comments.

Comment thread runners/modal_runner.py Outdated
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we remove this just for the sake of simplicity?

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done, removed. (Let me know if you'd like me to include that runner someplace else, but ok if not.)

Comment thread src/trainer/unlearn/base.py Outdated


class UnlearnTrainer(FinetuneTrainer):
def prediction_step(self, *args, **kwargs):
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we just copy the orginal transformers code of prediction_step and make edits just wherever required?

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Just to not accidentally break any other functionalities included in the hf prediction_step.

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done — copied prediction_step from transformers 4.51.3 with the one change of calling super().compute_loss() instead of self.compute_loss().

(Note though that this implementation could potentially break at some future bump because of these cherrypicked imports at the top. But for the current bump, I regression tested and it's fine.)

Comment thread src/trainer/base.py Outdated
self.log(eval_metrics)
return eval_metrics

if eval_dataset is None or eval_dataset == "dummy":
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Okay. Seems reasonable.

Can you just define a variable at the top _EVAL_PLACEHOLDER = "_EVAL_PLACEHOLDER" and use this variable at all places?

Comment thread README.md Outdated
conda create -n unlearning python=3.11
conda activate unlearning
pip install .[lm_eval]
pip install .
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You can bump the version and still keep it optional for people having
pip install .[lm_eval]

@filyp
Copy link
Copy Markdown
Contributor Author

filyp commented Mar 6, 2026

Ok, I applied all these changes. I also regression tested on tofu unlearning and on some code that uses prediction_step and it's unchanged.

So I think it's all done now.

@Dornavineeth
Copy link
Copy Markdown
Collaborator

Great Work. Thank you so much for this PR.

@Dornavineeth Dornavineeth merged commit a456aa2 into locuslab:main Mar 7, 2026
1 check passed
@filyp filyp deleted the pr-transformers-4.51.3 branch March 9, 2026 14:40
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants