Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
20 commits
Select commit Hold shift + click to select a range
5c539ba
multiprocessing working for ptycho multi gpu single node, linter erro…
arthurmccray May 19, 2026
d047283
Merge branch 'electronmicroscopy:dev' into multi_gpu_ptycho
arthurmccray May 19, 2026
0e7396a
getting working for DGP
arthurmccray May 21, 2026
f5c0c73
cleaning up multi GPU DGP fixing bugs in lr and opt persistence
arthurmccray May 22, 2026
34a09ce
consistent devices
arthurmccray May 22, 2026
f1417fe
fixing linter errors
arthurmccray May 22, 2026
71840e0
converting ptycho to dataloader while maintaining jupyter notebook mu…
arthurmccray May 23, 2026
caa5c39
move _build_dataloaders to ptychography_base
arthurmccray May 23, 2026
2ef8f9a
cleaning up ptycho_opt
arthurmccray May 26, 2026
107ccd3
adding tests
arthurmccray May 26, 2026
55d7d40
Merge branch 'diffractive_imaging' into multi_gpu_ptycho
arthurmccray May 26, 2026
07d739b
adding iterative ptycho constraint params
arthurmccray May 27, 2026
e5207c8
removing additional constraint flags from reconstruct
arthurmccray May 27, 2026
10d3538
improving constraint params docstrings, parse_dict
arthurmccray May 27, 2026
e3d15ae
changing pure_phase to be unwrapped real values
arthurmccray May 28, 2026
85dfa3f
bugfix of initializing to first devices even if not specified
arthurmccray May 28, 2026
74f52ea
adding TODO for amp/phase tv weight splitting
arthurmccray May 28, 2026
b7d084b
moving hard constraints outside of computational graph
arthurmccray May 29, 2026
0b15305
Merge branch 'diffractive_imaging' into multi_gpu_ptycho
arthurmccray Jun 1, 2026
73c9aee
Merge branch 'multi_gpu_ptycho' into ptycho_constraints
arthurmccray Jun 1, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
152 changes: 142 additions & 10 deletions src/quantem/core/ml/constraints.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,15 +7,22 @@
import torch
from numpy.typing import NDArray

from quantem.core import config


@dataclass
class BaseContext(ABC):
"""
Constraints should contain a context object that contains all necessary data for the constraints to be applied.
Context object bundling the data a constraint needs to be applied.

Tomography's ``ReconstructionContext`` subclasses this and is passed to its
``apply_soft_constraints(ctx)`` overrides. Ptychography models instead pass
their tensors positionally, so the base ``apply_soft_constraints`` signature
stays ``*args, **kwargs`` to accommodate both domains.
"""

pass

T_ctx = TypeVar("T_ctx", bound=BaseContext)

@dataclass(slots=False)
class Constraints(ABC):
Expand Down Expand Up @@ -56,54 +63,179 @@ def __str__(self) -> str:
)


class BaseConstraints(ABC, Generic[T_ctx]):
def parse_constraint_dict(
namespace: type,
d: dict,
*,
kind: str = "constraint",
) -> Constraints:
"""Dispatch a config dict to one of ``namespace``'s nested ``Constraints`` variants.

``namespace`` is a class with one or more nested ``@dataclass``\\ -decorated
``Constraints`` subclasses. The dict must contain a ``"name"`` or ``"type"`` key
whose value (case-insensitive) matches one variant's ``_name`` field; the
remaining keys are forwarded as constructor kwargs to that variant.

``kind`` is a short human-readable label ("object", "probe", "dataset", ...)
used only in error messages.
"""
d = dict(d)
name = d.pop("name", None) or d.pop("type", None)
if name is None:
raise ValueError(f"Must provide either 'name' or 'type' key for {kind} constraints")
if isinstance(name, type):
name = name.__name__.lower()
elif isinstance(name, str):
name = name.lower()
else:
raise ValueError(f"Unknown {kind} constraint type: {name!r}")

variants: dict[str, type[Constraints]] = {}
for attr in vars(namespace).values():
if isinstance(attr, type) and issubclass(attr, Constraints) and attr is not Constraints:
variant_name = getattr(attr, "_name", None)
if isinstance(variant_name, str):
variants[variant_name.lower()] = attr

if name not in variants:
raise ValueError(
f"Unknown {kind} constraint type: {name!r}; expected one of {sorted(variants)}"
)
return variants[name](**d)


C = TypeVar("C", bound=Constraints)


class BaseConstraints(ABC, Generic[C]):
"""
Base class for constraints.

Generic over a concrete ``Constraints`` subclass so that subclasses (and the
type checker) can see the specific fields available on ``self.constraints``.
Subclasses parameterize like ``BaseConstraints[MyConstraintsType]`` and set
``DEFAULT_CONSTRAINTS`` to an instance of that type.
"""

