
HookRegistry._child_registries_cache goes stale after enable_cache/disable_cache → ValueError: No context is set #14037

Description

@christopher5106

Describe the bug

HookRegistry._get_child_registries() caches the list of child-module registries (built by walking named_modules()) in self._child_registries_cache on first use and never invalidates it. However, register_hook / remove_hook, which enable_cache() / disable_cache() invoke, change which child modules have a _diffusers_hook.

If cache_context() is first entered while the model has no stateful block hooks (e.g. caching is disabled), the cached child-registry list is built without those blocks. A later enable_cache(FirstBlockCacheConfig(...)) adds FirstBlockCache hooks to the blocks, but the _set_context() call made by cache_context() still iterates the stale cached list, so the new block-level StateManagers never receive a context. The first cached forward pass then raises:
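The failure mode can be reproduced without diffusers or torch. The sketch below is a toy model, not the library code: ToyRegistry and ToyStateManager are hypothetical stand-ins that only mimic the shape described above (a memoized child list, context propagation through it, and a state lookup that raises when no context is set).

```python
# Hypothetical stand-ins: ToyStateManager and ToyRegistry only mimic the shape of the
# diffusers classes described in this report; they are not the real implementation.


class ToyStateManager:
    """Per-hook state that can only be read once a context has been set."""

    def __init__(self):
        self._context = None

    def set_context(self, name):
        self._context = name

    def get_state(self):
        if self._context is None:
            raise ValueError("No context is set. Please set a context before retrieving the state.")
        return f"state[{self._context}]"


class ToyRegistry:
    def __init__(self):
        self.children = []        # stands in for descendants that currently carry a _diffusers_hook
        self.state_managers = []  # stands in for stateful hooks registered on this module
        self._child_registries_cache = None

    def _get_child_registries(self):
        # Memoized on first use and never invalidated, as described in the report.
        if self._child_registries_cache is None:
            self._child_registries_cache = list(self.children)
        return self._child_registries_cache

    def _set_context(self, name):
        for manager in self.state_managers:
            manager.set_context(name)
        for child in self._get_child_registries():
            child._set_context(name)


# Step 1: cache_context() is entered before any block has a hook, so the cache is built empty.
transformer = ToyRegistry()
transformer._set_context("cond")
transformer._set_context(None)

# Step 2: "enable_cache" gives a block a registry holding a stateful hook.
block = ToyRegistry()
block.state_managers.append(ToyStateManager())
transformer.children.append(block)

# Step 3: the next cache_context() walks the stale (empty) list, so the block gets no context.
transformer._set_context("cond")
error = None
try:
    block.state_managers[0].get_state()
except ValueError as exc:
    error = str(exc)
print("step 3 raised:", error)

# Workaround: drop the stale list so the next _set_context() rebuilds it.
transformer._child_registries_cache = None
transformer._set_context("cond")
print("after workaround:", block.state_managers[0].get_state())
```

Step 3 fails for the same reason as the real model: the context is set on the parent, but the memoized list predates the block's hook.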

ValueError: No context is set. Please set a context before retrieving the state.

This is distinct from #12012 (Wan I2V pipeline not initializing the context, fixed in #12013): here the context is set on the transformer registry, but the stale _child_registries_cache prevents it from reaching the freshly registered block hooks. The bug is present on main (register_hook/remove_hook contain no cache invalidation) and was introduced by the _child_registries_cache optimization (whose docstring references Flux2).

Reproduction

Fully self-contained: no downloads, no GPU. The script builds a tiny, randomly initialized FluxTransformer2DModel on CPU (verified with diffusers==0.38.0, torch==2.8.0):

import torch
from diffusers import FirstBlockCacheConfig, FluxTransformer2DModel

torch.manual_seed(0)

heads, head_dim = 2, 16
hidden = heads * head_dim  # 32
model = FluxTransformer2DModel(
    patch_size=1, in_channels=hidden, num_layers=2, num_single_layers=2,
    attention_head_dim=head_dim, num_attention_heads=heads,
    joint_attention_dim=32, pooled_projection_dim=16,
    guidance_embeds=False, axes_dims_rope=(2, 6, 8),  # sums to head_dim (16)
).eval()

B, T_img, T_txt = 1, 8, 4
def make_inputs():
    return dict(
        hidden_states=torch.randn(B, T_img, hidden),
        encoder_hidden_states=torch.randn(B, T_txt, 32),
        pooled_projections=torch.randn(B, 16),
        timestep=torch.tensor([1.0]),
        img_ids=torch.zeros(T_img, 3), txt_ids=torch.zeros(T_txt, 3),
        return_dict=False,
    )

# 1) Enter cache_context BEFORE any cache hooks exist (mirrors a warmup/inference pass that
#    wraps the call in cache_context() while caching is disabled). This populates
#    model._diffusers_hook._child_registries_cache WITHOUT the block registries.
with torch.no_grad(), model.cache_context("cond"):
    model(**make_inputs())

# 2) enable_cache() registers FirstBlockCache hooks on the blocks (new child registries) but
#    does NOT invalidate the parent's _child_registries_cache.
model.enable_cache(FirstBlockCacheConfig(threshold=0.1))

# 3) The stale cache means cache_context()._set_context() never reaches the new block hooks.
with torch.no_grad(), model.cache_context("cond"):
    model(**make_inputs())   # ValueError: No context is set

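Output (from the same script with progress prints added, step 3 wrapped in try/except, and the workaround below applied before a fourth pass):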

[1] warmup pass with no cache hooks: OK
[2] enable_cache: OK
[3] cached pass raised: ValueError: No context is set. Please set a context before retrieving the state.
[workaround] cleared _child_registries_cache
[4] cached pass after workaround: OK  (bug confirmed + workaround fixes it)

The same staleness also bites a plain enable_cache → disable_cache → enable_cache cycle on a long-lived module: the second enable_cache adds hooks whose registries aren't in the cache built during the first cycle.

Root cause

HookRegistry._get_child_registries() memoizes _child_registries_cache and reuses it indefinitely:

def _get_child_registries(self) -> list["HookRegistry"]:
    if not hasattr(self, "_child_registries_cache"):
        self._child_registries_cache = None
    if self._child_registries_cache is not None:
        return self._child_registries_cache          # <-- stale after register/remove_hook
    registries = []
    for module_name, module in unwrap_module(self._module_ref).named_modules():
        ...
    self._child_registries_cache = registries
    return registries

Neither register_hook nor remove_hook resets it, so the cache's invariant ("the set of descendant modules with _diffusers_hook is stable after first build") is violated by the very APIs that mutate that set.

Suggested fix

Invalidate the cache whenever a descendant's _diffusers_hook is created or removed. The simplest option is to reset it in register_hook / remove_hook and propagate the reset to ancestor registries. E.g.:

def register_hook(self, hook, name):
    ...
    self._child_registries_cache = None      # structural change -> invalidate

def remove_hook(self, name, recurse=True):
    ...
    self._child_registries_cache = None

Since a child registry is created via HookRegistry.check_if_exists_or_initialize(child) and a parent's cache is what goes stale, invalidation likely needs to clear the cache on the whole ancestor chain (or fall back to always walking named_modules() when correctness must hold). Happy to send a PR if you agree on the approach.
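Clearing the whole ancestor chain is awkward in practice because torch.nn.Module keeps no reference to its parent, so a child registry cannot walk upward. One alternative, offered here as an assumption and not as the approach diffusers takes, is a process-wide generation counter: every structural change bumps it, and each cached list remembers the generation it was built at. GenerationRegistry below is a hypothetical toy; its "child has a registry" test (at least one hook present) is a simplification of the real hasattr(module, "_diffusers_hook") check.

```python
class GenerationRegistry:
    """Hypothetical registry whose cached child list is invalidated by a shared counter."""

    # Bumped on every structural change anywhere in the process (registry creation,
    # hook registration, hook removal). Being class-level, it needs no parent pointers.
    _structure_generation = 0

    def __init__(self, children=()):
        self.children = list(children)  # stands in for descendant modules
        self.hooks = {}
        self._child_registries_cache = None
        self._cache_generation = -1     # forces a build on first use
        GenerationRegistry._bump()      # creating a registry is itself a structural change

    @classmethod
    def _bump(cls):
        cls._structure_generation += 1

    def register_hook(self, hook, name):
        self.hooks[name] = hook
        GenerationRegistry._bump()

    def remove_hook(self, name):
        if self.hooks.pop(name, None) is not None:
            GenerationRegistry._bump()

    def _get_child_registries(self):
        # Rebuild lazily if anything changed since this list was built.
        if self._cache_generation != GenerationRegistry._structure_generation:
            # Toy filter: a child "has a registry" if it currently holds at least one hook.
            self._child_registries_cache = [c for c in self.children if c.hooks]
            self._cache_generation = GenerationRegistry._structure_generation
        return self._child_registries_cache


block = GenerationRegistry()
root = GenerationRegistry(children=[block])
print(root._get_child_registries())             # built before any hook exists
block.register_hook(object(), "fbc")            # bumps the generation
print(root._get_child_registries() == [block])  # the stale list is detected and rebuilt
```

The trade-off is coarseness: any hook change anywhere invalidates every registry's list. Rebuilds still only happen on the next lookup after such a change, which should be rare compared with forward passes. Registry creation must bump the counter too, since in this bug the block registries are created via check_if_exists_or_initialize before any hook is registered on them.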

Workaround

After enable_cache(...), manually clear the cache before the next cache_context():

reg = getattr(transformer, "_diffusers_hook", None)
if reg is not None and getattr(reg, "_child_registries_cache", None) is not None:
    reg._child_registries_cache = None
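The workaround above only resets the top-level registry, which is enough for the reproduction. If nested registries might also hold stale lists (an assumption; the report only demonstrates the top-level case), a broader sketch clears every registry in the module tree. The helper name clear_child_registry_caches is hypothetical; it relies only on torch.nn.Module.named_modules() and on the private _diffusers_hook / _child_registries_cache attributes named in this report, which may change between diffusers versions.

```python
def clear_child_registry_caches(root) -> int:
    """Reset _child_registries_cache on every HookRegistry in root's module tree.

    root is expected to be a torch.nn.Module; named_modules() yields root itself first,
    so the top-level registry is covered too. Returns how many caches were cleared.
    """
    cleared = 0
    for _, module in root.named_modules():
        # Private diffusers attributes named in this report; treat them as unstable.
        registry = getattr(module, "_diffusers_hook", None)
        if registry is not None and getattr(registry, "_child_registries_cache", None) is not None:
            registry._child_registries_cache = None  # None makes the next lookup rebuild
            cleared += 1
    return cleared
```

It would be called right after enable_cache(...) or disable_cache(...), e.g. clear_child_registry_caches(transformer), and before the next cache_context().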

System Info

  • diffusers: 0.38.0 (also present on main)
  • transformers: 5.10.2
  • torch: 2.8.0
  • Python: 3.11
  • Platform: Linux, NVIDIA H200

Who can help?

Hook/caching: @a-r-r-o-w @DN6
