Skip to content

Tomography INR Speed-ups and quantem-cuda API Built-in#246

Draft
cedriclim1 wants to merge 31 commits into
electronmicroscopy:devfrom
cedriclim1:feat/tomography-inr-fixes
Draft

Tomography INR Speed-ups and quantem-cuda API Built-in#246
cedriclim1 wants to merge 31 commits into
electronmicroscopy:devfrom
cedriclim1:feat/tomography-inr-fixes

Conversation

@cedriclim1

Copy link
Copy Markdown
Collaborator

What problem this PR addreseses

List of fixes and performance speed-ups that Claude was able to find. The performance gains require electronmicroscopy/quantem-cuda#2 and a definite blocking PR: #243 to be merged.

1. Non-square tilt series training avialable. I.e.., H≠W tilt stacks.
2. Optimizations on get_volume_tv_loss and interpolate_ms_features_tilted for the KPlanesTilted model.
3. Forward call includes TV loss computation concatenated along with the coordinates for integrating the rays to do one big forward call, rather than multiple predictions from the model.
4. Correct masking for rays that are evaluated outside the grid range. Along the beam direction, some of them were not being correctly masked.
5. Optimizations to transforming the rays, removed torch.tensor wrappings, vectorization of applying the rotation matrices to the rays.
6. SIRT inline alignment fix along with updates the fully-differentiable AD projector. Will be deprecated once FastTomo algorithms are implemented into the codebase.
7. OptimizerMixin fix since the schedulers DO NOT respect the minimum learning rate for multiple optimizers, see tests for examples on how this is rectified. Small fix on to rot_ZXZ rewrapping with torch.tensor again.
8. SIREN models were not respecting winner_initialization and has been rectified.
9. Pose parameters were being reset with .to which is now fixed. KPlanes tighter restrictions on anisotropic resolutions for example [1, 1, 1] would raise an error since it's useless to compute the volume at the same resolution.
10. Specifically just for INR datasets, implementation of the DeviceBatchSampler. The issue was Python would have to build the batches on CPU prior to sending to the GPU for training leading to a workflow that looks like the following:
image
which can be expressed as a single operation:
image
Directly loading the data into the GPU does not help since now the GPU needs to build the batches prior to inference. The above workflow speeds this up substantially.

What should the reviewer(s) do

cedriclim1 and others added 21 commits June 9, 2026 23:00
The per-pixel DataLoader path costs ~43 ms/batch on CPU (8192 Python
__getitem__ calls + collate + H2D copy) while the GPU step itself takes
~20 ms, so single-GPU reconstructions idle the GPU half the time.
DeviceBatchSampler keeps the tilt stack resident on the device and builds
each batch with index arithmetic and two tensor lookups, yielding the same
batch dicts (and drop_last/val-split semantics) as the DataLoader path.
DDP runs keep the DataLoader + DistributedSampler path.

Also disable bf16 autocast in the validation loop to match the training
pass: the so3 pose solve (lu_factor) has no BFloat16 kernel, and a bf16
val loss is not comparable to the fp32 train loss.
Every rank derives the same epoch permutation from seed + epoch (CPU
generator, identical across ranks and reproducible) and takes an
equal-size contiguous shard, with the ragged tail dropped so per-rank
batch counts always match and gradient sync cannot hang. The training
loop's existing sampler.set_epoch call drives reshuffling, exactly like
DistributedSampler; single-process runs auto-advance the epoch instead.

The train/val split now uses a fixed-seed generator: identical across
ranks (no leakage between a rank's train shard and another's val shard)
and stable across save/reload, so resumed runs keep validating on the
same held-out pixels.

Verified with a 2-GPU torchrun run: equal batch counts on both ranks,
finite identical reduced losses, no deadlock.
…raint crash

- TomographyINRDataset.__getitem__ decomposed pixel indices by shape[1] (height)
  instead of shape[2] (width), and __len__ used max(shape)^2 instead of H*W;
  both corrupted the pixel mapping and overcounted the dataset for
  rectangular tilt images.
- ObjectPixelated.apply_soft_constraints built soft_loss as a leaf tensor with
  requires_grad=True and then added the TV loss in-place, which raises
  RuntimeError whenever tv_vol > 0. Build it without requires_grad and add
  out-of-place so the loss backprops through get_tv_loss.
