transformers - ✅(Solved) Fix Support for sequence-level custom metrics with decoder-only models [1 pull request, 6 comments, 5 participants]

GitHub stats
huggingface/transformers#44593 (fetched 2026-04-08 00:27:29)
Comments: 6 · Participants: 5 · Timeline events: 11 · Reactions: 0
Timeline (top): commented ×6, subscribed ×2, cross-referenced ×1, labeled ×1

Root Cause

Retrieves GenerationConfig from model.generation_config

    # Update with defaults because earlier the generation config used to be init
    # with default values. Now we init it with `None` and keep defaults for BC
    gen_config = self.model.generation_config
    default_gen_config = gen_config._get_default_generation_params()
    gen_config.update(**default_gen_config, defaults_only=True)
    # in case the batch is shorter than max length, the output should be padded
    if generated_tokens.shape[-1] < gen_config.max_length:
        generated_tokens = self._pad_tensors_to_max_len(generated_tokens, gen_config.max_length)
    elif gen_config.max_new_tokens is not None and generated_tokens.shape[-1] < gen_config.max_new_tokens + 1:
        generated_tokens = self._pad_tensors_to_max_len(generated_tokens, gen_config.max_new_tokens + 1)
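
To make the two-branch padding check above easier to follow, here is a minimal sketch that uses plain Python lists in place of tensors. The helper names `pad_rows` and `pad_generated` are hypothetical, the explicit `None` check on `max_length` is added for the sketch, and right-padding with a pad id is an assumption about what `_pad_tensors_to_max_len` does rather than a copy of the library code.

```python
def pad_rows(rows, target_len, pad_id):
    """Right-pad every row with pad_id up to target_len (hypothetical helper)."""
    return [row + [pad_id] * (target_len - len(row)) for row in rows]


def pad_generated(rows, max_length, max_new_tokens, pad_id):
    """Mirror the check above: pad to max_length, else to max_new_tokens + 1."""
    width = len(rows[0])  # all rows in one batch share the same width
    if max_length is not None and width < max_length:
        return pad_rows(rows, max_length, pad_id)
    if max_new_tokens is not None and width < max_new_tokens + 1:
        return pad_rows(rows, max_new_tokens + 1, pad_id)
    return rows  # already wide enough, nothing to do


batch = [[5, 6, 7], [8, 9, 0]]
padded = pad_generated(batch, max_length=5, max_new_tokens=None, pad_id=0)
```

Only one of the two targets is ever applied, and `max_length` takes precedence whenever the batch is narrower than it.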

Fix Action

Fixed

PR fix notes

PR #44650: Fix Seq2SeqTrainer generation path for decoder-only models

Description (problem / solution / changelog)

Closes #44593

Summary

  • use generation_input_ids/generation_attention_mask when provided for decoder-only models
  • otherwise infer prompt from leading -100 labels and build left-padded prompt batch
  • return completion tokens for decoder-only generation (strip prompt)
  • keep encoder-decoder behavior unchanged
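
The three decoder-only steps in the summary can be sketched as follows. This is not the code from PR #44650: it is an illustrative outline that uses plain Python lists instead of tensors, and the helper names (`prompt_length`, `build_left_padded_prompts`, `strip_prompt`) and the pad id `0` are assumptions made for the example.

```python
IGNORE_INDEX = -100  # label value that marks positions excluded from the loss


def prompt_length(labels):
    """Count the run of IGNORE_INDEX values at the start of one label row."""
    count = 0
    for value in labels:
        if value != IGNORE_INDEX:
            break
        count += 1
    return count


def build_left_padded_prompts(input_ids, labels, pad_id):
    """Cut each row down to its prompt, then left-pad to a common width."""
    prompts = [row[:prompt_length(lab)] for row, lab in zip(input_ids, labels)]
    width = max(len(p) for p in prompts)
    padded = [[pad_id] * (width - len(p)) + p for p in prompts]
    mask = [[0] * (width - len(p)) + [1] * len(p) for p in prompts]
    return padded, mask


def strip_prompt(generated, prompt_width):
    """Keep only the tokens that come after the (padded) prompt."""
    return [row[prompt_width:] for row in generated]


input_ids = [[11, 12, 13, 21, 22], [31, 32, 41, 42, 43]]
labels = [[-100, -100, -100, 21, 22], [-100, -100, 41, 42, 43]]
prompts, attention_mask = build_left_padded_prompts(input_ids, labels, pad_id=0)
```

Padding goes on the left because a decoder-only model continues from the last position of each row, so every prompt has to end in the same column. That also means a single slice at the padded prompt width removes the prompt from every generated row.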

Tests

  • PYTHONPATH=src python -m pytest tests/trainer/test_trainer_seq2seq.py -k Seq2SeqTrainerPredictionStepTester -q -rs

Related: #26474, #33396

Follow-up to #32346

Changed files

  • src/transformers/trainer_seq2seq.py (modified, +76/-6)
  • tests/trainer/test_trainer_seq2seq.py (modified, +112/-0)


Feature request

Hi Hugging Face team,

I’m trying to compute custom metrics at the sequence level for a decoder-only Transformer model, but I ran into an issue. The Seq2SeqTrainer class provides the predict_with_generate option, but it is primarily designed for encoder-decoder architectures. As a result, using it with decoder-only models doesn’t fully support sequence-level metric computation out-of-the-box.

Motivation

Sequence-level metrics, such as BLEU, ROUGE, or other task-specific metrics, are useful for monitoring the training of large language models.

Your contribution

I implemented a local fix by subclassing Seq2SeqTrainer and overriding the prediction_step method. The main points of my approach are:

  • Compute the prompt length dynamically from the labels.
  • Mask the prompt and generate sequences per example.
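  • Pad generated sequences to make metric computation straightforward.

# Imports needed by the override below. The transformers paths mirror those used in
# trainer_seq2seq.py and may differ between versions.
import contextlib
from typing import Any

import torch
from torch import nn
from torch.distributed.fsdp import FullyShardedDataParallel
from transformers import Seq2SeqTrainer
from transformers.integrations.deepspeed import is_deepspeed_zero3_enabled
from transformers.integrations.fsdp import is_fsdp_managed_module


def get_prompt_length_from_labels(labels: torch.Tensor) -> int:
    """
    Return the number of prompt tokens (the run of consecutive -100 values at the start of labels).
    """
    # For a 1D tensor: the running count of -100 equals the 1-based position
    # only while every label so far has been -100.
    prompt_len = (labels == -100).cumsum(dim=0).eq(torch.arange(1, len(labels) + 1, device=labels.device)).sum().item()
    return prompt_len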


class Seq2SeqTrainerCustom(Seq2SeqTrainer):

    def prediction_step(
            self,
            model: nn.Module,
            inputs: dict[str, torch.Tensor | Any],
            prediction_loss_only: bool,
            ignore_keys: list[str] | None = None,
            **gen_kwargs,
    ) -> tuple[float | None, torch.Tensor | None, torch.Tensor | None]:
        """
        Perform an evaluation step on `model` using `inputs`.

        Subclass and override to inject custom behavior.

        Args:
            model (`nn.Module`):
                The model to evaluate.
            inputs (`dict[str, Union[torch.Tensor, Any]]`):
                The inputs and targets of the model.

                The dictionary will be unpacked before being fed to the model. Most models expect the targets under the
                argument `labels`. Check your model's documentation for all accepted arguments.
            prediction_loss_only (`bool`):
                Whether or not to return the loss only.
            gen_kwargs:
                Additional `generate` specific kwargs.

        Return:
            tuple[Optional[float], Optional[torch.Tensor], Optional[torch.Tensor]]: A tuple with the loss, logits and
            labels (each being optional).
        """

        if not self.args.predict_with_generate or prediction_loss_only:
            return super().prediction_step(
                model, inputs, prediction_loss_only=prediction_loss_only, ignore_keys=ignore_keys
            )

        has_labels = "labels" in inputs
        inputs = self._prepare_inputs(inputs)

        # Priority (handled in generate):
        # non-`None` gen_kwargs > model.generation_config > default GenerationConfig()
        if len(gen_kwargs) == 0 and hasattr(self, "_gen_kwargs"):
            gen_kwargs = self._gen_kwargs.copy()
        if "num_beams" in gen_kwargs and gen_kwargs["num_beams"] is None:
            gen_kwargs.pop("num_beams")
        if "max_length" in gen_kwargs and gen_kwargs["max_length"] is None:
            gen_kwargs.pop("max_length")

        default_synced_gpus = is_deepspeed_zero3_enabled() or is_fsdp_managed_module(self.model)
        gen_kwargs["synced_gpus"] = gen_kwargs.get("synced_gpus", default_synced_gpus)

        generation_inputs = inputs.copy()


        # If the `decoder_input_ids` was created from `labels`, evict the former, so that the model can freely generate
        # (otherwise, it would continue generating from the padded `decoder_input_ids`)
        if (
                "labels" in generation_inputs
                and "decoder_input_ids" in generation_inputs
                and generation_inputs["labels"].shape == generation_inputs["decoder_input_ids"].shape
        ):
            generation_inputs = {
                k: v for k, v in inputs.items() if k not in ("decoder_input_ids", "decoder_attention_mask")
            }

        summon_full_params_context = (
            FullyShardedDataParallel.summon_full_params(self.model)
            if isinstance(self.model, FullyShardedDataParallel)
            else contextlib.nullcontext()
        )

        with summon_full_params_context:
            # generated_tokens = self.model.generate(**generation_inputs, **gen_kwargs)
            # Beginning of fix
            # Prompt length per example = number of leading -100 labels.
            batch_prompt_lens = [get_prompt_length_from_labels(labels) for labels in inputs["labels"]]

            # Generate one example at a time, feeding only its prompt tokens.
            all_generated_tokens = []
            for i in range(len(batch_prompt_lens)):
                prompt_len = batch_prompt_lens[i]
                gen_input_ids = inputs["input_ids"][i, :prompt_len].unsqueeze(0)
                gen_attention_mask = inputs["attention_mask"][i, :prompt_len].unsqueeze(0)
                generated_tokens = self.model.generate(
                    input_ids=gen_input_ids,
                    attention_mask=gen_attention_mask,
                    **gen_kwargs
                )
                all_generated_tokens.append(generated_tokens)
        # Right-pad the per-example outputs to a common length so they stack into one tensor.
        # For a decoder-only model each output still starts with its prompt tokens.
        generated_tokens = torch.nn.utils.rnn.pad_sequence(
            [x.squeeze(0) for x in all_generated_tokens],
            batch_first=True,
            padding_value=self.processing_class.pad_token_id
        )
        # End of Fix

        # Temporary hack to ensure the generation config is not initialized for each iteration of the evaluation loop
        # TODO: remove this hack when the legacy code that initializes generation_config from a model config is
        # removed in https://github.com/huggingface/transformers/blob/98d88b23f54e5a23e741833f1e973fdf600cc2c5/src/transformers/generation/utils.py#L1183
        if self.model.generation_config._from_model_config:
            self.model.generation_config._from_model_config = False

        # Retrieves GenerationConfig from model.generation_config
        # Update with defaults because earlier the generation config used to be init
        # with default values. Now we init it with `None` and keep defaults for BC
        gen_config = self.model.generation_config
        default_gen_config = gen_config._get_default_generation_params()
        gen_config.update(**default_gen_config, defaults_only=True)
        # in case the batch is shorter than max length, the output should be padded
        if generated_tokens.shape[-1] < gen_config.max_length:
            generated_tokens = self._pad_tensors_to_max_len(generated_tokens, gen_config.max_length)
        elif gen_config.max_new_tokens is not None and generated_tokens.shape[-1] < gen_config.max_new_tokens + 1:
            generated_tokens = self._pad_tensors_to_max_len(generated_tokens, gen_config.max_new_tokens + 1)

        with torch.no_grad():
            if has_labels:
                with self.compute_loss_context_manager():
                    outputs = model(**inputs)
                if self.label_smoother is not None:
                    loss = self.label_smoother(outputs, inputs["labels"]).detach().mean()
                else:
                    loss = (outputs["loss"] if isinstance(outputs, dict) else outputs[0]).detach().mean()
            else:
                loss = None

        if self.args.prediction_loss_only:
            return loss, None, None

        if has_labels:
            labels = inputs["labels"]
            if labels.shape[-1] < gen_config.max_length:
                labels = self._pad_tensors_to_max_len(labels, gen_config.max_length)
            elif gen_config.max_new_tokens is not None and labels.shape[-1] < gen_config.max_new_tokens + 1:
                labels = self._pad_tensors_to_max_len(labels, gen_config.max_new_tokens + 1)
        else:
            labels = None

        return loss, generated_tokens, labels

extent analysis

Fix Plan

Step 1: Create a custom Seq2SeqTrainer class

Create a new file seq2seq_trainer_custom.py and add the following code:

from typing import Any

import torch
from torch import nn
from transformers import Seq2SeqTrainer

class Seq2SeqTrainerCustom(Seq2SeqTrainer):
    def prediction_step(
        self,
        model: nn.Module,
        inputs: dict[str, torch.Tensor | Any],
        prediction_loss_only: bool,
        ignore_keys: list[str] | None = None,
        **gen_kwargs,
    ) -> tuple[float | None, torch.Tensor | None, torch.Tensor | None]:
        # ... (rest of the code remains the same)

Step 2: Use the custom Seq2SeqTrainer class

In your training script, replace the original Seq2SeqTrainer with the custom class:

from seq2seq_trainer_custom import Seq2SeqTrainerCustom

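# Pass everything by keyword: the third positional parameter of Trainer is
# data_collator, so positional datasets would land in the wrong slots.
trainer = Seq2SeqTrainerCustom(
    model=model,
    args=args,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    processing_class=tokenizer,  # assumed variable name; the override reads self.processing_class.pad_token_id
    compute_metrics=compute_metrics,
)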
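
The `compute_metrics` callable passed above is where a sequence-level metric lives. As a minimal sketch, the function below computes exact match over token ids. It assumes the predictions contain only the completion (as the merged PR describes; with the issue author's override the prompt would have to be stripped first), and it assumes a pad id of `0`. The helper names `clean` and `exact_match` are hypothetical.

```python
IGNORE_INDEX = -100  # label positions that should not count toward the metric


def clean(row, pad_id):
    """Drop ignored label positions and padding from one token-id row."""
    return [int(t) for t in row if int(t) not in (IGNORE_INDEX, pad_id)]


def exact_match(predictions, label_ids, pad_id=0):
    """Fraction of sequences whose cleaned prediction equals the cleaned reference."""
    hits = 0
    total = 0
    for pred_row, label_row in zip(predictions, label_ids):
        hits += clean(pred_row, pad_id) == clean(label_row, pad_id)
        total += 1
    return hits / total if total else 0.0


def compute_metrics(eval_pred):
    # The trainer passes an object that unpacks into (predictions, label_ids).
    predictions, label_ids = eval_pred
    return {"exact_match": exact_match(predictions, label_ids)}
```

For text metrics such as BLEU or ROUGE, the same shape applies: decode the cleaned ids with the tokenizer first, then hand the strings to the metric of your choice.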

Step 3: Update the get_prompt_length_from_labels function

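Update the get_prompt_length_from_labels function to handle 2D tensors. For a batch it returns one prompt length per row:

def get_prompt_length_from_labels(labels: torch.Tensor) -> int | list[int]:
    is_ignored = labels == -100
    if labels.dim() == 2:
        # cumprod along each row stays 1 only while every label so far is -100,
        # so the row sum is the length of the leading -100 run.
        return is_ignored.long().cumprod(dim=1).sum(dim=1).tolist()
    prompt_len = is_ignored.cumsum(dim=0).eq(torch.arange(1, len(labels) + 1, device=labels.device)).sum().item()
    return prompt_len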
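
The 1D expression is compact but not obvious at first read. The sketch below reproduces the same idea step by step with the standard library, so that each intermediate value can be inspected. The function name `leading_ignored` is hypothetical and only illustrates the running-sum comparison.

```python
from itertools import accumulate


def leading_ignored(labels, ignore_index=-100):
    """Count leading ignore_index values by comparing a running sum to the position."""
    mask = [int(v == ignore_index) for v in labels]
    running = list(accumulate(mask))       # cumulative count of ignored labels
    positions = range(1, len(labels) + 1)  # 1-based position of each label
    # running == position holds only while every label so far was ignored
    return sum(r == p for r, p in zip(running, positions))


labels = [-100, -100, 7, 8, -100]
prompt_len = leading_ignored(labels)
```

Here the mask is `[1, 1, 0, 0, 1]` and the running sum is `[1, 2, 2, 2, 3]`, which matches the positions `1..5` only at the first two entries. The trailing `-100` is therefore not counted as part of the prompt.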

Step 4: Update the prediction_step method

Update the prediction_step method to use the get_prompt_length_from_labels function:

def prediction_step(
    self,
    model: nn.Module,
