Tomography INR Speed-ups and quantem-cuda API Built-in#246
Draft
cedriclim1 wants to merge 31 commits into
Draft
Conversation
The per-pixel DataLoader path costs ~43 ms/batch on CPU (8192 Python __getitem__ calls + collate + H2D copy) while the GPU step itself takes ~20 ms, so single-GPU reconstructions idle the GPU half the time. DeviceBatchSampler keeps the tilt stack resident on the device and builds each batch with index arithmetic and two tensor lookups, yielding the same batch dicts (and drop_last/val-split semantics) as the DataLoader path. DDP runs keep the DataLoader + DistributedSampler path. Also disable bf16 autocast in the validation loop to match the training pass: the so3 pose solve (lu_factor) has no BFloat16 kernel, and a bf16 val loss is not comparable to the fp32 train loss.
Every rank derives the same epoch permutation from seed + epoch (CPU generator, identical across ranks and reproducible) and takes an equal-size contiguous shard, with the ragged tail dropped so per-rank batch counts always match and gradient sync cannot hang. The training loop's existing sampler.set_epoch call drives reshuffling, exactly like DistributedSampler; single-process runs auto-advance the epoch instead. The train/val split now uses a fixed-seed generator: identical across ranks (no leakage between a rank's train shard and another's val shard) and stable across save/reload, so resumed runs keep validating on the same held-out pixels. Verified with a 2-GPU torchrun run: equal batch counts on both ranks, finite identical reduced losses, no deadlock.
…raint crash - TomographyINRDataset.__getitem__ decomposed pixel indices by shape[1] (height) instead of shape[2] (width), and __len__ used max(shape)^2 instead of H*W; both corrupted the pixel mapping and overcounted the dataset for rectangular tilt images. - ObjectPixelated.apply_soft_constraints built soft_loss as a leaf tensor with requires_grad=True and then added the TV loss in-place, which raises RuntimeError whenever tv_vol > 0. Build it without requires_grad and add out-of-place so the loss backprops through get_tv_loss.
- interpolate_ms_features_tilted: build the per-plane coordinate pairs with unbind/stack views instead of allocating an index tensor on device every call (a host-to-device copy) and running an (T, 3, B, 3) expand+gather. Bitwise-identical output; ~1.1x on the isolated forward (B=4096, T=8, 3 scales) plus one fewer H2D transfer per model call. - ObjectTensorDecomp.get_volume_tv_loss: evaluate the base points and the three axis-shifted copies in one batched forward (4N points) instead of four sequential model calls. Bitwise-identical loss; 3.1x faster forward and 3.3x with backward (10k samples, KPlanesTILTED backbone).
…e crash - ObjectINR.forward masked densities outside [-1, 1] in x and y only, so tilted rays whose sample points leave the volume along z still picked up extrapolated density and biased the integrated projections at high tilt. The mask now bounds all three axes. - ObjectINR.apply_soft_constraints read ctx.coords.device before the coords-is-None assert, raising AttributeError instead of returning a zero loss when no soft constraints are active and no coords are passed. Fall back to the model device when coords are absent.
- TomographyINRDataset.__getitem__ wrapped each of the three index fields in torch.tensor(), allocating three scalar tensors per item on the dataloader hot path; default_collate builds the same int64 batch tensors from plain ints. 2.9x faster batch loading (12.1 -> 4.2 ms/batch, batch_size=1024, 60x256x256 stack), which dominated the per-batch GPU compute (~0.4 ms). - transform_batch_rays applied the three Euler rotations as nine elementwise passes over the full (B, S) ray tensors. Compose them into one (B, 3, 3) matrix and apply with a single batched matmul: same result to 4e-7 (float32 op reordering), up to 1.15x at large batches and far fewer kernel launches.
- TomographyConventional._reconstruction_epoch wrote the aligned measurement into proj_forward, which radon_torch overwrites immediately after, while the error term keeps reading the original tilt stack -- inline_alignment was a no-op. The aligned image is now persisted into the tilt stack (the tensor the error term reads), and the measurement is transposed to match the slice/detector orientation of the forward projection it is correlated with. - differentiable_rotz/rotx_vectorized raised for more than one angle: the per-slice vmap built affine_grid with a (T, 2, 3) matrix against a slice batch of 1. A rotation about an axis applies the same 2-D transform to every slice along it, so that axis can ride along as grid_sample channels in a single call -- which both fixes multi-angle/per-volume batching and removes the vmap. Scalar-angle outputs are unchanged (verified to 1e-6).
…g in rot_ZXZ - Every SchedulerParams dataclass (Plateau, Exponential, Cyclic, Linear, CosineAnnealing) mutated its own fields inside params(): the first call permanently baked derived values (min_lr, gamma, base/max_lr, total_iters, T_max) into the instance, so a config shared between the object and pose optimizers -- or reused across reconstruct() calls with a different num_iter -- silently kept the first call's values. Derived values are now computed locally; explicitly-set fields still take precedence. - rot_ZXZ re-wrapped all three Euler angles with torch.tensor() whenever any one of them was a non-tensor, copying and detaching tensor angles and silently cutting gradient flow (plus a UserWarning). Each angle is now converted independently, leaving tensors untouched.
…able_tilts setter - Siren/HSiren created and seeded a torch.Generator for the winner initialization but drew the perturbation with torch.randn_like, which ignores generators -- so winner_initialization=72 (used by TomographyLiteINR) silently had no effect on reproducibility and every seed produced the same global-RNG noise. Draw from torch.randn with the seeded generator instead. - TomographyDatasetBase.learnable_tilts had a setter that wrote a private attribute the getter never read, so assignments appeared to succeed while silently doing nothing. The value is derived from the tilt series; the setter is removed so assignment now raises instead of lying.
…s resolutions - TomographyPixDataset.to() and TomographyINRDataset.to() rebuilt the z1/z3/ shift parameters from the initial-value buffers (_z1_angles etc.), which are never updated during training -- so any device move after training, including Tomography.from_file(...).to(device), silently reset every learned pose to zero. Both now go through a shared helper that moves the current parameter values once they exist and only uses the buffers on first materialization. - KPlanes/KPlanesTILTED allocate all three plane grids from one (3[, *T], C, res[1], res[0]) tensor, ignoring res[2]; an anisotropic resolution silently gave the XZ/YZ planes the wrong grid along z. Constructor now raises a clear ValueError instead (isotropic behavior unchanged; supporting true anisotropy needs per-plane parameters, which would change the checkpoint layout).
…-softloss Fix INR dataset pixel indexing for non-square tilt images and TV soft-constraint crash
Speed up KPlanesTILTED feature interpolation and batched volume TV loss
…vice Fix ObjectINR out-of-volume masking along z and soft-constraint device crash
Reduce INR dataloader and ray-transform overhead
Fix no-op SIRT inline alignment and multi-angle rotation operators
…-grads Make scheduler params() pure and fix gradient-detaching angle wrapping in rot_ZXZ
Fix ignored winner-initialization seed in Siren and remove dead learnable_tilts setter
Fix learned pose parameters being reset by to() and refuse anisotropic KPlanes resolutions
Device-resident batch sampling for INR tomography reconstruction
The DeviceBatchSampler parity test compared batch entries against __getitem__ with .to(dtype), but __getitem__ returns plain ints for the index keys since the dataloader overhead reduction. Coerce with torch.as_tensor before comparing.
A tilt stack with more than 95% zero pixels has a zero 95th quantile, so the normalization divided by zero and produced inf/NaN targets that poison every parameter on the first backward. Fall back to the absolute maximum when the quantile is non-positive and raise a clear error for all-zero stacks; same guard for the pretrain dataset. Adds regression tests for both paths.
apply_soft_constraints keyed the TV term on tv_vol alone, so setting tv_plane > 0 with tv_vol = 0 silently dropped the plane TV (the only TV used in the K-Planes paper, Eq. 3). Gate on either weight, and skip each TV term whose weight is zero so plane-only runs avoid the volume TV's three extra forward passes.
…p index tensors in interpolate_ms_features - interpolate_ms_features_cp_tilted rebuilt the (3T, 1, B, 2) sampling grid (reshape/permute plus a zeros_like and stack) inside the per-scale loop even though it only depends on the rotated points; build it once before the loop. - interpolate_ms_features (non-tilted) projected points onto the three planes with list-based advanced indexing, which allocates an index tensor on device (a host-to-device copy) three times per forward; use unbind/stack views, matching the TILTED variant. Both bitwise-identical to the previous implementations; ~1.3x each on the isolated forward (B=4096, 3 scales; CP with T=8).
Hoist loop-invariant sampling grid in CP-TILTED interpolation
- TomographyINRDataset now implements __getitems__: torch's DataLoader calls it with the whole batch of indices, replacing one Python __getitem__ call (plus per-item dict construction and collate stacking) per sample with a few tensor ops. setup_dataloader pairs datasets that define __getitems__ with a passthrough collate_fn (module-level so spawn workers can pickle it); the random_split Subset path delegates the batched form automatically. Batches are identical to default collate (values and dtypes). 39.7x faster batch fetching (4.16 -> 0.105 ms/batch, batch_size=1024, 60x256x256 stack). - The per-epoch loss reduction issued three all_reduce calls and three .item() host syncs; stack the three scalars into one tensor for a single all_reduce and a single .tolist() sync.
…ction - _build_optimizer now passes fused=True to Adam/AdamW when every parameter is on a CUDA device: the whole optimizer step runs in one fused kernel, 2.0x faster than the default foreach path on KPlanesTILTED-sized grids (1.24 -> 0.61 ms/step; ~9% end-to-end per training iteration, where opt.step was 19% of the profile). Same update rule; parameters agree to ~2e-6 after 5 steps (kernel-order float effects). - The epoch losses were re-wrapped in a tensor, all_reduced a second time (idempotent on already rank-averaged values), and synced to host again right after the stacked reduction; the redundant block is removed.
Vectorize INR dataset batch fetching; fused Adam and reduced epoch metrics traffic
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
What problem this PR addreseses
List of fixes and performance speed-ups that Claude was able to find. The performance gains require electronmicroscopy/quantem-cuda#2 and a definite blocking PR: #243 to be merged.
1. Non-square tilt series training avialable. I.e.., H≠W tilt stacks.