- interpolate_ms_features_tilted: build the per-plane coordinate pairs with
  unbind/stack views instead of allocating an index tensor on device every
  call (a host-to-device copy) and running an (T, 3, B, 3) expand+gather.
  Bitwise-identical output; ~1.1x on the isolated forward (B=4096, T=8,
  3 scales) plus one fewer H2D transfer per model call.
- ObjectTensorDecomp.get_volume_tv_loss: evaluate the base points and the
  three axis-shifted copies in one batched forward (4N points) instead of
  four sequential model calls. Bitwise-identical loss; 3.1x faster forward
  and 3.3x with backward (10k samples, KPlanesTILTED backbone).
…e crash

- ObjectINR.forward masked densities outside [-1, 1] in x and y only, so
  tilted rays whose sample points leave the volume along z still picked up
  extrapolated density and biased the integrated projections at high tilt.
  The mask now bounds all three axes.
- ObjectINR.apply_soft_constraints read ctx.coords.device before the
  coords-is-None assert, raising AttributeError instead of returning a zero
  loss when no soft constraints are active and no coords are passed. Fall
  back to the model device when coords are absent.
- TomographyINRDataset.__getitem__ wrapped each of the three index fields in
  torch.tensor(), allocating three scalar tensors per item on the dataloader
  hot path; default_collate builds the same int64 batch tensors from plain
  ints. 2.9x faster batch loading (12.1 -> 4.2 ms/batch, batch_size=1024,
  60x256x256 stack), which dominated the per-batch GPU compute (~0.4 ms).
- transform_batch_rays applied the three Euler rotations as nine elementwise
  passes over the full (B, S) ray tensors. Compose them into one (B, 3, 3)
  matrix and apply with a single batched matmul: same result to 4e-7 (float32
  op reordering), up to 1.15x at large batches and far fewer kernel launches.
- TomographyConventional._reconstruction_epoch wrote the aligned measurement
  into proj_forward, which radon_torch overwrites immediately after, while the
  error term keeps reading the original tilt stack -- inline_alignment was a
  no-op. The aligned image is now persisted into the tilt stack (the tensor
  the error term reads), and the measurement is transposed to match the
  slice/detector orientation of the forward projection it is correlated with.
- differentiable_rotz/rotx_vectorized raised for more than one angle: the
  per-slice vmap built affine_grid with a (T, 2, 3) matrix against a slice
  batch of 1. A rotation about an axis applies the same 2-D transform to every
  slice along it, so that axis can ride along as grid_sample channels in a
  single call -- which both fixes multi-angle/per-volume batching and removes
  the vmap. Scalar-angle outputs are unchanged (verified to 1e-6).
…g in rot_ZXZ

- Every SchedulerParams dataclass (Plateau, Exponential, Cyclic, Linear,
  CosineAnnealing) mutated its own fields inside params(): the first call
  permanently baked derived values (min_lr, gamma, base/max_lr, total_iters,
  T_max) into the instance, so a config shared between the object and pose
  optimizers -- or reused across reconstruct() calls with a different
  num_iter -- silently kept the first call's values. Derived values are now
  computed locally; explicitly-set fields still take precedence.
- rot_ZXZ re-wrapped all three Euler angles with torch.tensor() whenever any
  one of them was a non-tensor, copying and detaching tensor angles and
  silently cutting gradient flow (plus a UserWarning). Each angle is now
  converted independently, leaving tensors untouched.
…able_tilts setter

- Siren/HSiren created and seeded a torch.Generator for the winner
  initialization but drew the perturbation with torch.randn_like, which
  ignores generators -- so winner_initialization=72 (used by TomographyLiteINR)
  silently had no effect on reproducibility and every seed produced the same
  global-RNG noise. Draw from torch.randn with the seeded generator instead.
- TomographyDatasetBase.learnable_tilts had a setter that wrote a private
  attribute the getter never read, so assignments appeared to succeed while
  silently doing nothing. The value is derived from the tilt series; the
  setter is removed so assignment now raises instead of lying.
…s resolutions

- TomographyPixDataset.to() and TomographyINRDataset.to() rebuilt the z1/z3/
  shift parameters from the initial-value buffers (_z1_angles etc.), which are
  never updated during training -- so any device move after training, including
  Tomography.from_file(...).to(device), silently reset every learned pose to
  zero. Both now go through a shared helper that moves the current parameter
  values once they exist and only uses the buffers on first materialization.