# Default constraints are the dataclasses themselves.
DEFAULT_CONSTRAINTS = Constraints()
DEFAULT_CONSTRAINTS: C
_constraints: C

def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self._soft_constraint_losses = []
self._soft_constraint_loss: dict[str, torch.Tensor | float] = {}
self._iter_constraint_losses: dict[str, float] = {}
self.constraints = self.DEFAULT_CONSTRAINTS.copy()

@property
def soft_constraint_losses(self) -> NDArray[np.float32]:
return np.array(self._soft_constraint_losses, dtype=np.float32)

@property
def constraints(self) -> Constraints:
def soft_constraint_loss(self) -> dict[str, torch.Tensor | float]:
return self._soft_constraint_loss

@property
def constraints(self) -> C:
"""
Constraints for the model.
"""
return self._constraints

@constraints.setter
def constraints(self, constraints: Constraints | dict[str, Any]):
def constraints(self, constraints: C | dict[str, Any]):
"""
Setter for constraints class, can be a Constraints instance or a dictionary.
Dict keys are validated against the active Constraints dataclass's allowed_keys.
"""
if isinstance(constraints, Constraints):
self._constraints = constraints
elif isinstance(constraints, dict):
allowed = self._constraints.allowed_keys
for key, value in constraints.items():
if key not in allowed:
raise KeyError(
f"Invalid constraint key '{key}' for {type(self._constraints).__name__}, "
f"allowed keys are {allowed}"
)
setattr(self._constraints, key, value)
else:
raise ValueError(f"Invalid constraints type: {type(constraints)}")

# --- Required methods tha tneeds to implemented in subclasses ---
def add_constraint(self, key: str, value: Any) -> None:
"""
Set a single constraint field by name, with validation against allowed_keys.
"""
allowed = self._constraints.allowed_keys
if key not in allowed:
raise KeyError(
f"Invalid constraint key '{key}' for {type(self._constraints).__name__}, "
f"allowed keys are {allowed}"
)
setattr(self._constraints, key, value)

# --- helpers for consistent loss logging ---
def _get_zero_loss_tensor(self) -> torch.Tensor:
"""Helper method to create a zero loss tensor with proper device and dtype."""
device = getattr(self, "device", "cpu")
return torch.tensor(0, device=device, dtype=getattr(torch, config.get("dtype_real")))

def reset_soft_constraint_losses(self) -> None:
self._soft_constraint_loss = {}

def add_soft_constraint_loss(self, name: str, value: torch.Tensor | float) -> None:
"""Record a single soft-constraint loss for logging without holding the graph."""
if isinstance(value, torch.Tensor):
val = value.detach()
if val.ndim != 0:
val = val.mean()
self._soft_constraint_loss[name] = val
else:
self._soft_constraint_loss[name] = float(value)

def accumulate_constraint_losses(
self, batch_constraint_losses: dict[str, torch.Tensor | float] | None = None
) -> None:
"""Accumulate constraint losses across batches."""
if batch_constraint_losses is None:
batch_constraint_losses = self.soft_constraint_loss

for loss_name, loss_value in batch_constraint_losses.items():
if isinstance(loss_value, torch.Tensor):
try:
v = loss_value.item()
except Exception:
v = loss_value.detach().mean().item()
else:
v = float(loss_value)
self._iter_constraint_losses[loss_name] = (
self._iter_constraint_losses.get(loss_name, 0.0) + v
)

def get_iter_constraint_losses(self) -> dict[str, float]:
return self._iter_constraint_losses

def reset_iter_constraint_losses(self) -> None:
self._iter_constraint_losses = {}

# --- Required methods that need to be implemented in subclasses ---
@abstractmethod
def apply_hard_constraints(self, pred: torch.Tensor) -> torch.Tensor:
def apply_hard_constraints(self, *args, **kwargs) -> torch.Tensor | None:
"""
Apply hard constraints to the model.

May return a projected tensor (most models) or ``None`` when the
implementation mutates state in place (e.g. ``DatasetConstraints``).
"""
raise NotImplementedError

@abstractmethod
def apply_soft_constraints(self, ctx: T_ctx) -> torch.Tensor:
def apply_soft_constraints(self, *args, **kwargs) -> torch.Tensor:
"""
Apply soft constraints to the model.

Signature is intentionally permissive: ptychography models override with
positional tensors (e.g. ``(obj, mask)``), while tomography models
override with a ``ReconstructionContext`` (``(ctx)``).
"""
raise NotImplementedError
5 changes: 2 additions & 3 deletions src/quantem/core/ml/ddp.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,10 @@
import torch.nn as nn
from torch.utils.data import DataLoader, Dataset, DistributedSampler, random_split

