
Wan Animate Pipeline#367

Open
csgoogle wants to merge 1 commit intomainfrom
sagarchapara/wananimate-pipeline

Conversation

@csgoogle
Collaborator

@csgoogle csgoogle commented Mar 28, 2026

Wan Animate Pipeline

This CL adds the Wan Animate pipeline.

  • Reused the existing Wan attention operator for face encoder cross attention.
  • Swept Flash Attention block-size configurations to identify the best inference setting.
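The block-size sweep described above can be sketched roughly as follows. This is an illustrative stand-in, not the PR's actual code: `chunked_attention` and `sweep_block_sizes` are hypothetical names, the candidate tile sizes are made up, and a plain chunked softmax attention stands in for the real Wan Flash Attention kernel, which would take `(block_q, block_kv)` tile parameters directly.

```python
import time
import jax
import jax.numpy as jnp

def chunked_attention(q, k, v, block_q):
    # Reference attention computed in query chunks of size block_q.
    # Illustrative only: the real kernel is the Wan Flash Attention
    # operator, whose tile sizes change performance, not results.
    outs = []
    for start in range(0, q.shape[0], block_q):
        qb = q[start:start + block_q]
        scores = qb @ k.T / jnp.sqrt(jnp.float32(q.shape[-1]))
        outs.append(jax.nn.softmax(scores, axis=-1) @ v)
    return jnp.concatenate(outs)

def sweep_block_sizes(q, k, v, candidates=(64, 128, 256)):
    # Time each candidate configuration after a warm-up/compile call
    # and keep the fastest one.
    best_bq, best_dt = None, float("inf")
    for bq in candidates:
        fn = jax.jit(lambda q, k, v, bq=bq: chunked_attention(q, k, v, bq))
        jax.block_until_ready(fn(q, k, v))  # warm-up: trigger compilation
        t0 = time.perf_counter()
        jax.block_until_ready(fn(q, k, v))
        dt = time.perf_counter() - t0
        if dt < best_dt:
            best_bq, best_dt = bq, dt
    return best_bq
```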

Links

Performance

  • compile_time: 292.74
  • generation_time: 157.69

Configuration

  • cp: 8 (v6e8)
  • cfg: 1.0
  • prev_segments: 5
  • resolution: 1280x720
  • fps: 24
  • generated_frames: 77


@csgoogle csgoogle marked this pull request as ready for review April 6, 2026 16:33
@csgoogle csgoogle requested a review from entrpn as a code owner April 6, 2026 16:33
@csgoogle csgoogle force-pushed the sagarchapara/wananimate-pipeline branch from 67233e9 to e281524 Compare April 13, 2026 08:49
@csgoogle csgoogle force-pushed the sagarchapara/wananimate-pipeline branch from e281524 to 349d080 Compare April 13, 2026 09:10
```python
sigmas = 1.0 - alphas
sigmas = jnp.flip(self.config.flow_shift * sigmas / (1 + (self.config.flow_shift - 1) * sigmas))[:-1].copy()
timesteps = (sigmas * self.config.num_train_timesteps).copy().astype(jnp.int64)
sigmas = jnp.linspace(1.0, 1.0 / self.config.num_train_timesteps, num_inference_steps + 1)[:-1]
```
Collaborator

LGTM

```python
sigmas = self.config.flow_shift * sigmas / (1 + (self.config.flow_shift - 1) * sigmas)
eps = 1e-6
sigmas = sigmas.at[0].set(jnp.where(jnp.abs(sigmas[0] - 1.0) < eps, sigmas[0] - eps, sigmas[0]))
timesteps = (sigmas * self.config.num_train_timesteps).copy().astype(jnp.int32)
```
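A minimal sketch of why the `eps` clamp matters (with assumed, illustrative values `num_train_timesteps=1000` and `flow_shift=5.0`; the PR's actual config may differ): one plausible reading is that without the clamp, `sigmas[0]` is exactly 1.0, so the derived timestep equals `num_train_timesteps`, one past the last valid index.

```python
import jax.numpy as jnp

num_train_timesteps = 1000  # assumed value
flow_shift = 5.0            # assumed value
eps = 1e-6

sigmas = jnp.linspace(1.0, 1.0 / num_train_timesteps, 51)[:-1]
sigmas = flow_shift * sigmas / (1 + (flow_shift - 1) * sigmas)
# Without the clamp, sigmas[0] == 1.0 and the first timestep would
# be num_train_timesteps itself; the clamp nudges it just below 1.0.
sigmas = sigmas.at[0].set(
    jnp.where(jnp.abs(sigmas[0] - 1.0) < eps, sigmas[0] - eps, sigmas[0])
)
timesteps = (sigmas * num_train_timesteps).astype(jnp.int32)
```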
Collaborator

Why move from jnp.int64 to jnp.int32?

Collaborator Author

Timesteps fit within int32, so I cast to int32 only.
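For reference, the cast is safe because every timestep is bounded by `num_train_timesteps` (1000 below is an assumed value), far below the int32 ceiling. Note also that under JAX's default configuration (`jax_enable_x64=False`), a requested int64 dtype is canonicalized down to int32, so the earlier `astype(jnp.int64)` was likely never taking effect anyway.

```python
import jax.numpy as jnp

num_train_timesteps = 1000  # assumed config value
sigmas = jnp.linspace(1.0, 1.0 / num_train_timesteps, 50)
timesteps = (sigmas * num_train_timesteps).astype(jnp.int32)

# Timesteps live in [1, num_train_timesteps], nowhere near the
# int32 maximum of 2**31 - 1, so int32 loses nothing here.
assert int(timesteps.max()) <= num_train_timesteps

# With x64 disabled (the JAX default), an int64 request is
# silently canonicalized to int32.
assert (sigmas * num_train_timesteps).astype(jnp.int64).dtype == jnp.int32
```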

Collaborator

Could we move the assets to a public GCS path or use an existing hf dataset link?

Comment on lines +960 to +993
```python
noise_pred = animate_transformer_forward_pass(
    graphdef,
    state,
    rest_of_state,
    seg_latents,
    reference_latents,
    pose_latents,
    face_seg,
    timestep,
    prompt_embeds,
    image_embeds,
    motion_encode_batch_size=motion_encode_batch_size,
)

if do_classifier_free_guidance:
    # Blank face pixels (all -1) for the unconditional pass.
    face_seg_uncond = face_seg * 0 - 1
    noise_uncond = animate_transformer_forward_pass(
        graphdef,
        state,
        rest_of_state,
        seg_latents,
        reference_latents,
        pose_latents,
        face_seg_uncond,
        timestep,
        negative_prompt_embeds,
        image_embeds,
        motion_encode_batch_size=motion_encode_batch_size,
    )
    noise_pred = noise_uncond + guidance_scale * (noise_pred - noise_uncond)

noise_pred = noise_pred.astype(seg_latents.dtype)
seg_latents, scheduler_state = self.scheduler.step(scheduler_state, noise_pred, t, seg_latents, return_dict=False)
```
Collaborator

Can we batch the cfg and prompt?

Ref: 1 and 2
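The suggested batching could look roughly like this sketch, with a toy stand-in for `animate_transformer_forward_pass` (the real change would stack `face_seg`/`face_seg_uncond` and the prompt embeddings along a new leading axis and run one forward pass instead of two; names and shapes here are illustrative):

```python
import jax
import jax.numpy as jnp

def forward(latents, embeds):
    # Toy stand-in for animate_transformer_forward_pass.
    return latents + embeds.mean()

def cfg_two_passes(latents, cond, uncond, scale):
    # Current pattern: two sequential forward passes.
    noise_cond = forward(latents, cond)
    noise_uncond = forward(latents, uncond)
    return noise_uncond + scale * (noise_cond - noise_uncond)

def cfg_batched(latents, cond, uncond, scale):
    # Suggested pattern: stack the conditional and unconditional
    # inputs on a leading axis and run one batched forward pass.
    embeds = jnp.stack([cond, uncond])
    lat = jnp.broadcast_to(latents, (2, *latents.shape))
    noise_cond, noise_uncond = jax.vmap(forward)(lat, embeds)
    return noise_uncond + scale * (noise_cond - noise_uncond)
```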

```python
sigmas = 1.0 - alphas
sigmas = jnp.flip(self.config.flow_shift * sigmas / (1 + (self.config.flow_shift - 1) * sigmas))[:-1].copy()
timesteps = (sigmas * self.config.num_train_timesteps).copy().astype(jnp.int64)
sigmas = jnp.linspace(1.0, 1.0 / self.config.num_train_timesteps, num_inference_steps + 1)[:-1]
```
Collaborator

LGTM

```python
    f"{_frame_summary('mask', mask_video)}"
)

animate_settings = _get_animate_inference_settings(config)
```
Collaborator

Could you also add LoRA support?

Collaborator

I wonder if there is a need for a separate generate script? Can we add this to existing generate_wan.py file?

@Perseus14
Collaborator

Please resolve conflicts and enable support for diagnostics and profiling as in this PR
