Describe the bug
HookRegistry._get_child_registries() caches the list of child-module registries (built by walking named_modules()) in self._child_registries_cache on first use and never invalidates it. But register_hook / remove_hook — invoked by enable_cache() / disable_cache() — change which child modules have a _diffusers_hook. When cache_context() is first entered while a model has no stateful block hooks (e.g. the cache is disabled), the cached child-registry list is built without those blocks. A later enable_cache(FirstBlockCacheConfig(...)) adds FirstBlockCache hooks to the blocks, but cache_context()._set_context() still iterates the stale cached list, so the new block-level StateManagers are never given a context. The first cached forward then raises:
ValueError: No context is set. Please set a context before retrieving the state.
This is distinct from #12012 (Wan I2V pipeline not initializing the context, fixed in #12013): here the context is set on the transformer registry, but the staleness in _child_registries_cache prevents it reaching the freshly-registered block hooks. The bug is present on main (register_hook/remove_hook contain no cache invalidation) and was introduced by the _child_registries_cache optimization (docstring references Flux2).
Reproduction
Fully self-contained — no downloads, no GPU. Builds a tiny randomly-initialized FluxTransformer2DModel on CPU (verified with diffusers==0.38.0, torch==2.8.0):
import torch
from diffusers import FirstBlockCacheConfig, FluxTransformer2DModel

torch.manual_seed(0)
heads, head_dim = 2, 16
hidden = heads * head_dim  # 32
model = FluxTransformer2DModel(
    patch_size=1, in_channels=hidden, num_layers=2, num_single_layers=2,
    attention_head_dim=head_dim, num_attention_heads=heads,
    joint_attention_dim=32, pooled_projection_dim=16,
    guidance_embeds=False, axes_dims_rope=(2, 6, 8),  # sums to head_dim (16)
).eval()
B, T_img, T_txt = 1, 8, 4

def make_inputs():
    return dict(
        hidden_states=torch.randn(B, T_img, hidden),
        encoder_hidden_states=torch.randn(B, T_txt, 32),
        pooled_projections=torch.randn(B, 16),
        timestep=torch.tensor([1.0]),
        img_ids=torch.zeros(T_img, 3), txt_ids=torch.zeros(T_txt, 3),
        return_dict=False,
    )

# 1) Enter cache_context BEFORE any cache hooks exist (mirrors a warmup/inference pass that
#    wraps the call in cache_context() while caching is disabled). This populates
#    model._diffusers_hook._child_registries_cache WITHOUT the block registries.
with torch.no_grad(), model.cache_context("cond"):
    model(**make_inputs())

# 2) enable_cache() registers FirstBlockCache hooks on the blocks (new child registries) but
#    does NOT invalidate the parent's _child_registries_cache.
model.enable_cache(FirstBlockCacheConfig(threshold=0.1))

# 3) The stale cache means cache_context()._set_context() never reaches the new block hooks.
with torch.no_grad(), model.cache_context("cond"):
    model(**make_inputs())  # ValueError: No context is set
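Output (from an instrumented run that prints each step and then applies the workaround described under Workaround):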
[1] warmup pass with no cache hooks: OK
[2] enable_cache: OK
[3] cached pass raised: ValueError: No context is set. Please set a context before retrieving the state.
[workaround] cleared _child_registries_cache
[4] cached pass after workaround: OK (bug confirmed + workaround fixes it)
The same staleness also bites a plain enable_cache → disable_cache → enable_cache cycle on a long-lived module: the second enable_cache adds hooks whose registries aren't in the cache built during the first cycle.
Root cause
HookRegistry._get_child_registries() memoizes _child_registries_cache and reuses it indefinitely:
def _get_child_registries(self) -> list["HookRegistry"]:
    if not hasattr(self, "_child_registries_cache"):
        self._child_registries_cache = None
    if self._child_registries_cache is not None:
        return self._child_registries_cache  # <-- stale after register/remove_hook
    registries = []
    for module_name, module in unwrap_module(self._module_ref).named_modules():
        ...
    self._child_registries_cache = registries
    return registries
Neither register_hook nor remove_hook resets it, so the cache's invariant ("the set of descendant modules with _diffusers_hook is stable after first build") is violated by the very APIs that mutate that set.
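To make the failure mode concrete without diffusers or torch, here is a minimal toy model of the same memoization pattern. ToyModule and ToyRegistry are hypothetical stand-ins I made up for illustration (they are not the library's classes); only the shape of the cache check mirrors the snippet above.

```python
class ToyRegistry:
    """Hypothetical stand-in for HookRegistry: hooks plus a memoized child list."""

    def __init__(self, module):
        self.module = module
        self.hooks = {}
        self._child_registries_cache = None

    def get_child_registries(self):
        # Built once and reused forever, like the method quoted above.
        if self._child_registries_cache is not None:
            return self._child_registries_cache
        registries = [
            child.registry
            for child in self.module.descendants()
            if child.registry is not None
        ]
        self._child_registries_cache = registries
        return registries


class ToyModule:
    """Hypothetical stand-in for an nn.Module with an optional registry."""

    def __init__(self, name, children=()):
        self.name = name
        self.children = list(children)
        self.registry = None  # plays the role of _diffusers_hook

    def descendants(self):
        for child in self.children:
            yield child
            yield from child.descendants()

    def register_hook(self, hook_name):
        # Creates the registry on first use; never touches any parent's cache.
        if self.registry is None:
            self.registry = ToyRegistry(self)
        self.registry.hooks[hook_name] = object()


blocks = [ToyModule("block0"), ToyModule("block1")]
root = ToyModule("transformer", blocks)
root.register_hook("root_hook")

# First lookup happens while no block has a registry (like the warmup pass).
before = list(root.registry.get_child_registries())

# Later, hooks are added to the blocks (like enable_cache()).
for block in blocks:
    block.register_hook("first_block_cache")

# The parent still returns the memoized, now stale, empty list.
after = root.registry.get_child_registries()
fresh_walk = [m.registry for m in root.descendants() if m.registry is not None]

print(len(before), len(after), len(fresh_walk))
```

Note that an empty list is still "not None", so a cache built before any block hooks exist is treated as valid from then on.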
Suggested fix
Invalidate the cache whenever a descendant's _diffusers_hook is created or removed. The simplest correct option is to reset it in register_hook / remove_hook (and propagate to ancestors, or just reset on the registry being mutated and let parents rebuild lazily). E.g.:
def register_hook(self, hook, name):
    ...
    self._child_registries_cache = None  # structural change -> invalidate

def remove_hook(self, name, recurse=True):
    ...
    self._child_registries_cache = None
Since a child registry is created via HookRegistry.check_if_exists_or_initialize(child) and a parent's cache is what goes stale, invalidation likely needs to clear the cache on the whole ancestor chain (or fall back to always walking named_modules() when correctness must hold). Happy to send a PR if you agree on the approach.
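One possible way to reach parent caches without walking the ancestor chain is a shared "epoch" counter: every register/remove bumps it, and each cache remembers the epoch it was built at. The sketch below uses hypothetical toy classes (ToyRegistry, ToyModule, _epoch and _cache_epoch are all my own names, not diffusers APIs) and is only meant to illustrate the idea, not to propose the exact patch.

```python
class ToyRegistry:
    # Hypothetical process-wide counter, bumped on every structural hook change.
    _epoch = 0

    def __init__(self, module):
        self.module = module
        self.hooks = {}
        self._child_registries_cache = None
        self._cache_epoch = -1  # epoch at which the cache was last built

    def register_hook(self, name, hook):
        self.hooks[name] = hook
        ToyRegistry._epoch += 1  # invalidates every registry's cache lazily

    def remove_hook(self, name):
        self.hooks.pop(name, None)
        ToyRegistry._epoch += 1

    def get_child_registries(self):
        cache_is_current = self._cache_epoch == ToyRegistry._epoch
        if self._child_registries_cache is not None and cache_is_current:
            return self._child_registries_cache
        registries = [
            child.registry
            for child in self.module.descendants()
            if child.registry is not None
        ]
        self._child_registries_cache = registries
        self._cache_epoch = ToyRegistry._epoch
        return registries


class ToyModule:
    """Hypothetical stand-in for an nn.Module with an optional registry."""

    def __init__(self, name, children=()):
        self.name = name
        self.children = list(children)
        self.registry = None

    def descendants(self):
        for child in self.children:
            yield child
            yield from child.descendants()

    def get_registry(self):
        # Loosely mirrors "check if exists or initialize".
        if self.registry is None:
            self.registry = ToyRegistry(self)
        return self.registry


blocks = [ToyModule("block0"), ToyModule("block1")]
root = ToyModule("transformer", blocks)
root.get_registry().register_hook("root_hook", object())

first = root.get_registry().get_child_registries()    # built with no block hooks
second = root.get_registry().get_child_registries()   # cache hit, same list object
for block in blocks:
    block.get_registry().register_hook("first_block_cache", object())
refreshed = root.get_registry().get_child_registries()  # epoch changed, so rebuilt

print(len(first), second is first, len(refreshed))  # 0 True 2
```

The trade-off in this sketch is that a hook change on any module invalidates caches everywhere, including on unrelated models. Since hooks change rarely compared with forward passes, that may be acceptable, but clearing only the affected ancestors would be more precise.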
Workaround
After enable_cache(...), manually clear the cache before the next cache_context():
# `transformer` is the module the cache was enabled on (`model` in the reproduction above).
reg = getattr(transformer, "_diffusers_hook", None)
if reg is not None and getattr(reg, "_child_registries_cache", None) is not None:
    reg._child_registries_cache = None
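If several nested modules may each hold a stale list, the same workaround can be wrapped in a small helper. This is a sketch: the function name clear_child_registry_caches is made up, it relies on the private attributes _diffusers_hook and _child_registries_cache named above (which may change between versions), and it is exercised here with stand-in objects so it runs without torch or diffusers.

```python
from types import SimpleNamespace


def clear_child_registry_caches(modules):
    """Reset the memoized child-registry list on every module that has one.

    `modules` is any iterable of module-like objects. Returns the number of
    registries whose cache was actually cleared.
    """
    cleared = 0
    for module in modules:
        registry = getattr(module, "_diffusers_hook", None)
        if registry is None:
            continue  # module has no hook registry at all
        if getattr(registry, "_child_registries_cache", None) is not None:
            registry._child_registries_cache = None
            cleared += 1
    return cleared


# Stand-in objects (not real diffusers modules), just to exercise the helper.
stale = SimpleNamespace(_diffusers_hook=SimpleNamespace(_child_registries_cache=[]))
unbuilt = SimpleNamespace(_diffusers_hook=SimpleNamespace(_child_registries_cache=None))
plain = SimpleNamespace()

print(clear_child_registry_caches([stale, unbuilt, plain]))  # 1
```

With a real model, the call would presumably be clear_child_registry_caches(transformer.modules()) right after enable_cache(...) or disable_cache(), before the next cache_context().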
System Info
- diffusers: 0.38.0 (also present on main)
- transformers: 5.10.2
- torch: 2.8.0
- Python: 3.11
- Platform: Linux, NVIDIA H200
Who can help?
Hook/caching: @a-r-r-o-w @DN6