diff --git a/src/quantem/core/ml/constraints.py b/src/quantem/core/ml/constraints.py index 590da4cb..0204de9c 100644 --- a/src/quantem/core/ml/constraints.py +++ b/src/quantem/core/ml/constraints.py @@ -7,15 +7,22 @@ import torch from numpy.typing import NDArray +from quantem.core import config + @dataclass class BaseContext(ABC): """ - Constraints should contain a context object that contains all necessary data for the constraints to be applied. + Context object bundling the data a constraint needs to be applied. + + Tomography's ``ReconstructionContext`` subclasses this and is passed to its + ``apply_soft_constraints(ctx)`` overrides. Ptychography models instead pass + their tensors positionally, so the base ``apply_soft_constraints`` signature + stays ``*args, **kwargs`` to accommodate both domains. """ + pass -T_ctx = TypeVar("T_ctx", bound=BaseContext) @dataclass(slots=False) class Constraints(ABC): @@ -56,17 +63,68 @@ def __str__(self) -> str: ) -class BaseConstraints(ABC, Generic[T_ctx]): +def parse_constraint_dict( + namespace: type, + d: dict, + *, + kind: str = "constraint", +) -> Constraints: + """Dispatch a config dict to one of ``namespace``'s nested ``Constraints`` variants. + + ``namespace`` is a class with one or more nested ``@dataclass``\\ -decorated + ``Constraints`` subclasses. The dict must contain a ``"name"`` or ``"type"`` key + whose value (case-insensitive) matches one variant's ``_name`` field; the + remaining keys are forwarded as constructor kwargs to that variant. + + ``kind`` is a short human-readable label ("object", "probe", "dataset", ...) + used only in error messages. + """ + d = dict(d) + name = d.pop("name", None) or d.pop("type", None) + if name is None: + raise ValueError(f"Must provide either 'name' or 'type' key for {kind} constraints") + if isinstance(name, type): + name = name.__name__.lower() + elif isinstance(name, str): + name = name.lower() + else: + raise ValueError(f"Unknown {kind} constraint type: {name!r}") + + variants: dict[str, type[Constraints]] = {} + for attr in vars(namespace).values(): + if isinstance(attr, type) and issubclass(attr, Constraints) and attr is not Constraints: + variant_name = getattr(attr, "_name", None) + if isinstance(variant_name, str): + variants[variant_name.lower()] = attr + + if name not in variants: + raise ValueError( + f"Unknown {kind} constraint type: {name!r}; expected one of {sorted(variants)}" + ) + return variants[name](**d) + + +C = TypeVar("C", bound=Constraints) + + +class BaseConstraints(ABC, Generic[C]): """ Base class for constraints. + + Generic over a concrete ``Constraints`` subclass so that subclasses (and the + type checker) can see the specific fields available on ``self.constraints``. + Subclasses parameterize like ``BaseConstraints[MyConstraintsType]`` and set + ``DEFAULT_CONSTRAINTS`` to an instance of that type. """ - # Default constraints are the dataclasses themselves. - DEFAULT_CONSTRAINTS = Constraints() + DEFAULT_CONSTRAINTS: C + _constraints: C def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) self._soft_constraint_losses = [] + self._soft_constraint_loss: dict[str, torch.Tensor | float] = {} + self._iter_constraint_losses: dict[str, float] = {} self.constraints = self.DEFAULT_CONSTRAINTS.copy() @property @@ -74,36 +132,110 @@ def soft_constraint_losses(self) -> NDArray[np.float32]: return np.array(self._soft_constraint_losses, dtype=np.float32) @property - def constraints(self) -> Constraints: + def soft_constraint_loss(self) -> dict[str, torch.Tensor | float]: + return self._soft_constraint_loss + + @property + def constraints(self) -> C: """ Constraints for the model. """ return self._constraints @constraints.setter - def constraints(self, constraints: Constraints | dict[str, Any]): + def constraints(self, constraints: C | dict[str, Any]): """ Setter for constraints class, can be a Constraints instance or a dictionary. + Dict keys are validated against the active Constraints dataclass's allowed_keys. """ if isinstance(constraints, Constraints): self._constraints = constraints elif isinstance(constraints, dict): + allowed = self._constraints.allowed_keys for key, value in constraints.items(): + if key not in allowed: + raise KeyError( + f"Invalid constraint key '{key}' for {type(self._constraints).__name__}, " + f"allowed keys are {allowed}" + ) setattr(self._constraints, key, value) else: raise ValueError(f"Invalid constraints type: {type(constraints)}") - # --- Required methods tha tneeds to implemented in subclasses --- + def add_constraint(self, key: str, value: Any) -> None: + """ + Set a single constraint field by name, with validation against allowed_keys. + """ + allowed = self._constraints.allowed_keys + if key not in allowed: + raise KeyError( + f"Invalid constraint key '{key}' for {type(self._constraints).__name__}, " + f"allowed keys are {allowed}" + ) + setattr(self._constraints, key, value) + + # --- helpers for consistent loss logging --- + def _get_zero_loss_tensor(self) -> torch.Tensor: + """Helper method to create a zero loss tensor with proper device and dtype.""" + device = getattr(self, "device", "cpu") + return torch.tensor(0, device=device, dtype=getattr(torch, config.get("dtype_real"))) + + def reset_soft_constraint_losses(self) -> None: + self._soft_constraint_loss = {} + + def add_soft_constraint_loss(self, name: str, value: torch.Tensor | float) -> None: + """Record a single soft-constraint loss for logging without holding the graph.""" + if isinstance(value, torch.Tensor): + val = value.detach() + if val.ndim != 0: + val = val.mean() + self._soft_constraint_loss[name] = val + else: + self._soft_constraint_loss[name] = float(value) + + def accumulate_constraint_losses( + self, batch_constraint_losses: dict[str, torch.Tensor | float] | None = None + ) -> None: + """Accumulate constraint losses across batches.""" + if batch_constraint_losses is None: + batch_constraint_losses = self.soft_constraint_loss + + for loss_name, loss_value in batch_constraint_losses.items(): + if isinstance(loss_value, torch.Tensor): + try: + v = loss_value.item() + except Exception: + v = loss_value.detach().mean().item() + else: + v = float(loss_value) + self._iter_constraint_losses[loss_name] = ( + self._iter_constraint_losses.get(loss_name, 0.0) + v + ) + + def get_iter_constraint_losses(self) -> dict[str, float]: + return self._iter_constraint_losses + + def reset_iter_constraint_losses(self) -> None: + self._iter_constraint_losses = {} + + # --- Required methods that need to be implemented in subclasses --- @abstractmethod - def apply_hard_constraints(self, pred: torch.Tensor) -> torch.Tensor: + def apply_hard_constraints(self, *args, **kwargs) -> torch.Tensor | None: """ Apply hard constraints to the model. + + May return a projected tensor (most models) or ``None`` when the + implementation mutates state in place (e.g. ``DatasetConstraints``). """ raise NotImplementedError @abstractmethod - def apply_soft_constraints(self, ctx: T_ctx) -> torch.Tensor: + def apply_soft_constraints(self, *args, **kwargs) -> torch.Tensor: """ Apply soft constraints to the model. + + Signature is intentionally permissive: ptychography models override with + positional tensors (e.g. ``(obj, mask)``), while tomography models + override with a ``ReconstructionContext`` (``(ctx)``). """ raise NotImplementedError diff --git a/src/quantem/core/ml/ddp.py b/src/quantem/core/ml/ddp.py index dedd90dd..9cc6e0f2 100644 --- a/src/quantem/core/ml/ddp.py +++ b/src/quantem/core/ml/ddp.py @@ -5,11 +5,10 @@ import torch.nn as nn from torch.utils.data import DataLoader, Dataset, DistributedSampler, random_split +from quantem.core.ml.dist_utils import worker_init_fn from quantem.tomography.dataset_models import DatasetModelType - -def worker_init_fn(worker_id): - os.environ["CUDA_VISIBLE_DEVICES"] = "" +__all__ = ["DDPMixin", "worker_init_fn"] class DDPMixin: diff --git a/src/quantem/core/ml/dist_utils.py b/src/quantem/core/ml/dist_utils.py new file mode 100644 index 00000000..450c4fea --- /dev/null +++ b/src/quantem/core/ml/dist_utils.py @@ -0,0 +1,102 @@ +""" +Standalone distributed training utilities for ptychography. + +These are kept separate from ddp.py (which imports tomography types) so they +can be used by diffractive_imaging without circular imports. +""" + +from __future__ import annotations + +import os +from typing import Any + +import torch +import torch.distributed as dist + + +def is_distributed_launch() -> bool: + """True when launched via torchrun / torch.distributed.launch (RANK env var is set).""" + return "RANK" in os.environ + + +def init_process_group( + rank: int, + world_size: int, + backend: str = "nccl", + master_addr: str = "127.0.0.1", + master_port: str = "29500", + local_device: int | None = None, +) -> None: + """Initialize the distributed process group from within an mp.spawn worker. + + ``local_device`` is the physical CUDA device index this rank should bind to + (e.g. with ``GPU_IDS=[2, 3]``, rank 0 should get ``local_device=2``). + NCCL allocates communicator buffers on the *current* CUDA device at + ``init_process_group`` time, so the device must be set *before* that call + or the buffers will land on whichever device was current — typically + ``cuda:0``. Falling back to ``rank`` matches PyTorch's + ``LOCAL_RANK == device_index`` convention used by ``torchrun`` when each + process maps to a contiguous device starting at 0. + """ + os.environ["MASTER_ADDR"] = master_addr + os.environ["MASTER_PORT"] = master_port + if backend == "nccl": + device_index = local_device if local_device is not None else rank + torch.cuda.set_device(device_index) + dist.init_process_group( + backend=backend, + rank=rank, + world_size=world_size, + ) + + +def get_rank() -> int: + """Return the current process rank (0 if not in a distributed context).""" + if dist.is_available() and dist.is_initialized(): + return dist.get_rank() + return 0 + + +def get_world_size() -> int: + """Return the world size (1 if not in a distributed context).""" + if dist.is_available() and dist.is_initialized(): + return dist.get_world_size() + return 1 + + +def all_reduce_params(*params: torch.Tensor, op: Any = dist.ReduceOp.AVG) -> None: + """Average the .grad tensors of the given parameters across all ranks in-place.""" + for p in params: + if p.grad is not None: + _ = dist.all_reduce(p.grad, op=op) + + +def broadcast_params(*params: torch.Tensor, src: int = 0) -> None: + """Broadcast .data of each parameter from rank src to all other ranks.""" + for p in params: + _ = dist.broadcast(p.data, src=src) + + +def worker_init_fn(worker_id: int) -> None: + """Hide CUDA from DataLoader workers so they only touch CPU-resident tensors.""" + os.environ["CUDA_VISIBLE_DEVICES"] = "" + + +def spawn_distributed_workers( + worker_fn, devices: list[int], *worker_args, start_method: str = "forkserver" +) -> None: + """Launch one worker per device via torch.multiprocessing.start_processes. + + worker_fn must be a module-level callable with signature + (rank, world_size, *worker_args) — matches the mp.start_processes contract, + which passes rank as the first arg automatically. + """ + import torch.multiprocessing as mp + + mp.start_processes( # type: ignore + worker_fn, + args=(len(devices), *worker_args), + nprocs=len(devices), + join=True, + start_method=start_method, + ) diff --git a/src/quantem/diffractive_imaging/__init__.py b/src/quantem/diffractive_imaging/__init__.py index 2c26de60..9db66f11 100644 --- a/src/quantem/diffractive_imaging/__init__.py +++ b/src/quantem/diffractive_imaging/__init__.py @@ -1,15 +1,21 @@ from quantem.diffractive_imaging.dataset_models import ( + PtychoDatasetConstraintParams as PtychoDatasetConstraintParams, + PtychoDatasetConstraintsType as PtychoDatasetConstraintsType, PtychographyDatasetRaster as PtychographyDatasetRaster, ) from quantem.diffractive_imaging.detector_models import DetectorPixelated as DetectorPixelated from quantem.diffractive_imaging.object_models import ( ObjectDIP as ObjectDIP, ObjectPixelated as ObjectPixelated, + PtychoObjConstraintParams as PtychoObjConstraintParams, + PtychoObjConstraintsType as PtychoObjConstraintsType, ) from quantem.diffractive_imaging.probe_models import ( ProbeDIP as ProbeDIP, - ProbePixelated as ProbePixelated, ProbeParametric as ProbeParametric, + ProbePixelated as ProbePixelated, + PtychoProbeConstraintParams as PtychoProbeConstraintParams, + PtychoProbeConstraintsType as PtychoProbeConstraintsType, ) from quantem.diffractive_imaging.ptychography import Ptychography as Ptychography from quantem.diffractive_imaging.ptychography_lite import ( diff --git a/src/quantem/diffractive_imaging/constraints.py b/src/quantem/diffractive_imaging/constraints.py deleted file mode 100644 index e907690e..00000000 --- a/src/quantem/diffractive_imaging/constraints.py +++ /dev/null @@ -1,99 +0,0 @@ -from abc import ABC, abstractmethod -from typing import Any - -import torch - -from quantem.core import config - - -class BaseConstraints(ABC): - """Base class for constraint management with common functionality.""" - - # Subclasses should define their own DEFAULT_CONSTRAINTS - DEFAULT_CONSTRAINTS: dict[str, Any] = {} - - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - self._soft_constraint_loss = {} - self._constraints = self.DEFAULT_CONSTRAINTS.copy() - self._iter_constraint_losses = {} - - @property - def constraints(self) -> dict[str, Any]: - return self._constraints - - @constraints.setter - def constraints(self, c: dict[str, Any]): - allowed_keys = self.DEFAULT_CONSTRAINTS.keys() - constraint_type = self.__class__.__name__.lower().replace("constraints", "") - - for key, value in c.items(): - if key not in allowed_keys: - raise KeyError( - f"Invalid {constraint_type} constraint key '{key}', allowed keys are {list(allowed_keys)}" - ) - self._constraints[key] = value - - @property - def soft_constraint_loss(self) -> dict[str, torch.Tensor | float]: - return self._soft_constraint_loss - - def add_constraint(self, key: str, value: Any): - allowed_keys = self.DEFAULT_CONSTRAINTS.keys() - constraint_type = self.__class__.__name__.lower().replace("constraints", "") - - if key not in allowed_keys: - raise KeyError( - f"Invalid {constraint_type} constraint key '{key}', allowed keys are {list(allowed_keys)}" - ) - self._constraints[key] = value - - @abstractmethod - def apply_soft_constraints(self, *args, **kwargs) -> torch.Tensor: - """Apply soft constraints and return total constraint loss.""" - pass - - def _get_zero_loss_tensor(self) -> torch.Tensor: - """Helper method to create a zero loss tensor with proper device and dtype.""" - device = getattr(self, "device", "cpu") - return torch.tensor(0, device=device, dtype=getattr(torch, config.get("dtype_real"))) - - # --- helpers for consistent loss logging --- - def reset_soft_constraint_losses(self) -> None: - self._soft_constraint_loss = {} - - def add_soft_constraint_loss(self, name: str, value: torch.Tensor | float) -> None: - """Record a single soft-constraint loss for logging without holding the graph.""" - if isinstance(value, torch.Tensor): - val = value.detach() - if val.ndim != 0: - val = val.mean() - self._soft_constraint_loss[name] = val - else: - self._soft_constraint_loss[name] = float(value) - - def accumulate_constraint_losses( - self, batch_constraint_losses: dict[str, torch.Tensor | float] | None = None - ) -> None: - """Accumulate constraint losses across batches.""" - if batch_constraint_losses is None: - batch_constraint_losses = self.soft_constraint_loss - - for loss_name, loss_value in batch_constraint_losses.items(): - if isinstance(loss_value, torch.Tensor): - try: - v = loss_value.item() - except Exception: - print("loss value not singular: ", loss_value) # TODO remove - v = loss_value.detach().mean().item() - else: - v = float(loss_value) - self._iter_constraint_losses[loss_name] = ( - self._iter_constraint_losses.get(loss_name, 0.0) + v - ) - - def get_iter_constraint_losses(self) -> dict[str, float]: - return getattr(self, "_iter_constraint_losses", {}) # TODO clean this up - - def reset_iter_constraint_losses(self) -> None: - self._iter_constraint_losses = {} diff --git a/src/quantem/diffractive_imaging/dataset_models.py b/src/quantem/diffractive_imaging/dataset_models.py index 751c9245..daa66952 100644 --- a/src/quantem/diffractive_imaging/dataset_models.py +++ b/src/quantem/diffractive_imaging/dataset_models.py @@ -1,18 +1,20 @@ import warnings from abc import abstractmethod -from dataclasses import replace +from dataclasses import dataclass, replace from pathlib import Path -from typing import Any, Literal, Self +from typing import Any, Literal, Self, cast import matplotlib.pyplot as plt import numpy as np import torch import torch.nn as nn +import torch.utils.data from quantem.core import config from quantem.core.datastructures.dataset3d import Dataset3d from quantem.core.datastructures.dataset4dstem import Dataset4dstem from quantem.core.io.serialize import AutoSerialize +from quantem.core.ml.constraints import BaseConstraints, Constraints, parse_constraint_dict from quantem.core.ml.optimizer_mixin import OptimizerMixin, OptimizerParams from quantem.core.utils.utils import electron_wavelength_angstrom, tqdmnd from quantem.core.utils.validators import ( @@ -22,7 +24,6 @@ validate_tensor, ) from quantem.core.visualization import show_2d -from quantem.diffractive_imaging.constraints import BaseConstraints from quantem.diffractive_imaging.ptycho_utils import AffineTransform, fit_origin, shift_array """ @@ -30,7 +31,78 @@ """ -class PtychographyDatasetBase(AutoSerialize, OptimizerMixin, torch.nn.Module): +class PtychoDatasetConstraintParams: + """ + Namespace class for ptychography dataset constraint dataclasses. + + Tab-complete on ``PtychoDatasetConstraintParams`` in a notebook to discover the + available variants. Tab-complete inside a variant's constructor to see every + constraint field with its default value. + + Variants + -------- + Raster + Constraints for ``PtychographyDatasetRaster`` (descan TV penalty, descan + zero-out, scan position clipping and centering). + """ + + @dataclass + class Raster(Constraints): + """Constraints for raster-scan ptychography datasets (``PtychographyDatasetRaster``). + + Attributes + ---------- + descan_shifts_constant : bool, default ``False`` + Forces all descan shifts to zero after each update. Useful when you + want to keep the descan optimizer in the parameter group but freeze + its effect. + center_scan_positions : bool, default ``False`` + Shifts all scan positions uniformly so their mean sits at the object + center after each update. Prevents the reconstruction from + translating during long runs with ``lr_scan_positions > 0``. + clip_scan_positions : bool, default ``True`` + Clamps scan positions to lie within ``[0, obj_shape - 1]`` after each + update. On by default to prevent positions from drifting off the + padded object during refinement. + descan_tv_weight : float, default ``0.0`` + Soft penalty. Weight on the total-variation of the descan-shift + sequence (x and y averaged). Encourages smoothly varying descan, but + only contributes when ``learn_descan`` is on and the dataset has a + descan optimizer attached. + """ + + # hard constraints + descan_shifts_constant: bool = False + center_scan_positions: bool = False + clip_scan_positions: bool = True + # soft constraints + descan_tv_weight: float = 0.0 + _name: str = "raster" + + soft_constraint_keys = ["descan_tv_weight"] + hard_constraint_keys = [ + "descan_shifts_constant", + "center_scan_positions", + "clip_scan_positions", + ] + + @classmethod + def parse_dict(cls, d: dict) -> "PtychoDatasetConstraintsType": + """Instantiate the appropriate variant from a config dict. + + The dict must contain a ``'name'`` or ``'type'`` key (case-insensitive), + with value ``'raster'``. All other keys are forwarded as keyword + arguments to the chosen dataclass. + """ + return cast(PtychoDatasetConstraintsType, parse_constraint_dict(cls, d, kind="dataset")) + + +PtychoDatasetConstraintsType = PtychoDatasetConstraintParams.Raster + + +class PtychographyDatasetBase( + AutoSerialize, OptimizerMixin, torch.nn.Module, torch.utils.data.Dataset +): _token = object() _patch_indices: torch.Tensor @@ -64,6 +136,11 @@ def __init__( self.dset = dset self.verbose = verbose + # target_residency controls where loss targets live: + # "device" (default) — targets are kept resident on the compute device (current behavior) + # "cpu" — targets live in CPU RAM and are streamed to the device per-batch, + # enabling datasets larger than a single GPU's VRAM + self._target_residency: Literal["device", "cpu"] = "device" self._preprocessed = False self._preprocessing_params = {} # for serialization and reloading self._com_rotation_rad = 0 # default @@ -87,7 +164,9 @@ def __init__( self._initial_scan_positions_px = torch.zeros_like(self._scan_positions_px) self._initial_descan_shifts = torch.zeros_like(self._descan_shifts) - self.register_buffer("_targets", torch.zeros(self.num_gpts, *self.roi_shape)) + # _targets is a plain attribute (NOT a registered buffer) so that its device can be + # managed explicitly per target_residency; AutoSerialize does not serialize it either way. + self._targets = torch.zeros(self.num_gpts, *self.roi_shape) self.register_buffer( "_patch_indices", torch.zeros(self.num_gpts, *self.roi_shape, dtype=torch.int32) ) @@ -143,6 +222,15 @@ def to(self, *args, **kwargs): """Move all relevant tensors to a different device.""" # Call parent's to() method to handle PyTorch's internal device management super().to(*args, **kwargs) + # _targets is a plain attribute, so nn.Module.to() does not move it; do so explicitly + # unless residency is "cpu" (in which case targets intentionally stay on CPU and are + # streamed to the device per-batch). + if ( + getattr(self, "target_residency", "device") != "cpu" + and getattr(self, "_targets", None) is not None + ): + # After super().to(), self.device reflects the new device (it reads off a Parameter). + self._targets = self._targets.to(self.device) # Reconnect optimizer to parameters on the new device self.reconnect_optimizer_to_parameters() return self @@ -253,22 +341,37 @@ def targets(self) -> torch.Tensor: raise ValueError("dset must be preprocessed before targets can be accessed") return self._targets + @property + def target_residency(self) -> Literal["device", "cpu"]: + """Where the loss targets live: ``"device"`` (resident, fastest) or + ``"cpu"`` (streamed per-batch, enables datasets larger than VRAM).""" + return self._target_residency + + @target_residency.setter + def target_residency(self, value: str) -> None: + if value not in ("device", "cpu"): + raise ValueError(f"target_residency must be 'device' or 'cpu', got {value!r}") + self._target_residency = value + def _set_targets( self, loss_type: Literal[ "l2_amplitude", "l1_amplitude", "l2_intensity", "l1_intensity", "poisson" ], ): + # When residency is "cpu", build targets on CPU so they can be streamed per-batch + # (and read by DataLoader workers); otherwise keep them resident on the compute device. + target_device = "cpu" if self.target_residency == "cpu" else self.device if "amplitude" in loss_type: if self.learn_descan and self.has_optimizer(): - self._targets = self.amplitudes.clone().to(self.device) + self._targets = self.amplitudes.clone().to(target_device) else: - self._targets = self.centered_amplitudes.clone().to(self.device) + self._targets = self.centered_amplitudes.clone().to(target_device) elif "intensity" in loss_type or loss_type == "poisson": if self.learn_descan and self.has_optimizer(): - self._targets = self.intensities.clone().to(self.device) + self._targets = self.intensities.clone().to(target_device) else: - self._targets = self.centered_intensities.clone().to(self.device) + self._targets = self.centered_intensities.clone().to(target_device) else: raise ValueError(f"Unknown loss type {loss_type}") @@ -278,6 +381,21 @@ def patch_indices(self) -> torch.Tensor: # endregion --- buffers --- + # region --- torch.utils.data.Dataset interface --- + def __len__(self) -> int: + return self.num_gpts + + def __getitem__(self, idx: int) -> dict[str, Any]: + """Return one sample for the DataLoader. + + The target is returned as-is on whatever device target_residency dictates (CPU when + residency is "cpu" so DataLoader workers can read it). The integer index is collated + into a LongTensor by the default collate_fn. + """ + return {"index": idx, "target": self._targets[idx]} + + # endregion --- torch.utils.data.Dataset interface --- + # region --- explicit properties (have setters) --- @property def dset(self) -> Dataset3d: @@ -567,7 +685,7 @@ def _set_patch_indices(self, obj_padding_px: np.ndarray | tuple) -> None: patch_indices_list.append(patch_indices_chunk) self._patch_indices = torch.cat(patch_indices_list, dim=0) - self._last_patch_positions_px = self.scan_positions_px.clone() + self._last_patch_positions_px = self.scan_positions_px.detach().clone() def patch_indices_need_update(self) -> bool: """ @@ -584,24 +702,19 @@ def reset(self) -> None: # endregion --- class methods --- -class DatasetConstraints(BaseConstraints, PtychographyDatasetBase): - DEFAULT_CONSTRAINTS = { - "descan_tv_weight": 0.0, - "descan_shifts_constant": False, - "center_scan_positions": False, - "clip_scan_positions": True, - } +class DatasetConstraints( + BaseConstraints[PtychoDatasetConstraintParams.Raster], PtychographyDatasetBase +): + DEFAULT_CONSTRAINTS: PtychoDatasetConstraintParams.Raster = ( + PtychoDatasetConstraintParams.Raster() + ) def apply_soft_constraints(self, descan_shifts: torch.Tensor) -> torch.Tensor: self.reset_soft_constraint_losses() loss = self._get_zero_loss_tensor() - if ( - self.constraints.get("descan_tv_weight", 0) > 0 - and self.learn_descan - and self.has_optimizer() - ): - tv_loss = self.get_descan_tv_loss(descan_shifts, self.constraints["descan_tv_weight"]) + if self.constraints.descan_tv_weight > 0 and self.learn_descan and self.has_optimizer(): + tv_loss = self.get_descan_tv_loss(descan_shifts, self.constraints.descan_tv_weight) loss = loss + tv_loss self.add_soft_constraint_loss("descan_tv_weight", tv_loss) @@ -621,23 +734,17 @@ def apply_descan_constraints( self, descan: torch.Tensor, ) -> torch.Tensor: - if self.constraints["descan_shifts_constant"]: + if self.constraints.descan_shifts_constant: descan = torch.zeros_like(descan) return descan def apply_hard_constraints(self, obj_padding_px: np.ndarray | tuple) -> None: - # could clip positions here if needed positions = self.scan_positions_px obj_shape = torch.tensor(self._obj_shape_full_2d(obj_padding_px), device=positions.device) - if self.constraints.get( - "clip_scan_positions", self.DEFAULT_CONSTRAINTS["clip_scan_positions"] - ): + if self.constraints.clip_scan_positions: positions = torch.clamp(positions, min=torch.zeros_like(obj_shape), max=obj_shape - 1) - if self.constraints.get( - "center_scan_positions", self.DEFAULT_CONSTRAINTS["center_scan_positions"] - ): - # shift all positions uniformly so that the mean position is at the center of the object + if self.constraints.center_scan_positions: positions = positions - positions.mean(dim=0, keepdim=True) positions = positions + obj_shape / 2 @@ -667,12 +774,15 @@ def __init__( self.scan_sampling = dset.sampling[:2] self.scan_units = dset.units[:2] self.gpts = dset.shape[:2] - self.intensities_4d = dset.array.copy() + # TODO remove after Dataset torch migration complete + dset_numpy = dset.array if dset.array is not None else dset.tensor.cpu().numpy() + + self.intensities_4d = dset_numpy.copy() # convert to dataset3d - shp = dset.array.shape + shp = dset_numpy.shape dset3d = Dataset3d.from_array( - array=dset.array.reshape((shp[0] * shp[1], shp[2], shp[3])), + array=dset_numpy.reshape((shp[0] * shp[1], shp[2], shp[3])), name=dset.name, origin=[0, *dset.origin[2:]], sampling=[0, *dset.sampling[2:]], @@ -1014,7 +1124,11 @@ def preprocess( ), modify_in_place=True, ) - self.intensities_4d = self.dset.array.reshape( + # TODO remove after Dataset torch migration complete + dset_numpy = ( + self.dset.array if self.dset.array is not None else self.dset.tensor.cpu().numpy() + ) + self.intensities_4d = dset_numpy.reshape( (*self.gpts, *padded_diffraction_intensities_shape) ) self.detector_mask = torch.nn.functional.pad( diff --git a/src/quantem/diffractive_imaging/object_models.py b/src/quantem/diffractive_imaging/object_models.py index 0c5531de..140ca585 100644 --- a/src/quantem/diffractive_imaging/object_models.py +++ b/src/quantem/diffractive_imaging/object_models.py @@ -1,6 +1,7 @@ import math from abc import abstractmethod from copy import deepcopy +from dataclasses import dataclass from typing import Callable, Literal, Self, Sequence, cast from warnings import warn @@ -13,6 +14,7 @@ from quantem.core import config from quantem.core.io.serialize import AutoSerialize from quantem.core.ml.blocks import reset_weights +from quantem.core.ml.constraints import BaseConstraints, Constraints, parse_constraint_dict from quantem.core.ml.loss_functions import get_loss_module from quantem.core.ml.optimizer_mixin import ( OptimizerMixin, @@ -27,17 +29,153 @@ ) from quantem.core.visualization import show_2d from quantem.core.visualization.custom_normalizations import CustomNormalization -from quantem.diffractive_imaging.constraints import BaseConstraints from quantem.diffractive_imaging.ptycho_utils import sum_patches object_type = Literal["potential", "pure_phase", "complex"] + +class PtychoObjConstraintParams: + """ + Namespace class for ptychography object constraint dataclasses. + + Tab-complete on ``PtychoObjConstraintParams`` in a notebook to discover the + available variants. Tab-complete inside a variant's constructor to see every + constraint field with its default value. + + Variants + -------- + Raster + Constraints for grid-based object representations (``ObjectPixelated`` and + ``ObjectDIP`` share this set today). + INR + Placeholder for the upcoming implicit-neural-representation object. + + Examples + -------- + >>> PtychoObjConstraintParams.Raster(tv_weight_z=5.0, identical_slices=True) + >>> PtychoObjConstraintParams.parse_dict({"name": "raster", "positivity": False}) + """ + + @dataclass + class Raster(Constraints): + """Constraints for grid-based ptychography object models (``ObjectPixelated``, + ``ObjectDIP``). + + Fields are applied each iteration in two flavors: **hard** constraints + project / filter the object after the optimizer step; **soft** constraints + add a penalty term to the training loss. + + Attributes + ---------- + positivity : bool, default ``True`` + Clamps the object to be non-negative after each update. + Only consulted when ``obj_type="potential"``; for ``"complex"`` / + ``"pure_phase"`` the amplitude is clamped to ``[0, 1]`` (or fixed to 1) + regardless of this flag. + fix_potential_baseline : bool, default ``False`` + ``obj_type="potential"`` only. Subtracts an offset from the object so + background regions sit at zero. If an FOV mask is set the offset is + the mean of the background (``mask < 0.5 * mask.max()``); otherwise + it's ``obj.min()``. + fix_potential_baseline_factor : float, default ``1.0`` + Scales the baseline offset. Values ``<1`` relax the anchoring + (subtract less of the background); ``>1`` over-correct. + identical_slices : bool, default ``False`` + Multislice (``num_slices > 1``) only. Replaces every slice with the + mean across slices, forcing an effectively 2D object. + apply_fov_mask : bool, default ``False`` + Multiplies the object by the precomputed FOV mask after each update. + Useful when the scan does not cover the full padded object area. + gaussian_sigma : float | None, default ``None`` + Standard deviation (in pixels) of a 2D Gaussian blur applied to each + slice after each update. Smoothing prior; ``None`` disables. + butterworth_order : int, default ``4`` + Order of the Butterworth filter used by ``q_lowpass`` / ``q_highpass``. + q_lowpass : float | None, default ``None`` + Lowpass cutoff in inverse Angstroms. Fourier components above this + spatial frequency are suppressed via a Butterworth filter. + q_highpass : float | None, default ``None`` + Highpass cutoff in inverse Angstroms. Components below this frequency + are suppressed; typically used to remove a slowly varying background. + tv_weight_z : float, default ``0.0`` + Soft penalty. Weight on the depth-axis total-variation term in the + loss. Multislice (``num_slices > 1``) only. + tv_weight_xy : float, default ``0.0`` + Soft penalty. Weight on the in-plane total-variation term; + encourages piecewise-smooth regions while preserving edges. + surface_zero_weight : float, default ``0.0`` + Soft penalty pulling the first and last slices toward zero. Useful + for thick samples embedded in vacuum. Multislice only and requires + ``num_slices >= 3``. + """ + + # hard constraints + positivity: bool = True + fix_potential_baseline: bool = False + fix_potential_baseline_factor: float = 1.0 + identical_slices: bool = False + apply_fov_mask: bool = False + # filtering (treated as hard, applied post-update) + gaussian_sigma: float | None = None # pixels + butterworth_order: int = 4 + q_lowpass: float | None = None # A^-1 + q_highpass: float | None = None # A^-1 + # soft constraints + tv_weight_z: float = 0.0 + tv_weight_xy: float = 0.0 + surface_zero_weight: float = 0.0 + _name: str = "raster" + + soft_constraint_keys = ["tv_weight_z", "tv_weight_xy", "surface_zero_weight"] + hard_constraint_keys = [ + "positivity", + "fix_potential_baseline", + "fix_potential_baseline_factor", + "identical_slices", + "apply_fov_mask", + "gaussian_sigma", + "butterworth_order", + "q_lowpass", + "q_highpass", + ] + + @dataclass + class INR(Constraints): + """Placeholder for the upcoming ``ObjectINR`` variant. + + INR-specific constraints (e.g. sparsity / TV penalties evaluated at + sampled coordinates) will land here when the model is implemented. + Until then this exists so ``parse_dict`` accepts ``"inr"`` and downstream + code can pattern-match on the variant. + """ + + _name: str = "inr" + + soft_constraint_keys = [] + hard_constraint_keys = [] + + @classmethod + def parse_dict(cls, d: dict) -> "PtychoObjConstraintsType": + """Instantiate the appropriate variant from a config dict. + + The dict must contain a ``'name'`` or ``'type'`` key (case-insensitive), + with value ``'raster'`` or ``'inr'``. All other keys are forwarded as + keyword arguments to the chosen dataclass. + """ + return cast(PtychoObjConstraintsType, parse_constraint_dict(cls, d, kind="object")) + + +PtychoObjConstraintsType = PtychoObjConstraintParams.Raster | PtychoObjConstraintParams.INR + """ -Currently all object models.obj are complex valued for "complex" or "pure_phase" object types, -and real valued for "potential" object types. This could be changed to be always complex valued, -(after applying constraints) as currently the real-valued potential is made complex in get_obj_patches, -which will not be used for implicit NNs, which leads to an inconsistency. Leaving for now as I'm not -sure if this would lead to other issues, so a bit of testing will be needed. +Object representation by obj_type: +- "complex" : _obj is complex (amplitude * exp(1j * phase)) +- "pure_phase" : _obj is a real, unwrapped phase array +- "potential" : _obj is a real potential array + +The forward boundary (`_get_obj_patches`) wraps real `_obj` to `exp(1j * _obj)` for +both pure_phase and potential, so the rest of the forward model never has to +branch on obj_type. """ @@ -89,10 +227,9 @@ def shape_2d(self) -> tuple[int, int]: @property def dtype(self) -> "torch.dtype": - if self.obj_type == "potential": - return getattr(torch, config.get("dtype_real")) - else: + if self.obj_type == "complex": return getattr(torch, config.get("dtype_complex")) + return getattr(torch, config.get("dtype_real")) @property def device(self) -> str: @@ -251,6 +388,15 @@ def _propagate_array( return propagated def _get_obj_patches(self, obj_array, patch_indices): + """Forward boundary: wrap real obj to ``exp(1j * obj)`` and gather patches. + + ``obj_array`` may be complex (``obj_type="complex"``) or real (``"pure_phase"``, + ``"potential"``). Real inputs are wrapped to the complex transmission + function ``exp(1j * obj_array)`` here, so the rest of the forward model + never has to branch on ``obj_type``. ``patch_indices`` is a + ``(num_gpts, Hroi, Wroi)`` int tensor of flattened-index lookups into the + 2D padded object. + """ if not obj_array.is_complex(): # potential or pure_phase DIP -> float obj_array2 = torch.exp(1.0j * obj_array) else: @@ -271,88 +417,103 @@ def backward(self, *args, **kwargs): ) -class ObjectConstraints(BaseConstraints, ObjectBase): - DEFAULT_CONSTRAINTS = { - "positivity": True, - "fix_potential_baseline": False, - "fix_potential_baseline_factor": 1.0, - "identical_slices": False, - "apply_fov_mask": False, - "tv_weight_z": 0, - "tv_weight_xy": 0, - "surface_zero_weight": 0, - "gaussian_sigma": None, # pixels - "butterworth_order": 4, - "q_lowpass": None, # A^-1 - "q_highpass": None, # A^-1 - } +class ObjectConstraints(BaseConstraints[PtychoObjConstraintParams.Raster], ObjectBase): + DEFAULT_CONSTRAINTS: PtychoObjConstraintParams.Raster = PtychoObjConstraintParams.Raster() def apply_hard_constraints( - self, obj: torch.Tensor, mask: torch.Tensor | None = None + self, raw: torch.Tensor, mask: torch.Tensor | None = None ) -> torch.Tensor: - if self.obj_type in ["complex", "pure_phase"]: + """ + Apply hard constraints: range clamping and filtering. All hard constaints are applied in + place with torch.no_grad(). + """ + c = self.constraints + with torch.no_grad(): if self.obj_type == "complex": - amp = torch.clamp(torch.abs(obj), 0.0, 1.0) - else: - amp = 1.0 - phase = obj.angle() - obj.angle().mean() - if mask is not None and self.constraints["apply_fov_mask"]: - obj2 = amp * mask * torch.exp(1.0j * phase * mask) - else: - obj2 = amp * torch.exp(1.0j * phase) - else: # potential - if self.constraints["fix_potential_baseline"]: - if mask is not None: - background = mask < 0.5 * mask.max() - if background.any(): - offset = obj[background].mean() - else: - offset = obj.min() + constrained = self._apply_hard_complex(raw, c) + elif self.obj_type == "pure_phase": + constrained = self._apply_hard_pure_phase(raw, c) + else: # potential + constrained = self._apply_hard_potential(raw, c, mask) + constrained = self._apply_shared_hard(constrained, c, mask) + return raw + (constrained - raw).detach() + + def _apply_hard_complex( + self, obj: torch.Tensor, c: PtychoObjConstraintParams.Raster + ) -> torch.Tensor: + amp = torch.clamp(torch.abs(obj), 0.0, 1.0) + phase = obj.angle() - obj.angle().mean() + return amp * torch.exp(1.0j * phase) + + def _apply_hard_pure_phase( + self, obj: torch.Tensor, c: PtychoObjConstraintParams.Raster + ) -> torch.Tensor: + # phase stored directly as a real tensor; recenter to zero mean + return obj - obj.mean() + + def _apply_hard_potential( + self, + obj: torch.Tensor, + c: PtychoObjConstraintParams.Raster, + mask: torch.Tensor | None, + ) -> torch.Tensor: + if c.fix_potential_baseline: + if mask is not None: + background = mask < 0.5 * mask.max() + if background.any(): + offset = obj[background].mean() else: offset = obj.min() - offset = offset.detach() - offset *= self.constraints["fix_potential_baseline_factor"] else: - offset = 0 + offset = obj.min() + offset = offset.detach() + offset = offset * c.fix_potential_baseline_factor + else: + offset = 0 - if self.constraints.get("positivity", True): - obj2 = torch.clamp(obj - offset, min=0.0) - else: - obj2 = obj - offset + if c.positivity: + return torch.clamp(obj - offset, min=0.0) + return obj - offset - if self.constraints["apply_fov_mask"] and mask is not None: - obj2 *= mask + def _apply_shared_hard( + self, + obj: torch.Tensor, + c: PtychoObjConstraintParams.Raster, + mask: torch.Tensor | None, + ) -> torch.Tensor: + if c.apply_fov_mask and mask is not None: + obj = obj * mask - # want backwards compatibility for gaussian_sigma and q_lowpass/q_highpass, so use get - if self.constraints.get("gaussian_sigma") is not None: - obj2 = self.gaussian_blur_2d(obj2, sigma=self.constraints["gaussian_sigma"]) + if c.gaussian_sigma is not None: + obj = self.gaussian_blur_2d(obj, sigma=c.gaussian_sigma) - if any([self.constraints["q_lowpass"], self.constraints["q_highpass"]]): - obj2 = self.butterworth_constraint( - obj2, - sampling=self.sampling, - ) - if self.num_slices > 1: - if self.constraints["identical_slices"]: - with torch.no_grad(): - obj2[:] = torch.mean(obj2, dim=0, keepdim=True) + if any([c.q_lowpass, c.q_highpass]): + obj = self.butterworth_constraint(obj, sampling=self.sampling) - return obj2 + if self.num_slices > 1 and c.identical_slices: + # In-place mutation is safe because apply_hard_constraints is + # always called under outer torch.no_grad (see its docstring). + obj[:] = torch.mean(obj, dim=0, keepdim=True) + return obj def apply_soft_constraints( self, obj: torch.Tensor, mask: torch.Tensor | None = None ) -> torch.Tensor: + """Sum of the per-iteration soft penalties. + + Returns a scalar tensor that is added to the data-fidelity loss before + ``backward()``. Individual contributions are also recorded via + ``add_soft_constraint_loss`` for logging. + """ # reset recorded losses each call self.reset_soft_constraint_losses() - tv_loss = self.get_tv_loss( - obj, - ) + tv_loss = self.get_tv_loss(obj) self.add_soft_constraint_loss("tv_loss", tv_loss) surface_zero_loss = self.get_surface_zero_loss( obj, - weight=self.constraints["surface_zero_weight"], + weight=self.constraints.surface_zero_weight, ) self.add_soft_constraint_loss("surface_zero_loss", surface_zero_loss) self.accumulate_constraint_losses() @@ -361,42 +522,66 @@ def apply_soft_constraints( def get_tv_loss( self, array: torch.Tensor, weights: None | tuple[float, float] = None ) -> torch.Tensor: + """Total-variation soft penalty on the object. + + ``weights`` is a ``(z_weight, xy_weight)`` tuple. When ``None``, defaults + to ``(self.constraints.tv_weight_z, self.constraints.tv_weight_xy)``. A single + scalar is broadcast to both axes. The z weight is zeroed for + ``num_slices == 1``. + """ loss = self._get_zero_loss_tensor() + w = self._resolve_tv_weights(weights) + if not any(w): + return loss + + if self.obj_type == "complex": + return self._tv_complex(array, w) + # pure_phase and potential are both real tensors; phase wrapping is gone. + return self._calc_tv_loss(array, w) + + def _resolve_tv_weights( + self, weights: None | tuple[float, float] | float | int + ) -> tuple[float, float]: if weights is None: - w = ( - self.constraints["tv_weight_z"], - self.constraints["tv_weight_xy"], + w: tuple[float, float] = ( + self.constraints.tv_weight_z, + self.constraints.tv_weight_xy, ) elif isinstance(weights, (float, int)): - if weights == 0: - return loss - w = (weights, weights) + w = (float(weights), float(weights)) else: if len(weights) != 2: raise ValueError(f"weights must be a tuple of length 2, got {weights}") - w = weights - - if not any(w): - return loss - + w = (float(weights[0]), float(weights[1])) if self.num_slices == 1: - w = (0, w[1]) - - if array.is_complex(): - ph = array.angle() - warn( - "calculating TV loss for phase, need to check phase wrapping. Easiest fix is scalar phase array." - ) - loss = loss + self._calc_tv_loss(ph, w) - amp = array.abs() - if self.obj_type == "complex": - loss = loss + self._calc_tv_loss(amp, w) - else: - loss = loss + self._calc_tv_loss(array, w) + w = (0.0, w[1]) + return w + def _tv_complex(self, array: torch.Tensor, w: tuple[float, float]) -> torch.Tensor: + # complex objects carry information in both amplitude and phase. We + # still extract phase via angle() here, so the wrap warning stays — + # but only for obj_type == "complex". + loss = self._get_zero_loss_tensor() + ph = array.angle() + warn( + "calculating TV loss for phase of complex object, " + "phase wrapping may distort the gradient. Consider obj_type='pure_phase'." + ) + # TODO: amp and phase share `w` here. Consider splitting `tv_weight_xy` + # into separate amp/phase weights on PtychoObjConstraintParams.Raster + # so users can tune them independently for obj_type="complex". + loss = loss + self._calc_tv_loss(ph, w) + amp = array.abs() + loss = loss + self._calc_tv_loss(amp, w) return loss def _calc_tv_loss(self, array: torch.Tensor, weight: tuple[float, float]) -> torch.Tensor: + """Mean-|diff| TV on a real array. ``weight = (w_z, w_xy)``. + + For a 3D ``(slices, H, W)`` array, dim 0 uses ``w_z`` and dims 1+2 use + ``w_xy``. The result is averaged over the number of axes that actually + contributed (i.e. had a non-zero weight). + """ loss = self._get_zero_loss_tensor() calc_dim = 0 for dim in range(array.ndim): @@ -414,34 +599,48 @@ def _calc_tv_loss(self, array: torch.Tensor, weight: tuple[float, float]) -> tor def get_surface_zero_loss( self, array: torch.Tensor, weight: float | int = 0.0 ) -> torch.Tensor: + """Penalize the first and last slices to be near vacuum. + + Real ``pure_phase`` / ``potential`` arrays: penalizes ``|array[0]|`` and + ``|array[-1]|`` directly. ``complex`` arrays pull amplitude toward 1 + and phase toward its mean (see ``_surface_zero_complex``). A no-op for + single- or double-slice objects (``array.shape[0] < 3``). + """ loss = self._get_zero_loss_tensor() - if weight == 0: - return loss - if array.shape[0] < 3: + if weight == 0 or array.shape[0] < 3: return loss - if array.is_complex(): - ph = array.angle().abs() - if self.obj_type == "complex": - amp = array.abs() - loss = loss + weight * (torch.mean(1.0 - amp[0]) + torch.mean(1.0 - amp[-1])) - warn("calculating surface zero loss for phase, need to check phase wrapping.") - loss = loss + weight * ( - torch.mean(torch.abs(ph[0] - ph[0].mean())) - + torch.mean(torch.abs(ph[-1] - ph[-1].mean())) - ) - else: - loss = loss + weight * ( - torch.mean(torch.abs(array[0])) + torch.mean(torch.abs(array[-1])) - ) + if self.obj_type == "complex": + return self._surface_zero_complex(array, weight) + # pure_phase and potential: real array, penalize first/last slice magnitude + return loss + weight * (torch.mean(torch.abs(array[0])) + torch.mean(torch.abs(array[-1]))) + + def _surface_zero_complex(self, array: torch.Tensor, weight: float | int) -> torch.Tensor: + # complex: pull amp toward 1 (vacuum) at the surfaces, and phase toward its mean + loss = self._get_zero_loss_tensor() + amp = array.abs() + loss = loss + weight * (torch.mean(1.0 - amp[0]) + torch.mean(1.0 - amp[-1])) + ph = array.angle().abs() + warn( + "calculating surface zero loss for phase of complex object, " + "phase wrapping may distort the gradient. Consider obj_type='pure_phase'." + ) + loss = loss + weight * ( + torch.mean(torch.abs(ph[0] - ph[0].mean())) + + torch.mean(torch.abs(ph[-1] - ph[-1].mean())) + ) return loss def gaussian_blur_2d(self, tensor, sigma=1.0): - """ - Apply Gaussian blur along dimensions 2 and 3 of a 3D tensor. - - Args: - tensor: Can be real or complex - sigma: Standard deviation for Gaussian kernel + """Separable 2D Gaussian blur over the last two dimensions. + + Parameters + ---------- + tensor : torch.Tensor + Real or complex, shape ``(slices, H, W)``. Complex inputs are + filtered as independent real/imag channels. + sigma : float + Standard deviation of the Gaussian kernel, in pixels. The + kernel size is ``2 * ceil(3 * sigma) + 1``. """ kernel_size = int(2 * math.ceil(3 * sigma) + 1) if kernel_size % 2 == 0: @@ -484,14 +683,26 @@ def butterworth_constraint( tensor: torch.Tensor, sampling: tuple[float, float], ) -> torch.Tensor: + """Apply a Fourier-domain Butterworth low/high-pass to each 2D slice. + + Reads ``q_lowpass``, ``q_highpass``, and ``butterworth_order`` off + ``self.constraints``. The DC component is subtracted before filtering + and added back so the mean is preserved. + + Parameters + ---------- + tensor : torch.Tensor + Shape ``(slices, H, W)``. May be real or complex. Real inputs are + re-cast to real after the FFT round-trip when ``obj_type != "complex"``. + sampling : tuple[float, float] + ``(dy, dx)`` real-space sampling in Ångström per pixel. Sets the + inverse-Å scale of the Butterworth response; ``q_lowpass`` and + ``q_highpass`` are in *inverse Ångström (cycles / Å). """ - Butterworth filter used for low/high-pass filtering. - """ - - q_lowpass = self.constraints["q_lowpass"] - q_highpass = self.constraints["q_highpass"] - butterworth_order = self.constraints["butterworth_order"] + q_lowpass = self.constraints.q_lowpass + q_highpass = self.constraints.q_highpass + butterworth_order = self.constraints.butterworth_order qx = torch.fft.fftfreq(tensor.shape[-2], sampling[0], device=tensor.device) qy = torch.fft.fftfreq(tensor.shape[-1], sampling[1], device=tensor.device) @@ -515,8 +726,9 @@ def butterworth_constraint( tensor = tensor + tensor_mean - # Take real part for potential tensorects - if self.obj_type == "potential": + # FFT-based filter returns complex even for real inputs; cast back to real + # for any non-complex object type (pure_phase, potential). + if self.obj_type != "complex": tensor = tensor.real return tensor @@ -652,20 +864,26 @@ def _initialize_obj( return init_shape = tuple(int(x) for x in shape) if self._initialize_mode == "uniform": - if self.obj_type in ["complex", "pure_phase"]: + if self.obj_type == "complex": + # amp=1, phase=0 -> complex ones arr = torch.ones(init_shape) * torch.exp(1.0j * torch.zeros(init_shape)) else: + # pure_phase (phase=0) and potential start as real zeros arr = torch.zeros(init_shape) elif self._initialize_mode == "random": ph = ( torch.randn(init_shape, dtype=torch.float32, generator=self._rng_torch) - 0.5 ) * 1e-6 - if self.obj_type == "potential": - arr = ph - else: + if self.obj_type == "complex": arr = torch.exp(1.0j * ph) + else: + # pure_phase stores phase directly; potential stores real values + arr = ph elif self._initialize_mode == "array": arr = self._initial_obj + if self.obj_type == "pure_phase" and arr.is_complex(): + # Convert legacy complex initial_obj (amp*exp(1j*phase)) to bare phase + arr = arr.angle() else: raise ValueError(f"Invalid initialize mode: {self._initialize_mode}") @@ -816,7 +1034,7 @@ def from_pixelated( else: model_dtype = "real" - if pixelated.obj_type == "pure_phase" and model_dtype == "real": + if pixelated.obj_type == "complex" and model_dtype == "real": obj = pixelated.obj.angle().clone().detach() else: obj = pixelated.obj.clone().detach() @@ -835,9 +1053,6 @@ def from_pixelated( return obj_model - # TODO add a from_params that sets the model input and target from params, - # will need to specify a shape as well, at least before pre-training (so just set here) - @property def num_slices(self) -> int: return self._num_slices @@ -913,21 +1128,6 @@ def model_input(self, input_tensor: torch.Tensor | np.ndarray): self._model_input = input_tensor.type(self.dtype).to(self.device) - # def _generate_model_input(self, mode: Literal["random", "zeros", "ones"]) -> None: - # input_shape = (1, *self.shape) - # # could support for 3D CNN models, single channel 2D with identical slices - # if mode == "random": - # inp = torch.randn( - # input_shape, device=self.device, dtype=self.dtype, generator=self._rng_torch - # ) - # elif mode == "zeros": - # inp = torch.zeros(input_shape, device=self.device, dtype=self.dtype) - # elif mode == "ones": - # inp = torch.ones(input_shape, device=self.device, dtype=self.dtype) - # else: - # raise ValueError(f"Invalid mode: {mode} | must be one of: 'random', 'zeros', 'ones'") - # self._model_input = inp - @property def pretrain_target(self) -> torch.Tensor: """get the pretrain target""" @@ -976,16 +1176,12 @@ def pretrain_lrs(self) -> np.ndarray: @property def obj(self): """get the full object""" - obj = self.model(self._model_input)[0] - if self.obj_type == "pure_phase" and "complex" not in str(self.dtype): - # using a real-valued model for a pure-phase (complex) object - obj = torch.ones_like(obj) * torch.exp(1j * obj) + raw = self.model(self._model_input)[0] # TODO -- single channel 2D with identical slices, view as 3D num_slices - return self.apply_hard_constraints(obj, mask=self.mask) + return self.apply_hard_constraints(raw, mask=self.mask) @property def _obj(self): - # TODO -- single channel 2D with identical slices, view as 3D num_slices?? return self.model(self._model_input)[0] def forward(self, patch_indices: torch.Tensor): @@ -1056,11 +1252,12 @@ def pretrain( loss_fn: Callable | str = "l2", apply_constraints: bool = False, show: bool = True, - device: str | None = None, # allow overwriting of device + device: str | int | None = None, normalize_object_plotting: bool = True, ): if device is not None: - self.to(device) + dev, _ = config.validate_device(device) + self.to(dev) if optimizer_params is not None: self.set_optimizer(optimizer_params) @@ -1135,8 +1332,6 @@ def _pretrain( if apply_constraints: output = self.apply_hard_constraints(self.model(model_input)[0]) - if self.obj_type == "pure_phase": - output = output.angle() else: output = self.model(model_input)[0] loss: torch.Tensor = loss_fn(output, self.pretrain_target) @@ -1242,7 +1437,7 @@ def visualize_pretrain( ], cmap="magma", cbar=True, - norm=[norm_angle, norm_angle, norm_abs, norm_abs], # type:ignore + norm=[norm_angle, norm_angle, norm_abs, norm_abs], # type:ignore ) else: norm = None diff --git a/src/quantem/diffractive_imaging/optimize_hyperparameters.py b/src/quantem/diffractive_imaging/optimize_hyperparameters.py index eb73c257..28155a07 100644 --- a/src/quantem/diffractive_imaging/optimize_hyperparameters.py +++ b/src/quantem/diffractive_imaging/optimize_hyperparameters.py @@ -867,7 +867,8 @@ def _plot_grid_objects(self, results, param_names, figsize): if recon_obj.obj_type == "potential": obj = np.abs(obj).sum(0) elif recon_obj.obj_type == "pure_phase": - obj = np.angle(obj).sum(0) + # pure_phase obj_cropped is a real phase array — plot directly + obj = obj.sum(0) else: obj = np.angle(obj).sum(0) diff --git a/src/quantem/diffractive_imaging/probe_models.py b/src/quantem/diffractive_imaging/probe_models.py index af638b85..9cae9f7b 100644 --- a/src/quantem/diffractive_imaging/probe_models.py +++ b/src/quantem/diffractive_imaging/probe_models.py @@ -1,6 +1,7 @@ from abc import abstractmethod from copy import deepcopy -from typing import Any, Callable, Self, Union +from dataclasses import dataclass +from typing import Any, Callable, Self, Union, cast from warnings import warn import matplotlib.pyplot as plt @@ -14,6 +15,7 @@ from quantem.core.datastructures import Dataset2d, Dataset4dstem from quantem.core.io.serialize import AutoSerialize from quantem.core.ml.blocks import reset_weights +from quantem.core.ml.constraints import BaseConstraints, Constraints, parse_constraint_dict from quantem.core.ml.loss_functions import get_loss_module from quantem.core.ml.optimizer_mixin import ( OptimizerMixin, @@ -36,7 +38,6 @@ POLAR_SYMBOLS, real_space_probe, ) -from quantem.diffractive_imaging.constraints import BaseConstraints from quantem.diffractive_imaging.ptycho_utils import ( fourier_shift_expand, shift_array, @@ -45,6 +46,87 @@ DeviceType = Union[str, torch.device, int] +class PtychoProbeConstraintParams: + """ + Namespace class for ptychography probe constraint dataclasses. + + Tab-complete on ``PtychoProbeConstraintParams`` in a notebook to discover the + available variants. Tab-complete inside a variant's constructor to see every + constraint field with its default value. + + Variants + -------- + Raster + Constraints for grid-based probe representations (``ProbePixelated`` and + ``ProbeDIP`` share this set today). + Parametric + Placeholder for parametric probe models, where Gram-Schmidt orthogonalization + and pixel-domain TV are moot. + """ + + @dataclass + class Raster(Constraints): + """Constraints for grid-based ptychography probe models (``ProbePixelated``, + ``ProbeDIP``). + + Attributes + ---------- + orthogonalize_probe : bool, default ``True`` + Mixed-state probe (``num_probes > 1``) only. After each update applies + Gram-Schmidt orthogonalization across the probe stack and then sorts + the resulting probes by total intensity (descending). For + ``num_probes == 1`` this is effectively a renormalization no-op. + center_probe : bool, default ``False`` + Shifts the probe's intensity center-of-mass back to the array center + via a Fourier shift after each update. Useful when probe drift + competes with scan-position refinement; if both move freely the + reconstruction can wander while still fitting the diffraction data. + tv_weight : float, default ``0.0`` + Soft penalty. Weight on the in-plane total-variation of the (complex) + probe; encourages smooth probe magnitude / phase. + """ + + # hard constraints + orthogonalize_probe: bool = True + center_probe: bool = False + # soft constraints + tv_weight: float = 0.0 + _name: str = "raster" + + soft_constraint_keys = ["tv_weight"] + hard_constraint_keys = ["orthogonalize_probe", "center_probe"] + + @dataclass + class Parametric(Constraints): + """Placeholder for parametric probe constraints (``ProbeParametric``). + + Parametric probes are pure functions of aberration / aperture coefficients, + so pixel-domain projections like ``orthogonalize_probe`` and ``tv_weight`` + don't apply. Parametric-specific fields (e.g. bounds on individual + aberration coefficients) will land here when needed. + """ + + _name: str = "parametric" + + soft_constraint_keys = [] + hard_constraint_keys = [] + + @classmethod + def parse_dict(cls, d: dict) -> "PtychoProbeConstraintsType": + """Instantiate the appropriate variant from a config dict. + + The dict must contain a ``'name'`` or ``'type'`` key (case-insensitive), + with value ``'raster'`` or ``'parametric'``. All other keys are forwarded + as keyword arguments to the chosen dataclass. + """ + return cast(PtychoProbeConstraintsType, parse_constraint_dict(cls, d, kind="probe")) + + +PtychoProbeConstraintsType = ( + PtychoProbeConstraintParams.Raster | PtychoProbeConstraintParams.Parametric +) + + class ProbeBase(nn.Module, RNGMixin, OptimizerMixin, AutoSerialize): DEFAULT_PROBE_PARAMS = { "energy": None, @@ -427,13 +509,8 @@ def _compute_propagator_arrays( return propagators -class ProbeConstraints(BaseConstraints, ProbeBase): - DEFAULT_CONSTRAINTS = { - # "fix_probe": False, - "orthogonalize_probe": True, - "center_probe": False, - "tv_weight": 0.0, - } +class ProbeConstraints(BaseConstraints[PtychoProbeConstraintParams.Raster], ProbeBase): + DEFAULT_CONSTRAINTS: PtychoProbeConstraintParams.Raster = PtychoProbeConstraintParams.Raster() def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) @@ -441,8 +518,8 @@ def __init__(self, *args, **kwargs): def apply_soft_constraints(self, probe: torch.Tensor) -> torch.Tensor: self.reset_soft_constraint_losses() loss = self._get_zero_loss_tensor() - if self.constraints["tv_weight"]: - loss_tv = self._probe_tv_constraint(probe, self.constraints["tv_weight"]) + if self.constraints.tv_weight: + loss_tv = self._probe_tv_constraint(probe, self.constraints.tv_weight) self.add_soft_constraint_loss("tv_loss", loss_tv) loss = loss + loss_tv @@ -450,11 +527,9 @@ def apply_soft_constraints(self, probe: torch.Tensor) -> torch.Tensor: return loss def apply_hard_constraints(self, probe: torch.Tensor) -> torch.Tensor: - # if self.constraints["fix_probe"]: - # return self.initial_probe - if self.constraints["orthogonalize_probe"]: + if self.constraints.orthogonalize_probe: probe = self._probe_orthogonalization_constraint(probe) - if self.constraints["center_probe"]: + if self.constraints.center_probe: probe = self._probe_center_of_mass_constraint(probe) return probe @@ -594,6 +669,8 @@ def from_array( ): if isinstance(probe_array, np.ndarray): probe_array = torch.tensor(probe_array, dtype=dtype, device=device) + else: + probe_array = probe_array.to(dtype=dtype, device=device) if probe_array.ndim == 3: if num_probes is None: num_probes = probe_array.shape[0] @@ -603,9 +680,7 @@ def from_array( ) else: num_probes = 1 if num_probes is None else num_probes - probe_array = torch.tensor(probe_array, dtype=dtype, device=device) - # probe_array = torch.tile(probe_array, (num_probes, 1, 1)) - probe_array = torch.cat([probe_array] * num_probes, dim=0) + probe_array = torch.stack([probe_array] * num_probes, dim=0) probe_model = cls( num_probes=num_probes, @@ -1364,10 +1439,11 @@ def pretrain( loss_fn: Callable | str = "l2", apply_constraints: bool = False, show: bool = True, - device: str | None = None, # allow overwriting of device + device: str | int | None = None, ): if device is not None: - self.to(device) + dev, _ = config.validate_device(device) + self.to(dev) if optimizer_params is not None: self.set_optimizer(optimizer_params) diff --git a/src/quantem/diffractive_imaging/ptycho_utils.py b/src/quantem/diffractive_imaging/ptycho_utils.py index 023963ff..f3c09726 100644 --- a/src/quantem/diffractive_imaging/ptycho_utils.py +++ b/src/quantem/diffractive_imaging/ptycho_utils.py @@ -117,6 +117,51 @@ def val_len(self) -> int: return int(ceil(len(self.val_indices) / self.batch_size)) if self.has_validation else 0 +def compute_train_val_split( + num: int, + val_ratio: float, + val_mode: Literal["grid", "random"], + rng: np.random.Generator, +) -> tuple[np.ndarray, np.ndarray]: + """Compute the train/validation index split. + + Returns ``(train_indices, val_indices)`` as int numpy arrays. ``val_mode="grid"`` + selects every k-th index (with ``k = round(1/val_ratio)``, inverted when + ``val_ratio > 0.5``); ``"random"`` selects a seeded ``rng.permutation`` slice. + """ + indices = np.arange(num) + if val_ratio < 0 or val_ratio >= 1: + val_ratio = 0.0 + n_val = int(round(len(indices) * val_ratio)) + if n_val <= 0: + return indices, np.asarray([], dtype=int) + + if val_mode == "random": + # Random unique selection for validation + perm = rng.permutation(indices) + val_indices = perm[:n_val] + train_indices = np.setdiff1d(indices, val_indices, assume_unique=False) + else: # grid/regular selection: every k-th index + if val_ratio <= 0.5: + k = max(1, int(round(1.0 / val_ratio))) + invert = False + else: + k = max(1, int(round(1.0 / (1.0 - val_ratio)))) + invert = True + + grid_sel = indices[::k] + if len(grid_sel) > n_val: + grid_sel = grid_sel[:n_val] + if invert: + train_indices = grid_sel + val_indices = np.setdiff1d(indices, grid_sel, assume_unique=False) + else: + val_indices = grid_sel + train_indices = np.setdiff1d(indices, val_indices, assume_unique=False) + + return np.asarray(train_indices, dtype=int), np.asarray(val_indices, dtype=int) + + @overload def fourier_shift_expand( array: np.ndarray, positions: np.ndarray, expand_dim: bool = True @@ -136,7 +181,7 @@ def fourier_shift_expand( if af.is_complex(array): return shifted_array else: - return shifted_array.real # type:ignore ## will be numeric so this should be safe + return shifted_array.real # type:ignore ## will be numeric so this should be safe @overload diff --git a/src/quantem/diffractive_imaging/ptychography.py b/src/quantem/diffractive_imaging/ptychography.py index 199e8588..4d9345e7 100644 --- a/src/quantem/diffractive_imaging/ptychography.py +++ b/src/quantem/diffractive_imaging/ptychography.py @@ -1,38 +1,100 @@ import contextlib import copy import gc +import os import tempfile from pathlib import Path -from typing import TYPE_CHECKING, Literal, Self, Sequence, cast +from typing import Any, Literal, Self, Sequence, cast from warnings import warn import numpy as np +import torch +import torch.distributed as dist from tqdm.auto import tqdm -from quantem.core import config from quantem.core.io.serialize import load as autoserialize_load +from quantem.core.ml.dist_utils import ( + init_process_group, + is_distributed_launch, + spawn_distributed_workers, +) from quantem.diffractive_imaging.dataset_models import DatasetModelType from quantem.diffractive_imaging.detector_models import DetectorModelType from quantem.diffractive_imaging.logger_ptychography import LoggerPtychography from quantem.diffractive_imaging.object_models import ObjectModelType, ObjectPixelated from quantem.diffractive_imaging.probe_models import ProbeModelType, ProbeParametric -from quantem.diffractive_imaging.ptycho_utils import SimpleBatcher +from quantem.diffractive_imaging.ptycho_utils import compute_train_val_split from quantem.diffractive_imaging.ptychography_base import PtychographyBase from quantem.diffractive_imaging.ptychography_opt import PtychographyOpt from quantem.diffractive_imaging.ptychography_visualizations import PtychographyVisualizations -if TYPE_CHECKING: - import torch -else: - if config.get("has_torch"): - import torch +def _ddp_ptycho_worker( + rank: int, + world_size: int, + ptycho_path: str, + devices: list[int], + recon_kwargs: dict[str, Any], + result_path: str, +) -> None: + """Module-level worker for mp.start_processes — must live at module scope to be picklable. -class Ptychography(PtychographyOpt, PtychographyVisualizations, PtychographyBase): + Receives a file path rather than the Ptychography object directly so that no + large tensors cross the process boundary via pickle (which triggers PyTorch's + shared-memory tensor mechanism and fails in some Linux environments). + """ + device_id = devices[rank] + # Bind the CUDA device BEFORE init_process_group so NCCL allocates its + # communicator buffers on the correct GPU. Without this, NCCL grabs cuda:0 + # at init time, stranding small per-rank buffers on GPUs the user didn't + # ask for. + init_process_group( + rank, + world_size, + backend="nccl" if torch.cuda.is_available() else "gloo", + local_device=device_id if torch.cuda.is_available() else None, + ) + + # mmap=True so all workers share one memory-mapped RAM copy of the (potentially large, + # CPU-resident) state instead of each duplicating it. + ptycho = torch.load(ptycho_path, map_location="cpu", weights_only=False, mmap=True) + ptycho.to(f"cuda:{device_id}" if torch.cuda.is_available() else "cpu") + + if dist.is_available() and dist.is_initialized(): + ptycho._broadcast_parameters(src=0) + + ptycho._reconstruct_inner(**recon_kwargs, _dist_rank=rank, _dist_world_size=world_size) + + if rank == 0: + obj_opt = ptycho.optimizers.get("object") + probe_opt = ptycho.optimizers.get("probe") + torch.save( + { + "obj_state": {k: v.cpu() for k, v in ptycho.obj_model.state_dict().items()}, + "probe_state": {k: v.cpu() for k, v in ptycho.probe_model.state_dict().items()}, + "obj_optimizer_params": ptycho.obj_model._optimizer_params, + "probe_optimizer_params": ptycho.probe_model._optimizer_params, + "obj_optimizer_state": obj_opt.state_dict() if obj_opt is not None else None, + "probe_optimizer_state": probe_opt.state_dict() if probe_opt is not None else None, + "iter_losses": ptycho._iter_losses, + "iter_val_losses": ptycho._iter_val_losses, + "iter_lrs": ptycho._iter_lrs, + "iter_recon_types": ptycho._iter_recon_types, + }, + result_path, + ) + + dist.destroy_process_group() + + +class Ptychography(PtychographyOpt, PtychographyVisualizations, PtychographyBase): # pyright: ignore[reportUnsafeMultipleInheritance] """ A class for performing phase retrieval using the Ptychography algorithm. """ + _autograd: bool = True + _dataset_metadata: "dict[str, Any] | None" = None + @classmethod def from_models( cls, @@ -124,7 +186,7 @@ def _reset_iter_constraints(self) -> None: def _soft_constraints(self) -> torch.Tensor: """Calculate soft constraints by calling apply_soft_constraints on each model.""" - total_loss = torch.tensor(0, device=self.device, dtype=self._dtype_real) + total_loss = torch.tensor(0, device=self._single_device, dtype=self._dtype_real) obj_loss = self.obj_model.apply_soft_constraints( self.obj_model.obj, mask=self.obj_model.mask @@ -147,30 +209,124 @@ def reconstruct( self, num_iters: int = 0, reset: bool = False, - optimizer_params: dict | None = None, - scheduler_params: dict | None = None, - constraints: dict = {}, + optimizer_params: dict[str, Any] | None = None, + scheduler_params: dict[str, Any] | None = None, + constraints: dict[str, Any] | None = None, batch_size: int | None = None, store_snapshots: bool | None = None, store_snapshots_every: int | None = None, - device: Literal["cpu", "gpu"] | None = None, + device: Literal["cpu", "gpu"] | int | list[int] | None = None, autograd: bool = True, loss_type: Literal[ "l2_amplitude", "l1_amplitude", "l2_intensity", "l1_intensity", "poisson" ] = "l2_amplitude", + num_workers: int = 0, ) -> Self: - """ - reason for having a single reconstruct() is so that updating things like constraints - or recon_types only happens in one place, reason for having separate reoconstruction_ - methods would be to simplify the flags for this and not have to include all + """Run iterative ptychography reconstruction. + ``device`` accepts: + - ``None`` — keep current device + - ``"cpu"`` / ``"gpu"`` — existing string form + - ``int`` — specific GPU index, e.g. ``device=2`` → cuda:2 + - ``list[int]`` — multi-GPU, e.g. ``device=[0,1,2,3]`` + + ``constraints`` is a dict keyed by ``"object"``, ``"probe"``, ``"dataset"`` + (any subset). Each leaf may be: + + - a ``Constraints`` dataclass instance (e.g. ``PtychoObjConstraintParams.Raster(...)``), + which replaces that model's constraint state wholesale, or + - a plain ``dict`` of field-name -> value, which does a per-key partial update + on the existing constraint state. + + Multi-GPU (``device`` is a list) launches worker processes via ``mp.spawn`` when called + from a notebook, or uses the existing distributed process group when launched with + ``torchrun``. Only autograd mode is supported for multi-GPU in this release. """ - # TODO maybe make an "process args" method that handles things like: - # mode, store_iterations, device, self._check_preprocessed() - if device is not None: - self.to(device) - self.batch_size = batch_size + + # Determine effective device list: explicit arg takes priority, else fall back to stored. + devices_to_use = ( + device if isinstance(device, list) else getattr(self, "_multi_gpu_devices", None) + ) + + # Route to multi-GPU path + if isinstance(devices_to_use, list) and not is_distributed_launch(): + if not autograd: + raise ValueError("Multi-GPU reconstruction requires autograd=True.") + return self._spawn_reconstruct( + devices=devices_to_use, + num_iters=num_iters, + reset=reset, + optimizer_params=optimizer_params, + scheduler_params=scheduler_params, + constraints=constraints, + batch_size=batch_size, + store_snapshots=store_snapshots, + store_snapshots_every=store_snapshots_every, + autograd=autograd, + loss_type=loss_type, + num_workers=num_workers, + ) + + # Handle torchrun distributed launch (RANK env var present) + if is_distributed_launch(): + rank = int(os.environ["RANK"]) + world_size = int(os.environ["WORLD_SIZE"]) + local_rank = int(os.environ.get("LOCAL_RANK", rank)) + if not torch.distributed.is_initialized(): + # Bind the device BEFORE init_process_group so NCCL allocates + # its communicator buffers on the correct GPU. + if torch.cuda.is_available(): + torch.cuda.set_device(local_rank) + torch.distributed.init_process_group( + backend="nccl" if torch.cuda.is_available() else "gloo", + init_method="env://", + ) + dev = f"cuda:{local_rank}" if torch.cuda.is_available() else "cpu" + self.to(dev) + self._broadcast_parameters(src=0) + else: + rank, world_size = 0, 1 + if device is not None and not isinstance(device, list): + self.to(device) + + return self._reconstruct_inner( + num_iters=num_iters, + reset=reset, + optimizer_params=optimizer_params, + scheduler_params=scheduler_params, + constraints=constraints, + batch_size=batch_size, + store_snapshots=store_snapshots, + store_snapshots_every=store_snapshots_every, + autograd=autograd, + loss_type=loss_type, + num_workers=num_workers, + _dist_rank=rank, + _dist_world_size=world_size, + ) + + def _reconstruct_inner( + self, + num_iters: int = 0, + reset: bool = False, + optimizer_params: dict[str, Any] | None = None, + scheduler_params: dict[str, Any] | None = None, + constraints: dict[str, Any] | None = None, + batch_size: int | None = None, + store_snapshots: bool | None = None, + store_snapshots_every: int | None = None, + autograd: bool = True, + loss_type: Literal[ + "l2_amplitude", "l1_amplitude", "l2_intensity", "l1_intensity", "poisson" + ] = "l2_amplitude", + num_workers: int = 0, + _dist_rank: int = 0, + _dist_world_size: int = 1, + ) -> Self: + """Core reconstruction loop. Called by reconstruct() for all launch modes.""" + if batch_size is not None: + self.batch_size = batch_size self.store_snapshot_every = store_snapshots_every if store_snapshots_every is not None and store_snapshots is None: self.store_snapshots = True @@ -179,7 +335,8 @@ def reconstruct( if reset: self.reset_recon() - self.constraints = constraints + if constraints: + self.constraints = constraints new_scheduler = reset if optimizer_params is not None: @@ -196,22 +353,37 @@ def reconstruct( self.dset._set_targets(loss_type) self.compute_propagator_arrays() # required to avoid issue if stopped learning probe tilt - batcher = SimpleBatcher( + + # Compute the global scan count once — needed to keep loss scale consistent across world + global_n = self.dset.num_gpts + + train_indices, val_indices = compute_train_val_split( self.dset.num_gpts, - self.batch_size, - rng=self.rng, - val_ratio=self.val_ratio, - val_mode=self.val_mode, + self.val_ratio, + self.val_mode, + self.rng, + ) + train_loader, train_sampler, val_loader = self._build_dataloaders( + train_indices, + val_indices, + world_size=_dist_world_size, + rank=_dist_rank, + num_workers=num_workers, ) - pbar = tqdm(range(num_iters), disable=not self.verbose) + + pbar = tqdm(range(num_iters), disable=not self.verbose or _dist_rank != 0) for a0 in pbar: + if _dist_world_size > 1 and train_sampler is not None: + train_sampler.set_epoch(a0) consistency_loss = 0.0 total_loss = 0.0 self._reset_iter_constraints() - for batch_indices in batcher: + for batch in train_loader: self.zero_grad_all() + batch_indices = batch["index"].to(self._single_device) + targets = batch["target"].to(self._single_device, non_blocking=True) patch_indices, _positions_px, positions_px_fractional, descan_shifts = ( self.dset.forward(batch_indices, self.obj_padding_px) ) @@ -225,7 +397,9 @@ def reconstruct( batch_consistency_loss, targets = self.error_estimate( pred_intensities, batch_indices, + targets=targets, loss_type=loss_type, + global_n=global_n, ) batch_soft_constraint_loss = self._soft_constraints() @@ -240,21 +414,33 @@ def reconstruct( patch_indices, targets, ) + if _dist_world_size > 1: + self._all_reduce_gradients() self.step_optimizers() consistency_loss += batch_consistency_loss.item() total_loss += batch_loss.item() - num_batches = len(batcher) + num_batches = len(train_loader) total_loss = total_loss / num_batches consistency_loss = consistency_loss / num_batches + # Average loss across ranks so rank-0 reports the global mean + if _dist_world_size > 1: + loss_t = torch.tensor( + [total_loss, consistency_loss], device=self._single_device, dtype=torch.float64 + ) + dist.all_reduce(loss_t, op=dist.ReduceOp.AVG) + total_loss, consistency_loss = loss_t[0].item(), loss_t[1].item() + # Validation pass (no gradient, no optimizer steps) val_loss = None - if batcher.has_validation: + if val_loader is not None: val_consistency_loss = 0.0 val_batches = 0 with torch.no_grad(): - for batch_indices in batcher.iter_val(): + for batch in val_loader: + batch_indices = batch["index"].to(self._single_device) + targets = batch["target"].to(self._single_device, non_blocking=True) patch_indices, _positions_px, positions_px_fractional, descan_shifts = ( self.dset.forward(batch_indices, self.obj_padding_px) ) @@ -265,23 +451,36 @@ def reconstruct( ) pred_intensities = self.detector_model.forward(overlap) batch_val_loss, _ = self.error_estimate( - pred_intensities, batch_indices, loss_type=loss_type + pred_intensities, + batch_indices, + targets=targets, + loss_type=loss_type, + global_n=global_n, ) val_consistency_loss += batch_val_loss.item() val_batches += 1 if val_batches > 0: val_loss = val_consistency_loss / val_batches - self._iter_val_losses.append(val_loss) + # Average the val loss across ranks so rank-0 records the global mean + if _dist_world_size > 1: + val_t = torch.tensor( + val_loss, device=self._single_device, dtype=torch.float64 + ) + dist.all_reduce(val_t, op=dist.ReduceOp.AVG) + val_loss = val_t.item() + if _dist_rank == 0: + self._iter_val_losses.append(val_loss) - self._record_iter(total_loss) # TODO record val loss as well + if _dist_rank == 0: + self._record_iter(total_loss) # TODO record val loss as well # Step schedulers with current loss self.step_schedulers(total_loss) - if self.store_snapshots and (a0 % self.store_snapshot_every) == 0: + if _dist_rank == 0 and self.store_snapshots and (a0 % self.store_snapshot_every) == 0: self._store_current_iter_snapshot() - if self.logger is not None: + if _dist_rank == 0 and self.logger is not None: self.logger.log_iter( self.obj_model, self.probe_model, @@ -292,21 +491,120 @@ def reconstruct( self._get_current_lrs(), ) - if val_loss is not None: - pbar.set_description( - f"Iter {a0 + 1}/{num_iters}, Loss: {total_loss:.3e}, Val: {val_loss:.3e}" - ) - else: - pbar.set_description(f"Iter {a0 + 1}/{num_iters}, Loss: {total_loss:.3e}") + if _dist_rank == 0: + if val_loss is not None: + pbar.set_description( + f"Iter {a0 + 1}/{num_iters}, Loss: {total_loss:.3e}, Val: {val_loss:.3e}" + ) + else: + pbar.set_description(f"Iter {a0 + 1}/{num_iters}, Loss: {total_loss:.3e}") gc.collect() - torch.cuda.empty_cache() + if torch.cuda.is_available(): + torch.cuda.empty_cache() if hasattr(torch, "mps") and torch.backends.mps.is_available(): torch.mps.empty_cache() gc.collect() return self + def _spawn_reconstruct(self, devices: list[int], **recon_kwargs) -> Self: + """Notebook multi-GPU: spawn one worker process per device via forkserver. + + State is saved to a temp file so that no tensors cross the process boundary + via pickle. PyTorch's ForkingPickler automatically moves all CPU tensors to + shared memory when pickling for multiprocessing, which fails on some Linux + systems (EINVAL from ftruncate). Passing only a file path (a plain string) + avoids that mechanism entirely. + """ + restore_device = f"cuda:{devices[0]}" if torch.cuda.is_available() else "cpu" + # Persist batch_size on the main process so it carries into the saved file and + # is remembered on future calls that omit batch_size. + bs = recon_kwargs.get("batch_size") + if bs is not None: + self.batch_size = bs + self.to("cpu") + + with tempfile.TemporaryDirectory() as tmpdir: + tmpdir_path = Path(tmpdir) + ptycho_path = str(tmpdir_path / "ptycho_state.pt") + result_path = str(tmpdir_path / "result.pt") + + torch.save(self, ptycho_path, pickle_protocol=4) + + # forkserver: workers fork from a clean pre-started server (no inherited + # CUDA, no Jupyter FDs). Only plain Python scalars/strings cross the + # process boundary, so tensor pickling is never triggered. + spawn_distributed_workers( + _ddp_ptycho_worker, + devices, + ptycho_path, + devices, + recon_kwargs, + result_path, + ) + result = torch.load(result_path, map_location="cpu", weights_only=False) + + # --- model weights --- + self.obj_model.load_state_dict(result["obj_state"]) + self.probe_model.load_state_dict(result["probe_state"]) + self.to(restore_device) + + # --- restore optimizer params (worker may have set/changed them) so that future + # spawns (e.g. reset=True without optimizer_params) can re-init the optimizer --- + for model, key in ( + (self.obj_model, "obj_optimizer_params"), + (self.probe_model, "probe_optimizer_params"), + ): + saved = result.get(key) + if saved is not None: + model._optimizer_params = saved + + # Re-create optimizers on the restored device (main process never ran set_optimizers). + # set_optimizers() skips models whose _optimizer_params is NoneOptimizer. + self.set_optimizers() + + # --- optimizer states (params and device must be set before loading) --- + for name, key in (("object", "obj_optimizer_state"), ("probe", "probe_optimizer_state")): + opt_state = result.get(key) + opt = self.optimizers.get(name) + if opt_state is not None and opt is not None: + opt.load_state_dict(opt_state) + # State tensors were saved on CPU; move them to restore_device + for state in opt.state.values(): + for k, v in state.items(): + if isinstance(v, torch.Tensor): + state[k] = v.to(restore_device) + + # --- iteration tracking --- + # When reset=True the worker ran reset_recon() internally, so its lists start from 0. + # When reset=False the worker inherited existing history, so its lists are [old...new...]. + # n_before lets us take only the genuinely new tail in both cases. + n_before = len(self._iter_losses) + is_reset = recon_kwargs.get("reset", False) + + if is_reset: + self._iter_losses.clear() + self._iter_val_losses.clear() + self._iter_lrs.clear() + self._iter_recon_types.clear() + self._iter_losses.extend(result["iter_losses"]) + self._iter_val_losses.extend(result["iter_val_losses"]) + for k, v in result.get("iter_lrs", {}).items(): + self._iter_lrs[k] = list(v) + self._iter_recon_types.extend(result.get("iter_recon_types", [])) + else: + self._iter_losses.extend(result["iter_losses"][n_before:]) + self._iter_val_losses.extend(result["iter_val_losses"][n_before:]) + for k, v in result.get("iter_lrs", {}).items(): + if k not in self._iter_lrs: + self._iter_lrs[k] = [] + self._iter_lrs[k].extend(list(v)[n_before:]) + self._iter_recon_types.extend(result.get("iter_recon_types", [])[n_before:]) + + self._multi_gpu_devices = devices + return self + def _get_current_lrs(self) -> dict[str, float]: return { param_name: optimizer.param_groups[0]["lr"] @@ -329,12 +627,14 @@ def backward( # scaling pixelated ad gradients to closer match analytic if isinstance(self.obj_model, ObjectPixelated): obj_grad_scale = self.dset.upsample_factor**2 / 2 # factor of 2 from l2 grad - self.obj_model._obj.grad.mul_(obj_grad_scale) # type:ignore + if self.obj_model._obj.grad is not None: + self.obj_model._obj.grad.mul_(obj_grad_scale) if isinstance(self.probe_model, ProbeParametric): probe_grad_scale = np.sqrt(self.probe_model._mean_diffraction_intensity) for par in self.probe_model.params: - par.grad.mul_(probe_grad_scale) # type:ignore + if par.grad is not None: + par.grad.mul_(probe_grad_scale) else: gradient = self.gradient_step(amplitudes, overlap) @@ -457,7 +757,8 @@ def save( # Add other common skips for ptychography objects skips = skip - current_device = self.device + _dev = self.device + current_device: str = f"cuda:{_dev[0]}" if isinstance(_dev, list) else _dev self.to("cpu") if self.verbose and verbose: @@ -474,8 +775,8 @@ def save( self.to(current_device) # TODO figure out why this isn't working for DDIP sometimes? # Clean up temporary metadata - if not save_raw_data and hasattr(self, "_dataset_metadata"): - delattr(self, "_dataset_metadata") + if not save_raw_data and self._dataset_metadata is not None: + self._dataset_metadata = None @classmethod def from_file( @@ -517,7 +818,7 @@ def from_file( # If no dataset was provided, try to reload it from saved metadata if dset is None and auto_reload_dataset and not hasattr(ptycho, "dset"): - if hasattr(ptycho, "_dataset_metadata") and ptycho._dataset_metadata: + if ptycho._dataset_metadata is not None: metadata = ptycho._dataset_metadata file_path = metadata.get("file_path") @@ -560,13 +861,13 @@ def from_file( elif dset is not None: dset._set_initial_scan_positions_px(ptycho.obj_padding_px) dset._set_patch_indices(ptycho.obj_padding_px) - if hasattr(ptycho, "_dataset_metadata") and ptycho._dataset_metadata: + if ptycho._dataset_metadata is not None: metadata = ptycho._dataset_metadata # preserve learned scan positions and descan shifts if "learned_scan_positions_px" in metadata: - dset.scan_positions_px.data = metadata["learned_scan_positions_px"] + dset.scan_positions_px.data = metadata["learned_scan_positions_px"] # type: ignore[assignment] if "learned_descan_shifts" in metadata: - dset.descan_shifts.data = metadata["learned_descan_shifts"] + dset.descan_shifts.data = metadata["learned_descan_shifts"] # type: ignore[assignment] # check if dset was attached to ptycho object if dset is not None: diff --git a/src/quantem/diffractive_imaging/ptychography_base.py b/src/quantem/diffractive_imaging/ptychography_base.py index 665819d9..07071b84 100644 --- a/src/quantem/diffractive_imaging/ptychography_base.py +++ b/src/quantem/diffractive_imaging/ptychography_base.py @@ -4,9 +4,13 @@ import numpy as np import scipy.ndimage as ndi import torch +import torch.distributed as dist +from torch.utils.data import DataLoader, DistributedSampler from quantem.core import config from quantem.core.io.serialize import AutoSerialize +from quantem.core.ml.constraints import Constraints +from quantem.core.ml.dist_utils import all_reduce_params, worker_init_fn from quantem.core.utils.rng import RNGMixin from quantem.core.utils.utils import ( electron_wavelength_angstrom, @@ -93,12 +97,23 @@ def __init__( # TODO prevent direct instantiation raise RuntimeError("the quantEM Ptychography module requires torch to be installed.") super().__init__() + + # Pre-initialize private attributes so type checker sees them in __init__ + self._verbose: int = 0 + self._logger: LoggerPtychography | None = None + self._batch_size: int = 1 + self._dset: DatasetModelType = dset + self._obj_model: ObjectModelType = obj_model + self._probe_model: ProbeModelType = probe_model + self._detector_model: DetectorModelType = detector_model + self.verbose = verbose self.dset = dset self.device = device self.rng = rng # initializing default attributes + self._multi_gpu_devices: list[int] | None = None self._preprocessed: bool = False self._obj_padding_force_power2_level: int = 3 self._store_snapshots: bool = False @@ -106,7 +121,7 @@ def __init__( # TODO prevent direct instantiation self._iter_losses: list[float] = [] self._iter_val_losses: list[float] = [] self._iter_recon_types: list[str] = [] - self._iter_lrs: dict[str, list] = {} # LRs/step_sizes across iterations + self._iter_lrs: dict[str, list[float]] = {} # LRs/step_sizes across iterations self._snapshots: list[Snapshot] = [] self._obj_padding_px = np.array([0, 0]) self.obj_fov_mask = torch.ones(self.dset._obj_shape_full_2d(self.obj_padding_px).shape) @@ -127,7 +142,7 @@ def __init__( # TODO prevent direct instantiation self.detector_model = detector_model self.compute_propagator_arrays() self.logger = logger - self.to(self.device) + self.to(self._single_device) # region --- preprocessing --- ## hopefully will be able to remove some of thes preprocessing flags, @@ -170,7 +185,7 @@ def preprocess( self.roi_shape, self.reciprocal_sampling, self.dset.mean_diffraction_intensity, - device=self.device, + device=self._single_device, ) # change obj_padding_px and whatever else needs to be changed @@ -205,7 +220,7 @@ def _get_probe_overlap(self, max_batch_size: int | None = None) -> np.ndarray: batch_size = num_dps if max_batch_size is None else int(max_batch_size) probe_overlap = torch.zeros( - tuple(self.obj_shape_full[-2:]), dtype=self._dtype_real, device=self.device + tuple(self.obj_shape_full[-2:]), dtype=self._dtype_real, device=self._single_device ) for start, end in generate_batches(num_dps, max_batch=batch_size): probe_overlap += sum_patches( @@ -296,7 +311,7 @@ def slice_thicknesses(self) -> np.ndarray: return self._to_numpy(slice_thick) @slice_thicknesses.setter - def slice_thicknesses(self, val: float | Sequence | None) -> None: + def slice_thicknesses(self, val: float | Sequence[float] | None) -> None: self._obj_model.slice_thicknesses = val if hasattr(self, "_propagators"): # propagators already set, update with new slices self.compute_propagator_arrays() @@ -311,8 +326,14 @@ def verbose(self, v: bool | int | float) -> None: @property def obj(self) -> np.ndarray: + """Object array in its native representation per ``obj_type``: + + - ``"complex"`` → complex ndarray (amp * exp(1j*phase)); phase recentered. + - ``"pure_phase"`` → real ndarray of phase values. + - ``"potential"`` → real ndarray of potential values. + """ obj = self._to_numpy(self.obj_model.obj) - if self.obj_type in ["pure_phase", "complex"]: + if self.obj_type == "complex": ph = np.angle(obj) obj = np.abs(obj) * np.exp(1j * (ph - ph.mean())) return obj @@ -482,11 +503,10 @@ def get_snapshot_by_iter( if cropped: snp2 = snp.copy() cropped_obj = self._crop_rotate_obj_fov(snp2["obj"]) - # same logic as self.obj_cropped - if self.obj_type == "pure_phase": - ph = np.angle(cropped_obj) - cropped_obj = np.exp(1j * (ph - ph.mean())) - if self.obj_type in ["pure_phase", "complex"]: + # same logic as self.obj_cropped: only re-center for complex (which + # carries phase inside a complex tensor); pure_phase and potential + # are already real and recentered upstream. + if self.obj_type == "complex": ph = np.angle(cropped_obj) cropped_obj = np.abs(cropped_obj) * np.exp(1j * (ph - ph.mean())) snp2["obj"] = cropped_obj @@ -506,7 +526,7 @@ def obj_model(self, model: ObjectModelType | type): raise TypeError(f"obj_model must be a ObjectModelType, got {type(model)}") # Set object shape - model.to(self.device) + model.to(self._single_device) self._obj_model = cast(ObjectModelType, model) @property @@ -527,12 +547,12 @@ def probe_model(self, model: ProbeModelType | type): self.roi_shape, self.reciprocal_sampling, self.dset.mean_diffraction_intensity, - device=self.device, + device=self._single_device, ) else: # will be set in ptycho.preprocess after dset is preprocessed pass - self._probe_model.to(self.device) + self._probe_model.to(self._single_device) @property def constraints(self) -> dict[str, Any]: @@ -548,7 +568,12 @@ def constraints(self) -> dict[str, Any]: @constraints.setter def constraints(self, c: dict[str, Any]): - """Set constraints by forwarding to individual models.""" + """Set constraints by forwarding to individual models. + + Each leaf value may be either a plain ``dict`` (validated per-key against + the model's constraint dataclass) or a ``Constraints`` dataclass instance + (assigned wholesale to the model). + """ constraint_handlers = { "object": self.obj_model, "probe": self.probe_model, @@ -556,9 +581,16 @@ def constraints(self, c: dict[str, Any]): } for key, value in c.items(): - if key in constraint_handlers and isinstance(value, dict): - for subkey, subvalue in value.items(): - constraint_handlers[key].add_constraint(subkey, subvalue) + if key in constraint_handlers: + if isinstance(value, Constraints): + constraint_handlers[key].constraints = value + elif isinstance(value, dict): + constraint_handlers[key].constraints = value + else: + raise TypeError( + f"Constraints for '{key}' must be a dict or Constraints dataclass, " + f"got {type(value).__name__}" + ) elif key == "detector" and isinstance(value, dict): warn("Detector constraints not implemented, skipping") else: @@ -595,12 +627,13 @@ def logger(self, logger: LoggerPtychography | None): # region --- implicit class properties --- @property - def device(self) -> str: - """This should be of form 'cuda:X' or 'cpu', as defined by quantem.config""" + def device(self) -> str | list[int]: + """Returns the active device: 'cuda:X'/'cpu' for single-GPU, or [gpu_ids] for multi-GPU.""" + if self._multi_gpu_devices is not None: + return self._multi_gpu_devices if hasattr(self, "_device"): return self._device - else: - return config.get("device") + return config.get("device") @device.setter def device(self, device: str | int | None): @@ -613,6 +646,11 @@ def device(self, device: str | int | None): except AttributeError: pass + @property + def _single_device(self) -> str: + """Single-device string for internal tensor operations. Always str, never a list.""" + return self._device if hasattr(self, "_device") else str(config.get("device")) + @property def _obj_dtype(self) -> "torch.dtype": return self.obj_model.dtype @@ -628,8 +666,17 @@ def _dtype_complex(self) -> "torch.dtype": @property def obj_cropped(self) -> np.ndarray: + """Cropped + FOV-rotated object, in its native representation. + + - ``obj_type="complex"`` → complex array (amp * exp(1j*phase)); phase is + recentered to zero mean here as a defensive duplicate of + ``ObjectConstraints._apply_hard_complex``. + - ``obj_type="pure_phase"`` → real array of phase values (already + recentered upstream by ``_apply_hard_pure_phase``). + - ``obj_type="potential"`` → real array of potential values. + """ cropped = self._crop_rotate_obj_fov(self.obj, padding=self.obj_padding_px) - if self.obj_type in ["pure_phase", "complex"]: + if self.obj_type == "complex": ph = np.angle(cropped) cropped = np.abs(cropped) * np.exp(1j * (ph - ph.mean())) return cropped @@ -770,13 +817,13 @@ def _to_torch( raise TypeError(f"dtype should be string or torch.dtype, got {type(dtype)} {dtype}") if isinstance(array, np.ndarray): - t = torch.tensor(array.copy(), device=self.device, dtype=dt) + t = torch.tensor(array.copy(), device=self._single_device, dtype=dt) elif isinstance(array, torch.Tensor): - t = array.to(self.device) + t = array.to(self._single_device) if dt is not None: t = t.type(dt) elif isinstance(array, (list, tuple)): - t = torch.tensor(array, device=self.device, dtype=dt) + t = torch.tensor(array, device=self._single_device, dtype=dt) else: raise TypeError(f"arr should be ndarray or Tensor, got {type(array)}") return t @@ -876,10 +923,57 @@ def get_probe_intensities( intensities = np.abs(probe) ** 2 return intensities.sum(axis=(-2, -1)) / intensities.sum() + def _broadcast_parameters(self, src: int = 0) -> None: + """Broadcast obj, probe, and dataset parameters from rank src to all other ranks. + + Uses .parameters() so it works for both pixelated and DIP/INR models. The dataset's + learnable params (scan positions / descan shifts) must also be broadcast: with the + DistributedSampler partitioning scan positions, the full position params are replicated + on every rank, so they must start identical and stay synchronized. + """ + for p in self.obj_model.parameters(): + dist.broadcast(p.data, src=src) + for p in self.probe_model.parameters(): + dist.broadcast(p.data, src=src) + for group in self.dset.get_optimization_parameters().values(): + for p in group: + buf = p.data.contiguous() + dist.broadcast(buf, src=src) + p.data.copy_(buf) + + def _all_reduce_gradients(self) -> None: + """Average obj, probe, and dataset gradients across all ranks (call after backward, + before step). + + Uses .parameters() so it works for both pixelated and DIP/INR models. The dataset's + learnable params are included because each scan position's gradient is nonzero on + exactly one rank, so they must be reduced (AVG) to stay consistent across ranks. + """ + dset_params = [ + p for group in self.dset.get_optimization_parameters().values() for p in group + ] + params = [ + p + for p in ( + list(self.obj_model.parameters()) + + list(self.probe_model.parameters()) + + dset_params + ) + if p.grad is not None + ] + if params: + for p in params: + if p.grad is not None and not p.grad.is_contiguous(): + p.grad = p.grad.contiguous() + all_reduce_params(*params) + def to(self, device: str | int | torch.device): dev, _id = config.validate_device(device) - if dev != self.device: - self._device = dev + self._device = dev + self._multi_gpu_devices = None + # Sync each sub-model's own device tracker so their reset() uses the correct device + self.obj_model.device = dev + self.probe_model.device = dev self.obj_model.to(dev) self.probe_model.to(dev) self.dset.to(dev) @@ -887,6 +981,82 @@ def to(self, device: str | int | torch.device): self._propagators = self._to_torch(self._propagators) self._rng_to_device(dev) + def _build_dataloaders( + self, + train_indices: np.ndarray, + val_indices: np.ndarray, + world_size: int, + rank: int, + num_workers: int, + ) -> "tuple[DataLoader, DistributedSampler | None, DataLoader | None]": + """Build train + (optional) val DataLoaders for both single- and multi-GPU paths. + + Mirrors the shape of ``DDPMixin.setup_dataloader`` but adapted to ptycho's device + contract (``str | list[int]``) and ptycho's precomputed ``val_mode`` index split. + ``world_size > 1`` uses ``DistributedSampler`` over a ``Subset``; ``world_size == 1`` + uses ``shuffle=True`` with a seeded ``torch.Generator`` for run-to-run determinism. + ``__getitem__`` returns ``{"index": idx, ...}`` for the original dataset index, and + ``Subset[i]`` calls ``dataset[indices[i]]``, so ``batch["index"]`` is the original + dataset index under either branch. + """ + pin_memory = self.dset.target_residency == "cpu" and str(self._single_device).startswith( + "cuda" + ) + loader_kwargs: dict[str, Any] = { + "batch_size": self.batch_size, + "num_workers": num_workers, + "pin_memory": pin_memory, + "drop_last": False, + } + if num_workers > 0: + loader_kwargs.update( + multiprocessing_context="spawn", + persistent_workers=True, + worker_init_fn=worker_init_fn, + ) + + train_subset = torch.utils.data.Subset(self.dset, train_indices.tolist()) + val_subset = ( + torch.utils.data.Subset(self.dset, val_indices.tolist()) + if len(val_indices) > 0 + else None + ) + + if world_size > 1: + train_sampler = DistributedSampler( + train_subset, + num_replicas=world_size, + rank=rank, + shuffle=True, + seed=int(self.rng.integers(0, 2**31 - 1)), + drop_last=False, + ) + train_loader = DataLoader(train_subset, sampler=train_sampler, **loader_kwargs) + if val_subset is not None: + val_sampler = DistributedSampler( + val_subset, + num_replicas=world_size, + rank=rank, + shuffle=False, + drop_last=False, + ) + val_loader = DataLoader(val_subset, sampler=val_sampler, **loader_kwargs) + else: + val_loader = None + else: + train_sampler = None + shuffle_gen = torch.Generator().manual_seed(int(self.rng.integers(0, 2**31 - 1))) + train_loader = DataLoader( + train_subset, shuffle=True, generator=shuffle_gen, **loader_kwargs + ) + val_loader = ( + DataLoader(val_subset, shuffle=False, **loader_kwargs) + if val_subset is not None + else None + ) + + return train_loader, train_sampler, val_loader + # endregion # region --- ptychography foRcard model --- @@ -911,21 +1081,23 @@ def error_estimate( self, pred_intensities: torch.Tensor, batch_indices: np.ndarray, + targets: torch.Tensor, loss_type: Literal[ "l2_amplitude", "l1_amplitude", "l2_intensity", "l1_intensity", "poisson" ] = "l2_amplitude", + global_n: int | None = None, ) -> tuple[torch.Tensor, torch.Tensor]: - targets = self.dset.targets[batch_indices] if "amplitude" in loss_type: preds = torch.sqrt(pred_intensities + 1e-9) # add eps to avoid diverging gradients else: preds = pred_intensities diff = preds * self.dset.detector_mask - targets * self.dset.detector_mask + n = global_n if global_n is not None else self.dset.num_gpts if "l1" in loss_type: - error = torch.sum(torch.abs(diff)) / (diff.shape[0] / self.dset.num_gpts) + error = torch.sum(torch.abs(diff)) / (diff.shape[0] / n) elif "l2" in loss_type: - error = torch.sum(torch.abs(diff) ** 2) / (diff.shape[0] / self.dset.num_gpts) + error = torch.sum(torch.abs(diff) ** 2) / (diff.shape[0] / n) elif loss_type == "poisson": error = torch.sum(preds - targets * torch.log(preds + 1e-6)) else: diff --git a/src/quantem/diffractive_imaging/ptychography_lite.py b/src/quantem/diffractive_imaging/ptychography_lite.py index 3cdb4a02..0bca7a49 100644 --- a/src/quantem/diffractive_imaging/ptychography_lite.py +++ b/src/quantem/diffractive_imaging/ptychography_lite.py @@ -32,21 +32,21 @@ def from_dataset( *, # object settings num_slices: int = 1, - slice_thicknesses: float | Sequence | None = None, + slice_thicknesses: float | Sequence[float] | None = None, obj_type: Literal["complex", "pure_phase", "potential"] = "complex", # probe settings num_probes: int = 1, energy: float | None = None, defocus: float | None = None, semiangle_cutoff: float | None = None, - polar_parameters: dict | None = None, + polar_parameters: dict[str, Any] | None = None, middle_focus: bool = False, vacuum_probe_intensity: np.ndarray | Dataset4dstem | None = None, initial_probe_weights: list[float] | np.ndarray | None = None, # preprocessing obj_padding_px: tuple[int, int] = (0, 0), # logging/device - log_dir: os.PathLike | str | None = None, + log_dir: os.PathLike[str] | str | None = None, log_prefix: str = "", log_images_every: int = 10, log_probe_images: bool = False, @@ -166,7 +166,7 @@ def from_dataset( ) return ptycho - def reconstruct( # type:ignore could do overloads but this is simpler... + def reconstruct( # pyright: ignore[reportIncompatibleMethodOverride] self, num_iters: int = 0, reset: bool = False, @@ -178,9 +178,9 @@ def reconstruct( # type:ignore could do overloads but this is simpler... scheduler_type: Literal["exp", "cyclic", "plateau", "none"] = "none", scheduler_factor: float = 0.5, new_optimizers: bool = False, # not sure what the default should be - constraints: dict = {}, # TODO add constraints flags + constraints: dict[str, Any] | None = None, store_iterations_every: int | None = None, - device: Literal["cpu", "gpu"] | None = None, + device: "Literal['cpu', 'gpu'] | int | list[int] | None" = None, verbose: int | bool = True, ) -> Self: self.verbose = verbose @@ -202,6 +202,8 @@ def reconstruct( # type:ignore could do overloads but this is simpler... if not needs_dataset_optimizer and "dataset" in self.optimizers: self.remove_optimizer("dataset") + opt_params: dict[str, Any] | None + scheduler_params: dict[str, Any] | None if setup_new_optimizers or (needs_dataset_optimizer and "dataset" not in self.optimizers): opt_params = { "object": { @@ -237,8 +239,6 @@ def reconstruct( # type:ignore could do overloads but this is simpler... opt_params = None scheduler_params = None - constraints = constraints # placeholder for constraints flags - return super().reconstruct( num_iters=num_iters, reset=reset, @@ -283,7 +283,7 @@ def from_file( upgraded = cls._recursive_load_from_path(path) return upgraded # type: ignore[return-value] - return base # type: ignore[return-value] + return base # pyright: ignore[reportReturnType] class PtychoLiteDIP(Ptychography): @@ -305,9 +305,9 @@ def from_ptycholite( normalize_object_plotting: bool = True, # model settings cnn_num_layers: int = 3, - final_activation: str | Callable = nn.Identity(), + final_activation: "str | Callable[..., Any]" = nn.Identity(), # logging/device - log_dir: os.PathLike | str | None = None, + log_dir: os.PathLike[str] | str | None = None, log_prefix: str = "", log_images_every: int = 10, log_probe_images: bool = False, @@ -407,7 +407,7 @@ def from_ptycholite( ) return ptycho - def reconstruct( # type:ignore could do overloads but this is simpler... + def reconstruct( # pyright: ignore[reportIncompatibleMethodOverride] self, num_iters: int = 0, reset: bool = False, @@ -419,9 +419,9 @@ def reconstruct( # type:ignore could do overloads but this is simpler... scheduler_type: Literal["exp", "cyclic", "plateau", "none"] = "none", scheduler_factor: float = 0.5, new_optimizers: bool = False, # not sure what the default should be - constraints: dict = {}, # TODO add constraints flags + constraints: dict[str, Any] | None = None, store_iterations_every: int | None = None, - device: Literal["cpu", "gpu"] | None = None, + device: Literal["cpu", "gpu"] | int | list[int] | None = None, verbose: int | bool = True, ) -> Self: self.verbose = verbose @@ -443,6 +443,8 @@ def reconstruct( # type:ignore could do overloads but this is simpler... if not needs_dataset_optimizer and "dataset" in self.optimizers: self.remove_optimizer("dataset") + opt_params: dict[str, Any] | None + scheduler_params: dict[str, Any] | None if setup_new_optimizers or (needs_dataset_optimizer and "dataset" not in self.optimizers): opt_params = { "object": { @@ -478,8 +480,6 @@ def reconstruct( # type:ignore could do overloads but this is simpler... opt_params = None scheduler_params = None - constraints = constraints # placeholder for constraints flags - return super().reconstruct( num_iters=num_iters, reset=reset, diff --git a/src/quantem/diffractive_imaging/ptychography_opt.py b/src/quantem/diffractive_imaging/ptychography_opt.py index ff1b3a2c..a22456f6 100644 --- a/src/quantem/diffractive_imaging/ptychography_opt.py +++ b/src/quantem/diffractive_imaging/ptychography_opt.py @@ -1,8 +1,9 @@ from dataclasses import replace -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Any from quantem.core import config from quantem.core.ml.optimizer_mixin import ( + OptimizerMixin, OptimizerParams, OptimizerParamsType, SchedulerParams, @@ -19,7 +20,11 @@ class PtychographyOpt(PtychographyBase): """ - A class for performing phase retrieval using the Ptychography algorithm. + Optimizer/scheduler dispatch layer for `Ptychography`. + + Each optimizable component (`object`, `probe`, `dataset`) lives on its own + `OptimizerMixin`-equipped model. The methods here are thin façades that fan a + single dict (`{key: params}`) out to the three models via the `_models` dict. """ OPTIMIZABLE_VALS = ["object", "probe", "dataset"] @@ -28,6 +33,25 @@ class PtychographyOpt(PtychographyBase): def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) + @property + def _models(self) -> dict[str, OptimizerMixin]: + """Maps each optimization key to the model that owns its parameters. + + Not cached: `obj_model`, `probe_model`, and `dset` can be reassigned + (e.g. by `from_ptychography`), and we must always see the current binding. + """ + return { + "object": self.obj_model, + "probe": self.probe_model, + "dataset": self.dset, + } + + def _check_key(self, key: str) -> None: + if key not in self.OPTIMIZABLE_VALS: + raise ValueError( + f"key to be optimized, {key}, not in allowed keys: {self.OPTIMIZABLE_VALS}" + ) + def _get_default_lr(self, key: str) -> float: """Get default learning rate for a given optimization key.""" if key == "object": @@ -39,26 +63,22 @@ def _get_default_lr(self, key: str) -> float: else: raise ValueError(f"Unknown optimization key: {key}") - # region --- explicit properties and setters --- + # region --- optimizer params --- @property def optimizer_params(self) -> dict[str, OptimizerParamsType | dict[str, OptimizerParamsType]]: return { - key: params - for key, params in [ - ("object", self.obj_model.optimizer_params), - ("probe", self.probe_model.optimizer_params), - ("dataset", self.dset.optimizer_params), - ] - if not isinstance(params, OptimizerParams.NoneOptimizer) + key: model.optimizer_params + for key, model in self._models.items() + if not isinstance(model.optimizer_params, OptimizerParams.NoneOptimizer) } @optimizer_params.setter - def optimizer_params(self, d: dict) -> None: + def optimizer_params(self, d: dict[str, Any] | list[str] | tuple[str, ...]) -> None: """ Takes a dictionary mapping optimizable keys to either an ``OptimizerParamsType`` dataclass or a plain dict (with optional ``"name"``/``"type"`` and ``"lr"`` - keys). Missing ``"name"`` / ``"lr"`` are filled from ``DEFAULT_OPTIMIZER_TYPE`` + keys). Missing ``"name"`` / ``"lr"`` are filled from ``DEFAULT_OPTIMIZER_TYPE`` and ``_get_default_lr`` respectively. Examples @@ -67,10 +87,9 @@ def optimizer_params(self, d: dict) -> None: >>> ptycho.optimizer_params = {"object": {"name": "adam", "lr": 5e-3}} >>> ptycho.optimizer_params = ["object", "probe"] # use all defaults """ - if isinstance(d, (tuple, list)): - d = {k: {} for k in d} + _d: dict[str, Any] = {k: {} for k in d} if isinstance(d, (list, tuple)) else d - for k, v in d.items(): + for k, v in _d.items(): if isinstance(v, OptimizerParamsType): pass # already a dataclass, pass through elif isinstance(v, dict): @@ -84,51 +103,44 @@ def optimizer_params(self, d: dict) -> None: else: raise TypeError(f"Expected OptimizerParamsType or dict for key '{k}', got {type(v)}") - if k == "object": - self.obj_model.optimizer_params = v - elif k == "probe": - self.probe_model.optimizer_params = v - elif k == "dataset": - self.dset.optimizer_params = v - else: - raise ValueError( - f"key to be optimized, {k}, not in allowed keys: {self.OPTIMIZABLE_VALS}" - ) + self._models[k].optimizer_params = v # type: ignore[assignment] + + # endregion --- optimizer params --- + + # region --- optimizers --- @property def optimizers(self) -> dict[str, "torch.optim.Optimizer"]: - """Get optimizers from all models.""" - optimizers = {} - if self.obj_model.has_optimizer(): - optimizers["object"] = self.obj_model.optimizer - if self.probe_model.has_optimizer(): - optimizers["probe"] = self.probe_model.optimizer - if self.dset.has_optimizer(): - optimizers["dataset"] = self.dset.optimizer - return optimizers - - def set_optimizers(self): - """Set optimizers for each model.""" + """Active optimizers, keyed by optimization key.""" + return { + key: model.optimizer # type: ignore[reportIncompatibleMethodOverride] + for key, model in self._models.items() + if model.has_optimizer() + } + + def set_optimizers(self) -> None: + """(Re)create an optimizer on each model whose params are not `NoneOptimizer`.""" for key, params in self.optimizer_params.items(): - if key == "object": - self.obj_model.set_optimizer(params) - elif key == "probe": - self.probe_model.set_optimizer(params) - elif key == "dataset": - self.dset.set_optimizer(params) - else: - raise ValueError( - f"key to be optimized, {key}, not in allowed keys: {self.OPTIMIZABLE_VALS}" - ) + self._models[key].set_optimizer(params) def remove_optimizer(self, key: str) -> None: - """Remove optimizer from a specific model.""" - if key == "object": - self.obj_model.remove_optimizer() - elif key == "probe": - self.probe_model.remove_optimizer() - elif key == "dataset": - self.dset.remove_optimizer() + """Tear down the optimizer on the model for `key`.""" + self._check_key(key) + self._models[key].remove_optimizer() + + def step_optimizers(self) -> None: + for model in self._models.values(): + if model.has_optimizer(): + model.step_optimizer() + + def zero_grad_all(self) -> None: + for model in self._models.values(): + if model.has_optimizer(): + model.zero_optimizer_grad() + + # endregion --- optimizers --- + + # region --- schedulers --- @property def scheduler_params(self) -> dict[str, SchedulerParamsType]: @@ -140,7 +152,7 @@ def scheduler_params(self) -> dict[str, SchedulerParamsType]: } @scheduler_params.setter - def scheduler_params(self, d: dict) -> None: + def scheduler_params(self, d: dict[str, Any] | list[str] | tuple[str, ...]) -> None: """ Takes a dictionary mapping optimizable keys to either a ``SchedulerParamsType`` dataclass or a plain dict. Keys not present in ``d`` are set to @@ -151,76 +163,30 @@ def scheduler_params(self, d: dict) -> None: >>> ptycho.scheduler_params = {"object": SchedulerParams.Plateau(factor=0.5)} >>> ptycho.scheduler_params = {"object": {"name": "plateau", "factor": 0.5}} """ + _d: dict[str, Any] = {k: {} for k in d} if isinstance(d, (list, tuple)) else dict(d) for key in self.OPTIMIZABLE_VALS: - if key not in d: - d[key] = SchedulerParams.NoneScheduler() - for k, v in d.items(): - if k == "object": - self.obj_model.scheduler_params = v - elif k == "probe": - self.probe_model.scheduler_params = v - elif k == "dataset": - self.dset.scheduler_params = v - else: - raise ValueError( - f"key to be optimized, {k}, not in allowed keys: {self.OPTIMIZABLE_VALS}" - ) + _d.setdefault(key, SchedulerParams.NoneScheduler()) + for k, v in _d.items(): + self._check_key(k) + self._models[k].scheduler_params = v @property - def schedulers(self) -> dict[str, "torch.optim.lr_scheduler._LRScheduler"]: - """Get schedulers from all models.""" - schedulers = {} - if self.obj_model.scheduler is not None: - schedulers["object"] = self.obj_model.scheduler - if self.probe_model.scheduler is not None: - schedulers["probe"] = self.probe_model.scheduler - if self.dset.scheduler is not None: - schedulers["dataset"] = self.dset.scheduler - return schedulers + def schedulers(self) -> dict[str, "torch.optim.lr_scheduler.LRScheduler"]: + return { + key: model.scheduler + for key, model in self._models.items() + if model.scheduler is not None + } def set_schedulers(self, params: dict[str, SchedulerParamsType], num_iter: int | None = None): """Set schedulers for each model.""" for key, scheduler_params in params.items(): - if key not in self.OPTIMIZABLE_VALS: - raise ValueError( - f"key to be optimized, {key}, not in allowed keys: {self.OPTIMIZABLE_VALS}" - ) - - if key == "object": - self.obj_model.set_scheduler(scheduler_params, num_iter) - elif key == "probe": - self.probe_model.set_scheduler(scheduler_params, num_iter) - elif key == "dataset": - self.dset.set_scheduler(scheduler_params, num_iter) - - def step_optimizers(self): - """Step all active optimizers.""" - for key in self.optimizer_params.keys(): - if key == "object" and self.obj_model.has_optimizer(): - self.obj_model.step_optimizer() - elif key == "probe" and self.probe_model.has_optimizer(): - self.probe_model.step_optimizer() - elif key == "dataset" and self.dset.has_optimizer(): - self.dset.step_optimizer() - - def zero_grad_all(self): - """Zero gradients for all active optimizers.""" - for key in self.optimizer_params.keys(): - if key == "object" and self.obj_model.has_optimizer(): - self.obj_model.zero_optimizer_grad() - elif key == "probe" and self.probe_model.has_optimizer(): - self.probe_model.zero_optimizer_grad() - elif key == "dataset" and self.dset.has_optimizer(): - self.dset.zero_optimizer_grad() - - def step_schedulers(self, loss: float | None = None): - """Step all active schedulers.""" - for key in self.scheduler_params.keys(): - if key == "object" and self.obj_model.scheduler is not None: - self.obj_model.step_scheduler(loss) - elif key == "probe" and self.probe_model.scheduler is not None: - self.probe_model.step_scheduler(loss) - elif key == "dataset" and self.dset.scheduler is not None: - self.dset.step_scheduler(loss) - - # endregion --- explicit properties and setters --- + self._check_key(key) + self._models[key].set_scheduler(scheduler_params, num_iter) + + def step_schedulers(self, loss: float | None = None) -> None: + for model in self._models.values(): + if model.scheduler is not None: + model.step_scheduler(loss) + + # endregion --- schedulers --- diff --git a/src/quantem/diffractive_imaging/ptychography_visualizations.py b/src/quantem/diffractive_imaging/ptychography_visualizations.py index 3470a5f3..f1e7010d 100644 --- a/src/quantem/diffractive_imaging/ptychography_visualizations.py +++ b/src/quantem/diffractive_imaging/ptychography_visualizations.py @@ -70,12 +70,16 @@ def show_obj( ims = [] titles = [] cmaps = [] + # obj_np dtype depends on obj_type: + # potential -> real values to plot directly + # pure_phase -> real phase to plot directly (no np.angle wrap) + # complex -> complex; extract amp & phase via np.abs / np.angle if self.obj_type == "potential": ims.append(np.abs(obj_np).sum(0)) titles.append(t + "Potential") cmaps.append(ph_cmap) elif self.obj_type == "pure_phase": - ims.append(np.angle(obj_np).sum(0)) + ims.append(obj_np.sum(0)) titles.append(t + "Pure Phase") cmaps.append(ph_cmap) else: @@ -145,8 +149,14 @@ def show_obj_fft( tukey(obj_np.shape[-2], tukey_alpha)[:, None] * tukey(obj_np.shape[-1], tukey_alpha)[None, :] ) + # Build the complex transmission function and apply the spatial window: + # potential -> real values, plotted in real space + # pure_phase -> real phase; transmission = exp(1j*phase) + # complex -> already complex transmission if self.obj_type == "potential": windowed_obj = obj_np.sum(0) * window_2d + elif self.obj_type == "pure_phase": + windowed_obj = np.exp(1j * obj_np.sum(0)) * window_2d else: windowed_obj = ( np.abs(obj_np).sum(0) @@ -398,7 +408,7 @@ def show_obj_slices( objs_flat = [np.abs(obj[i]) for i in range(len(obj))] titles_flat = [f"Potential {t_parts[i]}" for i in range(len(obj))] elif self.obj_type == "pure_phase": - objs_flat = [np.angle(obj[i]) for i in range(len(obj))] + objs_flat = [obj[i] for i in range(len(obj))] titles_flat = [f"Pure Phase {t_parts[i]}" for i in range(len(obj))] else: objs_flat = [np.angle(obj[i]) for i in range(len(obj))] @@ -690,7 +700,7 @@ def _show_object_iters_only( all_titles.append(title_prefix + "Potential") all_cmaps.append(ph_cmap) elif self.obj_type == "pure_phase": - all_images.append(np.angle(obj).sum(0)) + all_images.append(obj.sum(0)) all_titles.append(title_prefix + "Phase") all_cmaps.append(ph_cmap) else: # complex @@ -805,7 +815,7 @@ def _show_object_and_probe_iters( row_titles.append(f"Iter {iteration} Potential") row_cmaps.append(ph_cmap) elif self.obj_type == "pure_phase": - row_images.append(np.angle(obj).sum(0)) + row_images.append(obj.sum(0)) row_titles.append(f"Iter {iteration} Phase") row_cmaps.append(ph_cmap) else: # complex diff --git a/tests/diffractive_imaging/test_constraints.py b/tests/diffractive_imaging/test_constraints.py new file mode 100644 index 00000000..eeb7e1a0 --- /dev/null +++ b/tests/diffractive_imaging/test_constraints.py @@ -0,0 +1,289 @@ +"""Tests for the ptychography constraint dataclass API.""" + +import warnings + +import numpy as np +import pytest +import torch + +from quantem.core.datastructures import Dataset4dstem +from quantem.diffractive_imaging import ( + DetectorPixelated, + ObjectPixelated, + ProbePixelated, + PtychoDatasetConstraintParams, + Ptychography, + PtychographyDatasetRaster, + PtychoObjConstraintParams, + PtychoProbeConstraintParams, +) + +N_SCAN = 8 +N_DET = 16 +PROBE_ENERGY = 80e3 +PROBE_SEMIANGLE = 20 +PROBE_DEFOCUS = 100 + + +@pytest.fixture +def ptycho(): + rng = np.random.default_rng(42) + array = rng.random((N_SCAN, N_SCAN, N_DET, N_DET)).astype(np.float32) + dset = Dataset4dstem.from_array( + array, + name="test", + sampling=[1.0, 1.0, 0.05, 0.05], + units=["A", "A", "A^-1", "A^-1"], + ) + pdset = PtychographyDatasetRaster.from_dataset4dstem(dset) + pdset.preprocess(com_fit_function="constant", plot_rotation=False, plot_com=False) + obj = ObjectPixelated.from_uniform(obj_type="pure_phase", num_slices=1) + probe = ProbePixelated.from_params( + probe_params={ + "energy": PROBE_ENERGY, + "defocus": PROBE_DEFOCUS, + "semiangle_cutoff": PROBE_SEMIANGLE, + } + ) + p = Ptychography.from_models( + dset=pdset, + obj_model=obj, + probe_model=probe, + detector_model=DetectorPixelated(), + verbose=False, + rng=42, + ) + p.preprocess(obj_padding_px=(4, 4)) + return p + + +# --- parse_dict tests --------------------------------------------------------- + + +class TestParseDict: + def test_object_raster_by_name(self): + c = PtychoObjConstraintParams.parse_dict({"name": "raster", "tv_weight_z": 5.0}) + assert isinstance(c, PtychoObjConstraintParams.Raster) + assert c.tv_weight_z == 5.0 + assert c.positivity is True # default preserved + + def test_object_inr_by_type(self): + c = PtychoObjConstraintParams.parse_dict({"type": "inr"}) + assert isinstance(c, PtychoObjConstraintParams.INR) + + def test_object_unknown_raises(self): + with pytest.raises(ValueError, match="Unknown object constraint type"): + PtychoObjConstraintParams.parse_dict({"name": "nope"}) + + def test_object_missing_name_raises(self): + with pytest.raises(ValueError, match="Must provide either 'name' or 'type'"): + PtychoObjConstraintParams.parse_dict({"tv_weight_z": 5.0}) + + def test_probe_raster_with_fields(self): + c = PtychoProbeConstraintParams.parse_dict( + {"name": "raster", "center_probe": True, "tv_weight": 0.1} + ) + assert isinstance(c, PtychoProbeConstraintParams.Raster) + assert c.center_probe is True + assert c.tv_weight == 0.1 + + def test_dataset_raster_default(self): + c = PtychoDatasetConstraintParams.parse_dict({"name": "raster"}) + assert isinstance(c, PtychoDatasetConstraintParams.Raster) + assert c.clip_scan_positions is True # default preserved + + +# --- Constraint typo catching ------------------------------------------------- + + +class TestTypoCatching: + def test_setting_unknown_field_via_dict_raises(self, ptycho): + with pytest.raises(KeyError, match="Invalid constraint key"): + ptycho.obj_model.constraints = {"not_a_real_field": True} + + def test_add_constraint_unknown_key_raises(self, ptycho): + with pytest.raises(KeyError, match="Invalid constraint key"): + ptycho.obj_model.add_constraint("not_a_real_field", True) + + +# --- Round-trip: pass dataclass via reconstruct(), read back through getter --- + + +class TestRoundtrip: + def test_obj_constraints_dataclass(self, ptycho): + obj_c = PtychoObjConstraintParams.Raster(tv_weight_z=2.5, identical_slices=True) + ptycho.constraints = {"object": obj_c} + assert ptycho.obj_model.constraints is obj_c + assert ptycho.obj_model.constraints.tv_weight_z == 2.5 + assert ptycho.obj_model.constraints.identical_slices is True + + def test_probe_constraints_dataclass(self, ptycho): + probe_c = PtychoProbeConstraintParams.Raster(center_probe=True, tv_weight=0.05) + ptycho.constraints = {"probe": probe_c} + assert ptycho.probe_model.constraints is probe_c + + def test_dataset_constraints_dataclass(self, ptycho): + dset_c = PtychoDatasetConstraintParams.Raster(descan_tv_weight=0.01) + ptycho.constraints = {"dataset": dset_c} + assert ptycho.dset.constraints is dset_c + + def test_dict_form_still_works(self, ptycho): + """Backward compatibility: nested-dict form sets individual fields.""" + ptycho.constraints = { + "object": {"tv_weight_z": 3.0, "positivity": False}, + "probe": {"tv_weight": 0.02}, + } + assert ptycho.obj_model.constraints.tv_weight_z == 3.0 + assert ptycho.obj_model.constraints.positivity is False + assert ptycho.probe_model.constraints.tv_weight == 0.02 + + +# --- Reconstruct() with constraints= ------------------------------------------ + + +class TestReconstructKwargs: + def test_dataclass_leaf_applied(self, ptycho): + from quantem.core.ml import OptimizerParams + + obj_c = PtychoObjConstraintParams.Raster(tv_weight_z=1.5) + ptycho.reconstruct( + num_iters=1, + reset=True, + optimizer_params={"object": OptimizerParams.Adam(lr=1e-2)}, + constraints={"object": obj_c}, + batch_size=4, + device="cpu", + ) + assert ptycho.obj_model.constraints.tv_weight_z == 1.5 + + def test_dict_leaf_partial_update(self, ptycho): + from quantem.core.ml import OptimizerParams + + ptycho.reconstruct( + num_iters=1, + reset=True, + optimizer_params={"object": OptimizerParams.Adam(lr=1e-2)}, + constraints={"object": {"surface_zero_weight": 0.7}}, + batch_size=4, + device="cpu", + ) + assert ptycho.obj_model.constraints.surface_zero_weight == 0.7 + # other fields keep their defaults + assert ptycho.obj_model.constraints.positivity is True + + def test_mixed_dataclass_and_dict_leaves(self, ptycho): + from quantem.core.ml import OptimizerParams + + ptycho.reconstruct( + num_iters=1, + reset=True, + optimizer_params={"object": OptimizerParams.Adam(lr=1e-2)}, + constraints={ + "object": PtychoObjConstraintParams.Raster(tv_weight_xy=0.4), + "probe": {"center_probe": True}, + }, + batch_size=4, + device="cpu", + ) + assert ptycho.obj_model.constraints.tv_weight_xy == 0.4 + assert ptycho.probe_model.constraints.center_probe is True + + +# --- Real-valued pure_phase representation ----------------------------------- + + +class TestPurePhaseRealValued: + def test_pure_phase_pixelated_obj_is_real(self): + obj = ObjectPixelated.from_uniform(obj_type="pure_phase", num_slices=1) + obj._initialize_obj((1, 16, 16), sampling=(0.1, 0.1)) + assert not obj._obj.is_complex(), f"pure_phase _obj should be real, got {obj._obj.dtype}" + + def test_complex_pixelated_obj_is_complex(self): + obj = ObjectPixelated.from_uniform(obj_type="complex", num_slices=1) + obj._initialize_obj((1, 16, 16), sampling=(0.1, 0.1)) + assert obj._obj.is_complex() + + def test_potential_pixelated_obj_is_real(self): + obj = ObjectPixelated.from_uniform(obj_type="potential", num_slices=1) + obj._initialize_obj((1, 16, 16), sampling=(0.1, 0.1)) + assert not obj._obj.is_complex() + + def test_pure_phase_tv_emits_no_phase_warning(self): + obj = ObjectPixelated.from_uniform(obj_type="pure_phase", num_slices=1) + obj._initialize_obj((1, 16, 16), sampling=(0.1, 0.1)) + obj.constraints.tv_weight_xy = 0.1 + with warnings.catch_warnings(record=True) as caught: + warnings.simplefilter("always") + obj.get_tv_loss(obj._obj) + phase_warnings = [w for w in caught if "phase wrapping" in str(w.message)] + assert not phase_warnings, ( + f"pure_phase should not emit phase-wrap warning, got {phase_warnings}" + ) + + def test_complex_tv_still_emits_phase_warning(self): + obj = ObjectPixelated.from_uniform(obj_type="complex", num_slices=1) + obj._initialize_obj((1, 16, 16), sampling=(0.1, 0.1)) + obj.constraints.tv_weight_xy = 0.1 + with warnings.catch_warnings(record=True) as caught: + warnings.simplefilter("always") + obj.get_tv_loss(obj._obj) + assert any("phase wrapping" in str(w.message) for w in caught), ( + "complex obj_type should still emit phase-wrap warning" + ) + + def test_pure_phase_apply_hard_constraints_stays_real(self): + obj = ObjectPixelated.from_uniform(obj_type="pure_phase", num_slices=1) + obj._initialize_obj((1, 16, 16), sampling=(0.1, 0.1)) + out = obj.apply_hard_constraints(obj._obj) + assert not out.is_complex() + + +# --- FOV-mask single application --------------------------------------------- + + +class TestFovMaskSingleApplication: + def _make_obj(self, obj_type) -> ObjectPixelated: + obj = ObjectPixelated.from_uniform(obj_type=obj_type, num_slices=1) + obj._initialize_obj((1, 16, 16), sampling=(0.1, 0.1)) + obj.constraints.apply_fov_mask = True + # Force a non-trivial _obj so masking is observable + if obj_type == "complex": + obj._obj = torch.nn.Parameter( + torch.ones(1, 16, 16, dtype=torch.complex64) * (0.5 + 0.3j) + ) + else: + obj._obj = torch.nn.Parameter(torch.full((1, 16, 16), 0.7)) + return obj + + @pytest.mark.parametrize("obj_type", ["pure_phase", "complex", "potential"]) + def test_mask_applied_once(self, obj_type): + obj = self._make_obj(obj_type) + # Half-mask: ones on the left, zeros on the right; if mask is applied + # twice the masked region squares the multiplication (no observable + # difference for 0/1 masks), so use a non-binary mask. + mask = torch.full((1, 16, 16), 0.5) + obj._mask = mask + out = obj.apply_hard_constraints(obj._obj, mask=mask) + # Verify nothing crashed and shape is preserved. + assert out.shape == obj._obj.shape + # If mask had been applied twice, |out| would scale by 0.5**2 = 0.25 + # of the unmasked value; once it scales by 0.5. We compare to the + # per-obj-type expected post-constraint value. + if obj_type == "pure_phase": + # phase recentered to zero mean, then *= 0.5 mask + expected_mag = 0.0 # phase=constant -> recenter to 0 -> *0.5 = 0 + elif obj_type == "potential": + # positivity clamp keeps 0.7, * 0.5 -> 0.35 (one application) + expected_mag = 0.35 + else: # complex + # amp clamp keeps 0.5+0.3j, * 0.5 -> magnitude 0.5 * |0.5+0.3j| + expected_mag = 0.5 * abs(0.5 + 0.3j) + # Sample the magnitude in the masked region + if out.is_complex(): + sampled = out.abs().mean().item() + else: + sampled = out.abs().mean().item() + assert abs(sampled - expected_mag) < 1e-4, ( + f"{obj_type}: expected mag ~{expected_mag}, got {sampled} " + f"(would be {expected_mag * 0.5} if mask were applied twice)" + ) diff --git a/tests/diffractive_imaging/test_multi_gpu.py b/tests/diffractive_imaging/test_multi_gpu.py new file mode 100644 index 00000000..1fb9b4dc --- /dev/null +++ b/tests/diffractive_imaging/test_multi_gpu.py @@ -0,0 +1,241 @@ +"""Multi-GPU state-management tests for iterative ptychography. + +Ported from the standalone verify_multi_gpu.py script. Tests are grouped by +scenario (one test per related cluster of assertions) so that the spawn +overhead is paid once per scenario rather than once per assertion. + +All tests are marked ``slow`` and skipped when fewer than 2 CUDA devices are +available. Run with: ``uv run pytest tests/diffractive_imaging/test_multi_gpu.py --runslow``. +""" + +import numpy as np +import pytest +import torch + +# Helpers must live at module scope so forkserver-spawned DataLoader / DDP workers +# can pickle and re-import them. + +N_SCAN = 8 +N_DET = 32 +PROBE_ENERGY = 80e3 +PROBE_SEMIANGLE = 20 +PROBE_DEFOCUS = 100 +N_ITERS = 4 +GPU_IDS = [0, 1] +DEVICE_0 = "cuda:0" + + +def _make_dataset(): + from quantem.core.datastructures import Dataset4dstem + + rng = np.random.default_rng(42) + array = rng.random((N_SCAN, N_SCAN, N_DET, N_DET)).astype(np.float32) + return Dataset4dstem.from_array( + array, + name="test", + sampling=[1.0, 1.0, 0.05, 0.05], + units=["A", "A", "A^-1", "A^-1"], + ) + + +def _make_ptycho(): + from quantem.diffractive_imaging import ( + DetectorPixelated, + ObjectPixelated, + ProbePixelated, + Ptychography, + PtychographyDatasetRaster, + ) + + pdset = PtychographyDatasetRaster.from_dataset4dstem(_make_dataset()) + pdset.preprocess(com_fit_function="constant", plot_rotation=False, plot_com=False) + obj = ObjectPixelated.from_uniform(obj_type="pure_phase", num_slices=1) + probe = ProbePixelated.from_params( + probe_params={ + "energy": PROBE_ENERGY, + "defocus": PROBE_DEFOCUS, + "semiangle_cutoff": PROBE_SEMIANGLE, + } + ) + ptycho = Ptychography.from_models( + dset=pdset, + obj_model=obj, + probe_model=probe, + detector_model=DetectorPixelated(), + verbose=False, + rng=42, + ) + ptycho.preprocess(obj_padding_px=(4, 4)) + return ptycho + + +def _make_dip_ptycho(): + from quantem.core.ml import OptimizerParams + from quantem.diffractive_imaging import PtychoLite, PtychoLiteDIP + + base = _make_ptycho() + base.reconstruct( + num_iters=5, + reset=True, + optimizer_params={ + "object": OptimizerParams.Adam(lr=1e-2), + "probe": OptimizerParams.Adam(lr=1e-2), + }, + batch_size=16, + device=0, + ) + lite = PtychoLite.from_models( + dset=base.dset, + obj_model=base.obj_model, + probe_model=base.probe_model, + detector_model=base.detector_model, + verbose=False, + rng=42, + ) + lite.preprocess(obj_padding_px=(4, 4)) + return PtychoLiteDIP.from_ptycholite(lite, device="cpu", pretrain_iters=None) + + +def _opt(): + from quantem.core.ml import OptimizerParams + + return { + "object": OptimizerParams.Adam(lr=1e-2), + "probe": OptimizerParams.Adam(lr=1e-2), + } + + +# Module-level marks: all tests in this file are slow and require >= 2 GPUs. +pytestmark = [ + pytest.mark.slow, + pytest.mark.skipif( + not torch.cuda.is_available() or torch.cuda.device_count() < 2, + reason="requires >= 2 CUDA devices", + ), +] + + +class TestSingleGPUDevicePersistence: + """device= argument is sticky across follow-up reconstruct() calls.""" + + def test_int_device_persists(self): + p = _make_ptycho() + p.reconstruct( + num_iters=N_ITERS, reset=True, optimizer_params=_opt(), batch_size=16, device=0 + ) + assert p.device == DEVICE_0 + assert p.obj_model._obj.device.type == "cuda" + + # follow-up call without device= keeps the previous device + p.reconstruct(num_iters=N_ITERS, batch_size=16) + assert p.device == DEVICE_0 + assert p.obj_model._obj.device.type == "cuda" + + # reset=True must not reset device tracking + p.reconstruct(num_iters=N_ITERS, reset=True, batch_size=16) + assert p.device == DEVICE_0 + + +class TestMultiGPUDeviceRestoration: + """device=[…] is stored and restored across spawn boundaries.""" + + def test_gpu_list_persists(self): + p = _make_ptycho() + p.reconstruct( + num_iters=N_ITERS, reset=True, optimizer_params=_opt(), batch_size=16, device=GPU_IDS + ) + assert p.device == GPU_IDS + assert p.obj_model._obj.device.type == "cuda" + + p.reconstruct(num_iters=N_ITERS, batch_size=16) + assert p.device == GPU_IDS + + +class TestMultiGPULossTracking: + """_iter_losses extends correctly across spawn(reset=False)/reset=True.""" + + def test_losses_length_lifecycle(self): + p = _make_ptycho() + p.reconstruct( + num_iters=N_ITERS, reset=True, optimizer_params=_opt(), batch_size=16, device=GPU_IDS + ) + assert len(p._iter_losses) == N_ITERS + + p.reconstruct(num_iters=N_ITERS, reset=False, batch_size=16, device=GPU_IDS) + assert len(p._iter_losses) == 2 * N_ITERS, "continuation must not double-count" + + p.reconstruct(num_iters=N_ITERS, reset=True, batch_size=16, device=GPU_IDS) + assert len(p._iter_losses) == N_ITERS + + +class TestMultiGPULRTracking: + """iter_lrs extends correctly across spawn boundaries.""" + + def test_iter_lrs_lifecycle(self): + p = _make_ptycho() + p.reconstruct( + num_iters=N_ITERS, reset=True, optimizer_params=_opt(), batch_size=16, device=GPU_IDS + ) + assert "object" in p.iter_lrs + assert len(p.iter_lrs["object"]) == N_ITERS + + p.reconstruct(num_iters=N_ITERS, reset=False, batch_size=16, device=GPU_IDS) + assert len(p.iter_lrs["object"]) == 2 * N_ITERS + + p.reconstruct(num_iters=N_ITERS, reset=True, batch_size=16, device=GPU_IDS) + assert len(p.iter_lrs["object"]) == N_ITERS + + +class TestMultiGPUOptimizerState: + """Adam state survives the save/restore around the spawn worker.""" + + def test_adam_state_restored_on_device(self): + p = _make_ptycho() + p.reconstruct( + num_iters=N_ITERS, reset=True, optimizer_params=_opt(), batch_size=16, device=GPU_IDS + ) + obj_opt = p.optimizers.get("object") + assert obj_opt is not None + assert len(obj_opt.state) > 0 + first = next(iter(obj_opt.state.values())) + assert "exp_avg" in first, "Adam moments missing" + assert first["exp_avg"].device.type == "cuda" + + +class TestDIPMultiGPU: + """DIP path mirrors the pixelated multi-GPU contract.""" + + def test_dip_device_and_loss_lifecycle(self): + d = _make_dip_ptycho() + d.reconstruct( + num_iters=N_ITERS, + reset=True, + lr_obj=1e-3, + lr_probe=1e-3, + batch_size=16, + device=GPU_IDS, + ) + assert d.device == GPU_IDS + assert len(d._iter_losses) == N_ITERS + assert "object" in d.iter_lrs + assert len(d.iter_lrs["object"]) == N_ITERS + + d.reconstruct( + num_iters=N_ITERS, + reset=False, + lr_obj=1e-3, + lr_probe=1e-3, + batch_size=16, + device=GPU_IDS, + ) + assert len(d._iter_losses) == 2 * N_ITERS, "DIP continuation must not double-count" + + d.reconstruct( + num_iters=N_ITERS, + reset=True, + lr_obj=1e-3, + lr_probe=1e-3, + batch_size=16, + device=GPU_IDS, + ) + assert len(d._iter_losses) == N_ITERS diff --git a/tests/diffractive_imaging/test_ptychography.py b/tests/diffractive_imaging/test_ptychography.py index 1434a8ee..6d9c16ce 100644 --- a/tests/diffractive_imaging/test_ptychography.py +++ b/tests/diffractive_imaging/test_ptychography.py @@ -1,5 +1,6 @@ """ -Tests for ptychography gradient equivalence between autograd and analytical methods +Tests for ptychography gradient equivalence between autograd and analytical methods, +plus property-style tests for state management and serialization. """ import numpy as np @@ -8,6 +9,8 @@ from quantem.core import config from quantem.core.datastructures.dataset4dstem import Dataset4dstem +from quantem.core.io.serialize import load as autoserialize_load +from quantem.core.ml import OptimizerParams from quantem.core.utils.utils import electron_wavelength_angstrom from quantem.diffractive_imaging.dataset_models import PtychographyDatasetRaster from quantem.diffractive_imaging.detector_models import DetectorPixelated @@ -237,7 +240,7 @@ def test_single_probe_gradients(self, single_probe_ptycho_model): } ptycho.reconstruct( - num_iter=1, + num_iters=1, reset=True, autograd=True, constraints=constraints, @@ -249,7 +252,7 @@ def test_single_probe_gradients(self, single_probe_ptycho_model): grads_probe_ad = ptycho.probe_model._probe.grad.clone().detach().cpu().numpy() ptycho.reconstruct( - num_iter=1, + num_iters=1, reset=True, autograd=False, constraints=constraints, @@ -311,7 +314,7 @@ def test_mixed_probe_gradients(self, mixed_probe_ptycho_model): } ptycho.reconstruct( - num_iter=1, + num_iters=1, reset=True, autograd=True, constraints=constraints, @@ -323,7 +326,7 @@ def test_mixed_probe_gradients(self, mixed_probe_ptycho_model): grads_probe_ad = ptycho.probe_model._probe.grad.clone().detach().cpu().numpy() ptycho.reconstruct( - num_iter=1, + num_iters=1, reset=True, autograd=False, constraints=constraints, @@ -360,3 +363,112 @@ def test_mixed_probe_gradients(self, mixed_probe_ptycho_model): assert ssim_obj_abs > 0.99 # type: ignore assert ssim_probe_angle > 0.7 # type: ignore + + +class TestTargetResidency: + """Property + serialization behavior for the streaming-target knob.""" + + def test_default_is_device(self, ptycho_dataset): + assert ptycho_dataset.target_residency == "device" + + def test_setter_accepts_valid(self, ptycho_dataset): + ptycho_dataset.target_residency = "cpu" + assert ptycho_dataset.target_residency == "cpu" + ptycho_dataset.target_residency = "device" + assert ptycho_dataset.target_residency == "device" + + @pytest.mark.parametrize("bad", ["gpu", "GPU", "CPU", "", "cuda", "Device"]) + def test_setter_rejects_invalid(self, ptycho_dataset, bad): + with pytest.raises(ValueError, match="target_residency"): + ptycho_dataset.target_residency = bad + # value should be unchanged after a rejected set + assert ptycho_dataset.target_residency == "device" + + def test_save_load_roundtrip(self, ptycho_dataset, tmp_path): + ptycho_dataset.target_residency = "cpu" + path = tmp_path / "pdset.zip" + ptycho_dataset.save(str(path)) + reloaded = autoserialize_load(str(path)) + assert reloaded.target_residency == "cpu" + + +@pytest.mark.slow +class TestPtychographySaveLoadRoundtrip: + """Reconstruct → save → load → continue training preserves training state. + + The 0.3 threshold reflects that, on this synthetic ducky-style dataset with the + analytical probe already in hand, a well-formed reconstruction should drive the + loss down by at least 70% in 20 iterations on the right configuration. The bar + is deliberately strict — if you tune optimizer settings and this fires, the + config probably regressed. + """ + + NUM_ITERS = 50 # enough headroom for the strict 0.3 threshold at lr=5e-3 + + @pytest.fixture + def trained_ptycho(self, single_probe_ptycho_model): + ptycho = single_probe_ptycho_model + ptycho.reconstruct( + num_iters=self.NUM_ITERS, + reset=True, + optimizer_params={ + "object": OptimizerParams.Adam(lr=5e-3), + "probe": OptimizerParams.Adam(lr=5e-3), + }, + batch_size=N**2, + device=config.get_device(), + ) + return ptycho + + def test_iter_losses_preserved(self, trained_ptycho, tmp_path): + path = tmp_path / "ptycho.zip" + trained_ptycho.save(str(path), save_raw_data=True) + reloaded = autoserialize_load(str(path)) + np.testing.assert_array_equal(reloaded._iter_losses, trained_ptycho._iter_losses) + + def test_scan_positions_preserved(self, trained_ptycho, tmp_path): + path = tmp_path / "ptycho.zip" + trained_ptycho.save(str(path), save_raw_data=True) + reloaded = autoserialize_load(str(path)) + original = trained_ptycho.dset.scan_positions_px.detach().cpu().numpy() + new = reloaded.dset.scan_positions_px.detach().cpu().numpy() + np.testing.assert_allclose(new, original, rtol=0, atol=0) + + def test_object_preserved(self, trained_ptycho, tmp_path): + path = tmp_path / "ptycho.zip" + trained_ptycho.save(str(path), save_raw_data=True) + reloaded = autoserialize_load(str(path)) + original = trained_ptycho.obj_model._obj.detach().cpu().numpy() + new = reloaded.obj_model._obj.detach().cpu().numpy() + np.testing.assert_allclose(new, original, rtol=0, atol=0) + + def test_loss_decreases_below_strict_threshold(self, trained_ptycho): + losses = trained_ptycho._iter_losses + assert losses[-1] < 0.3 * losses[0], ( + f"loss should drop below 30% of initial in {self.NUM_ITERS} iters: " + f"initial={losses[0]:.3e}, final={losses[-1]:.3e}, " + f"ratio={losses[-1] / losses[0]:.2f}" + ) + + def test_continue_training_after_reload(self, trained_ptycho, tmp_path): + """Reload a trained ptycho, continue training, and verify the loss keeps + decreasing — confirms optimizer state and parameter bindings survive the + save/load roundtrip end-to-end.""" + path = tmp_path / "ptycho.zip" + trained_ptycho.save(str(path), save_raw_data=True) + reloaded = autoserialize_load(str(path)) + loss_after_reload = reloaded._iter_losses[-1] + + n_continue = 10 + reloaded.reconstruct( + num_iters=n_continue, + reset=False, + batch_size=N**2, + device=config.get_device(), + ) + assert len(reloaded._iter_losses) == self.NUM_ITERS + n_continue, ( + "continuation must not reset history" + ) + assert reloaded._iter_losses[-1] <= loss_after_reload, ( + "loss must not regress after reload — optimizer state likely lost" + )