from quantem.core.ml.dist_utils import worker_init_fn
from quantem.tomography.dataset_models import DatasetModelType


def worker_init_fn(worker_id):
os.environ["CUDA_VISIBLE_DEVICES"] = ""
__all__ = ["DDPMixin", "worker_init_fn"]


class DDPMixin:
Expand Down
102 changes: 102 additions & 0 deletions src/quantem/core/ml/dist_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,102 @@
"""
Standalone distributed training utilities for ptychography.

These are kept separate from ddp.py (which imports tomography types) so they
can be used by diffractive_imaging without circular imports.
"""

from __future__ import annotations

import os
from typing import Any

import torch
import torch.distributed as dist


def is_distributed_launch() -> bool:
"""True when launched via torchrun / torch.distributed.launch (RANK env var is set)."""
return "RANK" in os.environ


def init_process_group(
rank: int,
world_size: int,
backend: str = "nccl",
master_addr: str = "127.0.0.1",
master_port: str = "29500",
local_device: int | None = None,
) -> None:
"""Initialize the distributed process group from within an mp.spawn worker.

``local_device`` is the physical CUDA device index this rank should bind to
(e.g. with ``GPU_IDS=[2, 3]``, rank 0 should get ``local_device=2``).
NCCL allocates communicator buffers on the *current* CUDA device at
``init_process_group`` time, so the device must be set *before* that call
or the buffers will land on whichever device was current — typically
``cuda:0``. Falling back to ``rank`` matches PyTorch's
``LOCAL_RANK == device_index`` convention used by ``torchrun`` when each
process maps to a contiguous device starting at 0.
"""
os.environ["MASTER_ADDR"] = master_addr
os.environ["MASTER_PORT"] = master_port
if backend == "nccl":
device_index = local_device if local_device is not None else rank
torch.cuda.set_device(device_index)
dist.init_process_group(
backend=backend,
rank=rank,
world_size=world_size,
)


def get_rank() -> int:
"""Return the current process rank (0 if not in a distributed context)."""
if dist.is_available() and dist.is_initialized():
return dist.get_rank()
return 0


def get_world_size() -> int:
"""Return the world size (1 if not in a distributed context)."""
if dist.is_available() and dist.is_initialized():
return dist.get_world_size()
return 1


def all_reduce_params(*params: torch.Tensor, op: Any = dist.ReduceOp.AVG) -> None:
"""Average the .grad tensors of the given parameters across all ranks in-place."""
for p in params:
if p.grad is not None:
_ = dist.all_reduce(p.grad, op=op)


def broadcast_params(*params: torch.Tensor, src: int = 0) -> None:
"""Broadcast .data of each parameter from rank src to all other ranks."""
for p in params:
_ = dist.broadcast(p.data, src=src)


def worker_init_fn(worker_id: int) -> None:
"""Hide CUDA from DataLoader workers so they only touch CPU-resident tensors."""
os.environ["CUDA_VISIBLE_DEVICES"] = ""


def spawn_distributed_workers(
worker_fn, devices: list[int], *worker_args, start_method: str = "forkserver"
) -> None:
"""Launch one worker per device via torch.multiprocessing.start_processes.

worker_fn must be a module-level callable with signature
(rank, world_size, *worker_args) — matches the mp.start_processes contract,
which passes rank as the first arg automatically.
"""
import torch.multiprocessing as mp

mp.start_processes( # type: ignore
worker_fn,
args=(len(devices), *worker_args),
nprocs=len(devices),
join=True,
start_method=start_method,
)
8 changes: 7 additions & 1 deletion src/quantem/diffractive_imaging/__init__.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,21 @@
from quantem.diffractive_imaging.dataset_models import (
PtychoDatasetConstraintParams as PtychoDatasetConstraintParams,
PtychoDatasetConstraintsType as PtychoDatasetConstraintsType,
PtychographyDatasetRaster as PtychographyDatasetRaster,
)
from quantem.diffractive_imaging.detector_models import DetectorPixelated as DetectorPixelated
from quantem.diffractive_imaging.object_models import (
ObjectDIP as ObjectDIP,
ObjectPixelated as ObjectPixelated,
PtychoObjConstraintParams as PtychoObjConstraintParams,
PtychoObjConstraintsType as PtychoObjConstraintsType,
)
from quantem.diffractive_imaging.probe_models import (
ProbeDIP as ProbeDIP,
ProbePixelated as ProbePixelated,
ProbeParametric as ProbeParametric,
ProbePixelated as ProbePixelated,
PtychoProbeConstraintParams as PtychoProbeConstraintParams,
PtychoProbeConstraintsType as PtychoProbeConstraintsType,
)
from quantem.diffractive_imaging.ptychography import Ptychography as Ptychography
from quantem.diffractive_imaging.ptychography_lite import (
Expand Down
Loading