transformers - (How to fix) MPS OOM error when fine-tuning T5Gemma2 with Seq2SeqTrainer

GitHub stats
huggingface/transformers#45517 · Fetched 2026-04-20 11:58:52
Comments: 0 · Participants: 1 · Timeline: 3 · Reactions: 0
Timeline (top): labeled ×1, mentioned ×1, subscribed ×1


System Info

macOS Tahoe on an M3 with 24 GB RAM, MPS backend, transformers 5.5.4 (with #45516 applied), torch 2.10.0

Who can help?

@SunMarc

Information

  • The official example scripts
  • My own modified scripts

Tasks

  • An officially supported task in the examples folder (such as GLUE/SQuAD, ...)
  • My own task or dataset (give details below)

Reproduction

Note that #45516 is needed: it fixes a bug that otherwise prevents DataCollatorForSeq2Seq from working in this code. You may want to use the Hyeonsieun/MathBridge_1M dataset instead (about 23x smaller), but I haven't tested that dataset.

# %%
from huggingface_hub import login

login()

# %%
import os
# Set the watermark ratios before torch initialises the MPS allocator.
os.environ["PYTORCH_MPS_LOW_WATERMARK_RATIO"] = "0.9"
os.environ["PYTORCH_MPS_HIGH_WATERMARK_RATIO"] = "1.4"  # OOM still happens with the default value

# %%
from datasets import load_dataset
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, DataCollatorForSeq2Seq, Seq2SeqTrainingArguments, Seq2SeqTrainer
import evaluate
import torch


# %%
dataset = load_dataset("aaai25withanonymous/MathBridge")
dataset = dataset["train"].train_test_split(test_size=0.01, seed=488373954176)

# %%
MODEL_ID = "google/t5gemma-2-270m-270m"
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, backend="torchvision")


# %%
model = AutoModelForSeq2SeqLM.from_pretrained(
    MODEL_ID, 
    device_map="auto",
)

# %%
collator = DataCollatorForSeq2Seq(tokenizer, model=model)
metric = evaluate.load("chrf")

# %%
import numpy as np

def postprocess_text(preds, labels):
    preds = [p.lower().strip() for p in preds]
    labels = [[l.lower().strip()] for l in labels]
    return preds, labels

def compute_metrics(eval_preds):
    preds, labels = eval_preds
    if isinstance(preds, tuple):
        preds = preds[0]

    decoded_preds = tokenizer.batch_decode(preds, skip_special_tokens=True)

    # Replace -100 padding inserted by the data collator before decoding
    labels = np.where(labels != -100, labels, tokenizer.pad_token_id)
    decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)

    decoded_preds, decoded_labels = postprocess_text(decoded_preds, decoded_labels)

    result = metric.compute(predictions=decoded_preds, references=decoded_labels)

    prediction_lens = [np.count_nonzero(pred != tokenizer.pad_token_id) for pred in preds]
    result["gen_len"] = np.mean(prediction_lens)
    result = {k: round(v, 4) for k, v in result.items()}
    return {"chrf": result["score"], "gen_len": result["gen_len"]}

# %%
prefix="Convert the LaTeX equation to spoken English: "
def preprocess_function(examples):
    inputs = [
        prefix + ctx_before + " " + eq + " " + ctx_after
        for ctx_before, eq, ctx_after in zip(
            examples["context_before"],
            examples["equation"],
            examples["context_after"]
        )
    ]
    # Dataset has some trailing punctuation noise, which is removed here.
    # There is also sometimes leading capitalisation noise, which is mitigated by
    # a case-insensitive metric at the training stage.
    targets = [s.lower().rstrip(". ") for s in examples["spoken_English"]]

    model_inputs = tokenizer(inputs)

    # Setup the tokenizer for targets
    labels = tokenizer(text_target=targets)

    model_inputs["labels"] = labels["input_ids"]
    return model_inputs

preprocess_function(dataset["train"][0:5])

# %%
dataset_tokenized = dataset.map(preprocess_function, batched=True)

# %%
batch_size = 2
model_name = MODEL_ID.split("/")[-1]
args = Seq2SeqTrainingArguments(
    f"MathBridge-{model_name}",
    num_train_epochs=1,
    warmup_steps=0.05,
    per_device_train_batch_size=batch_size,
    per_device_eval_batch_size=batch_size,
    gradient_accumulation_steps=64,
    bf16=True,
    eval_strategy="steps",
    eval_steps=500,
    save_strategy="steps",
    save_steps=500,
    save_total_limit=3,
    predict_with_generate=True,
    push_to_hub=True,
    generation_max_length=145,   # tune to your target sequence lengths
)

# %%
trainer = Seq2SeqTrainer(
    model,
    args,
    train_dataset=dataset_tokenized["train"],
    eval_dataset=dataset_tokenized["test"],
    data_collator=collator,
    processing_class=tokenizer,
    compute_metrics=compute_metrics,
)

# %%
trainer.train()

# %%

Expected behavior

Once training starts, memory usage grows steadily. Eventually PyTorch crashes with, for example: RuntimeError: MPS backend out of memory (MPS allocated: 3.78 GiB, other allocations: 20.97 GiB, max allowed: 24.86 GiB). Tried to allocate 320.00 MiB on private pool. Use PYTORCH_MPS_HIGH_WATERMARK_RATIO=0.0 to disable upper limit for memory allocations (may cause system failure).
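The numbers in that message can be checked with a little arithmetic. The sketch below parses the reported sizes and compares the remaining headroom with the failed request; the helper name `parse_mps_oom` is hypothetical and the message text is copied from the report, so this is an illustration of how to read the error rather than a diagnostic tool.

```python
import re

# Error text copied from the report (trailing hint omitted).
ERROR = (
    "RuntimeError: MPS backend out of memory (MPS allocated: 3.78 GiB, "
    "other allocations: 20.97 GiB, max allowed: 24.86 GiB). "
    "Tried to allocate 320.00 MiB on private pool."
)

UNIT_TO_GIB = {"GiB": 1.0, "MiB": 1.0 / 1024}


def parse_mps_oom(message: str) -> dict[str, float]:
    """Extract the four sizes from an MPS OOM message, converted to GiB."""
    patterns = {
        "mps_allocated": r"MPS allocated: ([\d.]+) (GiB|MiB)",
        "other": r"other allocations: ([\d.]+) (GiB|MiB)",
        "max_allowed": r"max allowed: ([\d.]+) (GiB|MiB)",
        "requested": r"Tried to allocate ([\d.]+) (GiB|MiB)",
    }
    sizes = {}
    for name, pattern in patterns.items():
        value, unit = re.search(pattern, message).groups()
        sizes[name] = float(value) * UNIT_TO_GIB[unit]
    return sizes


sizes = parse_mps_oom(ERROR)
in_use = sizes["mps_allocated"] + sizes["other"]
headroom = sizes["max_allowed"] - in_use
print(f"in use: {in_use:.2f} GiB, headroom: {headroom:.2f} GiB")
print(f"requested: {sizes['requested'] * 1024:.0f} MiB")
```

Roughly 0.11 GiB of headroom was left against a 320 MiB request, and most of the total falls under "other allocations" rather than the 3.78 GiB attributed to the MPS allocator.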

After the OOM, my Jupyter notebook doesn't release the memory. It appears to be compressed instead; see this Activity Monitor screenshot:

<img width="805" height="59" alt="Image" src="https://github.com/user-attachments/assets/06fee04d-a3ee-4421-84e9-6cd28173dcd8" />

As a consequence, I can't restart training without restarting the python kernel.
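One thing that could be tried before restarting the kernel is to drop references and empty PyTorch's MPS cache. This is a minimal sketch, not something the report confirms to work: `release_mps_memory` is a hypothetical helper, and a notebook may still keep tensors alive through the stored traceback of the failed cell.

```python
import gc

try:
    import torch
except ImportError:  # lets the sketch import on machines without PyTorch
    torch = None


def release_mps_memory() -> bool:
    """Collect garbage, then ask PyTorch to return cached MPS blocks.

    Returns True if the MPS cache was emptied, False if MPS is unavailable.
    """
    gc.collect()
    if torch is None or not torch.backends.mps.is_available():
        return False
    torch.mps.empty_cache()
    return True


# In the notebook, drop your own references first, for example:
#   del trainer, model
# and only then call the helper.
released = release_mps_memory()
print("MPS cache emptied:", released)
```

If memory stays high afterwards, something still holds references to the tensors, and restarting the kernel remains the reliable option.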

extent analysis

TL;DR

The most likely workaround is to lower peak memory per training step: reduce the per-device batch size, shorten sequences (for example via generation_max_length), or use a smaller model. The issue has no comments yet, so none of this is confirmed.

Guidance

  • Reduce per_device_train_batch_size and per_device_eval_batch_size. The script already uses 2, so 1 is the only step left.
  • Do not expect a lower gradient_accumulation_steps to save memory: gradients accumulate into the same buffers, so the setting changes how often the optimizer steps, not the peak footprint. Raise it when you lower the batch size to keep the effective batch size (currently 2 × 64 = 128).
  • Consider using a smaller model, or lowering generation_max_length (currently 145) to shorten generated sequences during evaluation.
  • Monitor memory usage and change one setting at a time, so you can tell which change actually helps.
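The sequence-length point also applies to inputs: the reproduction script calls the tokenizer without truncation, so a few very long examples could inflate a batch. Whether that contributes to this OOM is not confirmed. A minimal sketch of a bounded variant of the script's preprocessing follows; the function name and both caps (512 and 145) are assumptions for illustration.

```python
PREFIX = "Convert the LaTeX equation to spoken English: "

# Hypothetical caps; tune them to the length distribution of your data.
MAX_SOURCE_LEN = 512
MAX_TARGET_LEN = 145


def preprocess_with_truncation(examples, tokenizer):
    """Variant of the script's preprocess_function that bounds sequence length."""
    inputs = [
        PREFIX + before + " " + eq + " " + after
        for before, eq, after in zip(
            examples["context_before"],
            examples["equation"],
            examples["context_after"],
        )
    ]
    targets = [s.lower().rstrip(". ") for s in examples["spoken_English"]]

    # With truncation=True, tokens beyond max_length are dropped.
    model_inputs = tokenizer(inputs, max_length=MAX_SOURCE_LEN, truncation=True)
    labels = tokenizer(text_target=targets, max_length=MAX_TARGET_LEN, truncation=True)

    model_inputs["labels"] = labels["input_ids"]
    return model_inputs
```

It can replace the original function in the map call, for example `dataset.map(preprocess_with_truncation, batched=True, fn_kwargs={"tokenizer": tokenizer})`.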

Example

The adjustments above are configuration changes to the reproduction script rather than a standalone snippet.

Notes

In the reported error, PyTorch's MPS allocator accounts for only 3.78 GiB, while "other allocations" take up 20.97 GiB of the 24.86 GiB limit, so the training parameters may not be the whole story. Monitor memory usage while adjusting them to see which change, if any, stops the growth.
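For the monitoring itself, PyTorch exposes two MPS counters that can be logged during training. The sketch below wraps them in a hypothetical helper, `mps_memory_snapshot`, and returns None on machines without MPS; it only reports what PyTorch tracks and may not explain the "other allocations" figure.

```python
try:
    import torch
except ImportError:  # lets the sketch import on machines without PyTorch
    torch = None

GIB = 1024 ** 3


def mps_memory_snapshot():
    """Return MPS memory counters in GiB, or None when MPS is unavailable."""
    if torch is None or not torch.backends.mps.is_available():
        return None
    return {
        # Memory currently occupied by tensors.
        "allocated_gib": torch.mps.current_allocated_memory() / GIB,
        # Total memory the Metal driver has allocated for this process,
        # including blocks cached by the allocator.
        "driver_gib": torch.mps.driver_allocated_memory() / GIB,
    }


print(mps_memory_snapshot())
```

Calling it every few hundred steps, for instance from a `TrainerCallback.on_step_end` hook, would show whether the growth is visible to PyTorch at all.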

Recommendation

As a workaround, reduce the per-device batch size and compensate with more gradient accumulation steps, so the effective batch size stays the same. This lowers peak memory per step while training continues.
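Gradient accumulation by itself does not raise peak memory, so the batch size is the setting to lower, with the accumulation steps raised to compensate. The sketch below works out that trade for the script's values on a single device; `rebalance` and `overrides` are hypothetical names.

```python
def rebalance(per_device_batch: int, accumulation_steps: int,
              new_per_device_batch: int) -> int:
    """Return the accumulation steps that keep the effective batch size unchanged."""
    effective = per_device_batch * accumulation_steps
    if effective % new_per_device_batch != 0:
        raise ValueError("new batch size must divide the effective batch size")
    return effective // new_per_device_batch


# Values from the reproduction script: batch size 2, 64 accumulation steps.
new_accumulation = rebalance(2, 64, new_per_device_batch=1)

# Keyword arguments to merge into the script's Seq2SeqTrainingArguments call.
overrides = {
    "per_device_train_batch_size": 1,
    "per_device_eval_batch_size": 1,
    "gradient_accumulation_steps": new_accumulation,
}
print(overrides)
```

Each optimizer step still sees 128 examples, but only one example's activations are held in memory at a time.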