2. Optimizations on
get_volume_tv_lossandinterpolate_ms_features_tiltedfor theKPlanesTiltedmodel.3. Forward call includes TV loss computation concatenated along with the coordinates for integrating the rays to do one big forward call, rather than multiple predictions from the model.
4. Correct masking for rays that are evaluated outside the grid range. Along the beam direction, some of them were not being correctly masked.
5. Optimizations to transforming the rays, removed
torch.tensorwrappings, vectorization of applying the rotation matrices to the rays.6. SIRT inline alignment fix along with updates the fully-differentiable AD projector. Will be deprecated once
FastTomoalgorithms are implemented into the codebase.7.
OptimizerMixinfix since the schedulers DO NOT respect the minimum learning rate for multiple optimizers, see tests for examples on how this is rectified. Small fix on torot_ZXZrewrapping withtorch.tensoragain.8. SIREN models were not respecting
winner_initializationand has been rectified.9. Pose parameters were being reset with
.towhich is now fixed.KPlanestighter restrictions on anisotropic resolutions for example[1, 1, 1]would raise an error since it's useless to compute the volume at the same resolution.10. Specifically just for INR datasets, implementation of the
DeviceBatchSampler. The issue was Python would have to build the batches on CPU prior to sending to the GPU for training leading to a workflow that looks like the following:which can be expressed as a single operation:
Directly loading the data into the GPU does not help since now the GPU needs to build the batches prior to inference. The above workflow speeds this up substantially.
What should the reviewer(s) do