pytorch - 💡(How to fix) Improve static type hints for dataloader with known `collate_fn` [1 comment, 1 participant]

GitHub stats: pytorch/pytorch#179707 (fetched 2026-04-09 07:50:25). Comments: 1 · Participants: 1 · Timeline events: 2 (commented ×1, renamed ×1) · Reactions: 0


🚀 The feature, motivation and pitch

Feature Requested

I want to improve the type hints available to developers when they train their models by letting the IDE know the type of the data that the DataLoader will yield.

In particular, in code bases where multiple models share a single dataset type, or conversely where multiple datasets must be compatible with one model, this would make it possible to validate a training loop statically. That is much better than the sometimes grueling cycle of booting a training server, watching it fail because the yielded type is wrong, fixing it, and trying again (especially with the large model sizes we now trend towards).

While a comment in the DataLoader source states that there is no way to give the collate_fn result type a default value when the user does not pass a custom collate_fn, this is actually possible with overloads in a .pyi stub file.

Alternatives

If Python had intersection types, that would have been my first choice for my own projects. Alas, it does not.

Otherwise, users might decide to wrap the dataloader initialization and hide the fact that they are handling a dataloader, for example:

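from typing import Any, Callable, Iterable, TypeVar

from torch.utils.data import DataLoader, Dataset

DataType = TypeVar("DataType")  # variance has no effect on a function-scoped TypeVar
Collated = TypeVar("Collated")

def init_dataloader(dataset: Dataset[DataType], collate_fn: Callable[[list[DataType]], Collated], **kwargs: Any) -> Iterable[Collated]:
    # Still a DataLoader at runtime, but statically only known as Iterable[Collated]
    return DataLoader(dataset, collate_fn=collate_fn, **kwargs)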

I hope this option is of much less interest, since it completely loses the information that the object is a DataLoader.

Additional context

Implementation-wise, this feature is easy to add: it involves adding a TypeVar (_T_co_res in the examples below) and using it as follows:

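from __future__ import annotations

from typing import Callable, Generic, TypeVar

_T = TypeVar("_T")
_T_co = TypeVar("_T_co", covariant=True)
_T_co_res = TypeVar("_T_co_res", covariant=True)  # new: result type of collate_fn

_collate_fn_t = Callable[[list[_T]], _T_co_res]

class DataLoader(Generic[_T_co, _T_co_res]):
    def __init__(self, ..., collate_fn: _collate_fn_t[_T_co, _T_co_res] | None = None, ...):
        ...
    def __iter__(self) -> _BaseDataLoaderIter[_T_co_res]:
        ...

class _BaseDataLoaderIter(Generic[_T_co_res]):
    def __next__(self) -> _T_co_res:
        ...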
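To see how the result type would flow from collate_fn to iteration, here is a minimal, self-contained sketch in plain Python. It does not use PyTorch: ToyLoader, _ToyIter and sum_collate are hypothetical stand-ins for DataLoader, _BaseDataLoaderIter and a user-supplied collate_fn, and batching is simplified to consecutive slices of a sequence.

```python
from __future__ import annotations

from typing import Callable, Generic, Iterator, Sequence, TypeVar

_T = TypeVar("_T")
_T_co = TypeVar("_T_co", covariant=True)
_T_co_res = TypeVar("_T_co_res", covariant=True)

# Same shape as the proposed alias: maps a batch of samples to a collated result.
_collate_fn_t = Callable[[list[_T]], _T_co_res]


class _ToyIter(Generic[_T_co_res]):
    """Hypothetical stand-in for _BaseDataLoaderIter: yields collated batches."""

    def __init__(self, batches: Iterator[_T_co_res]) -> None:
        self._batches = batches

    def __iter__(self) -> _ToyIter[_T_co_res]:
        return self

    def __next__(self) -> _T_co_res:
        return next(self._batches)


class ToyLoader(Generic[_T_co, _T_co_res]):
    """Hypothetical stand-in for DataLoader, generic over sample and batch types."""

    def __init__(
        self,
        dataset: Sequence[_T_co],
        batch_size: int,
        collate_fn: _collate_fn_t[_T_co, _T_co_res],
    ) -> None:
        self.dataset = dataset
        self.batch_size = batch_size
        self.collate_fn = collate_fn

    def __iter__(self) -> _ToyIter[_T_co_res]:
        # Slice the dataset into consecutive batches and collate each one lazily.
        batches = (
            self.collate_fn(list(self.dataset[i : i + self.batch_size]))
            for i in range(0, len(self.dataset), self.batch_size)
        )
        return _ToyIter(batches)


def sum_collate(batch: list[int]) -> int:
    # Hypothetical collate_fn: reduce each batch to its sum.
    return sum(batch)


# A type checker should infer ToyLoader[int, int], so `totals` is list[int].
loader = ToyLoader([1, 2, 3, 4, 5], batch_size=2, collate_fn=sum_collate)
totals = list(loader)
```

Because the second type parameter is bound from the return type of collate_fn, a mismatch between what the loop expects and what collate_fn produces would be reported by the type checker instead of at runtime.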

In the dataloader.pyi file:

...
class DataLoader(Generic[_T_co, _T_co_res]):
    @overload
    def __init__(self, ..., collate_fn: _collate_fn_t[_T_co, _T_co_res], ...): ...
    @overload
    def __init__(self: "DataLoader[_T_co, Any]", ..., collate_fn: _collate_fn_t[_T_co, Any] | None = None, ...): ...
...
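Stub files are never executed, but the same overload pattern can be tried in a regular module, where the overloads are followed by a single implementation. The sketch below is hypothetical: OverloadedLoader and default_collate_list are illustrative names, the default collate simply returns the batch list (PyTorch's default_collate does much more), and collate_fn is made keyword-only so that the overload requiring it is valid Python. It shows how annotating self in the fallback overload pins the result type to Any when no collate_fn is passed.

```python
from __future__ import annotations

from typing import Any, Callable, Generic, Iterator, Sequence, TypeVar, overload

_T = TypeVar("_T")
_T_co = TypeVar("_T_co", covariant=True)
_T_co_res = TypeVar("_T_co_res", covariant=True)

_collate_fn_t = Callable[[list[_T]], _T_co_res]


def default_collate_list(batch: list[Any]) -> Any:
    """Hypothetical default collate: return the batch unchanged."""
    return batch


class OverloadedLoader(Generic[_T_co, _T_co_res]):
    # Overload 1: a collate_fn is given, so its return type binds _T_co_res.
    @overload
    def __init__(
        self,
        dataset: Sequence[_T_co],
        batch_size: int = 1,
        *,
        collate_fn: _collate_fn_t[_T_co, _T_co_res],
    ) -> None: ...

    # Overload 2: no collate_fn (or None); annotating self fixes _T_co_res to Any.
    @overload
    def __init__(
        self: OverloadedLoader[_T_co, Any],
        dataset: Sequence[_T_co],
        batch_size: int = 1,
        *,
        collate_fn: _collate_fn_t[_T_co, Any] | None = None,
    ) -> None: ...

    # Single runtime implementation shared by both overloads.
    def __init__(
        self,
        dataset: Sequence[Any],
        batch_size: int = 1,
        *,
        collate_fn: Callable[[list[Any]], Any] | None = None,
    ) -> None:
        self.dataset = dataset
        self.batch_size = batch_size
        self.collate_fn = collate_fn if collate_fn is not None else default_collate_list

    def __iter__(self) -> Iterator[_T_co_res]:
        for i in range(0, len(self.dataset), self.batch_size):
            yield self.collate_fn(list(self.dataset[i : i + self.batch_size]))


def total(batch: list[int]) -> int:
    return sum(batch)


typed = OverloadedLoader([1, 2, 3], batch_size=2, collate_fn=total)  # OverloadedLoader[int, int]
untyped = OverloadedLoader([1, 2, 3], batch_size=2)                  # OverloadedLoader[int, Any]
```

Users who pass a collate_fn get a precise batch type, while existing code that omits it keeps behaving as today (batches typed as Any), so the change should be backward compatible for callers.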

Note that one could define additional default overloads in the .pyi for common dataset types (tuple[Tensor, Tensor], etc.), but I think this is out of scope.

extent analysis

TL;DR

To improve type hints for DataLoader, add a second type parameter, _T_co_res, bound to the return type of collate_fn, and propagate it through __iter__ (via _BaseDataLoaderIter[_T_co_res]) so that each yielded batch is typed.

Guidance

  • Define a TypeVar _T_co_res to represent the type of data yielded by the DataLoader (the return type of collate_fn).
  • Make DataLoader generic over both _T_co and _T_co_res, and use _T_co_res in the collate_fn annotation and in the iterator's __next__ return type.
  • Add overloads to dataloader.pyi covering both cases: a custom collate_fn is provided (binding _T_co_res), or it is omitted (falling back to Any).
  • Optionally add overloads for common dataset types; the issue author considers this out of scope.

Example

class DataLoader(Generic[_T_co, _T_co_res]):
    def __init__(self, ..., collate_fn: _collate_fn_t[_T_co, _T_co_res] | None = None, ...):
        ...
    def __iter__(self) -> _BaseDataLoaderIter[_T_co_res]:
        ...

Notes

The proposed solution relies on TypeVars and overloads in the .pyi file. Its main limitation is that when no collate_fn is passed, the yielded type falls back to Any; more precise defaults for common dataset types would require additional overloads.

Recommendation

Apply the proposed change: add a TypeVar _T_co_res to DataLoader and use it as the element type of the iterator returned by __iter__. This gives more accurate type hints and enables static validation of training loops.