- KPlanes/KPlanesTILTED allocate all three plane grids from one
  (3[, *T], C, res[1], res[0]) tensor, ignoring res[2]; an anisotropic
  resolution silently gave the XZ/YZ planes the wrong grid along z. Constructor
  now raises a clear ValueError instead (isotropic behavior unchanged;
  supporting true anisotropy needs per-plane parameters, which would change
  the checkpoint layout).
…-softloss

Fix INR dataset pixel indexing for non-square tilt images and TV soft-constraint crash
Speed up KPlanesTILTED feature interpolation and batched volume TV loss
…vice

Fix ObjectINR out-of-volume masking along z and soft-constraint device crash
Reduce INR dataloader and ray-transform overhead
Fix no-op SIRT inline alignment and multi-angle rotation operators
…-grads

Make scheduler params() pure and fix gradient-detaching angle wrapping in rot_ZXZ
Fix ignored winner-initialization seed in Siren and remove dead learnable_tilts setter
Fix learned pose parameters being reset by to() and refuse anisotropic KPlanes resolutions
Device-resident batch sampling for INR tomography reconstruction
@cedriclim1 cedriclim1 requested a review from arthurmccray June 11, 2026 01:55
cedriclim1 and others added 8 commits June 10, 2026 19:03
The DeviceBatchSampler parity test compared batch entries against
__getitem__ with .to(dtype), but __getitem__ returns plain ints for the
index keys since the dataloader overhead reduction. Coerce with
torch.as_tensor before comparing.
A tilt stack with more than 95% zero pixels has a zero 95th quantile, so
the normalization divided by zero and produced inf/NaN targets that
poison every parameter on the first backward. Fall back to the absolute
maximum when the quantile is non-positive and raise a clear error for
all-zero stacks; same guard for the pretrain dataset. Adds regression
tests for both paths.
apply_soft_constraints keyed the TV term on tv_vol alone, so setting
tv_plane > 0 with tv_vol = 0 silently dropped the plane TV (the only TV
used in the K-Planes paper, Eq. 3). Gate on either weight, and skip each
TV term whose weight is zero so plane-only runs avoid the volume TV's
three extra forward passes.
…p index tensors in interpolate_ms_features

- interpolate_ms_features_cp_tilted rebuilt the (3T, 1, B, 2) sampling grid
  (reshape/permute plus a zeros_like and stack) inside the per-scale loop even
  though it only depends on the rotated points; build it once before the loop.
- interpolate_ms_features (non-tilted) projected points onto the three planes
  with list-based advanced indexing, which allocates an index tensor on device
  (a host-to-device copy) three times per forward; use unbind/stack views,
  matching the TILTED variant.

Both bitwise-identical to the previous implementations; ~1.3x each on the
isolated forward (B=4096, 3 scales; CP with T=8).
Hoist loop-invariant sampling grid in CP-TILTED interpolation
- TomographyINRDataset now implements __getitems__: torch's DataLoader calls
  it with the whole batch of indices, replacing one Python __getitem__ call
  (plus per-item dict construction and collate stacking) per sample with a few
  tensor ops. setup_dataloader pairs datasets that define __getitems__ with a
  passthrough collate_fn (module-level so spawn workers can pickle it); the
  random_split Subset path delegates the batched form automatically. Batches
  are identical to default collate (values and dtypes). 39.7x faster batch
  fetching (4.16 -> 0.105 ms/batch, batch_size=1024, 60x256x256 stack).
- The per-epoch loss reduction issued three all_reduce calls and three .item()
  host syncs; stack the three scalars into one tensor for a single all_reduce
  and a single .tolist() sync.
…ction

- _build_optimizer now passes fused=True to Adam/AdamW when every parameter is
  on a CUDA device: the whole optimizer step runs in one fused kernel, 2.0x
  faster than the default foreach path on KPlanesTILTED-sized grids (1.24 ->
  0.61 ms/step; ~9% end-to-end per training iteration, where opt.step was 19%
  of the profile). Same update rule; parameters agree to ~2e-6 after 5 steps
  (kernel-order float effects).
- The epoch losses were re-wrapped in a tensor, all_reduced a second time
  (idempotent on already rank-averaged values), and synced to host again right
  after the stacked reduction; the redundant block is removed.
Vectorize INR dataset batch fetching; fused Adam and reduced epoch metrics traffic
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant