
pypty.multislice_core

multislice

Simulate multislice wave propagation using a classic split-step integrator (second-order accurate with respect to slice thickness if the beam is optimized).
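The split-step idea can be sketched in NumPy for one probe mode, one object mode, and one batch item. This is an illustrative transmit-then-propagate loop under stated assumptions (paraxial Fresnel propagator, uniform slice thickness, propagation after every slice including the last). It is not pypty's implementation and does not reproduce the second-order scheme described above; `multislice_sketch` and `fresnel_propagator` are hypothetical names.

```python
import numpy as np


def fresnel_propagator(q2, wavelength, dz):
    # Paraxial Fresnel kernel exp(-i*pi*lambda*dz*q^2); q2, wavelength and dz
    # must use consistent units (e.g. 1/Å^2, Å, Å).
    return np.exp(-1j * np.pi * wavelength * dz * q2)


def multislice_sketch(probe, transmissions, wavelength, dz, q2):
    """Hypothetical single-mode, single-batch split-step multislice.

    probe:         [y, x] complex incident wave
    transmissions: [y, x, z] complex transmission function of each slice
    Returns (waves, exit_wave); waves[..., n] is the wave entering slice n,
    kept so that a backward pass can reuse it.
    """
    n_slices = transmissions.shape[-1]
    prop = fresnel_propagator(q2, wavelength, dz)
    waves = np.empty(probe.shape + (n_slices,), dtype=np.complex128)
    wave = probe.astype(np.complex128)
    for n in range(n_slices):
        waves[..., n] = wave
        # Transmit through slice n, then propagate by dz in Fourier space.
        # Assumption: the last slice is also propagated.
        wave = np.fft.ifft2(prop * np.fft.fft2(wave * transmissions[..., n]))
    return waves, wave
```

Storing `waves` costs one extra copy of the wave per slice, but it avoids recomputing the forward pass when gradients are needed.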

PARAMETER DESCRIPTION
full_probe

Probe wavefunction with shape [N_batch, y, x, modes].

TYPE: ndarray

this_obj_chopped

Object slices with shape [N_batch, y, x, z, modes].

TYPE: ndarray

num_slices

Number of object slices.

TYPE: int

n_obj_modes

Number of object modes.

TYPE: int

n_probe_modes

Number of probe modes.

TYPE: int

this_distances

Slice thicknesses.

TYPE: ndarray

this_wavelength

Electron wavelength.

TYPE: float

q2

Grid of squared spatial frequencies (qx² + qy²).

TYPE: ndarray

qx

Spatial frequency grid along x.

TYPE: ndarray

qy

Spatial frequency grid along y.

TYPE: ndarray

exclude_mask

Mask to exclude undesired frequencies.

TYPE: ndarray

is_single_dist

If True, use the same distance for all slices.

TYPE: bool

this_tan_x

Beam tilt along x (tangent of the tilt angle), with shape [N_batch].

TYPE: ndarray

this_tan_y

Beam tilt along y (tangent of the tilt angle), with shape [N_batch].

TYPE: ndarray

damping_cutoff_multislice

Damping frequency cutoff.

TYPE: float

smooth_rolloff

Rolloff rate for the damping filter.

TYPE: float

master_propagator_phase_space

Full propagator in Fourier space (optional).

TYPE: ndarray or None

half_master_propagator_phase_space

Half-step propagator (optional).

TYPE: ndarray or None

mask_clean

Clean propagation mask.

TYPE: ndarray

waves_multislice

Array that stores the intermediate exit waves.

TYPE: ndarray

wave

Array that stores the final exit wave.

TYPE: ndarray

default_float

Floating-point type.

TYPE: dtype

default_complex

Complex type.

TYPE: dtype
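The tilt and damping parameters above suggest a propagator built from `qx`, `qy`, `this_tan_x`, and `this_tan_y`, combined with a smooth low-pass filter. The sketch below assumes a paraxial Fresnel kernel in which a beam tilt with tangent t shifts the wave laterally by dz·t per slice (a linear phase ramp in Fourier space), and a Fermi-type rolloff. The function names, sign conventions, and cutoff form are hypothetical and may differ from how pypty uses `damping_cutoff_multislice` and `smooth_rolloff`.

```python
import numpy as np


def tilted_propagator(qx, qy, wavelength, dz, tan_x, tan_y):
    """Hypothetical paraxial propagator for a tilted beam.

    qx, qy: spatial-frequency grids, broadcastable to [y, x]
    tan_x, tan_y: tangents of the beam tilt angles
    Assumption: over one slice the tilt shifts the wave by (dz*tan_x, dz*tan_y),
    i.e. a linear phase ramp exp(-2*pi*i*dz*(qx*tan_x + qy*tan_y)).
    """
    q2 = qx**2 + qy**2
    defocus = -np.pi * wavelength * dz * q2
    shift = -2 * np.pi * dz * (qx * tan_x + qy * tan_y)
    return np.exp(1j * (defocus + shift))


def soft_cutoff(q2, cutoff, rolloff):
    """Hypothetical Fermi-type low-pass filter: ~1 below cutoff, ~0 above.

    Written with tanh, which equals 1 / (1 + exp((q - cutoff) / rolloff))
    but does not overflow for large arguments.
    """
    q = np.sqrt(q2)
    return 0.5 * (1.0 - np.tanh((q - cutoff) / (2.0 * rolloff)))
```

A damped kernel could then be formed as the product of the two arrays, although the source does not state that damping is applied this way.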

RETURNS DESCRIPTION
waves_multislice

Multislice stack of propagated waves.

TYPE: ndarray

wave

Final exit wave.

TYPE: ndarray

multislice_grads

Compute gradients of the classic multislice propagation model with respect to the object, probe, and tilts.
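Reverse-mode gradients of a split-step model follow from applying the adjoint of each step in reverse slice order, reusing the waves stored during the forward pass (the role `waves_multislice` plays in this API). A minimal sketch, assuming one mode, no batch, no tilt, and the Wirtinger convention in which the incoming gradient is ∂L/∂ψ_out* of a real loss L. `forward` and `backward` are hypothetical helpers, not pypty functions.

```python
import numpy as np


def forward(probe, t, prop):
    """Transmit-then-propagate multislice; stores the wave entering each slice."""
    n_slices = t.shape[-1]
    waves = np.empty(probe.shape + (n_slices,), dtype=complex)
    wave = probe
    for n in range(n_slices):
        waves[..., n] = wave
        wave = np.fft.ifft2(prop * np.fft.fft2(wave * t[..., n]))
    return waves, wave


def backward(grad_out, waves, t, prop):
    """Adjoint (reverse-mode) pass for `forward`.

    grad_out: dL/d(conj(psi_out)) for a real loss L.
    Returns (grad_t, grad_probe) in the same convention.
    """
    grad_t = np.empty_like(t)
    g = grad_out
    for n in reversed(range(t.shape[-1])):
        # Adjoint of ifft2(prop * fft2(.)) is ifft2(conj(prop) * fft2(.)).
        g = np.fft.ifft2(np.conj(prop) * np.fft.fft2(g))
        # Elementwise product wave * t_n: each factor gets conj(other) * g.
        grad_t[..., n] = np.conj(waves[..., n]) * g
        g = np.conj(t[..., n]) * g
    return grad_t, g
```

For a real loss, a finite-difference step along the real part of a variable z approximates 2·Re(∂L/∂z*), which is how the sketch can be checked numerically.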

PARAMETER DESCRIPTION
dLoss_dP_out

Gradient of the loss with respect to the final propagated wave.

TYPE: ndarray

waves_multislice

Intermediate wave stack from the forward multislice pass.

TYPE: ndarray

this_obj_chopped

Sliced object with shape [batch, y, x, z, modes].

TYPE: ndarray

object_grad

Gradient accumulator for object slices.

TYPE: ndarray

tilts_grad

Accumulator for tilt gradients.

TYPE: ndarray

is_single_dist

If True, slice distances are constant.

TYPE: bool

this_distances

Per-slice thicknesses.

TYPE: ndarray

exclude_mask

Frequency mask for FFT operations.

TYPE: ndarray

this_wavelength

Probe wavelength (Å).

TYPE: float

q2

Grid of squared spatial frequencies (qx² + qy²).

TYPE: ndarray

qx

Spatial frequency grid along x.

TYPE: ndarray

qy

Spatial frequency grid along y.

TYPE: ndarray

this_tan_x

Beam tilt along x (tangent of the tilt angle) for each batch item.

TYPE: float

this_tan_y

Beam tilt along y (tangent of the tilt angle) for each batch item.

TYPE: float

num_slices

Number of slices.

TYPE: int

n_obj_modes

Number of object modes.

TYPE: int

tiltind

Index in tilt update array.

TYPE: int

master_propagator_phase_space

Full Fourier propagation kernel.

TYPE: ndarray

this_step_tilts

Whether the tilt gradient is updated.

TYPE: int

damping_cutoff_multislice

Damping cutoff for high-frequency noise.

TYPE: float

smooth_rolloff

Width of damping transition.

TYPE: float

tilt_mode

Mode selector for tilt optimization.

TYPE: int

compute_batch

Current batch size.

TYPE: int

mask_clean

FFT domain mask.

TYPE: ndarray

this_step_probe

Whether to compute the probe gradient.

TYPE: int

this_step_obj

Whether to compute the object gradient.

TYPE: int

this_step_pos_correction

(Unused) Flag for positional corrections.

TYPE: int

masked_pixels_y

Y indices for applying gradients to the global object.

TYPE: ndarray

masked_pixels_x

X indices for applying gradients to the global object.

TYPE: ndarray

default_float

Floating-point type.

TYPE: dtype

default_complex

Complex type.

TYPE: dtype

helper_flag_4

If True, return probe gradient; else return None.

TYPE: bool

RETURNS DESCRIPTION
object_grad

Gradient for object slices.

TYPE: ndarray

interm_probe_grad

Gradient for input probe (if helper_flag_4 is True).

TYPE: ndarray or None

tilts_grad

Updated tilt gradient.

TYPE: ndarray
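For the tilt gradient, differentiating a tilted Fresnel propagator gives ∂P/∂tan_x = -2πi·dz·qx·P (and likewise for y), which can then be contracted with the incoming wave gradient. The sketch below assumes a single propagation step with the kernel written in its docstring; `tilt_gradient` is a hypothetical helper and does not reflect how pypty organizes `tilts_grad`, `tilt_mode`, or `tiltind`.

```python
import numpy as np


def tilt_gradient(grad_out, wave_in, qx, qy, wavelength, dz, tan_x, tan_y):
    """dL/d(tan_x) and dL/d(tan_y) for one tilted propagation step (sketch).

    Assumed forward step: w_out = ifft2(P * fft2(wave_in)) with
    P = exp(-1j*pi*wavelength*dz*q2 - 2j*pi*dz*(qx*tan_x + qy*tan_y)).
    grad_out is dL/d(conj(w_out)) for a real loss L.
    """
    q2 = qx**2 + qy**2
    P = np.exp(-1j * np.pi * wavelength * dz * q2
               - 2j * np.pi * dz * (qx * tan_x + qy * tan_y))
    spec = P * np.fft.fft2(wave_in)
    # dP/d(tan_x) = -2j*pi*dz*qx*P, and likewise for y.
    dw_dtx = np.fft.ifft2(-2j * np.pi * dz * qx * spec)
    dw_dty = np.fft.ifft2(-2j * np.pi * dz * qy * spec)
    # For a real loss and real parameter s: dL/ds = 2 * Re(sum(conj(grad_out) * dw/ds)).
    gx = 2.0 * np.real(np.sum(np.conj(grad_out) * dw_dtx))
    gy = 2.0 * np.real(np.sum(np.conj(grad_out) * dw_dty))
    return gx, gy
```

Over several slices, the per-slice contributions would be summed, with `grad_out` replaced by the back-propagated gradient at each slice.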

scatteradd

Adds batched object updates to their respective positions in the full object array. This wrapper is needed to support older CuPy versions.
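On NumPy arrays, the same accumulation can be expressed with `np.add.at`. Unlike `full[masky, maskx] += chop`, it performs unbuffered addition, so pixels covered by several patches (for example, overlapping probe positions) receive the sum of all contributions. A sketch with a hypothetical function name and assumed index shapes:

```python
import numpy as np


def scatteradd_sketch(full, masky, maskx, chop):
    """Accumulate batched patch gradients into the full object array in place.

    Assumed (hypothetical) shapes:
      full:  [Y, X, z, modes]
      masky: [batch, y, 1] row index of each patch pixel
      maskx: [batch, 1, x] column index of each patch pixel
      chop:  [batch, y, x, z, modes]
    The index arrays broadcast to [batch, y, x]; repeated (row, col) pairs
    are summed rather than overwritten.
    """
    np.add.at(full, (masky, maskx), chop)
```

With CuPy arrays, `cupyx.scatter_add` provides the equivalent unbuffered addition.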

PARAMETER DESCRIPTION
full

Full object gradient array.

TYPE: ndarray

masky

Index array for the y-axis.

TYPE: ndarray

maskx

Index array for the x-axis.

TYPE: ndarray

chop

Batched gradients to scatter-add.

TYPE: ndarray

RETURNS DESCRIPTION
None