
suite2p.registration package

Copyright © 2023 Howard Hughes Medical Institute, Authored by Carsen Stringer and Marius Pachitariu.

compute

compute(frames)

Compute the bidirectional phase offset between odd and even scan lines.

Estimates the pixel offset between alternating lines that can occur in bidirectional line scanning, using phase correlation along the x-axis.

Parameters:

Name | Type | Description | Default
frames | ndarray | Random subsample of frames of shape (n_frames, Ly, Lx). | required

Returns:

Name | Type | Description
bidiphase | int | Bidirectional phase offset in pixels. The search is limited to offsets between -10 and +10 pixels.

Source code in suite2p/registration/bidiphase.py
def compute(frames: np.ndarray) -> int:
    """
    Compute the bidirectional phase offset between odd and even scan lines.

    Estimates the pixel offset between alternating lines that can occur in
    bidirectional line scanning, using phase correlation along the x-axis.

    Parameters
    ----------
    frames : np.ndarray
        Random subsample of frames of shape (n_frames, Ly, Lx).

    Returns
    -------
    bidiphase : int
        Bidirectional phase offset in pixels.
    """

    _, Ly, Lx = frames.shape

    # compute phase-correlation between lines in x-direction
    d1 = fft.fft(frames[:, 1::2, :], axis=2)
    d1 /= np.abs(d1) + 1e-5

    d2 = np.conj(fft.fft(frames[:, ::2, :], axis=2))
    d2 /= np.abs(d2) + 1e-5
    d2 = d2[:, :d1.shape[1], :]

    cc = np.real(fft.ifft(d1 * d2, axis=2))
    cc = cc.mean(axis=1).mean(axis=0)
    cc = fft.fftshift(cc)

    bidiphase = -(np.argmax(cc[-10 + Lx // 2:11 + Lx // 2]) - 10)
    return bidiphase
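To make the sign convention concrete, here is a standalone NumPy sketch of the same phase-correlation steps. The name estimate_bidiphase and its max_offset argument are illustrative, not part of suite2p, and np.fft stands in for whichever fft module the source imports. The sketch builds a synthetic movie in which every line of a frame carries the same random profile and the odd lines are displaced by +3 pixels along x. Under those assumptions the estimate comes out as -3, which is the offset shift must apply to move the odd lines back.

```python
import numpy as np


def estimate_bidiphase(frames: np.ndarray, max_offset: int = 10) -> int:
    """Standalone sketch of the phase-correlation steps in compute."""
    Lx = frames.shape[2]
    # whitened spectra of odd lines and conjugated, whitened spectra of even lines
    d1 = np.fft.fft(frames[:, 1::2, :], axis=2)
    d1 /= np.abs(d1) + 1e-5
    d2 = np.conj(np.fft.fft(frames[:, ::2, :], axis=2))
    d2 /= np.abs(d2) + 1e-5
    d2 = d2[:, : d1.shape[1], :]  # drop the extra even line when Ly is odd
    # average over lines and frames, then move zero lag to index Lx // 2
    cc = np.real(np.fft.ifft(d1 * d2, axis=2)).mean(axis=(0, 1))
    cc = np.fft.fftshift(cc)
    window = cc[Lx // 2 - max_offset : Lx // 2 + max_offset + 1]
    return int(-(np.argmax(window) - max_offset))


# every line of a frame shows the same random profile ...
rng = np.random.default_rng(0)
profile = rng.standard_normal((5, 1, 64))
frames = np.repeat(profile, 32, axis=1)
# ... except that odd lines are displaced 3 pixels toward higher x
frames[:, 1::2, :] = np.roll(frames[:, 1::2, :], 3, axis=2)

print(estimate_bidiphase(frames))  # -3
```

Note that the estimate is the negative of the displacement, so it can be passed directly to shift.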

shift

shift(frames, bidiphase)

Shift odd scan lines by the bidirectional phase offset.

Corrects bidirectional scanning artifacts by shifting every other row (odd lines) along the x-axis by the given pixel offset.

Parameters:

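Name | Type | Description | Default
frames | ndarray | Frames of shape (n_frames, Ly, Lx). Modified in place. | required
bidiphase | int | Bidirectional phase offset in pixels. Positive values shift odd lines toward higher x, negative values toward lower x; 0 leaves the frames unchanged. | required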

Returns:

Name | Type | Description
frames | ndarray | The input frames with odd lines shifted.

Source code in suite2p/registration/bidiphase.py
def shift(frames: np.ndarray, bidiphase: int) -> np.ndarray:
    """
    Shift odd scan lines by the bidirectional phase offset.

    Corrects bidirectional scanning artifacts by shifting every other row
    (odd lines) along the x-axis by the given pixel offset.

    Parameters
    ----------
    frames : np.ndarray
        Frames of shape (n_frames, Ly, Lx). Modified in-place.
    bidiphase : int
        Bidirectional phase offset in pixels.

    Returns
    -------
    frames : np.ndarray
        The input frames with odd lines shifted.
    """
    if bidiphase > 0:
        frames[:, 1::2, bidiphase:] = frames[:, 1::2, :-bidiphase]
    elif bidiphase < 0:
        frames[:, 1::2, :bidiphase] = frames[:, 1::2, -bidiphase:]
    return frames
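The following sketch copies the slicing logic of shift into a standalone function (shift_odd_lines is a hypothetical name) and checks it on a synthetic movie whose odd lines were displaced by +3 pixels. Passing bidiphase = -3 moves them back. Only the last 3 columns, which the slice assignment does not overwrite, keep stale values.

```python
import numpy as np


def shift_odd_lines(frames: np.ndarray, bidiphase: int) -> np.ndarray:
    """Standalone copy of the in-place slicing in shift."""
    if bidiphase > 0:
        # odd-line content moves toward higher x; the first columns keep old values
        frames[:, 1::2, bidiphase:] = frames[:, 1::2, :-bidiphase]
    elif bidiphase < 0:
        # odd-line content moves toward lower x; the last columns keep old values
        frames[:, 1::2, :bidiphase] = frames[:, 1::2, -bidiphase:]
    return frames


rng = np.random.default_rng(1)
profile = rng.standard_normal((4, 1, 48))
frames = np.repeat(profile, 16, axis=1)
frames[:, 1::2, :] = np.roll(frames[:, 1::2, :], 3, axis=2)  # odd lines displaced by +3

shift_odd_lines(frames, -3)  # corrects frames in place
# apart from the last 3 columns, odd lines match even lines again
print(np.allclose(frames[:, 1::2, :-3], frames[:, ::2, :-3]))  # True
```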


get_pc_metrics

get_pc_metrics(f_reg, yrange=None, xrange=None, settings=default_settings()['registration'], device=torch.device('cpu'))

Compute registration metrics using top PCs of a registered movie.

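Subsamples up to 2000 evenly spaced frames from the registered movie (5000 when the movie has at least 5000 frames and both Ly and Lx are at most 700), crops them to yrange and xrange, and computes PCA. For each PC, the frames with the lowest and highest weights (up to 300 of each) are averaged, and the two averages are registered to each other. The resulting shift magnitudes indicate registration quality: large shifts suggest residual motion.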

Parameters:

Name | Type | Description | Default
f_reg | ndarray | Registered movie of shape (n_frames, Ly, Lx). | required
yrange | list of int or None | [y_start, y_end] row range to crop the movie. If None, uses the full vertical extent. | None
xrange | list of int or None | [x_start, x_end] column range to crop the movie. If None, uses the full horizontal extent. | None
settings | dict | Registration settings dictionary containing keys such as "smooth_sigma", "block_size", "maxregshift", "maxregshiftNR", "snr_thresh", "spatial_taper", and optionally "reg_metrics_rs" (PCA random seed, default None) and "reg_metric_n_pc" (number of PCs, default 30). | default_settings()['registration']
device | torch.device | Torch device (CPU or CUDA) on which to run the PC registration. | torch.device('cpu')

Returns:

Name | Type | Description
tPC | ndarray | Temporal PC weights of shape (n_samples, nPC), describing how each PC varies across the subsampled frames.
regPC | ndarray | Average of the bottom- and top-weighted frames for each PC, shape (2, nPC, Ly_crop, Lx_crop), where index 0 is pclow and index 1 is pchigh.
regDX | ndarray | Shift metrics of shape (nPC, 4) from pc_register; see pc_register for column definitions.

Source code in suite2p/registration/metrics.py
def get_pc_metrics(f_reg, yrange=None, xrange=None, settings=default_settings()["registration"], 
                   device=torch.device("cpu")):

    """
    Compute registration metrics using top PCs of a registered movie.

    Subsamples frames from the registered movie, computes PCA to find the top and
    bottom weighted frames, then registers them to each other. The resulting shift
    magnitudes indicate registration quality: large shifts suggest residual motion.

    Parameters
    ----------
    f_reg : np.ndarray
        Registered movie of shape (n_frames, Ly, Lx).
    yrange : list of int or None
        [y_start, y_end] row range to crop the movie. If None, uses the full
        vertical extent.
    xrange : list of int or None
        [x_start, x_end] column range to crop the movie. If None, uses the full
        horizontal extent.
    settings : dict
        Registration settings dictionary containing keys such as "smooth_sigma",
        "block_size", "maxregshift", "maxregshiftNR", "snr_thresh",
        "spatial_taper", and optionally "reg_metrics_rs" and "reg_metric_n_pc".
    device : torch.device
        Torch device (CPU or CUDA) on which to run the PC registration.

    Returns
    -------
    tPC : np.ndarray
        Temporal PC weights of shape (n_samples, nPC), describing how each PC
        varies across the subsampled frames.
    regPC : np.ndarray
        Average of top and bottom weighted frames for each PC, shape
        (2, nPC, Ly_crop, Lx_crop) where index 0 is pclow and index 1 is pchigh.
    regDX : np.ndarray
        Shift metrics of shape (nPC, 4) from pc_register; see pc_register for
        column definitions.
    """
    n_frames, Ly, Lx = f_reg.shape
    yrange = [0, Ly] if yrange is None else yrange 
    xrange = [0, Lx] if xrange is None else xrange

    # n frames to pick from full movie
    nsamp = 2000 if n_frames < 5000 or Ly > 700 or Lx > 700 else 5000
    nsamp = min(nsamp, n_frames)
    inds = np.linspace(0, n_frames - 1, nsamp).astype("int")
    mov = f_reg[inds][:, yrange[0] : yrange[-1], xrange[0] : xrange[-1]]

    random_state = settings["reg_metrics_rs"] if "reg_metrics_rs" in settings else None
    nPC = settings["reg_metric_n_pc"] if "reg_metric_n_pc" in settings else 30
    pclow, pchigh, sv, tPC = pclowhigh(
        mov, nlowhigh=np.minimum(300, mov.shape[0] // 2), nPC=nPC,
        random_state=random_state)
    pclow = torch.from_numpy(pclow).to(device).float()
    pchigh = torch.from_numpy(pchigh).to(device).float()
    regPC = torch.stack((pclow, pchigh), dim=0).cpu().numpy()
    regDX = pc_register(
        pclow, pchigh, smooth_sigma=settings["smooth_sigma"], block_size=settings["block_size"],
        maxregshift=settings["maxregshift"], maxregshiftNR=settings["maxregshiftNR"], 
        snr_thresh=settings["snr_thresh"], spatial_taper=settings["spatial_taper"])
    return tPC, regPC, regDX
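The sampling rule at the top of get_pc_metrics can be isolated as follows. pc_metric_sampling is a hypothetical helper that reproduces the nsamp, index, and nlowhigh computations above without loading a movie, which makes it easy to check how many frames a given recording contributes.

```python
import numpy as np


def pc_metric_sampling(n_frames: int, Ly: int, Lx: int):
    """Hypothetical helper reproducing the frame-sampling rule in get_pc_metrics."""
    # 5000 samples only for long recordings with frames of at most 700 x 700 pixels
    nsamp = 2000 if n_frames < 5000 or Ly > 700 or Lx > 700 else 5000
    nsamp = min(nsamp, n_frames)
    # evenly spaced indices that include the first and last frame
    inds = np.linspace(0, n_frames - 1, nsamp).astype("int")
    # frames averaged at each end of every PC
    nlowhigh = min(300, nsamp // 2)
    return inds, nlowhigh


inds, nlowhigh = pc_metric_sampling(1000, 512, 512)
print(len(inds), inds[0], inds[-1], nlowhigh)  # 1000 0 999 300
print(len(pc_metric_sampling(20000, 512, 512)[0]))   # 5000
print(len(pc_metric_sampling(20000, 1024, 512)[0]))  # 2000
```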

pc_register

pc_register(pclow, pchigh, smooth_sigma=1.15, block_size=(128, 128), maxregshift=0.25, maxregshiftNR=15, snr_thresh=1.25, spatial_taper=3.45)

Register top and bottom PC averages to each other and compute shift magnitudes.

For each PC, the bottom-weighted average image is used as a reference and the top-weighted average is registered to it using rigid and nonrigid shifts. The resulting shift magnitudes quantify registration quality.

Parameters:

Name | Type | Description | Default
pclow | Tensor | Average of bottom-weighted frames for each PC, shape (nPC, Ly, Lx). | required
pchigh | Tensor | Average of top-weighted frames for each PC, shape (nPC, Ly, Lx). | required
smooth_sigma | float | Standard deviation (in pixels) of the Gaussian smoothing applied to the reference image during registration. | 1.15
block_size | tuple of int | Block size (Ly_block, Lx_block) used for nonrigid registration. | (128, 128)
maxregshift | float | Maximum allowed rigid registration shift as a fraction of the smaller image dimension. | 0.25
maxregshiftNR | int | Maximum allowed nonrigid registration shift in pixels. | 15
snr_thresh | float | Signal-to-noise ratio threshold for accepting nonrigid block shifts. | 1.25
spatial_taper | float | Scalar controlling the slope of the spatial taper mask applied at image borders during registration. | 3.45

Returns:

Name | Type | Description
X | ndarray | Shift metrics of shape (nPC, 4): X[:, 0] is the rigid shift magnitude, X[:, 1] the mean nonrigid shift magnitude, X[:, 2] the maximum nonrigid shift magnitude, and X[:, 3] the mean magnitude of the combined rigid + nonrigid shift.

Source code in suite2p/registration/metrics.py
def pc_register(pclow, pchigh, smooth_sigma=1.15, block_size=(128, 128),
                maxregshift=0.25, maxregshiftNR=15, snr_thresh=1.25,
                spatial_taper=3.45):
    """
    Register top and bottom PC averages to each other and compute shift magnitudes.

    For each PC, the bottom-weighted average image is used as a reference and the
    top-weighted average is registered to it using rigid and nonrigid shifts. The
    resulting shift magnitudes quantify registration quality.

    Parameters
    ----------
    pclow : torch.Tensor
        Average of bottom-weighted frames for each PC, shape (nPC, Ly, Lx).
    pchigh : torch.Tensor
        Average of top-weighted frames for each PC, shape (nPC, Ly, Lx).
    smooth_sigma : float
        Standard deviation (in pixels) of the Gaussian smoothing applied to the
        reference image during registration.
    block_size : tuple of int
        Block size (Ly_block, Lx_block) used for nonrigid registration.
    maxregshift : float
        Maximum allowed rigid registration shift as a fraction of the smaller
        image dimension.
    maxregshiftNR : int
        Maximum allowed nonrigid registration shift in pixels.
    snr_thresh : float
        Signal-to-noise ratio threshold for accepting nonrigid block shifts.
    spatial_taper : float
        Scalar controlling the slope of the spatial taper mask applied at image
        borders during registration.

    Returns
    -------
    X : np.ndarray
        Shift metrics of shape (nPC, 4) where X[:, 0] is the rigid shift magnitude,
        X[:, 1] is the mean nonrigid shift magnitude, X[:, 2] is the max nonrigid
        shift magnitude, and X[:, 3] is the mean combined rigid+nonrigid shift.
    """
    # register pchigh[i] onto pclow[i] for every PC
    nPC, Ly, Lx = pclow.shape

    X = np.zeros((nPC, 4))
    for i in range(nPC):
        refImg = pclow[i].cpu().numpy().copy()
        Img = pchigh[i][np.newaxis, :, :]

        refAndMasks = register.compute_filters_and_norm(
            refImg, norm_frames=True, spatial_smooth=smooth_sigma,
            spatial_taper=spatial_taper, block_size=block_size,
            device=Img.device)
        fr_reg = Img.clone()
        offsets = register.compute_shifts(refAndMasks, fr_reg, maxregshift=maxregshift, smooth_sigma_time=0, 
                                          maxregshiftNR=maxregshiftNR, snr_thresh=snr_thresh)
        ymax, xmax, cmax, ymax1, xmax1, cmax1, zest, cmax_all = offsets

        X[i, 0] = ((ymax[0]**2 + xmax[0]**2)**.5).mean().cpu().numpy()
        X[i, 1] = ((ymax1**2 + xmax1**2)**.5).mean().cpu().numpy()
        X[i, 2] = ((ymax1**2 + xmax1**2)**.5).max().cpu().numpy()
        X[i, 3] = (((ymax[0] + ymax1)**2 + (xmax[0] + xmax1)**2)**0.5).mean().cpu().numpy()
    return X
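The four columns of X are plain vector norms of the rigid shift, the block shifts, and their sum. This NumPy sketch (shift_metrics is a hypothetical name, and NumPy arrays stand in for the torch tensors returned by compute_shifts) reproduces the arithmetic for one PC with a rigid shift of (3, 4) and two block shifts of (0, 0) and (3, 4).

```python
import numpy as np


def shift_metrics(ymax, xmax, ymax1, xmax1):
    """Hypothetical helper: the four X[i, :] values of pc_register for one PC.

    ymax, xmax: rigid shifts, shape (1,); ymax1, xmax1: block shifts, shape (1, nblocks).
    """
    rigid = np.sqrt(ymax[0] ** 2 + xmax[0] ** 2)
    nonrigid = np.sqrt(ymax1 ** 2 + xmax1 ** 2)
    combined = np.sqrt((ymax[0] + ymax1) ** 2 + (xmax[0] + xmax1) ** 2)
    return np.array([rigid, nonrigid.mean(), nonrigid.max(), combined.mean()])


x = shift_metrics(np.array([3.0]), np.array([4.0]),
                  np.array([[0.0, 3.0]]), np.array([[0.0, 4.0]]))
print(x.tolist())  # [5.0, 2.5, 5.0, 7.5]
```

Because the combined column adds the rigid shift to each block shift before taking the norm, X[:, 3] can exceed both X[:, 0] and X[:, 2].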

pclowhigh

pclowhigh(mov, nlowhigh, nPC, random_state)

Compute the mean frames at the top and bottom of each PC using sklearn PCA.

Computes nPC principal components of the movie and returns the average frames at the top and bottom of each PC's temporal weights.

Parameters:

Name | Type | Description | Default
mov | ndarray | Subsampled movie frames of shape (n_frames, Ly, Lx). | required
nlowhigh | int | Number of frames to average at the top and bottom of each PC. | required
nPC | int | Number of principal components to compute. | required
random_state | int or None | Seed for the PCA random state, used for reproducibility. | required

Returns:

Name | Type | Description
pclow | ndarray | Average of bottom-weighted frames for each PC, shape (nPC, Ly, Lx).
pchigh | ndarray | Average of top-weighted frames for each PC, shape (nPC, Ly, Lx).
w | ndarray | Singular values from the PCA decomposition, shape (nPC,).
v | ndarray | Temporal PC weights of shape (n_frames, nPC), describing how each PC varies across frames.

Source code in suite2p/registration/metrics.py
def pclowhigh(mov, nlowhigh, nPC, random_state):
    """
    Compute the mean frames at the top and bottom of each PC using sklearn PCA.

    Computes nPC principal components of the movie and returns the average
    frames at the top and bottom of each PC's temporal weights.

    Parameters
    ----------
    mov : np.ndarray
        Subsampled movie frames of shape (n_frames, Ly, Lx).
    nlowhigh : int
        Number of frames to average at the top and bottom of each PC.
    nPC : int
        Number of principal components to compute.
    random_state : int or None
        Seed for the PCA random state, used for reproducibility.

    Returns
    -------
    pclow : np.ndarray
        Average of bottom-weighted frames for each PC, shape (nPC, Ly, Lx).
    pchigh : np.ndarray
        Average of top-weighted frames for each PC, shape (nPC, Ly, Lx).
    w : np.ndarray
        Singular values from the PCA decomposition, shape (nPC,).
    v : np.ndarray
        Temporal PC weights of shape (n_frames, nPC), describing how each PC
        varies across frames.
    """
    nframes, Ly, Lx = mov.shape
    mov = mov.reshape((nframes, -1))
    mov = mov.astype(np.float32)
    mimg = mov.mean(axis=0)
    mov -= mimg
    pca = PCA(n_components=nPC, random_state=random_state).fit(mov.T)
    v = pca.components_.T
    w = pca.singular_values_
    mov += mimg
    mov = np.transpose(np.reshape(mov, (-1, Ly, Lx)), (1, 2, 0))
    pclow = np.zeros((nPC, Ly, Lx), np.float32)
    pchigh = np.zeros((nPC, Ly, Lx), np.float32)
    isort = np.argsort(v, axis=0)
    for i in range(nPC):
        pclow[i] = mov[:, :, isort[:nlowhigh, i]].mean(axis=-1)
        pchigh[i] = mov[:, :, isort[-nlowhigh:, i]].mean(axis=-1)
    return pclow, pchigh, w, v
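For readers without scikit-learn, here is a sketch of the same top/bottom averaging with np.linalg.svd in place of sklearn PCA. This is a substitution, not the suite2p code: sklearn's PCA also centers each frame across pixels and may return components with different signs, so individual weights can differ. On a toy movie that alternates between two images, the first PC separates them, and the low and high averages recover both images (in either order, since PC signs are arbitrary). The name pclowhigh_svd is hypothetical.

```python
import numpy as np


def pclowhigh_svd(mov: np.ndarray, nlowhigh: int, nPC: int):
    """Sketch of pclowhigh with np.linalg.svd substituted for sklearn PCA."""
    nframes, Ly, Lx = mov.shape
    flat = mov.reshape(nframes, -1).astype(np.float32)
    centered = flat - flat.mean(axis=0)  # subtract the mean image
    # rows of vt are temporal components: one weight per frame
    _, s, vt = np.linalg.svd(centered.T, full_matrices=False)
    v = vt[:nPC].T  # (nframes, nPC)
    isort = np.argsort(v, axis=0)  # frames ordered by weight, separately per PC
    pclow = np.stack([flat[isort[:nlowhigh, i]].mean(axis=0) for i in range(nPC)])
    pchigh = np.stack([flat[isort[-nlowhigh:, i]].mean(axis=0) for i in range(nPC)])
    return pclow.reshape(nPC, Ly, Lx), pchigh.reshape(nPC, Ly, Lx), s[:nPC], v


# toy movie alternating between two images A and B, plus a little noise
rng = np.random.default_rng(0)
A, B = rng.random((8, 8)), rng.random((8, 8))
mov = np.stack([A if t % 2 == 0 else B for t in range(40)])
mov = mov + 0.01 * rng.standard_normal(mov.shape)

pclow, pchigh, w, v = pclowhigh_svd(mov, nlowhigh=10, nPC=3)
print(pclow.shape, v.shape)  # (3, 8, 8) (40, 3)
```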

pclowhigh_torch

pclowhigh_torch(mov, nlowhigh, nPC, random_state)

Compute the mean frames at the top and bottom of each PC using torch SVD.

Computes the thin SVD of the mean-subtracted movie with torch and returns the average frames at the top and bottom of each of the first nPC temporal components.

Parameters:

Name | Type | Description | Default
mov | Tensor | Subsampled movie frames of shape (n_frames, Ly, Lx). | required
nlowhigh | int | Number of frames to average at the top and bottom of each PC. | required
nPC | int | Number of principal components for which top and bottom averages are computed. | required
random_state | int or None | Unused; kept for API compatibility with pclowhigh. | required

Returns:

Name | Type | Description
pclow | Tensor | Average of bottom-weighted frames for each PC, shape (nPC, Ly, Lx).
pchigh | Tensor | Average of top-weighted frames for each PC, shape (nPC, Ly, Lx).
w | Tensor | All singular values from the thin SVD, shape (min(n_frames, Ly * Lx),); not truncated to nPC.
v | Tensor | Temporal weights of shape (n_frames, min(n_frames, Ly * Lx)), describing how each component varies across frames; only the first nPC columns are used to form pclow and pchigh.

Source code in suite2p/registration/metrics.py
def pclowhigh_torch(mov, nlowhigh, nPC, random_state):
    """
    Compute the mean frames at the top and bottom of each PC using torch SVD.

    Computes the thin SVD of the mean-subtracted movie with torch and returns
    the average frames at the top and bottom of the first nPC temporal components.

    Parameters
    ----------
    mov : torch.Tensor
        Subsampled movie frames of shape (n_frames, Ly, Lx).
    nlowhigh : int
        Number of frames to average at the top and bottom of each PC.
    nPC : int
        Number of principal components to compute.
    random_state : int or None
        Unused, kept for API compatibility with pclowhigh.

    Returns
    -------
    pclow : torch.Tensor
        Average of bottom-weighted frames for each PC, shape (nPC, Ly, Lx).
    pchigh : torch.Tensor
        Average of top-weighted frames for each PC, shape (nPC, Ly, Lx).
    w : torch.Tensor
        All singular values from the thin SVD, shape (min(n_frames, Ly * Lx),);
        not truncated to nPC.
    v : torch.Tensor
        Temporal weights of shape (n_frames, min(n_frames, Ly * Lx)); only the
        first nPC columns are used to form pclow and pchigh.
    """
    nframes, Ly, Lx = mov.shape
    mov = mov.reshape((nframes, -1))
    mimg = mov.mean(axis=0)
    mov -= mimg
    w, v = torch.linalg.svd(mov.T, full_matrices=False)[1:]
    v = v.T
    mov += mimg
    mov = mov.reshape(nframes, Ly, Lx)
    pclow = torch.zeros((nPC, Ly, Lx), dtype=torch.float, device=mov.device)
    pchigh = torch.zeros((nPC, Ly, Lx), dtype=torch.float, device=mov.device)
    isort = v.argsort(axis=0)
    for i in range(nPC):
        pclow[i] = mov[isort[:nlowhigh, i]].mean(axis=0)
        pchigh[i] = mov[isort[-nlowhigh:, i]].mean(axis=0)
    return pclow, pchigh, w, v
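The shapes produced by the SVD call in pclowhigh_torch follow the thin-SVD convention. Because np.linalg.svd and torch.linalg.svd both return (U, S, Vh), the NumPy sketch below (random data standing in for a mean-subtracted movie) shows that with full_matrices=False the singular values and temporal weights have min(n_pixels, n_frames) entries rather than nPC. Slicing to the first nPC is left to the caller.

```python
import numpy as np

# random stand-in for a mean-subtracted movie flattened to (n_pixels, n_frames)
rng = np.random.default_rng(0)
n_frames, Ly, Lx, nPC = 50, 4, 5, 3
X = rng.standard_normal((Ly * Lx, n_frames))

# np.linalg.svd, like torch.linalg.svd, returns (U, S, Vh)
_, S, Vh = np.linalg.svd(X, full_matrices=False)
v = Vh.T  # temporal weights, one row per frame
k = min(Ly * Lx, n_frames)
print(k, S.shape, v.shape)  # 20 (20,) (50, 20)

# truncate explicitly when (nPC,) and (n_frames, nPC) shapes are needed
w_top, v_top = S[:nPC], v[:, :nPC]
print(w_top.shape, v_top.shape)  # (3,) (50, 3)
```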


calculate_nblocks

calculate_nblocks(L, block_size)

Return the effective block size and the number of blocks along one image dimension, given the dimension length and the desired block size.

Parameters:

Name | Type | Description | Default
L | int | Number of pixels in one dimension of the image. | required
block_size | int | Block size in pixels. | required

Returns:

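Name | Type | Description
block_size | int | min(L, block_size).
nblocks | int | Number of blocks along the dimension: 1 if block_size >= L, otherwise ceil(1.5 * L / block_size), so that neighboring blocks overlap.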

Source code in suite2p/registration/nonrigid.py
def calculate_nblocks(L: int, block_size: int):
    """
    Return the effective block size and number of blocks along one dimension.

    Parameters
    ----------
    L: int
        Number of pixels in one dimension in image.
    block_size: int
        Block size in pixels.

    Returns
    -------
    block_size: int
        min(L, block_size).
    nblocks: int
        Number of blocks to make along dimension.
    """
    return (L, 1) if block_size >= L else (block_size,
                                           int(np.ceil(1.5 * L / block_size)))
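A few worked values make the rule concrete. The function below is a copy of calculate_nblocks; with a 128-pixel block, a 512-pixel dimension gets ceil(1.5 * 512 / 128) = 6 blocks, 1.5 times the 4 that would tile it without overlap.

```python
import numpy as np


def calculate_nblocks(L: int, block_size: int):
    # copy of the rule above: one block spanning L, or about 1.5x as many blocks as tiles
    return (L, 1) if block_size >= L else (block_size, int(np.ceil(1.5 * L / block_size)))


for L in (100, 128, 256, 500, 512):
    print(L, calculate_nblocks(L, 128))
# 100 (100, 1)
# 128 (128, 1)
# 256 (128, 3)
# 500 (128, 6)
# 512 (128, 6)
```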

compute_masks_ref_smooth_fft

compute_masks_ref_smooth_fft(refImg0, maskSlope, smooth_sigma, yblock, xblock)

Compute per-block taper masks, offsets, and FFT-smoothed reference images for nonrigid phase-correlation registration. This function extracts blocks from a full 2D reference image, applies a spatial taper (window) to each block, computes a per-block constant offset to compensate for masked/background regions, and computes a Gaussian-smoothed version of each block in the frequency domain (complex FFT) for use in phase-correlation based registration.

Parameters:

Name | Type | Description | Default
refImg0 | Tensor | 2D reference image tensor of shape (Ly_full, Lx_full) with a numeric dtype (e.g. uint16 or float32). The function extracts sub-blocks using the indices supplied in yblock and xblock. | required
maskSlope | float | Scalar controlling the slope of the sigmoid of the full-frame spatial taper. Higher values increase the size of the tapered region. | required
smooth_sigma | float | Standard deviation (in pixels) of the Gaussian smoothing applied to each block. Smoothing is performed in the frequency domain (via ref_smooth_fft). Typical values are >= 0; a value of 0 should behave as no smoothing (identity). Also sets the slope (2 * smooth_sigma) of the block-local taper. | required
yblock | list[ndarray] | List of length (ny * nx) giving the vertical (row) slice for each block. Each element is a 1D integer numpy array [y_start, y_end] specifying the inclusive start (y_start) and exclusive end (y_end) indices of the block along the vertical axis. Blocks are ordered row-major by block-grid row (iy) then column (ix): block_idx = iy * nx + ix. | required
xblock | list[ndarray] | List of length (ny * nx) giving the horizontal (column) slice for each block. Each element is a 1D integer numpy array [x_start, x_end] specifying the inclusive start and exclusive end indices along the horizontal axis. Ordering matches yblock (row-major block-grid order). | required

Returns:

Name | Type | Description
maskMul_block | Tensor | Float32 tensor of shape (nb, Ly, Lx), where Ly and Lx are the block height and width. Per-block multiplicative taper masks: the product of a block-local taper (slope 2 * smooth_sigma) and the matching crop of the full-frame taper (slope maskSlope).
maskOffset_block | Tensor | Float32 tensor of shape (nb, Ly, Lx). Per-block additive offset fields computed as block_mean * (1 - maskMul_block), so that tapered regions are filled with the local block mean weighted by the complement of the taper.
cfRefImg_block | Tensor (complex64) | Complex64 tensor of shape (nb, Ly, Lx). Frequency-domain (FFT) representation of the Gaussian-smoothed reference blocks (output of ref_smooth_fft), intended for use in phase-correlation registration.

Source code in suite2p/registration/nonrigid.py
def compute_masks_ref_smooth_fft(refImg0, maskSlope, smooth_sigma,
                                 yblock, xblock):
    """
    Compute per-block taper masks, offsets, and FFT-smoothed reference images for
    nonrigid phase-correlation registration.
    This function extracts blocks from a full 2D reference image, applies a
    spatial taper (window) to each block, computes a per-block constant offset
    to compensate for masked/background regions, and computes a Gaussian-smoothed
    version of each block in the frequency domain (complex FFT) for use in
    phase-correlation based registration.

    Parameters
    ----------
    refImg0 : torch.Tensor
        2D reference image tensor of shape (Ly_full, Lx_full) with a numeric
        dtype (e.g. uint16 or float32). The function extracts sub-blocks using
        the indices supplied in yblock and xblock.
    maskSlope : float
        Scalar parameter controlling the slope of the sigmoid of the spatial taper. 
        Higher values increase tapered region size.
    smooth_sigma : float
        Standard deviation (in pixels) of the Gaussian smoothing applied to each
        block. Smoothing is performed in the frequency domain (via ref_smooth_fft).
        Typical values are >= 0. A value of 0 should behave as no
        smoothing (identity). Also sets the slope (2 * smooth_sigma) of the
        block-local taper.
    yblock : list[numpy.ndarray]
        List of length (ny * nx) giving the vertical (row) slice for each block.
        Each element is a 1D integer numpy array [y_start, y_end] specifying the
        inclusive start (y_start) and exclusive end (y_end) indices of the block
        along the vertical axis. Blocks are ordered row-major by block-grid row
        (iy) then column (ix): block_idx = iy * nx + ix.
    xblock : list[numpy.ndarray]
        List of length (ny * nx) giving the horizontal (column) slice for each
        block. Each element is a 1D integer numpy array [x_start, x_end]
        specifying the inclusive start and exclusive end indices along the
        horizontal axis. Ordering matches yblock (row-major block-grid order).

    Returns
    -------
    maskMul_block : torch.Tensor
        Float32 tensor of shape (nb, Ly, Lx), where Ly and Lx are the block
        height and width. Per-block multiplicative taper masks: the product of a
        block-local taper (slope 2 * smooth_sigma) and the matching crop of the
        full-frame taper (slope maskSlope).
    maskOffset_block : torch.Tensor
        Float32 tensor of shape (nb, Ly, Lx). Per-block additive offset fields
        computed as block_mean * (1 - maskMul_block) so that tapered regions are
        filled with the local block mean weighted by the complement of the taper.
    cfRefImg_block : torch.Tensor (complex64)
        Complex64 tensor of shape (nb, Ly, Lx). Frequency-domain (FFT) representation
        of the Gaussian-smoothed reference blocks (output of ref_smooth_fft). These
        are intended for use in phase-correlation registration.

    """
    nb, Ly, Lx = len(yblock), yblock[0][1] - yblock[0][0], xblock[0][1] - xblock[0][0]
    dims = (nb, Ly, Lx)
    cfRef_dims = dims
    cfRefImg1 = torch.zeros(cfRef_dims, dtype=torch.complex64)

    maskMul = spatial_taper(maskSlope, *refImg0.shape)
    maskMul1 = torch.zeros(dims, dtype=torch.float)
    maskMul1[:] = spatial_taper(2 * smooth_sigma, Ly, Lx)
    maskOffset1 = torch.zeros(dims, dtype=torch.float)
    for yind, xind, maskMul1_n, maskOffset1_n, cfRefImg1_n in zip(
            yblock, xblock, maskMul1, maskOffset1, cfRefImg1):
        ix = np.ix_(
            np.arange(yind[0], yind[-1]).astype("int"),
            np.arange(xind[0], xind[-1]).astype("int"))
        refImg = refImg0[ix]

        # mask params
        maskMul1_n *= maskMul[yind[0] : yind[-1], xind[0] : xind[-1]]
        maskOffset1_n[:] = (refImg.float().mean() * (1. - maskMul1_n))

        # gaussian filter
        cfRefImg1_n[:] = ref_smooth_fft(refImg, smooth_sigma)

    return maskMul1, maskOffset1, cfRefImg1
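The masking arithmetic can be sketched in NumPy as follows. Two things are assumptions, not taken from this page: sigmoid_taper is a hypothetical stand-in for suite2p's spatial_taper (whose body is not shown here), and the masks are assumed to be applied to an image block as block * maskMul + maskOffset. The FFT smoothing step (ref_smooth_fft) is omitted, and block_masks is likewise a hypothetical name. Under these assumptions, a flat reference image stays flat after masking, because the offset refills exactly the fraction of each block that the taper removes.

```python
import numpy as np


def sigmoid_taper(slope: float, Ly: int, Lx: int) -> np.ndarray:
    """Hypothetical stand-in for spatial_taper: ~1 in the middle, rolling off at the borders."""
    dy = np.abs(np.arange(Ly) - (Ly - 1) / 2)  # distance from the vertical center
    dx = np.abs(np.arange(Lx) - (Lx - 1) / 2)  # distance from the horizontal center
    ty = 1 / (1 + np.exp((dy - ((Ly - 1) / 2 - 2 * slope)) / slope))
    tx = 1 / (1 + np.exp((dx - ((Lx - 1) / 2 - 2 * slope)) / slope))
    return ty[:, None] * tx[None, :]


def block_masks(ref, yblock, xblock, maskSlope, smooth_sigma):
    """Per-block maskMul and maskOffset, following the loop in compute_masks_ref_smooth_fft."""
    full_taper = sigmoid_taper(maskSlope, *ref.shape)
    by, bx = yblock[0][1] - yblock[0][0], xblock[0][1] - xblock[0][0]
    local_taper = sigmoid_taper(2 * smooth_sigma, by, bx)
    mul = np.empty((len(yblock), by, bx))
    off = np.empty_like(mul)
    for n, (yind, xind) in enumerate(zip(yblock, xblock)):
        crop = (slice(yind[0], yind[1]), slice(xind[0], xind[1]))
        # block-local taper times the matching crop of the full-frame taper
        mul[n] = local_taper * full_taper[crop]
        # the tapered-away fraction is refilled with the block mean
        off[n] = ref[crop].mean() * (1.0 - mul[n])
    return mul, off


# 64 x 64 frame split into four non-overlapping 32 x 32 blocks, row-major order
yblock = [np.array([0, 32]), np.array([0, 32]), np.array([32, 64]), np.array([32, 64])]
xblock = [np.array([0, 32]), np.array([32, 64]), np.array([0, 32]), np.array([32, 64])]
ref = np.full((64, 64), 7.0)
mul, off = block_masks(ref, yblock, xblock, maskSlope=2.0, smooth_sigma=1.15)

# assumed application: block * maskMul + maskOffset, so a flat image stays flat
blocks = np.stack([ref[y[0]:y[1], x[0]:x[1]] for y, x in zip(yblock, xblock)])
print(np.allclose(blocks * mul + off, 7.0))  # True
```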

getSNR

getSNR(cc, lcorr, lpad)

Compute the signal-to-noise ratio (SNR) of phase-correlation maps. This function estimates the SNR for one or more phase-correlation maps by (1) locating the peak value within the central search region of each map, (2) zeroing a square neighborhood around that peak in a copy of the full map to exclude the main peak energy, and (3) taking the ratio of the peak value to the maximum remaining value in the map (with a small epsilon to avoid division by zero).

Parameters:

Name | Type | Description | Default
cc | ndarray | Array of phase-correlation maps with shape (n_maps, H, W). Each spatial dimension is expected to equal (2 * lcorr + 1) + 2 * lpad, i.e. the central searchable region of size (2 * lcorr + 1) is padded on all sides by lpad pixels. The first axis indexes independent maps (e.g. frames). | required
lcorr | int | Half-size of the central correlation search window. The central region searched for the peak is of size (2 * lcorr + 1) x (2 * lcorr + 1). | required
lpad | int | Padding width (in pixels) around the central search region. When masking the peak, a square of side length 2 * lpad is zeroed around the detected peak location in a copy of the map to measure the maximum background response. | required

Returns:

Name | Type | Description
snr | ndarray | Array of SNR values, one per input map, with shape (n_maps,). Each entry is the peak value found inside the central region divided by the maximum value remaining in the map after masking the peak neighborhood. The denominator is floored at 1e-10, so values stay finite.

Source code in suite2p/registration/nonrigid.py
def getSNR(cc, lcorr, lpad):
    """
    Compute the signal-to-noise ratio (SNR) of phase-correlation maps.
    This function estimates the SNR for one or more phase-correlation maps by
    (1) locating the peak value within the central search region of each map,
    (2) zeroing a square neighborhood around that peak in a copy of the full map
        to exclude the main peak energy, and
    (3) taking the ratio of the peak value to the maximum remaining value in the
        map (with a small epsilon to avoid division by zero).

    Parameters
    ----------
    cc : np.ndarray
        Array of phase-correlation maps with shape (n_maps, H, W). Each spatial
        dimension is expected to equal (2 * lcorr + 1) + 2 * lpad, i.e. the
        central searchable region of size (2*lcorr+1) is padded on all sides by
        lpad pixels. The first axis indexes independent maps (e.g. frames).
    lcorr : int
        Half-size of the central correlation search window. The central region
        searched for the peak is of size (2 * lcorr + 1) x (2 * lcorr + 1).
    lpad : int
        Padding width (in pixels) around the central search region. When masking
        the peak, a square of side length 2 * lpad is zeroed around the detected
        peak location in the copy of the map to measure the maximum background
        response.

    Returns
    -------
    snr : ndarray
        Array of SNR values, one per input map, with shape (n_maps,). Each entry
        is the peak value found inside the central region divided by the maximum
        value remaining in the map after masking the peak neighborhood. Values
        are finite due to a small numerical epsilon (1e-10) used in the
        denominator.
    """
    cc0 = cc[:, lpad:-lpad, lpad:-lpad].reshape(cc.shape[0], -1)
    # set to 0 all pts +-lpad from ymax,xmax
    cc1 = cc.copy()

    for c1, ymax, xmax in zip(
            cc1,
            *np.unravel_index(cc0.argmax(axis=1), (2 * lcorr + 1, 2 * lcorr + 1))):
        c1[ymax:ymax + 2 * lpad, xmax:xmax + 2 * lpad] = 0

    snr = cc0.max(axis=1) / np.maximum(1e-10, cc1.max(axis=(1, 2)))
    return snr
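Since getSNR operates on NumPy arrays, it can be exercised directly on a toy map. With lcorr = 2 and lpad = 1 each map is 7 x 7; a peak of 1.0 at the center and a side lobe of 0.25 in the border give an SNR of 1.0 / 0.25 = 4. Note that the zeroed square is indexed in full-map coordinates from the peak position in the central crop, so it spans lpad pixels before the peak and lpad - 1 pixels after it.

```python
import numpy as np


def getSNR(cc, lcorr, lpad):
    # copy of the function above (NumPy arrays in, NumPy array out)
    cc0 = cc[:, lpad:-lpad, lpad:-lpad].reshape(cc.shape[0], -1)
    cc1 = cc.copy()
    for c1, ymax, xmax in zip(
            cc1, *np.unravel_index(cc0.argmax(axis=1), (2 * lcorr + 1, 2 * lcorr + 1))):
        c1[ymax:ymax + 2 * lpad, xmax:xmax + 2 * lpad] = 0
    return cc0.max(axis=1) / np.maximum(1e-10, cc1.max(axis=(1, 2)))


lcorr, lpad = 2, 1
size = 2 * lcorr + 1 + 2 * lpad
cc = np.zeros((1, size, size))  # one 7 x 7 map
cc[0, 3, 3] = 1.0   # main peak at zero shift (center of the map)
cc[0, 0, 6] = 0.25  # weaker side lobe in the padded border
print(getSNR(cc, lcorr, lpad).tolist())  # [4.0]
```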

make_blocks

make_blocks(Ly, Lx, block_size, lpad=3, subpixel=10)

Compute overlapping registration blocks covering a 2D field of view. This function splits a full-frame image of size (Ly, Lx) into an array of overlapping rectangular blocks to be processed independently for nonrigid registration. Block start positions are computed so that blocks tile the image with (approximately) equal spacing and specified overlap determined by the requested block_size. The function also computes a spatial smoothing matrix (NRsm) over the block grid and an upsampling convolution matrix (Kmat) used for subpixel shift estimation.
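The tiling can be sketched as follows. The source listing for make_blocks is cut off on this page, so block_grid (a hypothetical name) assumes evenly spaced block starts from np.linspace between 0 and L - block_size, combined with the calculate_nblocks rule; NRsm and Kmat are not computed. For a 256 x 256 frame with 128 x 128 blocks this gives a 3 x 3 grid whose blocks start at 0, 64, and 128 along each axis, so neighboring blocks overlap by half a block.

```python
import numpy as np


def block_grid(Ly, Lx, block_size):
    """Hypothetical sketch of the block tiling; start positions via np.linspace are assumed."""
    def n_blocks_1d(L, bs):  # same rule as calculate_nblocks
        return (L, 1) if bs >= L else (bs, int(np.ceil(1.5 * L / bs)))

    (by, ny), (bx, nx) = n_blocks_1d(Ly, block_size[0]), n_blocks_1d(Lx, block_size[1])
    ystart = np.linspace(0, Ly - by, ny).astype(int)
    xstart = np.linspace(0, Lx - bx, nx).astype(int)
    # row-major block order: block_idx = iy * nx + ix
    yblock = [np.array([ystart[iy], ystart[iy] + by]) for iy in range(ny) for _ in range(nx)]
    xblock = [np.array([xstart[ix], xstart[ix] + bx]) for _ in range(ny) for ix in range(nx)]
    return yblock, xblock, [ny, nx], (by, bx)


yblock, xblock, nblocks, bsize = block_grid(256, 256, (128, 128))
print(nblocks, bsize)                          # [3, 3] (128, 128)
print(yblock[3].tolist(), xblock[1].tolist())  # [64, 192] [64, 192]
```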

Parameters:

Name | Type | Description | Default
Ly | int | Number of pixels in the vertical dimension (image height). | required
Lx | int | Number of pixels in the horizontal dimension (image width). | required
block_size | tuple[int, int] | Block size in pixels as (block_height, block_width). | required
lpad | int | Padding in pixels used when constructing the upsampling matrix; passed to mat_upsample(...). | 3
subpixel | int | Subpixel upsampling factor; passed to mat_upsample(...). | 10

Returns:

Name Type Description
yblock list[ndarray]

List of length (ny * nx) giving the vertical (row) slice for each block. Each element is a 1D integer numpy array [y_start, y_end] specifying the inclusive start (y_start) and exclusive end (y_end) indices of the block along the vertical axis. Blocks are ordered row-major by block-grid row (iy) then column (ix): block_idx = iy * nx + ix.

xblock list[ndarray]

List of length (ny * nx) giving the horizontal (column) slice for each block. Each element is a 1D integer numpy array [x_start, x_end] specifying the inclusive start and exclusive end indices along the horizontal axis. Ordering matches yblock (row-major block-grid order).

nblocks list[int, int]

Two-element list [ny, nx] with the number of blocks in the vertical and horizontal directions respectively (ny = number of block rows, nx = number of block columns).

block_size tuple[int, int]

Effective block size used, min of input block size and frame size.

NRsm ndarray

Smoothing matrix over the block grid, of shape (ny * nx, ny * nx): phasecorr applies it as NRsm @ cc to correlation maps flattened to one row per block. It is derived from kernelD2 over block-grid coordinates and is used to smooth blockwise phase-correlation maps spatially.

Kmat ndarray

Upsampling kriging interpolation matrix returned by mat_upsample(lpad, subpixel). This matrix is used for subpixel shift estimation within +/- lpad pixels.

nup int

Number of upsampled positions along each axis of the subpixel grid; Kmat.shape[-1] equals nup**2.
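Because phasecorr applies NRsm as `NRsm @ cc0` to maps flattened to one row per block, NRsm acts as a block-to-block weighting matrix. The NumPy sketch below builds a matrix of that kind. It is not kernelD2: the Gaussian form with a unit length scale (in block units) and the row normalization, which makes each smoothed map a weighted average of neighboring blocks, are assumptions made for illustration.

```python
import numpy as np


def block_smoothing_matrix(ny, nx, length_scale=1.0):
    """Hypothetical NRsm-like matrix of shape (ny * nx, ny * nx)."""
    iy, ix = np.meshgrid(np.arange(ny), np.arange(nx), indexing="ij")
    # block-grid coordinates in row-major block order (block_idx = iy * nx + ix)
    coords = np.stack([iy.ravel(), ix.ravel()], axis=1).astype(float)
    d2 = ((coords[:, None, :] - coords[None, :, :]) ** 2).sum(axis=-1)
    W = np.exp(-d2 / length_scale**2)
    # normalize rows so that W @ maps averages each block with its neighbors
    return W / W.sum(axis=1, keepdims=True)


ny, nx = 3, 4
NRsm = block_smoothing_matrix(ny, nx)
rng = np.random.default_rng(0)
cc = rng.normal(size=(ny * nx, 25))  # one flattened correlation map per block
cc_once = NRsm @ cc                  # smoothed once
cc_twice = NRsm @ NRsm @ cc          # smoothed twice, like cc2 in phasecorr
```

With row-normalized weights, a map that is identical in every block passes through unchanged, while block-to-block noise is averaged down.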

Source code in suite2p/registration/nonrigid.py
def make_blocks(Ly, Lx, block_size, lpad=3, subpixel=10):
    """
    Compute overlapping registration blocks covering a 2D field of view.
    This function splits a full-frame image of size (Ly, Lx) into an array of
    overlapping rectangular blocks to be processed independently for nonrigid
    registration. Block start positions are computed so that blocks tile the image
    with (approximately) equal spacing and specified overlap determined by the
    requested block_size. The function also computes a spatial smoothing matrix
    (NRsm) over the block grid and an upsampling convolution matrix (Kmat) used
    for subpixel shift estimation.

    Parameters
    ----------
    Ly : int
        Number of pixels in the vertical dimension (image height).
    Lx : int
        Number of pixels in the horizontal dimension (image width).
    block_size : tuple[int, int]
        Block size in pixels as (block_height, block_width).
    lpad : int, optional
        Padding in pixels used when constructing the upsampling matrix.
        Passed to mat_upsample(...). Default is 3.
    subpixel : int, optional
        Subpixel upsampling factor. Passed to mat_upsample(...). Default is 10.

    Returns
    -------
    yblock : list[numpy.ndarray]
        List of length (ny * nx) giving the vertical (row) slice for each block.
        Each element is a 1D integer numpy array [y_start, y_end] specifying the
        inclusive start (y_start) and exclusive end (y_end) indices of the block
        along the vertical axis. Blocks are ordered row-major by block-grid row
        (iy) then column (ix): block_idx = iy * nx + ix.
    xblock : list[numpy.ndarray]
        List of length (ny * nx) giving the horizontal (column) slice for each
        block. Each element is a 1D integer numpy array [x_start, x_end]
        specifying the inclusive start and exclusive end indices along the
        horizontal axis. Ordering matches yblock (row-major block-grid order).
    nblocks : list[int, int]
        Two-element list [ny, nx] with the number of blocks in the vertical and
        horizontal directions respectively (ny = number of block rows,
        nx = number of block columns).
    block_size : tuple[int, int]
        Effective block size used, min of input block size and frame size.
    NRsm : numpy.ndarray
        Smoothing matrix over the block grid, of shape (ny * nx, ny * nx);
        phasecorr applies it as NRsm @ cc to correlation maps flattened to one
        row per block. Derived from kernelD2 over block-grid coordinates.
    Kmat : numpy.ndarray
        Upsampling kriging interpolation matrix returned by mat_upsample(lpad, subpixel).
        This matrix is used for subpixel shift estimation within +/- lpad pixels.
    nup : int
        Number of upsampled positions along each axis; Kmat.shape[-1] == nup**2.

    """
    block_size = (int(block_size[0]), int(block_size[1]))
    block_size_y, ny = calculate_nblocks(L=Ly, block_size=block_size[0])
    block_size_x, nx = calculate_nblocks(L=Lx, block_size=block_size[1])
    block_size = (block_size_y, block_size_x)

    # todo: could rounding to int here over-represent some pixels over others?
    ystart = np.linspace(0, Ly - block_size[0], ny).astype("int")
    xstart = np.linspace(0, Lx - block_size[1], nx).astype("int")
    yblock = [
        np.array([ystart[iy], ystart[iy] + block_size[0]])
        for iy in range(ny)
        for _ in range(nx)
    ]
    xblock = [
        np.array([xstart[ix], xstart[ix] + block_size[1]])
        for _ in range(ny)
        for ix in range(nx)
    ]

    NRsm = kernelD2(xs=torch.arange(nx), ys=torch.arange(ny)).T.numpy()
    Kmat, nup = mat_upsample(lpad=lpad, subpixel=subpixel)
    return yblock, xblock, [ny, nx], block_size, NRsm, Kmat, nup
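The tiling logic above can be sketched in plain NumPy. This is an illustration, not the library code: `n_blocks_along` is a hypothetical stand-in for `calculate_nblocks` (whose body is not shown here), and its 1.5x coverage rule is an assumption. The start positions and the row-major block ordering follow the `make_blocks` source.

```python
import math

import numpy as np


def n_blocks_along(L, block_size):
    """Hypothetical stand-in for calculate_nblocks (assumed rule).

    One block spanning the whole axis if block_size >= L, otherwise enough
    blocks to cover the axis about 1.5 times, so neighbors overlap.
    """
    if block_size >= L:
        return L, 1
    return block_size, math.ceil(1.5 * L / block_size)


def tile_blocks(Ly, Lx, block_size):
    """Return yblock, xblock, [ny, nx] and the effective block size."""
    by, ny = n_blocks_along(Ly, block_size[0])
    bx, nx = n_blocks_along(Lx, block_size[1])
    # evenly spaced start positions, truncated to int as in make_blocks
    ystart = np.linspace(0, Ly - by, ny).astype(int)
    xstart = np.linspace(0, Lx - bx, nx).astype(int)
    # row-major order over the block grid: block_idx = iy * nx + ix
    yblock = [np.array([ystart[iy], ystart[iy] + by]) for iy in range(ny) for _ in range(nx)]
    xblock = [np.array([xstart[ix], xstart[ix] + bx]) for _ in range(ny) for ix in range(nx)]
    return yblock, xblock, [ny, nx], (by, bx)


yblock, xblock, nblocks, bsize = tile_blocks(512, 512, (128, 128))
print(nblocks, yblock[1], xblock[1])  # second block: same row, next column
```

Under this assumed rule, a 512 x 512 frame with 128-pixel blocks gives a 6 x 6 grid whose start positions are 76 or 77 pixels apart, so neighboring blocks overlap by roughly 50 pixels.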

phasecorr #

phasecorr(data, blocks, maskMul, maskOffset, cfRefImg, snr_thresh, maxregshiftNR, subpixel=10, lpad=3)

Compute per-block shifts using phase correlation. Each block of each frame in data is phase-correlated in the Fourier domain with the matching block of the complex reference cfRefImg. The integer (y, x) shift that maximizes the phase-correlation is searched within a window of ±lcorr pixels, where lcorr is derived from maxregshiftNR. The correlation maps are also smoothed across neighboring blocks with NRsm, and the smoothed maps replace the raw ones for frames whose block SNR is below snr_thresh. A small neighborhood around each peak is then upsampled by kriging interpolation with the Kmat matrix, and the peak of the upsampled phase-correlation gives the subpixel shift.

Parameters:

Name Type Description Default
data Tensor

Input image sequence of shape (nimg, Ly, Lx), where nimg is the number of frames. The tensor may be on CPU or CUDA. Each block is first copied into an int16 buffer, then converted to float, masked, and converted to complex for the Fourier-domain operations performed by the helper convolve.

required
blocks tuple

Block descriptors as returned by make_blocks, unpacked in this function as (yblock, xblock, _, _, NRsm, Kmat, nup).

required
maskMul Tensor

Multiplicative mask applied to data per-block before correlation. Broadcasted over frames.

required
maskOffset Tensor

Additive offset applied after maskMul per-block. Broadcasted over frames.

required
cfRefImg Tensor

Complex-valued, Fourier-domain reference blocks of shape (n_blocks, ly, lx), where (ly, lx) is the block size; block n of every frame is correlated with cfRefImg[n].

required
snr_thresh float

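SNR threshold for replacing a block's raw correlation map with smoothed versions computed via NRsm. For each frame and block, the raw map is kept if its SNR reaches snr_thresh; otherwise the map smoothed once with NRsm is used, and if its SNR is still below the threshold, the map smoothed twice. Lower values make smoothing less likely; with snr_thresh <= 1 no smoothing is applied.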

required
maxregshiftNR int

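Maximum allowed nonrigid shift in pixels (rounded to the nearest integer). The search half-width lcorr is this value capped at floor(min(ly, lx) / 2) - lpad, where (ly, lx) is the block size.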

required
lpad int

Padding in pixels used when constructing the upsampling matrix. Default is 3.

3
subpixel int

Subpixel upsampling factor. Default is 10.

10

Returns:

Name Type Description
ymax1 Tensor

Tensor of shape (N, n_blocks) with the subpixel y (row) shift, in pixels, of each frame and block that maximizes the phase-correlation.

xmax1 Tensor

Tensor of shape (N, n_blocks) with the subpixel x (column) shift, in pixels, of each frame and block that maximizes the phase-correlation.

cmax1 Tensor

Tensor of shape (N, n_blocks) containing the maximum of the upsampled phase-correlation for each frame and block.

ccsm ndarray

Phase-correlation maps (smoothed where the SNR was low) used for peak selection for each frame and block. Shape: (n_blocks, N, 2*lcorr + 2*lpad + 1, 2*lcorr + 2*lpad + 1).

ccb Tensor

Tensor of shape (n_blocks, N, nup**2) containing the upsampled phase-correlation around the peak for each frame and block.
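phasecorr refines each integer peak with the Kmat matrix from mat_upsample. The sketch below approximates that idea in NumPy: a Gaussian-kernel kriging (RBF) interpolator fitted to the (2*lpad+1)^2 neighborhood of the peak and evaluated on a grid with spacing 1/subpixel. The helper names and the kernel width sig=0.85 are assumptions, not the library's mat_upsample; the peak-index arithmetic (imax // nup, imax % nup, mdpt = nup // 2) mirrors the source.

```python
import numpy as np


def gauss_kernel(a, b, sig=0.85):
    """Gaussian covariance between all points of the grids a x a and b x b.

    sig=0.85 is an assumed kernel width. Points are in row-major order.
    """
    ay, ax = [g.ravel() for g in np.meshgrid(a, a, indexing="ij")]
    by, bx = [g.ravel() for g in np.meshgrid(b, b, indexing="ij")]
    d2 = (ay[:, None] - by[None, :]) ** 2 + (ax[:, None] - bx[None, :]) ** 2
    return np.exp(-d2 / (2 * sig**2))


def upsample_matrix(lpad=3, subpixel=10):
    """Kriging matrix mapping (2*lpad+1)^2 samples to an nup x nup fine grid."""
    lar = np.arange(-lpad, lpad + 1)                         # integer positions
    lar_up = np.arange(-lpad, lpad + 1e-3, 1.0 / subpixel)   # step 1/subpixel
    Kmat = np.linalg.solve(gauss_kernel(lar, lar), gauss_kernel(lar, lar_up))
    return Kmat, lar_up.size                                 # Kmat: (n_samples, nup**2)


def subpixel_peak(patch, Kmat, nup, subpixel=10):
    """Subpixel (dy, dx) of the maximum, relative to the centre of `patch`."""
    up = patch.ravel() @ Kmat              # interpolated values on the fine grid
    imax = int(np.argmax(up))
    mdpt = nup // 2                        # fine-grid index of zero offset
    dy = (imax // nup - mdpt) / subpixel   # same index arithmetic as phasecorr
    dx = (imax % nup - mdpt) / subpixel
    return dy, dx


Kmat, nup = upsample_matrix(lpad=3, subpixel=10)
lar = np.arange(-3, 4)
yy, xx = np.meshgrid(lar, lar, indexing="ij")
# a smooth correlation peak whose true maximum sits at (+0.3, -0.2) pixels
patch = np.exp(-((yy - 0.3) ** 2 + (xx + 0.2) ** 2) / 2)
print(subpixel_peak(patch, Kmat, nup))
```

Because kriging interpolates, the fine-grid values at integer offsets reproduce the input samples exactly; between them the kernel supplies a smooth estimate whose maximum is resolved to 1/subpixel of a pixel.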

Source code in suite2p/registration/nonrigid.py
def phasecorr(data, blocks, maskMul, maskOffset, cfRefImg, snr_thresh,
              maxregshiftNR, subpixel = 10, lpad = 3):
    """
    Compute per-block shifts using phase correlation.
    This function performs a Fourier-domain phase-correlation based registration between each frame and each block in
    `data` and a provided (complex) reference image `cfRefImg`, in blocks. It computes the integer pixel shifts
    (y, x) that maximize the phase-correlation within a limited search window, defined by `maxregshiftNR`.
    The phase-correlations are smoothed across blocks, and these smoothed phase-correlations are used if the 
    block SNR is below `snr_thresh`. A small neighborhood around each peak is then upsampled via Kriging interpolation 
    using the provided Kmat kernel, and the peak of the upsampled phase-correlation is used to obtain subpixel-level shifts.

    Parameters
    ----------
    data : torch.Tensor
        Input image sequence, expected shape (nimg, Ly, Lx) where nimg is the number of frames.
        The tensor may be on CPU or CUDA; it is converted to float and then to complex for the
        Fourier-domain operations performed by the helper `convolve`.
    blocks : tuple
        Tuple of block descriptors produced by the caller, unpacked in this function as:
            (yblock, xblock, _, _, NRsm, Kmat, nup)
    maskMul : torch.Tensor
        Multiplicative mask applied to `data` per-block before correlation. Broadcasted over frames.
    maskOffset : torch.Tensor
        Additive offset applied after `maskMul` per-block. Broadcasted over frames.
    cfRefImg : torch.Tensor
        Complex-valued, Fourier-domain reference blocks of shape (n_blocks, ly, lx);
        block n of every frame is correlated with cfRefImg[n].
    snr_thresh : float
        SNR threshold used to decide whether to replace a block's raw correlation map with
        progressively more-smoothed versions computed via NRsm. Lower values make smoothing
        less likely.
    maxregshiftNR : int
        Maximum allowed registration shift (interpreted as pixels and rounded).
    lpad : int, optional
        Padding in pixels used when constructing the upsampling matrix. Default is 3.
    subpixel : int, optional
        Subpixel upsampling factor. Default is 10.

    Returns
    -------
    ymax1 : torch.Tensor
        Tensor of shape (N, n_blocks) with the subpixel y (row) shift of each frame
        and block that maximizes the phase-correlation.
    xmax1 : torch.Tensor
        Tensor of shape (N, n_blocks) with the subpixel x (column) shift of each
        frame and block that maximizes the phase-correlation.
    cmax1 : torch.Tensor
        Tensor of shape (N, n_blocks) containing the maximum upsampled
        phase-correlation for each frame and block.
    ccsm : numpy.ndarray
        Phase-correlation maps (smoothed where the SNR was low) used for peak
        selection for each frame and block. Shape:
            (n_blocks, N, 2*lcorr + 2*lpad + 1, 2*lcorr + 2*lpad + 1)
    ccb : torch.Tensor
        Tensor of shape (n_blocks, N, nup**2) containing upsampled phase-correlation
        values around the peak for each frame and block.

    """


    yblock, xblock, _, _, NRsm, Kmat, nup = blocks

    device = data.device

    nimg = data.shape[0]
    ly, lx = cfRefImg.shape[-2:]

    # maximum registration shift allowed
    lcorr = int(
        np.minimum(np.round(maxregshiftNR),
                   np.floor(np.minimum(ly, lx) / 2.) - lpad))
    nb = len(yblock)

    # shifts and corrmax
    Y = torch.zeros((nimg, nb, ly, lx), dtype=torch.int16, device=device)
    for n in range(nb):
        yind, xind = yblock[n], xblock[n]
        Y[:, n] = data[:, yind[0]:yind[-1], xind[0]:xind[-1]]
    Y = (Y.float() * maskMul + maskOffset).type(torch.complex64)
    batch = min(64, Y.shape[1])  #16
    for n in np.arange(0, nb, batch):
        nend = min(Y.shape[1], n + batch)
        Y[:, n:nend] = convolve(mov=Y[:, n:nend], img=cfRefImg[n:nend])

    # calculate ccsm
    lhalf = lcorr + lpad
    cc0 = torch.cat((torch.cat((Y[..., -lhalf:, -lhalf:], Y[..., -lhalf:, :lhalf + 1]), axis=-1),   
                    torch.cat((Y[..., :lhalf + 1, -lhalf:], Y[..., :lhalf + 1, :lhalf + 1]), axis=-1)), axis=-2)
    cc0 = torch.real(cc0)
    cc0 = cc0.permute(1, 0, 2, 3)
    cc0 = cc0.reshape(cc0.shape[0], -1)
    cc0 = cc0.cpu().numpy()

    del Y
    if device.type == "cuda":
        torch.cuda.empty_cache()    
        torch.cuda.synchronize()

    cc2 = [cc0, NRsm @ cc0, NRsm @ NRsm @ cc0]
    cc2 = [
        c2.reshape(nb, nimg, 2 * lcorr + 2 * lpad + 1, 2 * lcorr + 2 * lpad + 1)
        for c2 in cc2
    ]
    ccsm = cc2[0]

    for n in range(nb):
        snr = np.ones(nimg, dtype="float32")
        for j, c2 in enumerate(cc2):
            ism = snr < snr_thresh
            if ism.sum() == 0:
                break
            cc = c2[n, ism, :, :]
            if j > 0:
                ccsm[n, ism, :, :] = cc#.cpu().numpy()
            snr[ism] = getSNR(cc, lcorr, lpad)

    # calculate ymax1, xmax1, cmax1
    mdpt = nup // 2
    ymax1 = np.empty((nimg, nb), "float32")
    cmax1 = np.empty((nimg, nb), "float32")
    xmax1 = np.empty((nimg, nb), "float32")
    ymax = np.empty((nb,), "int32")
    xmax = np.empty((nb,), "int32")

    imax = ccsm[..., lpad:-lpad, lpad:-lpad].reshape(nb, nimg, -1).argmax(axis=-1)
    ymax, xmax = np.unravel_index(imax, (2 * lcorr + 1, 2 * lcorr + 1))
    ccmat = np.empty((nb, nimg, 2 * lpad + 1, 2 * lpad + 1), "float32")
    for t in range(nimg):
        for n in range(nb):
            ym, xm = ymax[n, t], xmax[n, t]
            ccmat[n, t] = ccsm[n, t, ym:ym + 2 * lpad + 1, xm:xm + 2 * lpad + 1]
    ccmat = torch.from_numpy(ccmat.reshape(nb * nimg, -1)).to(device)
    ccb = (ccmat @ Kmat.to(device)).reshape(nb, nimg, -1)
    cmax1, imax1 = ccb.max(axis=-1)
    ymax1, xmax1 = torch.div(imax1, nup, rounding_mode="floor"), imax1 % nup
    ymax1 = (ymax1 - mdpt) / subpixel + torch.from_numpy(ymax).to(device) - lcorr
    xmax1 = (xmax1 - mdpt) / subpixel + torch.from_numpy(xmax).to(device) - lcorr

    return ymax1.T.float(), xmax1.T.float(), cmax1.T, ccsm, ccb
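The integer-shift step of phasecorr can be sketched for a single block in NumPy: whiten the cross-power spectrum so only phase remains, invert it, and take the maximum inside a ±max_shift window around zero lag. This is a simplified illustration under assumptions: it omits maskMul/maskOffset, the per-block batching, and the NRsm fallback, and the sign convention (np.roll(block, (dy, dx)) reproduces ref) is chosen for the sketch and need not match the sign of phasecorr's ymax1/xmax1.

```python
import numpy as np


def phasecorr_block(block, ref, max_shift, eps=1e-5):
    """Integer (dy, dx) aligning `block` to `ref`, searched within +/- max_shift.

    Sign convention of this sketch: np.roll(block, (dy, dx), axis=(0, 1))
    reproduces ref for a pure circular shift.
    """
    R = np.fft.fft2(ref) * np.conj(np.fft.fft2(block))
    R /= np.abs(R) + eps                 # whiten: keep only the phase
    cc = np.real(np.fft.ifft2(R))        # correlation surface, zero lag at [0, 0]
    # gather the (2 * max_shift + 1)^2 window around zero lag, wrapping negative lags
    lags = np.arange(-max_shift, max_shift + 1)
    win = cc[np.ix_(lags % cc.shape[0], lags % cc.shape[1])]
    iy, ix = np.unravel_index(np.argmax(win), win.shape)
    return int(lags[iy]), int(lags[ix]), float(win[iy, ix])


rng = np.random.default_rng(2)
ref = rng.normal(size=(64, 64))
block = np.roll(ref, (-3, 5), axis=(0, 1))   # block[y, x] = ref[y + 3, x - 5]
dy, dx, cmax = phasecorr_block(block, ref, max_shift=10)
print(dy, dx)  # 3 -5
```

For a pure circular shift the whitened correlation is close to a single spike of height 1, which is why phase correlation gives sharp, well-separated peaks; the SNR test in phasecorr measures how far real data falls short of that.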

transform_data #

transform_data(data, nblocks, xblock, yblock, ymax1, xmax1, data_ups=None, counts_ups=None)

Apply bilinear interpolation to transform image data using block-wise shifts. This function performs nonrigid image registration by bilinearly interpolating the block-wise shifts across the image and applying the resulting displacement field with grid_sample. On Apple Silicon (MPS) devices it emulates padding_mode="border" by replicate-padding the input and clamping the sampling grid; on other devices it calls grid_sample with padding_mode="border" directly.

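Parameters:

Name Type Description Default
data Tensor

Input image data of shape (nimg, Ly, Lx), where nimg is the number of images, Ly is the height, and Lx is the width.

required
nblocks tuple of int

Number of blocks (ny, nx) along y and x in the registration grid.

required
xblock list of ndarray

[x_start, x_end] bounds of each block, of length nblocks[0]*nblocks[1], as returned by make_blocks.

required
yblock list of ndarray

[y_start, y_end] bounds of each block, of length nblocks[0]*nblocks[1], as returned by make_blocks.

required
ymax1 Tensor

Tensor of shape (N, n_blocks) with the y (row) shift of each frame and block, as returned by phasecorr.

required
xmax1 Tensor

Tensor of shape (N, n_blocks) with the x (column) shift of each frame and block, as returned by phasecorr.

required
data_ups Tensor or None

Optional accumulator of shape (Ly * uy, Lx * ux), where the upsampling factors (uy, ux) are inferred from its shape. If both data_ups and counts_ups are given, each input pixel (y, x) of each frame is added in place at position ((y - dy) * uy, (x - dx) * ux), rounded to the nearest upsampled pixel and clamped to the array.

None
counts_ups Tensor or None

Optional accumulator with the same shape as data_ups, incremented in place by the number of pixels added at each position.

None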

Returns:

Name Type Description
fr_shift Tensor

Shifted image data of shape (nimg, Ly, Lx) with dtype int16 (short), warped according to the interpolated displacement field. The output is squeezed, so a single input frame (nimg = 1) yields shape (Ly, Lx).

Source code in suite2p/registration/nonrigid.py
def transform_data(data, nblocks, xblock, yblock, ymax1, xmax1, 
                   data_ups=None, counts_ups=None):
    """
    Apply bilinear interpolation to transform image data using block-wise shifts.
    This function performs non-rigid image registration by interpolating block-wise 
    shift values across the image and applying the resulting displacement field via 
    the `grid_sample` function. It handles both standard GPU and Apple Silicon (MPS) devices.

    Parameters
    ----------
    data : torch.Tensor
        Input image data of shape (nimg, Ly, Lx) where nimg is the number of images,
        Ly is the height, and Lx is the width.
    nblocks : tuple of int
        Number of blocks in (y, x) dimensions for the registration grid.
    xblock : list of np.ndarray
        [x_start, x_end] bounds of each block, of length nblocks[0]*nblocks[1].
    yblock : list of np.ndarray
        [y_start, y_end] bounds of each block, of length nblocks[0]*nblocks[1].
    ymax1 : torch.Tensor
        Tensor of shape (N, n_blocks) with the y (row) shift for each frame and block.
    xmax1 : torch.Tensor
        Tensor of shape (N, n_blocks) with the x (column) shift for each frame and block.
    data_ups, counts_ups : torch.Tensor, optional
        Accumulators at an upsampled resolution. If both are given, each frame's
        pixels are added to data_ups at their shifted positions and counted in
        counts_ups, in place.

    Returns
    -------
    fr_shift : torch.Tensor
        Shifted image data of shape (nimg, Ly, Lx) with dtype int16 (short).
        The input images are warped according to the interpolated displacement field.
    """
    n_frames, Ly, Lx = data.shape
    #device = torch.device("cuda")
    #data = torch.from_numpy(data).to(device).float()
    device = data.device
    ymax1 = ymax1.reshape(-1, *nblocks)
    xmax1 = xmax1.reshape(-1, *nblocks)
    mshy, mshx = torch.meshgrid(torch.arange(Ly, dtype=torch.float, device=device),
                         torch.arange(Lx, dtype=torch.float, device=device), indexing="ij")
    yb = np.array(yblock[::nblocks[1]]).mean(axis=1).astype("int")
    xb = np.array(xblock[:nblocks[1]]).mean(axis=1).astype("int")
    Lyc, Lxc = int(yb.max() - yb.min()), int(xb.max() - xb.min())
    yxup = F.interpolate(torch.stack((ymax1, xmax1), dim=1), 
                         size=(Lyc, Lxc), mode="bilinear", align_corners=True)
    yxup = F.pad(yxup, (int(xb.min()), Lx - int(xb.max()), 
                        int(yb.min()), Ly - int(yb.max())), mode="replicate")

    if data_ups is not None and counts_ups is not None:
        ups = torch.Tensor([data_ups.shape[0] // Ly, data_ups.shape[1] // Lx]).to(device)
        yxup_round = -1*yxup.clone() + torch.stack((mshy, mshx), dim=0)
        yxup_round = torch.floor(0.5 + (yxup_round * ups.unsqueeze(-1).unsqueeze(-1))).long()
        yxup_round[:,0] = torch.clamp(yxup_round[:,0], min=0, max=Ly*ups[0] - 1)
        yxup_round[:,1] = torch.clamp(yxup_round[:,1], min=0, max=Lx*ups[1] - 1)
        for t in range(n_frames):
            mat = torch.sparse_coo_tensor(indices=yxup_round[t].reshape(2, -1), values=data[t].flatten(), 
                                            size=(int(Ly*ups[0]), int(Lx*ups[1]))).to_dense()
            data_ups += mat 
            mat = torch.sparse_coo_tensor(indices=yxup_round[t].reshape(2, -1), 
                                          values=torch.ones(data[t].numel(), device=device, dtype=torch.long), 
                                          size=(int(Ly*ups[0]), int(Lx*ups[1]))).to_dense()
            counts_ups += mat

            # data_ups[yxup_round[t,0], yxup_round[t,1]] += data[t]
            # counts_ups[yxup_round[t,0], yxup_round[t,1]] += 1

    # rescale for grid_sample
    yxup[:,0] += mshy
    yxup[:,1] += mshx
    yxup /= torch.Tensor([Ly-1, Lx-1]).to(device).unsqueeze(-1).unsqueeze(-1)
    yxup *= 2 
    yxup -= 1
    yxup = yxup.permute(0, 2, 3, 1)

    if device.type == "mps":
        # Manually pad the input tensor with the border values.
        data_padded = F.pad(data.float().unsqueeze(1), (1, 1, 1, 1), mode="replicate")
        # Get the height and width of the original data tensor
        height, width = data.shape[-2:] 
        # Rescale the grid for the padding: the padded data has shape (height + 2) x (width + 2).
        # With align_corners=True, a normalized coordinate u on the original image maps to
        # u * (L - 1) / (L + 1) on the padded image, which keeps sampling on the original pixels.
        scale_x = (width - 1) / (width + 1)
        scale_y = (height - 1) / (height + 1)
        # The last axis of yxup is ordered (y, x), so scale_y must come first.
        adjusted_yxup = yxup * torch.tensor([[[[scale_y, scale_x]]]], device=yxup.device)
        # Clamp the grid before subsampling as all coordinate values must lie between [-1,1]. 
        # Sampling should always be along the image (not include padding coordinates, which will exceed [-1,1] range).
        adjusted_yxup = torch.clamp(adjusted_yxup, -1, 1)
        # Perform grid sampling on the padded tensor
        fr_shift = F.grid_sample(
            data_padded,
            adjusted_yxup[:, :, :, [1, 0]],
            mode="bilinear",
            padding_mode="zeros",  # Default or any supported mode
            align_corners=True
        )
    else:
        fr_shift = F.grid_sample(data.float().unsqueeze(1), yxup[:,:,:,[1,0]], 
                             mode="bilinear", padding_mode="border", align_corners=True)


    return fr_shift.squeeze().short()#.cpu().numpy()
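The two steps of transform_data, spreading block shifts over the frame and then resampling, can be approximated in NumPy as below. This sketch rests on assumptions: it interpolates between block centres with np.interp (holding edge values), which is close to but not identical to F.interpolate(align_corners=True) followed by replicate padding in the source; it samples out[y, x] = data(y + dy, x + dx), mirroring the grid built from mshy + ymax and mshx + xmax; and clamping the sample coordinates stands in for padding_mode="border". `to_normalized` shows the pixel-to-grid_sample mapping that the rescaling lines implement.

```python
import numpy as np


def shift_field(shifts, yc, xc, Ly, Lx):
    """Bilinearly interpolate a (ny, nx) grid of block shifts to every pixel.

    yc, xc: increasing block-centre coordinates. np.interp holds the end values
    outside [yc[0], yc[-1]] and [xc[0], xc[-1]], similar to replicate padding.
    """
    # interpolate along x within each block row, then along y for every column
    rows = np.stack([np.interp(np.arange(Lx), xc, s) for s in shifts])  # (ny, Lx)
    return np.stack([np.interp(np.arange(Ly), yc, rows[:, x]) for x in range(Lx)], axis=1)


def bilinear_warp(img, dy, dx):
    """out[y, x] = img(y + dy[y, x], x + dx[y, x]), clamped to the border."""
    Ly, Lx = img.shape
    yy, xx = np.meshgrid(np.arange(Ly), np.arange(Lx), indexing="ij")
    ys = np.clip(yy + dy, 0, Ly - 1)  # clamping mimics padding_mode="border"
    xs = np.clip(xx + dx, 0, Lx - 1)
    y0, x0 = np.floor(ys).astype(int), np.floor(xs).astype(int)
    y1, x1 = np.minimum(y0 + 1, Ly - 1), np.minimum(x0 + 1, Lx - 1)
    wy, wx = ys - y0, xs - x0
    return ((1 - wy) * (1 - wx) * img[y0, x0] + (1 - wy) * wx * img[y0, x1]
            + wy * (1 - wx) * img[y1, x0] + wy * wx * img[y1, x1])


def to_normalized(coord, L):
    """Pixel coordinate -> grid_sample coordinate for align_corners=True."""
    return 2 * coord / (L - 1) - 1


# 2 x 2 block grid on an 8 x 12 frame; block centres at rows 2, 6 and columns 3, 9
dy = shift_field(np.array([[0.0, 2.0], [4.0, 6.0]]), [2, 6], [3, 9], 8, 12)
img = np.arange(8 * 12, dtype=float).reshape(8, 12)
warped = bilinear_warp(img, dy, np.zeros_like(dy))
```

Interpolating the shift field first and sampling once keeps the warp smooth across block boundaries, instead of shifting each block independently and stitching the results.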

assign_reg_io #

assign_reg_io(f_reg, f_raw=None, f_reg_chan2=None, f_raw_chan2=None, align_by_chan2=False, save_path=None, reg_tif=False, reg_tif_chan2=False)

Assign input/output arrays and tiff directories for registration I/O.

Determines which channel is the alignment source and which is the alternate, based on align_by_chan2. Sets up tiff output directories if requested.

Parameters:

Name Type Description Default
f_reg ndarray or BinaryFile

Destination for the registered functional channel frames. If f_raw is None, f_reg is also the input for that channel.

required
f_raw ndarray or BinaryFile or None

Raw functional channel frames, used as input when available.

None
f_reg_chan2 ndarray or BinaryFile or None

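Destination for the registered second channel frames. If f_raw_chan2 is None, f_reg_chan2 is also the input for that channel.

None
f_raw_chan2 ndarray or BinaryFile or None

Raw second channel frames, used as input when available.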

None
align_by_chan2 bool

If True, use the second channel as the alignment source.

False
save_path str or None

Base directory for saving registered tiff files.

None
reg_tif bool

If True, save registered functional channel frames as tiffs.

False
reg_tif_chan2 bool

If True, save registered second channel frames as tiffs.

False

Returns:

Name Type Description
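f_align_in ndarray or BinaryFile

Input frames for alignment: the raw array of the alignment channel if given, else its registered array.

f_align_out ndarray or BinaryFile or None

Output destination for aligned frames, or None when no raw array was given for the alignment channel (f_align_in is then the registered array itself).

f_alt_in ndarray or BinaryFile or None

Input frames for the alternate channel, or None if no arrays were given for it.

f_alt_out ndarray or BinaryFile or None

Output destination for shifted alternate channel frames, or None when no raw array was given for that channel.

tif_root_align str or None

Tiff output directory for the alignment channel, or None if its frames are not saved as tiffs.

tif_root_alt str or None

Tiff output directory for the alternate channel, or None if its frames are not saved as tiffs.

Raises:

Type Description
ValueError

If f_align_in and f_alt_in have different numbers of frames.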

Source code in suite2p/registration/register.py
def assign_reg_io(f_reg, f_raw=None, f_reg_chan2=None, 
               f_raw_chan2=None, align_by_chan2=False, 
               save_path=None,
               reg_tif=False, reg_tif_chan2=False):
    """
    Assign input/output arrays and tiff directories for registration I/O.

    Determines which channel is the alignment source and which is the alternate,
    based on align_by_chan2. Sets up tiff output directories if requested.

    Parameters
    ----------
    f_reg : np.ndarray or BinaryFile
        Registered functional channel frames.
    f_raw : np.ndarray or BinaryFile or None
        Raw functional channel frames, used as input when available.
    f_reg_chan2 : np.ndarray or BinaryFile or None
        Registered second channel frames.
    f_raw_chan2 : np.ndarray or BinaryFile or None
        Raw second channel frames.
    align_by_chan2 : bool
        If True, use the second channel as the alignment source.
    save_path : str or None
        Base directory for saving registered tiff files.
    reg_tif : bool
        If True, save registered functional channel frames as tiffs.
    reg_tif_chan2 : bool
        If True, save registered second channel frames as tiffs.

    Returns
    -------
    f_align_in : np.ndarray or BinaryFile
        Input frames for alignment.
    f_align_out : np.ndarray or BinaryFile or None
        Output destination for aligned frames.
    f_alt_in : np.ndarray or BinaryFile or None
        Input frames for the alternate channel.
    f_alt_out : np.ndarray or BinaryFile or None
        Output destination for shifted alternate channel frames.
    tif_root_align : str or None
        Tiff output directory for the alignment channel.
    tif_root_alt : str or None
        Tiff output directory for the alternate channel.
    """
    if f_reg_chan2 is None or not align_by_chan2:
        # test against None: truthiness of a multi-element numpy array raises ValueError
        f_align_in = f_reg if f_raw is None else f_raw
        f_alt_in = f_reg_chan2 if f_raw_chan2 is None else f_raw_chan2
        f_align_out = f_reg if f_raw is not None else None
        f_alt_out = f_reg_chan2 if f_raw_chan2 is not None else None
    else:
        f_align_in = f_reg_chan2 if f_raw_chan2 is None else f_raw_chan2
        f_alt_in = f_reg if f_raw is None else f_raw
        f_align_out = f_reg_chan2 if f_raw_chan2 is not None else None
        f_alt_out = f_reg if f_raw is not None else None

    if f_alt_in is not None:
        if f_align_in.shape[0] != f_alt_in.shape[0]:
            raise ValueError("number of frames in f_align_in and f_alt_in must match")

    tif_root_align, tif_root_alt = None, None
    if save_path:
        if reg_tif:
            tifroot = os.path.join(save_path, "reg_tif")
            os.makedirs(tifroot, exist_ok=True)
            if not align_by_chan2:
                tif_root_align = tifroot
            else:
                tif_root_alt = tifroot
        if reg_tif_chan2:
            tifroot = os.path.join(save_path, "reg_tif_chan2")
            os.makedirs(tifroot, exist_ok=True)
            if align_by_chan2:
                tif_root_align = tifroot
            else:
                tif_root_alt = tifroot

    return f_align_in, f_align_out, f_alt_in, f_alt_out, tif_root_align, tif_root_alt
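The routing rule in assign_reg_io reduces to: for each channel, read from the raw array when one is given and write into the registered array; otherwise read the registered array and return no separate output. The two channels are then swapped when aligning by channel 2. A plain-Python sketch of that rule, with explicit `is None` checks (`route_channels` is a hypothetical helper, not part of suite2p, and it leaves out the tiff-directory handling):

```python
def route_channels(f_reg, f_raw=None, f_reg_chan2=None, f_raw_chan2=None,
                   align_by_chan2=False):
    """Return (f_align_in, f_align_out, f_alt_in, f_alt_out) like assign_reg_io."""

    def pick(reg, raw):
        # raw data present: read raw, write to reg; otherwise read reg, no separate output
        return (reg, None) if raw is None else (raw, reg)

    func_in, func_out = pick(f_reg, f_raw)
    chan2_in, chan2_out = pick(f_reg_chan2, f_raw_chan2)
    if f_reg_chan2 is None or not align_by_chan2:
        return func_in, func_out, chan2_in, chan2_out
    return chan2_in, chan2_out, func_in, func_out


print(route_channels("reg", "raw", "reg2", "raw2", align_by_chan2=True))
# ('raw2', 'reg2', 'raw', 'reg')
```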

check_offsets #

check_offsets(yoff, xoff, yoff1, xoff1, n_frames)

Validate that registration offset arrays have the expected number of frames.

Parameters:

Name Type Description Default
yoff ndarray or None

Rigid y offsets of length n_frames.

required
xoff ndarray or None

Rigid x offsets of length n_frames.

required
yoff1 ndarray or None

Nonrigid y offsets of shape (n_frames, n_blocks), or None.

required
xoff1 ndarray or None

Nonrigid x offsets of shape (n_frames, n_blocks), or None.

required
n_frames int

Expected number of frames.

required

Raises:

Type Description
ValueError

If rigid offsets are None or any offset array length does not match n_frames.

Source code in suite2p/registration/register.py
def check_offsets(yoff, xoff, yoff1, xoff1, n_frames):
    """
    Validate that registration offset arrays have the expected number of frames.

    Parameters
    ----------
    yoff : np.ndarray or None
        Rigid y offsets of length n_frames.
    xoff : np.ndarray or None
        Rigid x offsets of length n_frames.
    yoff1 : np.ndarray or None
        Nonrigid y offsets of shape (n_frames, n_blocks), or None.
    xoff1 : np.ndarray or None
        Nonrigid x offsets of shape (n_frames, n_blocks), or None.
    n_frames : int
        Expected number of frames.

    Raises
    ------
    ValueError
        If rigid offsets are None or any offset array length does not match
        n_frames.
    """
    if yoff is None or xoff is None:
        raise ValueError("no rigid registration offsets provided")
    elif yoff.shape[0] != n_frames or xoff.shape[0] != n_frames:
        raise ValueError(
            "rigid registration offsets are not the same size as input frames")
    if yoff1 is not None and (yoff1.shape[0] != n_frames or xoff1.shape[0] != n_frames):
        raise ValueError(
                "nonrigid registration offsets are not the same size as input frames")

compute_crop #

compute_crop(xoff, yoff, corrXY, th_badframes, badframes, maxregshift, Ly, Lx)

Determine how much to crop the FOV based on registration motion offsets.

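Identifies bad frames (frames with large outlier shifts or low phase-correlation, as set by th_badframes, or with shifts near the maxregshift limit) and excludes them when computing the valid y and x ranges for cropping the field of view. If at least half of the frames are bad, a warning is issued and all frames are used instead.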

Parameters:

Name Type Description Default
xoff ndarray

1-D array of length n_frames with x (column) rigid registration offsets.

required
yoff ndarray

1-D array of length n_frames with y (row) rigid registration offsets.

required
corrXY ndarray

1-D array of length n_frames with phase-correlation values for each frame.

required
th_badframes float

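Threshold multiplier for detecting bad frames. A frame is marked bad when its deviation from the median-filtered offsets (normalized by the mean deviation), divided by its phase-correlation relative to the median-filtered correlation, exceeds 100 * th_badframes.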

required
badframes ndarray

1-D boolean array of length n_frames with pre-existing bad frame labels.

required
maxregshift float

Maximum allowed registration shift as a fraction of the image dimension. Frames exceeding 95% of this limit are marked as bad.

required
Ly int

Height of a frame in pixels.

required
Lx int

Width of a frame in pixels.

required

Returns:

Name Type Description
badframes ndarray

Updated 1-D boolean array of length n_frames indicating bad frames.

yrange list of int

[ymin, ymax] valid row range after cropping for motion.

xrange list of int

[xmin, xmax] valid column range after cropping for motion.

Source code in suite2p/registration/register.py
def compute_crop(xoff: np.ndarray, yoff: np.ndarray, corrXY, th_badframes, badframes, maxregshift,
                 Ly: int, Lx: int):
    """
    Determine how much to crop the FOV based on registration motion offsets.

    Identifies badframes (frames with large outlier shifts, thresholded by
    th_badframes) and excludes them when computing valid y and x ranges for
    cropping the field of view.

    Parameters
    ----------
    xoff : np.ndarray
        1-D array of length n_frames with x (column) rigid registration offsets.
    yoff : np.ndarray
        1-D array of length n_frames with y (row) rigid registration offsets.
    corrXY : np.ndarray
        1-D array of length n_frames with phase-correlation values for each frame.
    th_badframes : float
        Threshold multiplier for detecting bad frames based on the ratio of shift
        deviation to correlation quality.
    badframes : np.ndarray
        1-D boolean array of length n_frames with pre-existing bad frame labels.
    maxregshift : float
        Maximum allowed registration shift as a fraction of the image dimension.
        Frames exceeding 95% of this limit are marked as bad.
    Ly : int
        Height of a frame in pixels.
    Lx : int
        Width of a frame in pixels.

    Returns
    -------
    badframes : np.ndarray
        Updated 1-D boolean array of length n_frames indicating bad frames.
    yrange : list of int
        [ymin, ymax] valid row range after cropping for motion.
    xrange : list of int
        [xmin, xmax] valid column range after cropping for motion.
    """
    filter_window = min((len(yoff) // 2) * 2 - 1, 101)
    dx = xoff - medfilt(xoff, filter_window)
    dy = yoff - medfilt(yoff, filter_window)
    # offset in x and y (normed by mean offset)
    dxy = (dx**2 + dy**2)**.5
    dxy = dxy / dxy.mean()
    # phase-corr of each frame with reference (normed by median phase-corr)
    cXY = corrXY / medfilt(corrXY, filter_window)
    # exclude frames which have a large deviation and/or low correlation
    px = dxy / np.maximum(0, cXY)
    badframes = np.logical_or(px > th_badframes * 100, badframes)
    badframes = np.logical_or(abs(xoff) > (maxregshift * Lx * 0.95), badframes)
    badframes = np.logical_or(abs(yoff) > (maxregshift * Ly * 0.95), badframes)
    if badframes.mean() < 0.5:
        ymin = np.ceil(np.abs(yoff[np.logical_not(badframes)]).max())
        xmin = np.ceil(np.abs(xoff[np.logical_not(badframes)]).max())
    else:
        warn(
            "WARNING: >50% of frames have large movements, registration likely problematic"
        )
        ymin = np.ceil(np.abs(yoff).max())
        xmin = np.ceil(np.abs(xoff).max())
    ymax = Ly - ymin
    xmax = Lx - xmin
    yrange = [int(ymin), int(ymax)]
    xrange = [int(xmin), int(xmax)]

    return badframes, yrange, xrange
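
The bad-frame test above combines two signals: how far each frame's shift departs from a running median of the shift trace, and how its phase correlation compares with the local median correlation. Below is a minimal NumPy-only sketch of the same thresholding and cropping, handy for checking the logic on synthetic offsets. `running_median` and `crop_from_offsets` are hypothetical helpers, not suite2p functions; the running median here pads edges by repetition, whereas (assuming the `medfilt` above is `scipy.signal.medfilt`) the source zero-pads, so values near the ends of a trace can differ.

```python
import numpy as np
from numpy.lib.stride_tricks import sliding_window_view


def running_median(x, window):
    """Odd-length running median; edges padded by repeating end values (hypothetical helper)."""
    half = window // 2
    padded = np.pad(x, half, mode="edge")
    return np.median(sliding_window_view(padded, window), axis=1)


def crop_from_offsets(yoff, xoff, corr, Ly, Lx, th_badframes=1.0, maxregshift=0.1):
    """Sketch of bad-frame detection and FOV cropping from rigid offsets."""
    window = min((len(yoff) // 2) * 2 - 1, 101)
    dy = yoff - running_median(yoff, window)
    dx = xoff - running_median(xoff, window)
    dxy = np.sqrt(dx**2 + dy**2)
    dxy = dxy / dxy.mean()                        # deviation relative to typical jitter
    cxy = corr / running_median(corr, window)     # correlation relative to its local median
    with np.errstate(divide="ignore", invalid="ignore"):
        px = dxy / np.maximum(0, cxy)             # large when jumpy and/or poorly correlated
    bad = px > th_badframes * 100
    bad |= np.abs(xoff) > maxregshift * Lx * 0.95
    bad |= np.abs(yoff) > maxregshift * Ly * 0.95
    # if most frames are bad, fall back to using all of them (as the source does)
    use = ~bad if bad.mean() < 0.5 else np.ones_like(bad)
    ymin = int(np.ceil(np.abs(yoff[use]).max()))
    xmin = int(np.ceil(np.abs(xoff[use]).max()))
    return bad, [ymin, Ly - ymin], [xmin, Lx - xmin]


yoff = np.tile([1.0, -1.0], 25)   # small frame-to-frame jitter
yoff[10] = 30.0                    # one large jump (> 0.95 * 0.1 * 100 pixels)
xoff = np.zeros(50)
xoff[5] = 2.0
corr = np.ones(50)
bad, yrange, xr = crop_from_offsets(yoff, xoff, corr, Ly=100, Lx=100)
print(np.flatnonzero(bad), yrange, xr)  # [10] [1, 99] [2, 98]
```

The crop margin is set by the largest shift among the good frames only, so a single outlier frame does not shrink the usable field of view.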

compute_filters_and_norm

compute_filters_and_norm(refImg, norm_frames=True, spatial_smooth=1.15, spatial_taper=3.45, block_size=(128, 128), lpad=3, subpixel=10, device=torch.device('cuda'))

Compute registration masks, smoothed reference FFTs, and normalization bounds.

Builds rigid and (optionally) nonrigid spatial taper masks, smoothed Fourier-domain reference images, and intensity normalization bounds from the reference image. If refImg is a list (multi-plane), recurses for each plane.

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| refImg | np.ndarray or list of np.ndarray | Reference image of shape (Ly, Lx), or a list of reference images for multi-plane registration. | required |
| norm_frames | bool | If True, clip the reference image to its [1st, 99th] percentile range and return the clipping bounds. | True |
| spatial_smooth | float | Standard deviation (in pixels) of the Gaussian smoothing applied to the reference image in the frequency domain. | 1.15 |
| spatial_taper | float | Scalar controlling the slope of the sigmoid spatial taper mask at the image borders. | 3.45 |
| block_size | tuple of int or None | Block size (Ly_block, Lx_block) for nonrigid registration. If None, nonrigid masks are not computed. | (128, 128) |
| lpad | int | Number of pixels to pad each nonrigid block. | 3 |
| subpixel | int | Subpixel accuracy factor for nonrigid block shifts. | 10 |
| device | torch.device | Torch device to move the masks and reference FFTs to. | torch.device('cuda') |

Returns:

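| Type | Description |
|---|---|
| tuple | If refImg is a single image, returns (maskMul, maskOffset, cfRefImg, maskMulNR, maskOffsetNR, cfRefImgNR, blocks, rmin, rmax). The three nonrigid entries are None and blocks is an empty list when block_size is None; rmin, rmax are -inf, inf when norm_frames is False. If refImg is a list, returns a list of such tuples, one per plane. |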

Source code in suite2p/registration/register.py
def compute_filters_and_norm(refImg, norm_frames=True, spatial_smooth=1.15, spatial_taper=3.45,
                             block_size=(128, 128), lpad=3, subpixel=10, device=torch.device("cuda")):
    """
    Compute registration masks, smoothed reference FFTs, and normalization bounds.

    Builds rigid and (optionally) nonrigid spatial taper masks, smoothed
    Fourier-domain reference images, and intensity normalization bounds from the
    reference image. If refImg is a list (multi-plane), recurses for each plane.

    Parameters
    ----------
    refImg : np.ndarray or list of np.ndarray
        Reference image of shape (Ly, Lx), or a list of reference images for
        multi-plane registration.
    norm_frames : bool
        If True, clip the reference image to [1st, 99th] percentile and return
        the clipping bounds.
    spatial_smooth : float
        Standard deviation (in pixels) of Gaussian smoothing applied to the
        reference image in the frequency domain.
    spatial_taper : float
        Scalar controlling the slope of the sigmoid spatial taper mask at image
        borders.
    block_size : tuple of int or None
        Block size (Ly_block, Lx_block) for nonrigid registration. If None,
        nonrigid masks are not computed.
    lpad : int
        Number of pixels to pad each nonrigid block.
    subpixel : int
        Subpixel accuracy factor for nonrigid block shifts.
    device : torch.device
        Torch device to move the masks and reference FFTs to.

    Returns
    -------
    tuple
        If refImg is a single image, returns (maskMul, maskOffset, cfRefImg,
        maskMulNR, maskOffsetNR, cfRefImgNR, blocks, rmin, rmax). If refImg is
        a list, returns a list of such tuples.
    """
    if isinstance(refImg, list):
        refAndMasks_all = []
        for rimg in refImg:

            refAndMasks = compute_filters_and_norm(rimg, norm_frames=norm_frames, 
                                                   spatial_smooth=spatial_smooth, 
                                                   spatial_taper=spatial_taper, 
                                                   lpad=lpad, subpixel=subpixel, 
                                                   block_size=block_size, device=device)

            refAndMasks_all.append(refAndMasks)
        return refAndMasks_all
    else:
        if norm_frames:
            refImg, rmin, rmax = normalize_reference_image(refImg)
        else:
            rmin, rmax = -np.inf, np.inf

        rimg = torch.from_numpy(refImg)
        maskMul, maskOffset, cfRefImg = rigid.compute_masks_ref_smooth_fft(refImg=rimg, maskSlope=spatial_taper,
                                                                 smooth_sigma=spatial_smooth)
        Ly, Lx = refImg.shape
        # MPS backend does not support float64, convert to float32
        if device.type == "mps":
            maskMul, maskOffset = maskMul.to(torch.float32), maskOffset.to(torch.float32)
            cfRefImg = cfRefImg.to(torch.complex64)
        maskMul, maskOffset = maskMul.to(device), maskOffset.to(device)
        cfRefImg = cfRefImg.to(device)
        blocks = []
        if block_size is not None:
            blocks = nonrigid.make_blocks(Ly=Ly, Lx=Lx, block_size=block_size,
                                          lpad=lpad, subpixel=subpixel)
            maskMulNR, maskOffsetNR, cfRefImgNR = nonrigid.compute_masks_ref_smooth_fft(
                refImg0=rimg, maskSlope=spatial_taper, smooth_sigma=spatial_smooth,
                yblock=blocks[0], xblock=blocks[1],
            )
            # MPS backend does not support float64, convert to float32
            if device.type == "mps":
                maskMulNR, maskOffsetNR = maskMulNR.to(torch.float32), maskOffsetNR.to(torch.float32)
                cfRefImgNR = cfRefImgNR.to(torch.complex64)
            maskMulNR, maskOffsetNR = maskMulNR.to(device), maskOffsetNR.to(device)
            cfRefImgNR = cfRefImgNR.to(device)

        else:
            maskMulNR, maskOffsetNR, cfRefImgNR = None, None, None

        return (maskMul, maskOffset, cfRefImg, 
                maskMulNR, maskOffsetNR, cfRefImgNR, 
                blocks,
                rmin, rmax)
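
compute_filters_and_norm delegates the actual mask and FFT construction to rigid.compute_masks_ref_smooth_fft, whose code is not shown on this page. As a rough illustration of the three rigid outputs (maskMul, maskOffset, cfRefImg), here is a NumPy sketch that assumes a logistic border taper, a mean-intensity fill for the tapered border, and a phase-only (whitened) reference spectrum multiplied by the Fourier transform of a Gaussian of width spatial_smooth. The helper names are hypothetical and suite2p's exact formulas may differ.

```python
import numpy as np


def sigmoid_taper(Ly, Lx, slope):
    """~1 in the interior, rolling off to ~0 at the borders (logistic profile assumed)."""
    y = np.abs(np.arange(Ly) - (Ly - 1) / 2)
    x = np.abs(np.arange(Lx) - (Lx - 1) / 2)
    my = 1.0 / (1.0 + np.exp((y - ((Ly - 1) / 2 - 2 * slope)) / slope))
    mx = 1.0 / (1.0 + np.exp((x - ((Lx - 1) / 2 - 2 * slope)) / slope))
    return np.outer(my, mx)


def gaussian_fft(sigma, Ly, Lx):
    """Analytic Fourier transform of a unit-area Gaussian with std `sigma` pixels."""
    fy = np.fft.fftfreq(Ly)[:, None]
    fx = np.fft.fftfreq(Lx)[None, :]
    return np.exp(-2 * np.pi**2 * sigma**2 * (fy**2 + fx**2))


def rigid_masks_and_ref_fft(ref, spatial_taper=3.45, spatial_smooth=1.15):
    Ly, Lx = ref.shape
    mask_mul = sigmoid_taper(Ly, Lx, spatial_taper)
    # fill the tapered border with the mean intensity rather than with zeros
    mask_offset = ref.mean() * (1.0 - mask_mul)
    cf_ref = np.conj(np.fft.fft2(ref))
    cf_ref /= np.abs(cf_ref) + 1e-5          # whiten: keep phase, drop amplitude
    cf_ref *= gaussian_fft(spatial_smooth, Ly, Lx)  # low-pass in the frequency domain
    return mask_mul, mask_offset, cf_ref


rng = np.random.default_rng(0)
ref = rng.normal(size=(64, 64))
mask_mul, mask_offset, cf_ref = rigid_masks_and_ref_fft(ref)
# a frame would then be prepared as frame * mask_mul + mask_offset before its FFT
```

Whitening makes the later correlation peak depend on image structure (phase) rather than on overall brightness, and the Gaussian factor suppresses high-frequency noise in that comparison.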

compute_reference

compute_reference(frames, settings=default_settings(), device=torch.device('cuda'))

Compute the reference image by iterative rigid alignment.

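Picks an initial reference via pick_initial_reference, then runs 8 iterations in which all frames are rigidly registered to the current reference and the reference is replaced by the mean of the best-correlated frames (a fraction that grows from 1/16 to 1/2 of the frames), re-centered by the average shift of those frames.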

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| frames | np.ndarray | Frames of shape (nimg_init, Ly, Lx), dtype int16, used to build the reference image. | required |
| settings | dict | Registration settings dictionary containing the keys "batch_size", "smooth_sigma", "smooth_sigma_time", "spatial_taper", and "maxregshift". | default_settings() |
| device | torch.device | Torch device (CPU, CUDA, or MPS) on which to run registration. | torch.device('cuda') |

Returns:

| Name | Type | Description |
|---|---|---|
| refImg | np.ndarray | Reference image of shape (Ly, Lx), dtype int16. |

Source code in suite2p/registration/register.py
def compute_reference(frames, settings=default_settings(), device=torch.device("cuda")):
    """
    Compute the reference image by iterative rigid alignment.

    Picks an initial reference via pick_initial_reference, then iteratively
    registers frames to the current reference and updates the reference as the
    mean of the best-correlated frames.

    Parameters
    ----------
    frames : np.ndarray
        Frames of shape (nimg_init, Ly, Lx), dtype int16, used to build the
        reference image.
    settings : dict
        Registration settings dictionary containing keys "batch_size",
        "smooth_sigma", "smooth_sigma_time", "spatial_taper", and "maxregshift".
    device : torch.device
        Torch device (CPU or CUDA) on which to run registration.

    Returns
    -------
    refImg : np.ndarray
        Reference image of shape (Ly, Lx), dtype int16.
    """
    fr_reg = torch.from_numpy(frames)
    refImg = pick_initial_reference(fr_reg)

    niter = 8
    batch_size = settings["batch_size"]
    for iter in range(0, niter):
        # rigid registration shifts to reference
        maskMul, maskOffset, cfRefImg = compute_filters_and_norm(refImg, False, settings["smooth_sigma"],
                                           settings["spatial_taper"], block_size=None, device=device)[:3]

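        # collect shifts and correlations from every batch so that the frame
        # indices used below refer to the whole stack, not just the last batch
        ymax_all, xmax_all, cmax_all = [], [], []
        for k in range(0, fr_reg.shape[0], batch_size):
            fr_reg_batch = fr_reg[k:min(k + batch_size, fr_reg.shape[0])].to(device)
            ymax, xmax, cmax = rigid.phasecorr(fr_reg_batch, cfRefImg, maskMul, maskOffset,
                maxregshift=settings["maxregshift"],
                smooth_sigma_time=settings["smooth_sigma_time"])[:3]

            # shift frames to reference
            fr_reg_batch = torch.stack([torch.roll(frame, shifts=(-dy, -dx), dims=(0, 1))
                                for frame, dy, dx in zip(fr_reg_batch, ymax, xmax)], axis=0)
            fr_reg[k:min(k + batch_size, fr_reg.shape[0])] = fr_reg_batch.cpu()
            ymax_all.append(ymax)
            xmax_all.append(xmax)
            cmax_all.append(cmax)
        ymax, xmax, cmax = torch.cat(ymax_all), torch.cat(xmax_all), torch.cat(cmax_all)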

        # frames to average for new reference
        nmax = max(2, int(frames.shape[0] * (1. + iter) / (2 * niter)))
        isort = torch.argsort(-cmax)[:nmax].cpu()
        refImg = fr_reg[isort].double().mean(dim=0)

        # recenter reference image
        if device.type == 'mps':
            # MPS backend currently can not support float64
            dy, dx = -torch.round(ymax[isort].to(torch.float32).mean()).int(), -torch.round(xmax[isort].to(torch.float32).mean()).int()
        else:
            dy, dx = -torch.round(ymax[isort].double().mean()).int(), -torch.round(xmax[isort].double().mean()).int()
        refImg = torch.roll(refImg, shifts=(-dy, -dx), dims=(0, 1))
        refImg = refImg.numpy().astype("int16")

    del fr_reg_batch 
    if device.type == "cuda":
        torch.cuda.empty_cache()    
        torch.cuda.synchronize()

    if device.type == "mps":
        torch.mps.empty_cache()
        torch.mps.synchronize()

    return refImg
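
The number of frames averaged into the new reference grows with each pass, following nmax = max(2, int(n * (1 + iter) / (2 * niter))) from the source. A tiny sketch (the helper name is hypothetical) that evaluates this schedule for niter = 8:

```python
def frames_to_average(n_frames, niter=8):
    """nmax per refinement pass: grows linearly from n/16 to n/2 (niter=8), never below 2."""
    return [max(2, int(n_frames * (1.0 + it) / (2 * niter))) for it in range(niter)]


print(frames_to_average(300))  # [18, 37, 56, 75, 93, 112, 131, 150]
print(frames_to_average(10))   # [2, 2, 2, 2, 3, 3, 4, 5]
```

Early passes therefore average only the most strongly correlated frames, and later passes broaden the average to half of the stack.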

compute_shifts

compute_shifts(refAndMasks, fr_reg, maxregshift=0.1, smooth_sigma_time=0, snr_thresh=1.2, maxregshiftNR=5, nZ=1)

Compute rigid and nonrigid registration shifts for a batch of frames.

Performs rigid phase-correlation registration, then (if nonrigid masks are provided) applies rigid shifts and computes nonrigid block shifts. For multi-plane data (nZ > 1), selects the best z-plane per frame by maximum correlation.

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| refAndMasks | tuple or list of tuple | Registration masks and reference FFTs from compute_filters_and_norm. If nZ > 1, a list of tuples (one per z-plane). | required |
| fr_reg | torch.Tensor | Frames to register, shape (N, Ly, Lx). | required |
| maxregshift | float | Maximum allowed rigid shift as a fraction of the smaller image dimension. | 0.1 |
| smooth_sigma_time | float | Sigma for temporal smoothing of phase-correlation maps. If <= 0, no temporal smoothing is applied. | 0 |
| snr_thresh | float | Signal-to-noise ratio threshold for accepting nonrigid block shifts. | 1.2 |
| maxregshiftNR | int | Maximum allowed nonrigid shift in pixels. If 0, nonrigid shifts are not computed. | 5 |
| nZ | int | Number of z-planes. If > 1, performs multi-plane registration. | 1 |

Returns:

| Name | Type | Description |
|---|---|---|
| ymax | torch.LongTensor | 1-D rigid y shifts of length N. |
| xmax | torch.LongTensor | 1-D rigid x shifts of length N. |
| cmax | torch.Tensor | 1-D rigid correlation values of length N. |
| ymax1 | torch.Tensor or None | Nonrigid y shifts of shape (N, n_blocks), or None if nonrigid is disabled. |
| xmax1 | torch.Tensor or None | Nonrigid x shifts of shape (N, n_blocks), or None if nonrigid is disabled. |
| cmax1 | torch.Tensor or None | Nonrigid correlation values of shape (N, n_blocks), or None if nonrigid is disabled. |
| zest | np.ndarray or None | Best z-plane index per frame, length N (only if nZ > 1); otherwise None. |
| cmax_all | np.ndarray or None | Correlation values across all z-planes, shape (N, nZ) (only if nZ > 1); otherwise None. |

Source code in suite2p/registration/register.py
def compute_shifts(refAndMasks, fr_reg, maxregshift=0.1, smooth_sigma_time=0,
                   snr_thresh=1.2, maxregshiftNR=5, nZ=1):
    """
    Compute rigid and nonrigid registration shifts for a batch of frames.

    Performs rigid phase-correlation registration, then (if nonrigid masks are
    provided) applies rigid shifts and computes nonrigid block shifts. For
    multi-plane data (nZ > 1), selects the best z-plane per frame by maximum
    correlation.

    Parameters
    ----------
    refAndMasks : tuple or list of tuple
        Registration masks and reference FFTs from compute_filters_and_norm. If
        nZ > 1, a list of tuples (one per z-plane).
    fr_reg : torch.Tensor
        Frames to register, shape (N, Ly, Lx).
    maxregshift : float
        Maximum allowed rigid shift as a fraction of the smaller image dimension.
    smooth_sigma_time : float
        Sigma for temporal smoothing of phase-correlation maps. If <= 0, no
        temporal smoothing is applied.
    snr_thresh : float
        Signal-to-noise ratio threshold for accepting nonrigid block shifts.
    maxregshiftNR : int
        Maximum allowed nonrigid shift in pixels.
    nZ : int
        Number of z-planes. If > 1, performs multi-plane registration.

    Returns
    -------
    ymax : torch.LongTensor
        1-D rigid y shifts of length N.
    xmax : torch.LongTensor
        1-D rigid x shifts of length N.
    cmax : torch.Tensor
        1-D rigid correlation values of length N.
    ymax1 : torch.Tensor or None
        Nonrigid y shifts of shape (N, n_blocks), or None if nonrigid is disabled.
    xmax1 : torch.Tensor or None
        Nonrigid x shifts of shape (N, n_blocks), or None if nonrigid is disabled.
    cmax1 : torch.Tensor or None
        Nonrigid correlation values of shape (N, n_blocks), or None.
    zest : np.ndarray or None
        Best z-plane index per frame of length N (only if nZ > 1), else None.
    cmax_all : np.ndarray or None
        Correlation values across all z-planes of shape (N, nZ) (only if nZ > 1),
        else None.
    """
    n_fr = fr_reg.shape[0]
    if nZ > 1:
        # find best plane
        offsets_all = []
        for z in range(nZ):
            fr_reg0 = fr_reg.clone()
            offsets0 = compute_shifts(refAndMasks[z], fr_reg0, maxregshift, 
                                      smooth_sigma_time, snr_thresh, 
                                      maxregshiftNR, nZ=1)
            offsets_all.append(offsets0)
        cmax_all = np.array([offsets[2].cpu().numpy() for offsets in offsets_all]).T
        zest = cmax_all.argmax(axis=1)

        nb = refAndMasks[0][3].shape[0] if refAndMasks[0][3] is not None else 0
        device = fr_reg.device
        shapes = [(n_fr,), (n_fr,), (n_fr,), (n_fr, nb), (n_fr, nb), (n_fr, nb)]
        offsets_best = [torch.zeros(shapes[i], device=device, 
                                    dtype=torch.float32 if i > 1 else torch.long) 
                        for i in range(6)]
        for z in range(nZ):
            iz = np.nonzero(zest == z)[0]
            if len(iz) > 0:
                for i, offsets in enumerate(offsets_all[z][:6]):
                    offsets_best[i][iz] = offsets[iz] if offsets is not None else 0

        return *offsets_best[:6], zest, cmax_all

    else:
        (maskMul, maskOffset, cfRefImg, maskMulNR, maskOffsetNR, cfRefImgNR, 
        blocks, rmin, rmax) = refAndMasks
        device = fr_reg.device

        fr_reg = torch.clip(fr_reg, rmin, rmax) if rmin > -np.inf else fr_reg

        # rigid registration
        ymax, xmax, cmax = rigid.phasecorr(fr_reg, cfRefImg, maskMul, maskOffset, 
                                        maxregshift, smooth_sigma_time)[:3]

        # non-rigid registration
        if maskMulNR is not None and maxregshiftNR > 0:     
            # shift torch frames to reference
            fr_reg = torch.stack([torch.roll(frame, shifts=(-dy, -dx), dims=(0, 1))
                                for frame, dy, dx in zip(fr_reg, ymax, xmax)], axis=0)
            ymax1, xmax1, cmax1 = nonrigid.phasecorr(fr_reg, blocks, 
                                                    maskMulNR, maskOffsetNR, cfRefImgNR, 
                                                    snr_thresh, maxregshiftNR)[:3]
        else:    
            ymax1, xmax1, cmax1 = None, None, None

        del fr_reg
        if device.type == "cuda":
            torch.cuda.empty_cache()    

        if device.type == "mps":
            torch.mps.empty_cache()

    return ymax, xmax, cmax, ymax1, xmax1, cmax1, None, None
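
For a single plane, the rigid step amounts to phase correlation: whiten the spectra of frame and reference, multiply one by the conjugate of the other, and take the peak of the inverse FFT within ±maxregshift of zero shift. For nZ > 1 the same is done against each plane and, per frame, the plane with the highest peak is kept. The sketch below is NumPy-only and integer-pixel; it omits the taper masks, temporal smoothing, clipping and nonrigid blocks, and its function names are hypothetical rather than suite2p's rigid.phasecorr. It returns (dy, dx) such that frame ≈ np.roll(ref, (dy, dx)), so rolling the frame by (-dy, -dx) aligns it, matching the torch.roll(frame, shifts=(-dy, -dx)) call in the source.

```python
import numpy as np


def whitened_fft(img):
    F = np.fft.fft2(img)
    return F / (np.abs(F) + 1e-5)


def rigid_phasecorr(frame, ref, maxregshift=0.1):
    """Integer (dy, dx, peak) with frame ~= np.roll(ref, (dy, dx), axis=(0, 1))."""
    Ly, Lx = ref.shape
    cc = np.real(np.fft.ifft2(whitened_fft(frame) * np.conj(whitened_fft(ref))))
    cc = np.fft.fftshift(cc)  # zero shift now sits at (Ly // 2, Lx // 2)
    lim = int(maxregshift * min(Ly, Lx))
    cy, cx = Ly // 2, Lx // 2
    win = cc[cy - lim:cy + lim + 1, cx - lim:cx + lim + 1]
    iy, ix = np.unravel_index(np.argmax(win), win.shape)
    return int(iy) - lim, int(ix) - lim, float(win[iy, ix])


def best_plane_shifts(frames, refs, maxregshift=0.1):
    """Per frame, keep the shift from the plane with the highest correlation peak."""
    results = [[rigid_phasecorr(f, r, maxregshift) for r in refs] for f in frames]
    cmax_all = np.array([[res[2] for res in row] for row in results])  # (N, nZ)
    zest = cmax_all.argmax(axis=1)
    ymax = np.array([results[i][z][0] for i, z in enumerate(zest)])
    xmax = np.array([results[i][z][1] for i, z in enumerate(zest)])
    return ymax, xmax, zest, cmax_all


rng = np.random.default_rng(1)
ref0, ref1 = rng.normal(size=(2, 64, 64))
frame = np.roll(ref1, (3, -5), axis=(0, 1))  # a frame from plane 1, displaced by (3, -5)
ymax, xmax, zest, cmax_all = best_plane_shifts([frame], [ref0, ref1])
print(zest, ymax, xmax)  # [1] [3] [-5]
```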

normalize_reference_image

normalize_reference_image(refImg)

Clip reference image to [1st, 99th] intensity percentiles.

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| refImg | np.ndarray | Reference image of shape (Ly, Lx). | required |

Returns:

| Name | Type | Description |
|---|---|---|
| refImg | np.ndarray | Clipped reference image of shape (Ly, Lx). |
| rmin | np.int16 | 1st percentile intensity value used as the lower clip bound. |
| rmax | np.int16 | 99th percentile intensity value used as the upper clip bound. |

Source code in suite2p/registration/register.py
def normalize_reference_image(refImg):
    """
    Clip reference image to [1st, 99th] intensity percentiles.

    Parameters
    ----------
    refImg : np.ndarray
        Reference image of shape (Ly, Lx).

    Returns
    -------
    refImg : np.ndarray
        Clipped reference image of shape (Ly, Lx).
    rmin : np.int16
        1st percentile intensity value used as the lower clip bound.
    rmax : np.int16
        99th percentile intensity value used as the upper clip bound.
    """
    rmin, rmax = np.percentile(refImg, [1, 99]).astype(np.int16)
    refImg = np.clip(refImg, rmin, rmax)
    return refImg, rmin, rmax
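
A small NumPy illustration of the clipping on a synthetic ramp image. Note that .astype(np.int16) truncates the fractional percentiles rather than rounding them. The same bounds are reused later: compute_shifts clips each incoming batch with torch.clip(fr_reg, rmin, rmax), which the second half of the example mimics with np.clip on a made-up frame.

```python
import numpy as np

ref = np.arange(10000, dtype=np.int16).reshape(100, 100)
rmin, rmax = np.percentile(ref, [1, 99]).astype(np.int16)  # 99.99 -> 99, 9899.01 -> 9899
ref_clipped = np.clip(ref, rmin, rmax)

# incoming frames are clipped to the reference's bounds before phase correlation
frame = np.array([[-500, 50], [5000, 20000]], dtype=np.int16)
frame_clipped = np.clip(frame, rmin, rmax)
print(rmin, rmax, frame_clipped.tolist())  # 99 9899 [[99, 99], [5000, 9899]]
```

Clipping both the reference and the frames to the same range keeps a few very bright pixels from dominating the correlation.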

pick_initial_reference

pick_initial_reference(frames)

Compute the initial reference image by finding the most correlated frame.

The seed frame is the frame whose 19 most correlated other frames have the highest mean correlation with it. The initial reference is the average of the 20 frames most correlated with the seed, which include the seed itself. Each frame's mean intensity is subtracted before correlating and averaging.

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| frames | torch.Tensor | Input frames of shape (n_frames, Ly, Lx). | required |

Returns:

| Name | Type | Description |
|---|---|---|
| refImg | np.ndarray | Initial reference image of shape (Ly, Lx), dtype int16. |

Source code in suite2p/registration/register.py
def pick_initial_reference(frames: torch.Tensor):
    """
    Compute the initial reference image by finding the most correlated frame.

    The seed frame is the frame whose 19 most correlated other frames have the
    highest mean correlation with it. The initial reference is the average of
    the 20 frames most correlated with the seed, which include the seed itself.

    Parameters
    ----------
    frames : torch.Tensor
        Input frames of shape (n_frames, Ly, Lx).

    Returns
    -------
    refImg : np.ndarray
        Initial reference image of shape (Ly, Lx), dtype int16.
    """
    nimg, Ly, Lx = frames.shape
    fr_z = frames.clone().reshape(nimg, -1).double()
    fr_z -= fr_z.mean(dim=1, keepdim=True)
    cc = fr_z @ fr_z.T 
    ndiag = torch.diag(cc)**0.5
    cc = cc / torch.outer(ndiag, ndiag)
    CCsort = -torch.sort(-cc, dim=1)[0]
    # find frame most correlated to other frames
    bestCC = CCsort[:, 1:20].mean(dim=1) # columns 1-19: skip column 0, the frame's self-correlation
    imax = torch.argmax(bestCC)
    # average the 20 frames most correlated to imax (imax itself included)
    indsort = torch.argsort(-cc[imax, :])
    refImg = fr_z[indsort[:20]].mean(axis=0).cpu().numpy().astype("int16")
    refImg = refImg.reshape(Ly, Lx)
    return refImg
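
The seed selection can be reproduced in NumPy on synthetic data: 25 noisy copies of one image plus 5 unrelated frames. The function name is hypothetical; it mirrors the steps in the source (per-frame mean removal, Pearson correlation matrix, mean of columns 1..k-1 of each sorted row, then averaging the k frames closest to the seed) but is a sketch, not the library code.

```python
import numpy as np


def pick_seed_and_reference(frames, k=20):
    n = frames.shape[0]
    X = frames.reshape(n, -1).astype(np.float64)
    X -= X.mean(axis=1, keepdims=True)          # zero-mean each frame
    cc = X @ X.T
    norms = np.sqrt(np.diag(cc))
    cc /= np.outer(norms, norms)                # Pearson correlation matrix
    cc_sorted = -np.sort(-cc, axis=1)           # each row in descending order
    # column 0 is the self-correlation (1.0); columns 1..k-1 are the k-1 best partners
    best = cc_sorted[:, 1:k].mean(axis=1)
    seed = int(np.argmax(best))
    top = np.argsort(-cc[seed])[:k]             # the seed itself plus its k-1 best partners
    ref = X[top].mean(axis=0).reshape(frames.shape[1:])
    return seed, top, ref


rng = np.random.default_rng(0)
base = rng.normal(size=(32, 32))
good = base + 0.1 * rng.normal(size=(25, 32, 32))   # frames 0..24: noisy copies of base
outliers = rng.normal(size=(5, 32, 32))             # frames 25..29: unrelated
frames = np.concatenate([good, outliers])
seed, top, ref = pick_seed_and_reference(frames)
```

Because the seed is chosen by its best partners rather than by its average correlation with every frame, a handful of outlier frames cannot pull the initial reference away from the dominant cluster.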

register_frames

register_frames(f_align_in, refImg, f_align_out=None, batch_size=100, bidiphase=0, norm_frames=True, smooth_sigma=1.15, spatial_taper=3.45, block_size=(128, 128), nonrigid=True, maxregshift=0.1, smooth_sigma_time=0, snr_thresh=1.2, maxregshiftNR=5, subpixel=10, device=torch.device('cuda'), tif_root=None, apply_shifts=True, upsample_meanImg=False)

Register frames to a reference image using rigid and optionally nonrigid shifts.

Computes registration masks from the reference, then processes frames in batches: computes shifts, applies them, accumulates a mean image, and optionally writes registered frames to f_align_out. Supports multi-plane registration when refImg is a list.
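
The batching and mean-image bookkeeping can be sketched in isolation. In the sketch below, register_batch is a hypothetical stand-in for "compute shifts and apply them" (the identity by default); the batch bounds and the incremental mean follow the tstart/tend and mean_img += frames.sum(axis=0) / n_frames lines of the source.

```python
import numpy as np


def batched_mean(movie, batch_size=100, register_batch=lambda b: b):
    """Process a movie in batches and accumulate the global mean image (sketch)."""
    n_frames, Ly, Lx = movie.shape
    n_batches = int(np.ceil(n_frames / batch_size))
    mean_img = np.zeros((Ly, Lx), dtype=np.float32)
    bounds = []
    for n in range(n_batches):
        tstart, tend = n * batch_size, min((n + 1) * batch_size, n_frames)
        frames = register_batch(movie[tstart:tend])  # hypothetical registration step
        mean_img += frames.sum(axis=0) / n_frames    # each batch adds its share of the mean
        bounds.append((tstart, tend))
    return mean_img, bounds


rng = np.random.default_rng(0)
movie = rng.random((250, 8, 8)).astype(np.float32)
mean_img, bounds = batched_mean(movie)
print(bounds)  # [(0, 100), (100, 200), (200, 250)]
```

Dividing each batch sum by the total frame count means no second pass is needed, and only one batch has to be in memory at a time.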

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| f_align_in | np.ndarray or BinaryFile | Input frames of shape (n_frames, Ly, Lx), supporting slice indexing. | required |
| refImg | np.ndarray or list of np.ndarray | Reference image of shape (Ly, Lx), or a list for multi-plane registration. | required |
| f_align_out | np.ndarray or BinaryFile or None | Output array for registered frames. If None, registered frames are written back to f_align_in. | None |
| batch_size | int | Number of frames to process per batch. | 100 |
| bidiphase | int | Bidirectional phase offset in pixels. If non-zero, frames are corrected before registration. | 0 |
| norm_frames | bool | If True, clip frames to the reference image's [1st, 99th] percentile range. | True |
| smooth_sigma | float | Standard deviation of the Gaussian smoothing applied to the reference image. | 1.15 |
| spatial_taper | float | Slope of the sigmoid spatial taper mask at the image borders. | 3.45 |
| block_size | tuple of int | Block size (Ly_block, Lx_block) for nonrigid registration. | (128, 128) |
| nonrigid | bool | If True, compute nonrigid shifts in addition to rigid shifts. | True |
| maxregshift | float | Maximum rigid shift as a fraction of the smaller image dimension. | 0.1 |
| smooth_sigma_time | float | Sigma for temporal smoothing of phase-correlation maps. | 0 |
| snr_thresh | float | SNR threshold for accepting nonrigid block shifts. | 1.2 |
| maxregshiftNR | int | Maximum nonrigid shift in pixels. | 5 |
| subpixel | int | Subpixel accuracy factor for nonrigid block shifts (passed to compute_filters_and_norm). | 10 |
| device | torch.device | Torch device for computation. | torch.device('cuda') |
| tif_root | str or None | If provided, save registered frames as TIFF files in this directory. | None |
| apply_shifts | bool | If True, apply the computed shifts to the frames. If False, only compute the shifts. | True |
| upsample_meanImg | bool, int, list, or tuple | Upsampling factor for the super-resolution mean image. If False or None, no upsampling is performed; if an int, the same factor is used for Y and X; if a list/tuple of length 2, it gives [Y_factor, X_factor]. The mean image is computed by accumulating registered frames at subpixel locations and normalizing by pixel counts. | False |

Returns:

| Name | Type | Description |
|---|---|---|
| rmin | np.int16 or list | Lower intensity clip bound(s) from reference normalization. |
| rmax | np.int16 or list | Upper intensity clip bound(s) from reference normalization. |
| mean_img | np.ndarray | Mean registered image of shape (Ly, Lx). |
| offsets_all | list | List of [yoff, xoff, corrXY, yoff1, xoff1, corrXY1, zest, cmax_all] concatenated across all batches. |
| blocks | list | Block definitions from nonrigid.make_blocks. |
| mean_img_ups | torch.Tensor or None | Raw upsampled mean image tensor of shape (Ly * upsample_meanImg[0], Lx * upsample_meanImg[1]) before normalization. None if upsample_meanImg is False. |
| counts_ups | torch.Tensor or None | Pixel-count tensor of shape (Ly * upsample_meanImg[0], Lx * upsample_meanImg[1]) giving how many frames contributed to each upsampled pixel. None if upsample_meanImg is False. |
| meanImg_ups | np.ndarray or None | Super-resolution mean image of shape (Ly * upsample_meanImg[0], Lx * upsample_meanImg[1]) after Gaussian smoothing and normalization by counts. None if upsample_meanImg is False. |
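
The idea behind the super-resolution mean image can be sketched as nearest-neighbour binning: each registered pixel is dropped onto an upsampled grid at its shift-corrected position, and the sum in each bin is divided by the number of samples it received. The function below is hypothetical. It assumes that a frame displaced by (dy, dx) samples reference position (y - dy, x - dx), and it skips the Gaussian smoothing step. shift_frames, which does the real accumulation, is not shown on this page, so its exact placement rule may differ.

```python
import numpy as np


def accumulate_upsampled(frames, dys, dxs, factor=(2, 2)):
    """Bin each frame's pixels onto an upsampled grid at shift-corrected positions (sketch)."""
    n, Ly, Lx = frames.shape
    uy, ux = factor
    H, W = int(Ly * uy), int(Lx * ux)
    acc = np.zeros((H, W))
    counts = np.zeros((H, W), dtype=int)
    yy, xx = np.meshgrid(np.arange(Ly), np.arange(Lx), indexing="ij")
    for frame, dy, dx in zip(frames, dys, dxs):
        # assumed convention: pixel (y, x) of this frame samples reference point (y - dy, x - dx)
        ty = np.rint((yy - dy) * uy).astype(int)
        tx = np.rint((xx - dx) * ux).astype(int)
        ok = (ty >= 0) & (ty < H) & (tx >= 0) & (tx < W)
        np.add.at(acc, (ty[ok], tx[ok]), frame[ok])
        np.add.at(counts, (ty[ok], tx[ok]), 1)
    mean_ups = np.divide(acc, counts, out=np.zeros_like(acc), where=counts > 0)
    return mean_ups, counts


frames = np.stack([np.ones((4, 4)), 3 * np.ones((4, 4))])
mean_ups, counts = accumulate_upsampled(frames, dys=[0.0, -0.5], dxs=[0.0, -0.5])
print(counts.sum(), mean_ups[0, 0], mean_ups[1, 1], counts[0, 1])  # 32 1.0 3.0 0
```

Frames with half-pixel offsets fill the in-between bins that an unshifted frame never reaches, which is why sub-pixel jitter across a recording can raise the effective resolution of the mean image.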

Source code in suite2p/registration/register.py
def register_frames(f_align_in, refImg, f_align_out=None, batch_size=100,
                    bidiphase=0, norm_frames=True, smooth_sigma=1.15, spatial_taper=3.45,
                    block_size=(128,128), nonrigid=True, maxregshift=0.1,
                    smooth_sigma_time=0, snr_thresh=1.2, maxregshiftNR=5,
                    subpixel=10, device=torch.device("cuda"), tif_root=None, apply_shifts=True,
                    upsample_meanImg=False):
    """
    Register frames to a reference image using rigid and optionally nonrigid shifts.

    Computes registration masks from the reference, then processes frames in
    batches: computes shifts, applies them, accumulates a mean image, and
    optionally writes registered frames to f_align_out. Supports multi-plane
    registration when refImg is a list.

    Parameters
    ----------
    f_align_in : np.ndarray or BinaryFile
        Input frames of shape (n_frames, Ly, Lx), supporting slice indexing.
    refImg : np.ndarray or list of np.ndarray
        Reference image of shape (Ly, Lx), or a list for multi-plane registration.
    f_align_out : np.ndarray or BinaryFile or None
        Output array for registered frames. If None, registered frames are
        written back to f_align_in.
    batch_size : int
        Number of frames to process per batch.
    bidiphase : int
        Bidirectional phase offset in pixels. If non-zero, frames are corrected
        before registration.
    norm_frames : bool
        If True, clip frames to the reference image's [1st, 99th] percentile range.
    smooth_sigma : float
        Standard deviation of Gaussian smoothing applied to the reference image.
    spatial_taper : float
        Slope of the sigmoid spatial taper mask at image borders.
    block_size : tuple of int
        Block size (Ly_block, Lx_block) for nonrigid registration.
    nonrigid : bool
        If True, compute nonrigid shifts in addition to rigid shifts.
    maxregshift : float
        Maximum rigid shift as a fraction of the smaller image dimension.
    smooth_sigma_time : float
        Sigma for temporal smoothing of phase-correlation maps.
    snr_thresh : float
        SNR threshold for accepting nonrigid block shifts.
    maxregshiftNR : int
        Maximum nonrigid shift in pixels.
    subpixel : int
        Subpixel accuracy factor for nonrigid block shifts.
    device : torch.device
        Torch device for computation.
    tif_root : str or None
        If provided, save registered frames as tiffs in this directory.
    apply_shifts : bool
        If True, apply computed shifts to frames. If False, only compute shifts.
    upsample_meanImg : bool, int, list, or tuple
        Upsampling factor for super-resolution mean image computation.
        If False or None, no upsampling is performed. If int, same factor is used
        for both Y and X. If list/tuple of length 2, specifies [Y_factor, X_factor].
        The mean image is computed by accumulating registered frames at subpixel
        locations and normalizing by pixel counts.

    Returns
    -------
    rmin : np.int16 or list
        Lower intensity clip bound(s) from reference normalization.
    rmax : np.int16 or list
        Upper intensity clip bound(s) from reference normalization.
    mean_img : np.ndarray
        Mean registered image of shape (Ly, Lx).
    offsets_all : list
        List of [yoff, xoff, corrXY, yoff1, xoff1, corrXY1, zest, cmax_all]
        concatenated across all batches.
    blocks : list
        Block definitions from nonrigid.make_blocks.
    mean_img_ups : torch.Tensor or None
        Raw upsampled mean image tensor of shape (Ly*upsample[0], Lx*upsample[1])
        before normalization. None if upsample_meanImg is False.
    counts_ups : torch.Tensor or None
        Pixel counts tensor of shape (Ly*upsample[0], Lx*upsample[1]) indicating
        how many frames contributed to each upsampled pixel. None if upsample_meanImg is False.
    meanImg_ups : np.ndarray or None
        Super-resolution mean image of shape (Ly*upsample[0], Lx*upsample[1])
        after Gaussian smoothing and normalization by counts. None if upsample_meanImg is False.
    """

    n_frames, Ly, Lx = f_align_in.shape

    if isinstance(refImg, list):
        nZ = len(refImg)
        logger.info(f"List of reference frames len = {nZ}")
    else:
        nZ = 1

    refAndMasks = compute_filters_and_norm(refImg, norm_frames=norm_frames,
                                           spatial_smooth=smooth_sigma,
                                           spatial_taper=spatial_taper,
                                           block_size=block_size if nonrigid else None,
                                           subpixel=subpixel, device=device)
    blocks = refAndMasks[-3] if nZ==1 else refAndMasks[0][-3]
    rmin = refAndMasks[-2] if nZ==1 else [refAndMasks[z][-2] for z in range(nZ)]
    rmax = refAndMasks[-1] if nZ==1 else [refAndMasks[z][-1] for z in range(nZ)]
    ### ------------- register frames to reference image ------------ ###

    mean_img = np.zeros((Ly, Lx), "float32")
    if upsample_meanImg:
        if not isinstance(upsample_meanImg, (np.ndarray, list, tuple)):
            upsample_meanImg = [upsample_meanImg, upsample_meanImg]
        # MPS backend does not support float64
        ups_dtype = torch.float32 if device.type == "mps" else torch.double
        mean_img_ups = torch.zeros((int(Ly*upsample_meanImg[0]), int(Lx*upsample_meanImg[1])), dtype=ups_dtype, device=device)
        counts_ups = torch.zeros((int(Ly*upsample_meanImg[0]), int(Lx*upsample_meanImg[1])), dtype=torch.int, device=device)
    else:
        mean_img_ups, counts_ups, meanImg_ups = None, None, None

    n_batches = int(np.ceil(n_frames / batch_size))
    logger.info(f"Registering {n_frames} frames in {n_batches} batches")
    tqdm_out = TqdmToLogger(logger, level=logging.INFO)
    for n in trange(n_batches, mininterval=10, file=tqdm_out):
        tstart, tend = n * batch_size, min((n+1) * batch_size, n_frames)
        frames = f_align_in[tstart : tend]
        if device.type == "cuda":
            fr_torch = torch.from_numpy(frames).pin_memory().to(device)
        else:
            fr_torch = torch.from_numpy(frames).to(device)
        if bidiphase != 0:
            fr_torch = bidi.shift(fr_torch, bidiphase)

        fr_reg = fr_torch.clone()
        offsets = compute_shifts(refAndMasks, fr_reg, maxregshift=maxregshift, 
                                 smooth_sigma_time=smooth_sigma_time, 
                                 snr_thresh=snr_thresh, maxregshiftNR=maxregshiftNR, 
                                 nZ=nZ)
        ymax, xmax, cmax, ymax1, xmax1, cmax1, zest, cmax_all = offsets

        if apply_shifts:
            frames = shift_frames(fr_torch, ymax, xmax, ymax1, xmax1, blocks, 
                                  mean_img_ups=mean_img_ups, counts_ups=counts_ups, device=device)

        # convert to numpy and concatenate offsets
        ymax, xmax, cmax = ymax.cpu().numpy(), xmax.cpu().numpy(), cmax.cpu().numpy()
        if ymax1 is not None:
            ymax1, xmax1 = ymax1.cpu().numpy(), xmax1.cpu().numpy()
            cmax1 = cmax1.cpu().numpy()
        offsets = [ymax, xmax, cmax, ymax1, xmax1, cmax1, zest, cmax_all]
        offsets_all = ([np.concatenate((offset_all, offset), axis=0) 
                       if offset is not None else None
                       for offset_all, offset in zip(offsets_all, offsets)] 
                        if n > 0 else offsets)

        # make mean image from all registered frames
        mean_img += frames.sum(axis=0) / n_frames

        # save aligned frames to bin file
        if apply_shifts:
            if f_align_out is not None:
                f_align_out[tstart : tend] = frames
            else:
                f_align_in[tstart : tend] = frames

            # save aligned frames to tiffs
            if tif_root:
                fname = os.path.join(tif_root, f"file{n : 05d}.tif")
                save_tiff(mov=frames, fname=fname)

    if upsample_meanImg:
        # apply Gaussian smoothing and normalize by counts
        mimg = mean_img_ups.cpu().numpy()
        cimg = counts_ups.cpu().numpy()
        sig = 1 
        mimg = gaussian_filter(mimg, sig)
        cimg = gaussian_filter(cimg, sig)
        meanImg_ups = mimg / cimg

    return rmin, rmax, mean_img, offsets_all, blocks, mean_img_ups, counts_ups, meanImg_ups
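The final block of register_frames smooths both the accumulated upsampled sums and the per-pixel counts with the same Gaussian, then divides them. This is a form of normalized convolution: upsampled pixels that received few or no samples take their value from nearby samples. The sketch below uses NumPy only. Its kernel construction, its zero-padded borders (the source calls scipy.ndimage.gaussian_filter, whose default border mode is "reflect"), and its `eps` guard are assumptions. The helper names are hypothetical.

```python
import numpy as np

def gaussian_kernel1d(sigma: float, truncate: float = 4.0) -> np.ndarray:
    """Normalized 1-D Gaussian kernel with radius int(truncate * sigma + 0.5)."""
    radius = int(truncate * sigma + 0.5)
    x = np.arange(-radius, radius + 1, dtype=np.float64)
    k = np.exp(-0.5 * (x / sigma) ** 2)
    return k / k.sum()

def _conv_same(row: np.ndarray, k: np.ndarray) -> np.ndarray:
    """1-D convolution with zero padding, output aligned with the input row."""
    r = (len(k) - 1) // 2
    return np.convolve(row, k, mode="full")[r:r + len(row)]

def blur2d(img: np.ndarray, sigma: float) -> np.ndarray:
    """Separable Gaussian blur (zero-padded borders, unlike scipy's default 'reflect')."""
    k = gaussian_kernel1d(sigma)
    out = np.apply_along_axis(_conv_same, 0, img.astype(np.float64), k)
    return np.apply_along_axis(_conv_same, 1, out, k)

def normalize_upsampled_mean(sum_img: np.ndarray, counts: np.ndarray,
                             sigma: float = 1.0, eps: float = 1e-12) -> np.ndarray:
    """Blur sums and counts with the same kernel, then divide (normalized convolution)."""
    s = blur2d(sum_img, sigma)
    c = blur2d(counts, sigma)
    # pixels with no samples within the kernel radius become NaN instead of dividing by zero
    return np.where(c > eps, s / np.maximum(c, eps), np.nan)
```

Blurring is linear, so when every sample carries the same value v, the ratio recovers exactly v at each covered pixel, however unevenly the samples are spread.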

registration_outputs_to_dict

registration_outputs_to_dict(refImg, rmin, rmax, meanImg, rigid_offsets, nonrigid_offsets, zest, meanImg_chan2, badframes, badframes0, yrange, xrange, bidiphase)

Pack registration results into a dictionary.

Parameters:

Name Type Description Default
refImg ndarray

Reference image of shape (Ly, Lx).

required
rmin int16

Lower intensity clip bound.

required
rmax int16

Upper intensity clip bound.

required
meanImg ndarray

Mean registered image of shape (Ly, Lx).

required
rigid_offsets tuple

Tuple of (yoff, xoff, corrXY) rigid registration offsets.

required
nonrigid_offsets tuple

Tuple of (yoff1, xoff1, corrXY1) nonrigid offsets, elements may be None.

required
zest tuple

Tuple of (zpos, cmax_all) for multi-plane registration, elements may be None.

required
meanImg_chan2 ndarray or None

Mean image of the second channel, shape (Ly, Lx).

required
badframes ndarray

1-D boolean array of detected bad frames.

required
badframes0 ndarray

1-D boolean array of initial bad frames before registration.

required
yrange list of int

[ymin, ymax] valid row range.

required
xrange list of int

[xmin, xmax] valid column range.

required
bidiphase int

Bidirectional phase offset in pixels.

required

Returns:

Name Type Description
reg_outputs dict

Dictionary with keys "refImg", "rmin", "rmax", "yoff", "xoff", "corrXY", "meanImg", "badframes", "badframes0", "yrange", "xrange", "bidiphase", and optionally "yoff1", "xoff1", "corrXY1", "meanImg_chan2", "zpos_registration", "cmax_registration".

Source code in suite2p/registration/register.py
def registration_outputs_to_dict(refImg, rmin, rmax, meanImg, rigid_offsets,
                                 nonrigid_offsets, zest, meanImg_chan2,
                                 badframes, badframes0, yrange, xrange, bidiphase):
    """
    Pack registration results into a dictionary.

    Parameters
    ----------
    refImg : np.ndarray
        Reference image of shape (Ly, Lx).
    rmin : np.int16
        Lower intensity clip bound.
    rmax : np.int16
        Upper intensity clip bound.
    meanImg : np.ndarray
        Mean registered image of shape (Ly, Lx).
    rigid_offsets : tuple
        Tuple of (yoff, xoff, corrXY) rigid registration offsets.
    nonrigid_offsets : tuple
        Tuple of (yoff1, xoff1, corrXY1) nonrigid offsets, elements may be None.
    zest : tuple
        Tuple of (zpos, cmax_all) for multi-plane registration, elements may be
        None.
    meanImg_chan2 : np.ndarray or None
        Mean image of the second channel, shape (Ly, Lx).
    badframes : np.ndarray
        1-D boolean array of detected bad frames.
    badframes0 : np.ndarray
        1-D boolean array of initial bad frames before registration.
    yrange : list of int
        [ymin, ymax] valid row range.
    xrange : list of int
        [xmin, xmax] valid column range.
    bidiphase : int
        Bidirectional phase offset in pixels.

    Returns
    -------
    reg_outputs : dict
        Dictionary with keys "refImg", "rmin", "rmax", "yoff", "xoff",
        "corrXY", "meanImg", "badframes", "badframes0", "yrange", "xrange",
        "bidiphase", and optionally "yoff1", "xoff1", "corrXY1",
        "meanImg_chan2", "zpos_registration", "cmax_registration".
    """
    reg_outputs = {}
    # assign reference image and normalizers
    reg_outputs["refImg"] = refImg
    reg_outputs["rmin"], reg_outputs["rmax"] = rmin, rmax
    # assign rigid offsets to reg_outputs
    reg_outputs["yoff"], reg_outputs["xoff"], reg_outputs["corrXY"] = rigid_offsets
    # assign nonrigid offsets to reg_outputs
    if nonrigid_offsets[0] is not None:
        reg_outputs["yoff1"], reg_outputs["xoff1"], reg_outputs["corrXY1"] = nonrigid_offsets
    # assign mean images
    reg_outputs["meanImg"] = meanImg
    if meanImg_chan2 is not None:
        reg_outputs["meanImg_chan2"] = meanImg_chan2
    # assign crop computation and badframes
    reg_outputs["badframes"], reg_outputs["badframes0"] = badframes, badframes0
    reg_outputs["yrange"], reg_outputs["xrange"] = yrange, xrange
    if zest[0] is not None:
        reg_outputs["zpos_registration"] = np.array(zest[0])
        reg_outputs["cmax_registration"] = np.array(zest[1])
    reg_outputs["bidiphase"] = bidiphase
    return reg_outputs
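Several keys in the returned dictionary are conditional: "yoff1", "xoff1", and "corrXY1" appear only when nonrigid offsets exist, and "meanImg_chan2" appears only when a second channel was registered. Downstream code should therefore test for these keys rather than index them directly. The helper below is a hypothetical sketch, not part of suite2p. It reads such a dictionary defensively and assumes only the keys documented above.

```python
import numpy as np

def summarize_registration(reg_outputs: dict) -> dict:
    """Summarize a registration-output dict, tolerating the optional keys."""
    yoff = np.asarray(reg_outputs["yoff"])
    xoff = np.asarray(reg_outputs["xoff"])
    return {
        "n_frames": int(yoff.size),
        # largest rigid displacement (pixels) across all frames
        "max_rigid_shift": float(np.max(np.hypot(yoff, xoff))) if yoff.size else 0.0,
        "n_badframes": int(np.count_nonzero(reg_outputs["badframes"])),
        # "yoff1"/"xoff1"/"corrXY1" are only present after nonrigid registration
        "nonrigid": "yoff1" in reg_outputs,
        # "meanImg_chan2" is only present for two-channel recordings
        "has_chan2": "meanImg_chan2" in reg_outputs,
        "bidiphase": int(reg_outputs["bidiphase"]),
    }
```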

registration_wrapper

registration_wrapper(f_reg, f_raw=None, f_reg_chan2=None, f_raw_chan2=None, refImg=None, align_by_chan2=False, save_path=None, aspect=1.0, badframes=None, settings=default_settings(), device=torch.device('cuda'))

Main registration function for single- or dual-channel movies.

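Computes a reference image (if not provided), estimates the bidirectional phase offset (when settings["do_bidiphase"] is True and settings["bidiphase"] is 0; otherwise settings["bidiphase"] is used as given), and registers the alignment channel. If settings["two_step_registration"] is True and f_raw is provided, a second pass re-registers the raw frames against a new reference image averaged from evenly sampled registered frames, excluding bad frames. The same shifts are then applied to the alternate channel if present, and all registration outputs are returned as a dictionary.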

Parameters:

Name Type Description Default
f_reg ndarray or BinaryFile

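Functional channel frames of shape (n_frames, Ly, Lx). If f_raw is None, these frames are the registration input and are overwritten in place with the registered frames; otherwise they are the output destination.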

required
f_raw ndarray or BinaryFile or None

Raw functional channel frames. If provided, used as the registration input with f_reg as the output destination.

None
f_reg_chan2 ndarray or BinaryFile or None

Registered second channel frames.

None
f_raw_chan2 ndarray or BinaryFile or None

Raw second channel frames.

None
refImg ndarray or None

Reference image of shape (Ly, Lx), dtype int16. If None, a reference is computed from the data.

None
align_by_chan2 bool

If True, use the second channel as the alignment source.

False
save_path str or None

Base directory for saving registered tiff files.

None
aspect float

Pixel aspect ratio used for computing the enhanced mean image.

1.0
badframes ndarray or None

1-D boolean array of pre-existing bad frame labels. If None, initialized to all False.

None
settings dict

Registration settings dictionary from default_settings().

default_settings()
device device

Torch device for computation.

device('cuda')

Returns:

Name Type Description
reg_outputs dict

Dictionary containing registration results with keys "refImg", "rmin", "rmax", "meanImg", "yoff", "xoff", "corrXY", "badframes", "badframes0", "yrange", "xrange", "bidiphase", and "meanImgE". Optional keys are "yoff1", "xoff1", and "corrXY1" (nonrigid registration only), "meanImg_chan2" (two-channel data only), "zpos_registration" and "cmax_registration" (multi-plane reference only), and "meanImg_upsample", "mean_img_ups", and "counts_ups" (only when settings["upsample_meanImg"] is enabled).

Source code in suite2p/registration/register.py
def registration_wrapper(f_reg, f_raw=None, f_reg_chan2=None, f_raw_chan2=None,
                        refImg=None, align_by_chan2=False, save_path=None, aspect=1.,
                        badframes=None, settings=default_settings(), device=torch.device("cuda")):
    """
    Main registration function for single- or dual-channel movies.

    Computes a reference image (if not provided), estimates bidirectional phase
    offset, registers the primary channel, optionally performs a two-step
    registration, applies shifts to an alternate channel if present, and returns
    all registration outputs as a dictionary.

    Parameters
    ----------
    f_reg : np.ndarray or BinaryFile
        Registered functional channel frames of shape (n_frames, Ly, Lx).
    f_raw : np.ndarray or BinaryFile or None
        Raw functional channel frames. If provided, used as the registration
        input with f_reg as the output destination.
    f_reg_chan2 : np.ndarray or BinaryFile or None
        Registered second channel frames.
    f_raw_chan2 : np.ndarray or BinaryFile or None
        Raw second channel frames.
    refImg : np.ndarray or None
        Reference image of shape (Ly, Lx), dtype int16. If None, a reference is
        computed from the data.
    align_by_chan2 : bool
        If True, use the second channel as the alignment source.
    save_path : str or None
        Base directory for saving registered tiff files.
    aspect : float
        Pixel aspect ratio used for computing the enhanced mean image.
    badframes : np.ndarray or None
        1-D boolean array of pre-existing bad frame labels. If None, initialized
        to all False.
    settings : dict
        Registration settings dictionary from default_settings().
    device : torch.device
        Torch device for computation.

    Returns
    -------
    reg_outputs : dict
        Dictionary containing registration results with keys: "refImg", "rmin",
        "rmax", "meanImg", "yoff", "xoff", "corrXY", "yoff1", "xoff1",
        "corrXY1", "meanImg_chan2", "badframes", "badframes0", "yrange",
        "xrange", "bidiphase", "meanImgE", and optionally "zpos_registration",
        "cmax_registration", "meanImg_upsample", "mean_img_ups", 
        and "counts_ups".
    """
    out = assign_reg_io(f_reg, f_raw, f_reg_chan2, f_raw_chan2, align_by_chan2,
                        save_path, settings["reg_tif"], settings["reg_tif_chan2"])
    f_align_in, f_align_out, f_alt_in, f_alt_out, tif_root_align, tif_root_alt = out

    nchannels = 2 if f_alt_in is not None else 1
    logger.info(f"registering {nchannels} channels")
    if device.type == "mps":
        logger.warning("MPS device does not support float64, using float32 for registration. "
                       "If you encounter registration issues, try using cuda or cpu instead.")

    ### ----- compute reference image and bidiphase shift -------------- ###
    n_frames, Ly, Lx = f_align_in.shape
    badframes0 = np.zeros(n_frames, "bool") if badframes is None else badframes.copy()

    compute_bidi = settings["do_bidiphase"] and settings["bidiphase"] == 0
    # grab frames
    if refImg is None or compute_bidi:
        ix_frames = np.linspace(0, n_frames, 1 + min(settings["nimg_init"], n_frames), 
                                dtype=int)[:-1]
        frames = f_align_in[ix_frames].copy()

    # compute bidiphase shift
    if compute_bidi:
        bidiphase = bidi.compute(frames)
        logger.info("Estimated bidiphase offset from data: %d pixels" % bidiphase)
        # shift frames for reference image computation
    else:
        bidiphase = settings["bidiphase"]

    if bidiphase != 0 and refImg is None:
        frames = bidi.shift(frames, bidiphase) 

    if refImg is None:
        t0 = time.time()
        refImg = compute_reference(frames, settings=settings, device=device)
        logger.info("Reference frame, %0.2f sec." % (time.time() - t0))
    refImg_orig = refImg.copy()

    for step in range(1 + (settings["two_step_registration"] and f_raw is not None)):
        if step == 1:
            logger.info("starting step 2 of two-step registration")
            logger.info("(making new reference image without badframes)")
            nsamps = min(n_frames, settings["nimg_init"])
            inds = np.linspace(0, n_frames, 1 + nsamps).astype(np.int64)[:-1]
            inds = inds[~np.isin(inds, np.nonzero(badframes)[0])]
            refImg = f_align_out[inds].astype(np.float32).mean(axis=0)
            refImg_orig = refImg.copy()

        ### ----- register frames to reference image -------------- ###
        outputs = register_frames(f_align_in, f_align_out=f_align_out, bidiphase=bidiphase,
                                refImg=refImg, tif_root=tif_root_align,
                                batch_size=settings["batch_size"],
                                norm_frames=settings["norm_frames"], smooth_sigma=settings["smooth_sigma"],
                                spatial_taper=settings["spatial_taper"], block_size=settings["block_size"],
                                nonrigid=settings["nonrigid"],
                                maxregshift=settings["maxregshift"], smooth_sigma_time=settings["smooth_sigma_time"],
                                snr_thresh=settings["snr_thresh"], maxregshiftNR=settings["maxregshiftNR"],
                                subpixel=settings["subpixel"],
                                device=device, upsample_meanImg=settings.get("upsample_meanImg", False))
        rmin, rmax, mean_img, offsets_all, blocks, mean_img_ups, counts_ups, meanImg_ups = outputs
        yoff, xoff, corrXY, yoff1, xoff1, corrXY1, zest, cmax_all = offsets_all

        # compute valid region and timepoints to exclude
        badframes, yrange, xrange = compute_crop(xoff=xoff, yoff=yoff, corrXY=corrXY,
                                             th_badframes=settings["th_badframes"],
                                             badframes=badframes0.copy(),
                                             maxregshift=settings["maxregshift"], Ly=Ly,
                                             Lx=Lx)

    ### ----- register second channel -------------- ###
    if nchannels > 1:
        mean_img_alt = shift_frames_and_write(f_alt_in, f_alt_out, settings["batch_size"], yoff, xoff, yoff1,
                                              xoff1, blocks=blocks, bidiphase=bidiphase,
                                              tif_root=tif_root_align, device=device)
    else:
        mean_img_alt = None

    if device.type == "cuda":
        torch.cuda.empty_cache()

    if device.type == "mps":
        torch.mps.empty_cache()

    meanImg = mean_img if nchannels == 1 or not align_by_chan2 else mean_img_alt
    if nchannels == 2:
        meanImg_chan2 = mean_img_alt if not align_by_chan2 else mean_img
    else:
        meanImg_chan2 = None

    reg_outputs = registration_outputs_to_dict(refImg_orig, rmin, rmax, meanImg,
                                               (yoff, xoff, corrXY),
                                               (yoff1, xoff1, corrXY1),
                                               (zest, cmax_all), meanImg_chan2,
                                               badframes, badframes0,
                                               yrange, xrange, bidiphase, 
                                               )

    # add enhanced mean image
    meanImgE = utils.highpass_mean_image(meanImg.astype("float32"), aspect=aspect)
    reg_outputs["meanImgE"] = meanImgE

    # add upsampled mean image if computed
    if mean_img_ups is not None and counts_ups is not None:
        reg_outputs["meanImg_upsample"] = meanImg_ups
        reg_outputs["mean_img_ups"] = mean_img_ups.cpu().numpy()
        reg_outputs["counts_ups"] = counts_ups.cpu().numpy()

    return reg_outputs
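registration_wrapper picks frames for the reference image and the bidiphase estimate with an evenly spaced linspace pattern. In the second step of two-step registration, it applies the same pattern and then drops frames flagged as bad. The hypothetical helper below isolates that index selection so its behaviour can be checked on small inputs. It mirrors the expressions in the listing but is not a suite2p function.

```python
import numpy as np

def sample_frame_indices(n_frames: int, nimg_init: int, badframes=None) -> np.ndarray:
    """Evenly spaced frame indices, mirroring the linspace pattern in registration_wrapper."""
    n_samples = min(nimg_init, n_frames)
    # 1 + n_samples points from 0 to n_frames inclusive; the endpoint n_frames is dropped
    inds = np.linspace(0, n_frames, 1 + n_samples).astype(np.int64)[:-1]
    if badframes is not None:
        # second step of two-step registration: skip frames flagged as bad
        inds = inds[~np.isin(inds, np.nonzero(badframes)[0])]
    return inds
```

For example, with 10 frames and nimg_init = 4 the indices are [0, 2, 5, 7]. If frame 5 is marked bad, the second-step sample becomes [0, 2, 7].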

save_tiff

save_tiff(mov, fname)

Save image stack array to a tiff file.

Parameters:

Name Type Description Default
mov ndarray

Image stack of shape (nimg, Ly, Lx) to save. Values are floored and cast to int16 before writing.

required
fname str

Output tiff file path.

required
Source code in suite2p/registration/register.py
def save_tiff(mov: np.ndarray, fname: str) -> None:
    """
    Save image stack array to a tiff file.

    Parameters
    ----------
    mov : np.ndarray
        Image stack of shape (nimg, Ly, Lx) to save. Values are floored and
        cast to int16 before writing.
    fname : str
        Output tiff file path.
    """
    from tifffile import TiffWriter
    with TiffWriter(fname) as tif:
        for frame in np.floor(mov).astype(np.int16):
            tif.write(frame, contiguous=True)
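save_tiff floors the values before the int16 cast, so negative fractional values round down (-0.2 becomes -1). A plain astype would truncate toward zero instead (-0.2 becomes 0). The sketch below reproduces that conversion without writing a file. The `clip` option is a hypothetical addition, not part of save_tiff: it saturates out-of-range values instead of relying on float-to-int16 casting for them.

```python
import numpy as np

def to_int16_frames(mov: np.ndarray, clip: bool = False) -> np.ndarray:
    """Floor, optionally clip, and cast to int16, as save_tiff does before writing."""
    out = np.floor(mov)
    if clip:
        # hypothetical option: saturate to the int16 range before casting
        info = np.iinfo(np.int16)
        out = np.clip(out, info.min, info.max)
    return out.astype(np.int16)
```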

shift_frames

shift_frames(fr_torch, yoff, xoff, yoff1=None, xoff1=None, blocks=None, mean_img_ups=None, counts_ups=None, device=torch.device('cuda'))

Apply rigid and, optionally, nonrigid shifts to frames and return them as a NumPy array.

Parameters:

Name Type Description Default
fr_torch Tensor

Frames to shift, shape (N, Ly, Lx).

required
yoff LongTensor or ndarray

1-D integer rigid y shifts of length N.

required
xoff LongTensor or ndarray

1-D integer rigid x shifts of length N.

required
yoff1 Tensor or ndarray or None

Nonrigid y shifts of shape (N, n_blocks). If None, only rigid shifts are applied.

None
xoff1 Tensor or ndarray or None

Nonrigid x shifts of shape (N, n_blocks).

None
blocks list or None

Block definitions from nonrigid.make_blocks, used for nonrigid interpolation.

None
device device

Torch device for nonrigid shift tensors.

device('cuda')

Returns:

Name Type Description
frames_out ndarray

Shifted frames of shape (N, Ly, Lx), dtype matching the torch output.

Source code in suite2p/registration/register.py
def shift_frames(fr_torch, yoff, xoff, yoff1=None, xoff1=None, blocks=None, 
                 mean_img_ups=None, counts_ups=None, device=torch.device("cuda")):
    """
    Apply rigid and optionally nonrigid shifts to frames and return as numpy int16.

    Parameters
    ----------
    fr_torch : torch.Tensor
        Frames to shift, shape (N, Ly, Lx).
    yoff : torch.LongTensor
        1-D rigid y shifts of length N.
    xoff : torch.LongTensor
        1-D rigid x shifts of length N.
    yoff1 : torch.Tensor or np.ndarray or None
        Nonrigid y shifts of shape (N, n_blocks). If None, only rigid shifts are
        applied.
    xoff1 : torch.Tensor or np.ndarray or None
        Nonrigid x shifts of shape (N, n_blocks).
    blocks : list or None
        Block definitions from nonrigid.make_blocks, used for nonrigid
        interpolation.
    device : torch.device
        Torch device for nonrigid shift tensors.

    Returns
    -------
    frames_out : np.ndarray
        Shifted frames of shape (N, Ly, Lx), dtype matching the torch output.
    """
    fr_torch = torch.stack([torch.roll(frame, shifts=(-dy, -dx), dims=(0, 1))
                               for frame, dy, dx in zip(fr_torch, yoff, xoff)], axis=0)

    if yoff1 is not None:
        if isinstance(yoff1, np.ndarray):
            if fr_torch.device.type == "cuda":
                yoff1 = torch.from_numpy(yoff1).pin_memory().to(device)
                xoff1 = torch.from_numpy(xoff1).pin_memory().to(device)
            elif device.type == "mps":
                # MPS backend does not support float64
                yoff1 = torch.from_numpy(yoff1).to(torch.float32).to(device)
                xoff1 = torch.from_numpy(xoff1).to(torch.float32).to(device)
            else:
                yoff1 = torch.from_numpy(yoff1).to(device)
                xoff1 = torch.from_numpy(xoff1).to(device)

        fr_torch = nonrigid.transform_data(fr_torch, blocks[2], blocks[1], blocks[0], 
                                           yoff1, xoff1, data_ups=mean_img_ups, counts_ups=counts_ups)

    frames_out = np.empty(fr_torch.shape, dtype="int16")
    frames_out = fr_torch.cpu().numpy()

    return frames_out
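The rigid part of shift_frames is a per-frame circular roll by (-yoff, -xoff). The NumPy sketch below is a hypothetical stand-in for the torch.roll step and omits the nonrigid interpolation. Because the shift is circular, pixels pushed off one edge reappear on the opposite edge. That wrapped border is presumably one reason registration also reports a valid [ymin, ymax] / [xmin, xmax] crop range.

```python
import numpy as np

def rigid_shift_frames(frames: np.ndarray, yoff, xoff) -> np.ndarray:
    """Circularly shift frame i by (-yoff[i], -xoff[i]), like the torch.roll step in shift_frames."""
    return np.stack([np.roll(frame, shift=(-int(dy), -int(dx)), axis=(0, 1))
                     for frame, dy, dx in zip(frames, yoff, xoff)], axis=0)
```

A bright pixel at (5, 5) with yoff = 2 and xoff = 1 moves to (3, 4). With yoff = 7 in a 10-row frame, it wraps around to row 8.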

shift_frames_and_write

shift_frames_and_write(f_alt_in, f_alt_out=None, batch_size=100, yoff=None, xoff=None, yoff1=None, xoff1=None, blocks=None, bidiphase=0, device=torch.device('cuda'), tif_root=None)

Apply pre-computed registration shifts to an alternate channel and write results.

Applies rigid (and optionally nonrigid) shifts that were computed on the primary channel to the alternate channel frames, in batches. Writes the shifted frames to f_alt_out if provided, otherwise overwrites f_alt_in.

Parameters:

Name Type Description Default
f_alt_in ndarray or BinaryFile

Alternate channel input frames of shape (n_frames, Ly, Lx).

required
f_alt_out ndarray or BinaryFile or None

Output array for shifted frames. If None, writes back to f_alt_in.

None
batch_size int

Number of frames per batch.

100
yoff ndarray

Rigid y offsets of length n_frames.

None
xoff ndarray

Rigid x offsets of length n_frames.

None
yoff1 ndarray or None

Nonrigid y offsets of shape (n_frames, n_blocks).

None
xoff1 ndarray or None

Nonrigid x offsets of shape (n_frames, n_blocks).

None
blocks list or None

Block definitions from nonrigid.make_blocks.

None
bidiphase int

Bidirectional phase offset in pixels.

0
device device

Torch device for computation.

device('cuda')
tif_root str or None

If provided, save shifted frames as tiffs in this directory.

None

Returns:

Name Type Description
mean_img ndarray

Mean image of the shifted alternate channel, shape (Ly, Lx).

Source code in suite2p/registration/register.py
def shift_frames_and_write(f_alt_in, f_alt_out=None, batch_size=100, yoff=None, xoff=None, yoff1=None,
                           xoff1=None, blocks=None, bidiphase=0, 
                           device=torch.device("cuda"), tif_root=None):
    """
    Apply pre-computed registration shifts to an alternate channel and write results.

    Applies rigid (and optionally nonrigid) shifts that were computed on the
    primary channel to the alternate channel frames, in batches. Writes the
    shifted frames to f_alt_out if provided, otherwise overwrites f_alt_in.

    Parameters
    ----------
    f_alt_in : np.ndarray or BinaryFile
        Alternate channel input frames of shape (n_frames, Ly, Lx).
    f_alt_out : np.ndarray or BinaryFile or None
        Output array for shifted frames. If None, writes back to f_alt_in.
    batch_size : int
        Number of frames per batch.
    yoff : np.ndarray
        Rigid y offsets of length n_frames.
    xoff : np.ndarray
        Rigid x offsets of length n_frames.
    yoff1 : np.ndarray or None
        Nonrigid y offsets of shape (n_frames, n_blocks).
    xoff1 : np.ndarray or None
        Nonrigid x offsets of shape (n_frames, n_blocks).
    blocks : list or None
        Block definitions from nonrigid.make_blocks.
    bidiphase : int
        Bidirectional phase offset in pixels.
    device : torch.device
        Torch device for computation.
    tif_root : str or None
        If provided, save shifted frames as tiffs in this directory.

    Returns
    -------
    mean_img : np.ndarray
        Mean image of the shifted alternate channel, shape (Ly, Lx).
    """
    n_frames, Ly, Lx = f_alt_in.shape
    check_offsets(yoff, xoff, yoff1, xoff1, n_frames)

    mean_img = np.zeros((Ly, Lx), "float32")
    yoff1k, xoff1k = None, None
    n_batches = int(np.ceil(n_frames / batch_size))
    logger.info(f"Second channel: Shifting {n_frames} frames in {n_batches} batches")
    tqdm_out = TqdmToLogger(logger, level=logging.INFO)
    for n in trange(n_batches, mininterval=10, file=tqdm_out):
        tstart, tend = n * batch_size, min((n+1) * batch_size, n_frames)
        frames = f_alt_in[tstart : tend]
        yoffk, xoffk = yoff[tstart : tend].astype(int), xoff[tstart : tend].astype(int)
        if yoff1 is not None:
            yoff1k, xoff1k = yoff1[tstart : tend], xoff1[tstart : tend]

        if device.type == "cuda":
            fr_torch = torch.from_numpy(frames).pin_memory().to(device)
        else:
            fr_torch = torch.from_numpy(frames).to(device)

        if bidiphase != 0:
            fr_torch = bidi.shift(fr_torch, bidiphase)
        frames = shift_frames(fr_torch, yoffk, xoffk, yoff1k, xoff1k, blocks, device=device)
        mean_img += frames.sum(axis=0) / n_frames

        if f_alt_out is None:
            f_alt_in[tstart : tend] = frames
        else:
            f_alt_out[tstart : tend] = frames

        # save aligned frames to tiffs
        if tif_root:
            fname = os.path.join(tif_root, f"file{n : 05d}.tif")
            save_tiff(mov=frames, fname=fname)

    return mean_img
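shift_frames_and_write walks the alternate channel in batches of batch_size frames. It shifts each batch with offsets computed on the alignment channel, writes the batch either in place or to a separate output, and accumulates the mean image as a running sum of batch sums divided by n_frames. The hypothetical sketch below keeps only that loop structure with integer rigid shifts. It omits the bidiphase correction, nonrigid shifts, torch devices, and tiff output.

```python
import numpy as np

def shift_in_batches(frames_in, yoff, xoff, batch_size=100, frames_out=None):
    """Apply integer rigid shifts batch by batch and accumulate the mean image."""
    n_frames, Ly, Lx = frames_in.shape
    out = frames_in if frames_out is None else frames_out  # in place when no output is given
    mean_img = np.zeros((Ly, Lx), np.float32)
    n_batches = int(np.ceil(n_frames / batch_size))
    for n in range(n_batches):
        tstart, tend = n * batch_size, min((n + 1) * batch_size, n_frames)
        batch = np.asarray(frames_in[tstart:tend])
        shifted = np.stack([np.roll(f, (-int(dy), -int(dx)), axis=(0, 1))
                            for f, dy, dx in zip(batch, yoff[tstart:tend], xoff[tstart:tend])])
        mean_img += shifted.sum(axis=0) / n_frames  # running mean over all frames
        out[tstart:tend] = shifted
    return mean_img
```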

Copyright © 2023 Howard Hughes Medical Institute, Authored by Carsen Stringer and Marius Pachitariu.

compute_masks_ref_smooth_fft

compute_masks_ref_smooth_fft(refImg, maskSlope, smooth_sigma)

Compute the multiplicative and additive masks used for spatial tapering in rigid registration, together with the Gaussian-smoothed Fourier-domain reference image.

Parameters:

Name Type Description Default
refImg Tensor

2D reference image of shape (Ly, Lx).

required
maskSlope float

Scalar parameter controlling the slope of the sigmoid of the spatial taper. Higher values increase tapered region size.

required
smooth_sigma float

Standard deviation (in pixels) of the Gaussian smoothing applied to the reference image. Smoothing is performed in the frequency domain (via ref_smooth_fft). Values should be >= 0; a value of 0 should behave as no smoothing (identity).

required

Returns:

Name Type Description
maskMul Tensor

Floating-point multiplicative mask of shape (Ly, Lx), intended to smoothly reduce the influence of border pixels during registration.

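maskOffset Tensor

Floating-point additive offset mask of shape (Ly, Lx), computed as mean(refImg) * (1.0 - maskMul), which pulls border pixels toward the image mean.

cfRefImg Tensor

Complex-valued reference image in the Fourier domain, shape (Ly, Lx), Gaussian-smoothed by ref_smooth_fft; this is the form expected by the cfRefImg argument of phasecorr.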

Source code in suite2p/registration/rigid.py
def compute_masks_ref_smooth_fft(refImg, maskSlope, smooth_sigma):
    """
    Compute multiplicative and additive masks used for spatial tapering in rigid registration, 
    and smooth with Gaussian.

    Parameters
    ----------
    refImg : torch.Tensor
        2D reference image of shape (Ly, Lx).
    maskSlope : float
        Scalar parameter controlling the slope of the sigmoid of the spatial taper. Higher
        values increase tapered region size.
    smooth_sigma : float
        Standard deviation (in pixels) of the Gaussian smoothing applied to each
        block. Smoothing is performed in the frequency domain (via ref_smooth_fft). Typical values are >= 0. A value of 0 should behave as no
        smoothing (identity).

    Returns
    -------
    maskMul : torch.Tensor
        Floating-point multiplicative mask of shape (Ly, Lx), intended to smoothly
        reduce the influence of border pixels during registration.
    maskOffset : torch.Tensor
        Floating-point additive offset mask of shape (Ly, Lx), computed as
        mean(refImg) * (1.0 - maskMul), setting the border pixels to the mean.
    """
    Ly, Lx = refImg.shape
    maskMul = spatial_taper(maskSlope, Ly, Lx)
    maskOffset = refImg.float().mean() * (1. - maskMul)
    cfRefImg = ref_smooth_fft(refImg=refImg, smooth_sigma=smooth_sigma)
    return maskMul, maskOffset, cfRefImg
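spatial_taper is not documented on this page. The sketch below assumes a separable sigmoid taper whose midpoint sits 2 * maskSlope pixels in from each edge, which matches the description "higher values increase tapered region size"; the library's exact formula may differ. The sketch shows how maskMul and maskOffset = mean(refImg) * (1 - maskMul) combine: a frame transformed as frame * maskMul + maskOffset keeps its interior and fades to the reference mean at the borders. The function names are hypothetical.

```python
import numpy as np

def spatial_taper_sketch(sig: float, Ly: int, Lx: int) -> np.ndarray:
    """Separable sigmoid taper: ~1 in the interior, falling toward 0 at the borders."""
    yy = np.abs(np.arange(Ly) - (Ly - 1) / 2)[:, None]  # distance from the centre row
    xx = np.abs(np.arange(Lx) - (Lx - 1) / 2)[None, :]  # distance from the centre column
    mY = (Ly - 1) / 2 - 2 * sig  # sigmoid midpoint, 2*sig pixels in from the edge
    mX = (Lx - 1) / 2 - 2 * sig
    maskY = 1.0 / (1.0 + np.exp((yy - mY) / sig))
    maskX = 1.0 / (1.0 + np.exp((xx - mX) / sig))
    return maskY * maskX

def masks_for_reference(refImg: np.ndarray, maskSlope: float):
    """maskMul from the taper; maskOffset fills the tapered border with the image mean."""
    maskMul = spatial_taper_sketch(maskSlope, *refImg.shape)
    maskOffset = refImg.mean() * (1.0 - maskMul)
    return maskMul, maskOffset
```

With these masks, (frame * maskMul + maskOffset) - mean equals (frame - mean) * maskMul. Border pixels are therefore pulled toward the mean in proportion to how small maskMul is there, which suppresses edge discontinuities before the FFT.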

phasecorr

phasecorr(frames, cfRefImg, maskMul, maskOffset, maxregshift, smooth_sigma_time, return_cc=False)

Compute rigid-registration shifts using phase correlation, with optional temporal smoothing.

This function performs Fourier-domain phase-correlation registration between each frame in frames and the complex reference image cfRefImg. It finds the integer pixel shifts (y, x) that maximize the phase correlation within a search window limited by maxregshift.

Parameters:

Name Type Description Default
frames Tensor

Input image sequence, expected shape (N, Ly, Lx) where N is the number of frames. The tensor may be on CPU or CUDA; it is converted to float and then to complex for the Fourier-domain operations performed by the helper convolve.

required
cfRefImg Tensor

Complex-valued reference of shape (Ly, Lx) in the Fourier domain, used to compute the cross-correlation with each frame.

required
maskMul Tensor

Multiplicative mask applied to frames before correlation. Broadcasted over frames.

required
maskOffset Tensor

Additive offset applied after maskMul. Broadcasted over frames.

required
maxregshift float

Maximum allowed registration shift expressed as a fraction of the smaller spatial image dimension. The actual integer search half-window lcorr is computed as min(round(maxregshift * min(Ly, Lx)), floor(min(Ly, Lx) / 2)).

required
smooth_sigma_time float

If > 0, applies temporal smoothing (via helper temporal_smooth) to the phase-correlation maps along the time axis with this sigma before finding maxima. If <= 0, no temporal smoothing is used.

required
return_cc bool, optional (default False)

If True, return the computed local phase-correlation maps as a NumPy array on CPU; otherwise the correlation maps are freed to save memory and None is returned in their place.

False

Returns:

Name Type Description
ymax LongTensor

1-D integer tensor of length N with the y (row) shift for each frame that maximizes the phase-correlation.

xmax LongTensor

1-D integer tensor of length N with the x (column) shift for each frame that maximizes the phase-correlation.

cmax Tensor

1-D tensor of length N containing the maximum phase-correlation value found for each frame.

cc ndarray or None

If return_cc is True, a NumPy array of shape (N, 2 * lcorr + 1, 2 * lcorr + 1) with the real-valued local phase-correlation maps (dtype float32). If return_cc is False, cc is None.
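The core of phasecorr can be sketched in NumPy as follows. The steps are: whiten the spectra of the frames and of the conjugated reference, multiply them, inverse-transform, shift zero lag to the centre, and take the argmax inside a window of half-width lcorr computed with the formula given for maxregshift. This is a simplified, hypothetical stand-in. It omits maskMul/maskOffset, temporal smoothing, and torch, and it builds the conjugated reference spectrum on the fly instead of receiving cfRefImg. The returned shifts follow the convention used by shift_frames: rolling a frame by (-ymax, -xmax) aligns it to the reference.

```python
import numpy as np

def phasecorr_sketch(frames, refImg, maxregshift=0.1, eps=1e-5):
    """Integer rigid shifts by phase correlation within a +/- lcorr search window."""
    _, Ly, Lx = frames.shape
    lcorr = int(min(np.round(maxregshift * min(Ly, Lx)), min(Ly, Lx) // 2))
    # whitened (phase-only) spectra of the frames and of the conjugated reference
    F = np.fft.fft2(frames.astype(np.float64))
    R = np.conj(np.fft.fft2(refImg.astype(np.float64)))
    F /= np.abs(F) + eps
    R /= np.abs(R) + eps
    cc = np.real(np.fft.ifft2(F * R))         # (N, Ly, Lx) circular correlation maps
    cc = np.fft.fftshift(cc, axes=(1, 2))     # zero shift moves to (Ly//2, Lx//2)
    cy, cx = Ly // 2, Lx // 2
    win = cc[:, cy - lcorr:cy + lcorr + 1, cx - lcorr:cx + lcorr + 1]
    flat = win.reshape(win.shape[0], -1)
    iy, ix = np.unravel_index(flat.argmax(axis=1), win.shape[1:])
    ymax, xmax = iy - lcorr, ix - lcorr       # window index -> signed shift
    cmax = flat.max(axis=1)
    return ymax, xmax, cmax
```

For a 64 x 64 frame that is the reference rolled by (3, -2), with maxregshift = 0.1 (so lcorr = 6), the sketch recovers ymax = 3 and xmax = -2. The peak value is close to 1.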

Source code in suite2p/registration/rigid.py
def phasecorr(frames, cfRefImg, maskMul, maskOffset, maxregshift, smooth_sigma_time, 
              return_cc=False):
    """
    Compute rigid-registration shifts using phase correlation, with optional temporal smoothing.

    The function performs Fourier-domain phase correlation between each frame in `frames` and a
    (complex) reference image `cfRefImg`, and returns the integer pixel shifts (y, x) that maximize
    the phase correlation within a search window whose half-width is set by `maxregshift`.

    Parameters
    ----------
    frames : torch.Tensor
        Input image sequence, expected shape (N, Ly, Lx) where N is the number of frames.
        The tensor may be on CPU or CUDA; it is converted to float and then to complex for the
        Fourier-domain operations performed by the helper `convolve`.
    cfRefImg : torch.Tensor
        Complex-valued reference of shape (Ly, Lx) in the Fourier domain, used to compute the
        cross-correlation with each frame.
    maskMul : torch.Tensor
        Multiplicative mask applied to `frames` before correlation. Broadcasted over frames.
    maskOffset : torch.Tensor
        Additive offset applied after `maskMul`. Broadcasted over frames.
    maxregshift : float
        Maximum allowed registration shift expressed as a fraction of the smaller spatial image
        dimension. The actual integer search half-window `lcorr` is computed as
        min(round(maxregshift * min(Ly, Lx)), floor(min(Ly, Lx) / 2)).
    smooth_sigma_time : float
        If > 0, applies temporal smoothing (via helper `temporal_smooth`) to the phase-correlation maps
        along the time axis with this sigma before finding maxima. If <= 0, no temporal smoothing is used.
    return_cc : bool, optional (default False)
        If True, return the computed local phase-correlation maps as a NumPy array on CPU;
        otherwise the correlation maps are freed to save memory and None is returned in their place.

    Returns
    -------
    ymax : torch.LongTensor
        1-D integer tensor of length N with the y (row) shift for each frame that maximizes the
        phase correlation, in the range [-lcorr, lcorr].
    xmax : torch.LongTensor
        1-D integer tensor of length N with the x (column) shift for each frame that maximizes the
        phase correlation, in the range [-lcorr, lcorr].
    cmax : torch.Tensor
        1-D tensor of length N containing the maximum phase-correlation value found for each frame.
    cc : numpy.ndarray or None
        If `return_cc` is True, a NumPy array of shape (N, 2*lcorr+1, 2*lcorr+1) with the real-valued
        local phase-correlation maps (dtype float32) is returned. If `return_cc` is False, cc is None.
    """

    device = frames.device
    data = (frames.float() * maskMul + maskOffset).type(torch.complex64)
    min_dim = min(data.shape[1], data.shape[2])  # maximum registration shift allowed
    lcorr = int(np.minimum(np.round(maxregshift * min_dim), min_dim // 2))

    data = convolve(data, cfRefImg)
    cc = torch.cat((torch.cat((data[:, -lcorr:, -lcorr:], data[:, -lcorr:, :lcorr + 1]), axis=2),   
                    torch.cat((data[:, :lcorr + 1, -lcorr:], data[:, :lcorr + 1, :lcorr + 1]), axis=2)), axis=1)
    cc = torch.real(cc)

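    if smooth_sigma_time > 0:
        # temporal_smooth works on NumPy arrays, so round-trip through the CPU
        cc = torch.from_numpy(temporal_smooth(cc.cpu().numpy(), smooth_sigma_time)).to(device)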

    imax = torch.stack([torch.argmax(cc[t]) for t in range(data.shape[0])], dim=0)
    ymax, xmax = torch.div(imax, 2 * lcorr + 1, rounding_mode="floor"), imax % (2 * lcorr + 1)
    cmax = cc[torch.arange(len(cc)), ymax, xmax]
    ymax, xmax = ymax - lcorr, xmax - lcorr

    del data
    if return_cc: 
        cc = cc.cpu().numpy()
    else:
        del cc
        cc = None
    if device.type == "cuda":
        torch.cuda.empty_cache()

    return ymax, xmax, cmax, cc
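
The windowed search above can be sketched in NumPy as follows. This is a minimal illustration, not the suite2p implementation: NumPy stands in for torch, maskMul/maskOffset and temporal smoothing are omitted, and phasecorr_sketch is a hypothetical name. It whitens both spectra, gathers lags -lcorr..lcorr from the wrapped-around correlation map, and recovers a known integer shift.

```python
import numpy as np


def phasecorr_sketch(frames, ref, maxregshift=0.1, eps=1e-5):
    """Hypothetical NumPy sketch of windowed rigid phase correlation.

    No masks and no temporal smoothing. Returns integer (y, x) shifts in
    [-lcorr, lcorr] and the peak correlation value for each frame.
    """
    n, Ly, Lx = frames.shape
    # Whitened complex conjugate of the reference spectrum.
    cf_ref = np.conj(np.fft.fft2(ref))
    cf_ref /= eps + np.abs(cf_ref)
    # Whitened spectrum of every frame times the reference spectrum, back to real space.
    spec = np.fft.fft2(frames, axes=(-2, -1))
    spec /= eps + np.abs(spec)
    cc_full = np.real(np.fft.ifft2(spec * cf_ref, axes=(-2, -1)))
    # Half-width of the search window, computed as in phasecorr.
    min_dim = min(Ly, Lx)
    lcorr = int(min(np.round(maxregshift * min_dim), min_dim // 2))
    # Lags -lcorr..-1 sit at the end of the array (wrap-around), 0..lcorr at the start.
    lags = np.r_[-lcorr:0, 0:lcorr + 1]
    cc = cc_full[:, lags][:, :, lags]
    # Flattened argmax per frame, then split into row and column indices.
    imax = cc.reshape(n, -1).argmax(axis=1)
    ymax, xmax = np.divmod(imax, 2 * lcorr + 1)
    cmax = cc[np.arange(n), ymax, xmax]
    return ymax - lcorr, xmax - lcorr, cmax


rng = np.random.default_rng(0)
ref = rng.standard_normal((64, 64))
# Frame 0 is the reference rolled by +3 rows and -5 columns; frame 1 is unshifted.
frames = np.stack([np.roll(ref, (3, -5), axis=(0, 1)), ref])
ymax, xmax, cmax = phasecorr_sketch(frames, ref)
# Expected: ymax -> [3, 0], xmax -> [-5, 0], cmax close to 1 for both frames.
```

With this sign convention, a frame equal to the reference rolled by (+dy, +dx) reports (+dy, +dx). Indexing with negative lags keeps the sketch short; the torch code assembles the same window from four corner slices with torch.cat.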


complex_fft2

complex_fft2(img)

Compute the complex conjugate of the 2D FFT of an image.

Parameters:

Name Type Description Default
img Tensor

2D input image of shape (Ly, Lx).

required

Returns:

Name Type Description
cfImg Tensor

Complex conjugate of the 2D FFT, shape (Ly, Lx).

Source code in suite2p/registration/utils.py
def complex_fft2(img):
    """
    Compute the complex conjugate of the 2D FFT of an image.

    Parameters
    ----------
    img : torch.Tensor
        2D input image of shape (Ly, Lx).

    Returns
    -------
    cfImg : torch.Tensor
        Complex conjugate of the 2D FFT, shape (Ly, Lx).
    """
    return torch.conj(fft2(img))
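
Keeping the conjugate spectrum means a single element-wise product gives a circular cross-correlation. A small NumPy check of that identity, with NumPy standing in for torch and complex_fft2_np and circular_xcorr as hypothetical helper names:

```python
import numpy as np


def complex_fft2_np(img):
    """NumPy analogue of complex_fft2: complex conjugate of the 2D FFT."""
    return np.conj(np.fft.fft2(img))


def circular_xcorr(frame, ref):
    """Circular cross-correlation c[d] = sum_n frame[n] * ref[n - d], computed via the FFT."""
    return np.real(np.fft.ifft2(np.fft.fft2(frame) * complex_fft2_np(ref)))


rng = np.random.default_rng(1)
frame = rng.standard_normal((8, 6))
ref = rng.standard_normal((8, 6))
c = circular_xcorr(frame, ref)
# Direct sum at lag d = (2, 1): np.roll(ref, d)[n] == ref[n - d].
direct = np.sum(frame * np.roll(ref, (2, 1), axis=(0, 1)))
```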

convolve

convolve(mov, img)

Convolve a 3D frame sequence with a 2D image in the Fourier domain (phase correlation).

Applies a 2D FFT to each frame, divides each coefficient by its magnitude (whitening), multiplies by img, and returns the real part of the inverse FFT.

Parameters:

Name Type Description Default
mov Tensor

Input frames of shape (nImg, Ly, Lx).

required
img Tensor

2D complex-valued convolution kernel of shape (Ly, Lx), typically a conjugate FFT of a reference image.

required

Returns:

Name Type Description
convolved_data Tensor

Real-valued convolution result of shape (nImg, Ly, Lx).

Source code in suite2p/registration/utils.py
def convolve(mov: torch.Tensor, img: torch.Tensor) -> torch.Tensor:
    """
    Convolve a 3D frame sequence with a 2D image in the Fourier domain (phase correlation).

    Applies a 2D FFT to each frame, divides each coefficient by its magnitude (whitening),
    multiplies by `img`, and returns the real part of the inverse FFT.

    Parameters
    ----------
    mov : torch.Tensor
        Input frames of shape (nImg, Ly, Lx).
    img : torch.Tensor
        2D complex-valued convolution kernel of shape (Ly, Lx), typically a conjugate FFT
        of a reference image.

    Returns
    -------
    convolved_data : torch.Tensor
        Real-valued convolution result of shape (nImg, Ly, Lx).
    """
    mov = fft2(mov)
    mov /= (eps + torch.abs(mov))
    mov *= img
    mov = torch.real(ifft2(mov))
    return mov
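
Because every Fourier coefficient is divided by its own magnitude, the output depends only on phase: changing image contrast leaves it almost unchanged, and an image correlated with its own whitened reference peaks at lag (0, 0) with a value close to 1. A NumPy sketch, assuming eps = 1e-5 for the module-level constant (its actual value is not shown here) and using the hypothetical name convolve_np:

```python
import numpy as np


def convolve_np(mov, img, eps=1e-5):
    """NumPy analogue of convolve; eps=1e-5 is an assumed value for the module-level eps."""
    spec = np.fft.fft2(mov, axes=(-2, -1))
    spec = spec / (eps + np.abs(spec))  # whitening: keep the phase, drop the magnitude
    return np.real(np.fft.ifft2(spec * img, axes=(-2, -1)))


rng = np.random.default_rng(2)
ref = rng.standard_normal((32, 32))
cf_ref = np.conj(np.fft.fft2(ref))
cf_ref /= 1e-5 + np.abs(cf_ref)
# The same image at two contrast levels.
mov = np.stack([ref, 10.0 * ref])
out = convolve_np(mov, cf_ref)
```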

gaussian_fft

gaussian_fft(sig, Ly, Lx)

Compute the real-valued FFT of a 2D isotropic Gaussian kernel for smoothing.

Parameters:

Name Type Description Default
sig float

Standard deviation (in pixels) of the isotropic Gaussian kernel.

required
Ly int

Number of pixels along the y-axis.

required
Lx int

Number of pixels along the x-axis.

required

Returns:

Name Type Description
gaussian_fft Tensor

Real-valued 2D FFT of the Gaussian kernel, shape (Ly, Lx).

Source code in suite2p/registration/utils.py
def gaussian_fft(sig, Ly: int, Lx: int):
    """
    Compute the real-valued FFT of a 2D isotropic Gaussian kernel for smoothing.

    Parameters
    ----------
    sig : float
        Standard deviation (in pixels) of the isotropic Gaussian kernel.
    Ly : int
        Number of pixels along the y-axis.
    Lx : int
        Number of pixels along the x-axis.

    Returns
    -------
    gaussian_fft : torch.Tensor
        Real-valued 2D FFT of the Gaussian kernel, shape (Ly, Lx).
    """
    kernel = gaussian_kernel(sig, sig, Ly, Lx)
    return torch.real(fft2(ifftshift(kernel)))
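
Multiplying a spectrum by this array smooths the image with a circular (wrap-around) Gaussian. The NumPy sketch below, with hypothetical helper names, checks two consequences: the gain at zero frequency is 1 because the kernel sums to 1, and filtering a unit impulse at (0, 0) returns the kernel centred on the origin. It uses odd Ly and Lx, for which the shifted kernel is exactly symmetric and its FFT is real, so taking the real part discards nothing.

```python
import numpy as np


def gaussian_kernel_np(sigma_y, sigma_x, Ly, Lx):
    """NumPy analogue of gaussian_kernel (centred on the mean pixel coordinate, sums to 1)."""
    y = np.arange(Ly, dtype=float) - (Ly - 1) / 2
    x = np.arange(Lx, dtype=float) - (Lx - 1) / 2
    kernel = np.exp(-y**2 / (2 * sigma_y**2))[:, None] * np.exp(-x**2 / (2 * sigma_x**2))
    return kernel / kernel.sum()


def gaussian_fft_np(sig, Ly, Lx):
    """NumPy analogue of gaussian_fft: real part of the FFT of the origin-centred kernel."""
    kernel = gaussian_kernel_np(sig, sig, Ly, Lx)
    return np.real(np.fft.fft2(np.fft.ifftshift(kernel)))


Ly, Lx, sig = 33, 31, 2.0
G = gaussian_fft_np(sig, Ly, Lx)
# Filtering a unit impulse at (0, 0) returns the kernel itself, wrapped around the origin.
impulse = np.zeros((Ly, Lx))
impulse[0, 0] = 1.0
smoothed = np.real(np.fft.ifft2(np.fft.fft2(impulse) * G))
```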

gaussian_kernel

gaussian_kernel(sigma_y, sigma_x, Ly, Lx, device=torch.device('cpu'))

Generate a normalized 2D Gaussian kernel.

Parameters:

Name Type Description Default
sigma_y float

Standard deviation of the Gaussian along the y-axis.

required
sigma_x float

Standard deviation of the Gaussian along the x-axis.

required
Ly int

Number of pixels along the y-axis.

required
Lx int

Number of pixels along the x-axis.

required
device torch.device, optional (default torch.device("cpu"))

Device on which to create the kernel tensor.

device('cpu')

Returns:

Name Type Description
kernel Tensor

Normalized 2D Gaussian kernel of shape (Ly, Lx), summing to 1.0.

Source code in suite2p/registration/utils.py
def gaussian_kernel(sigma_y, sigma_x, Ly, Lx, device=torch.device("cpu")):
    """
    Generate a normalized 2D Gaussian kernel.

    Parameters
    ----------
    sigma_y : float
        Standard deviation of the Gaussian along the y-axis.
    sigma_x : float
        Standard deviation of the Gaussian along the x-axis.
    Ly : int
        Number of pixels along the y-axis.
    Lx : int
        Number of pixels along the x-axis.
    device : torch.device, optional (default torch.device("cpu"))
        Device on which to create the kernel tensor.

    Returns
    -------
    kernel : torch.Tensor
        Normalized 2D Gaussian kernel of shape (Ly, Lx), summing to 1.0.
    """
    y = torch.arange(0, Ly, device=device, dtype=torch.float)
    y -= y.mean()
    x = torch.arange(0, Lx, device=device, dtype=torch.float)
    x -= x.mean()
    ky = torch.exp(-y**2 / (2 * sigma_y**2)) 
    kx = torch.exp(-x**2 / (2 * sigma_x**2))
    kernel = ky[:, None] * kx
    kernel /= kernel.sum()
    return kernel
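
A NumPy port (hypothetical name, float64 instead of torch float32) that shows the kernel is separable, sums to 1, and peaks at the centre pixel for odd sizes. For even sizes the centre falls between the two middle pixels, because the coordinates are centred on their mean.

```python
import numpy as np


def gaussian_kernel_np(sigma_y, sigma_x, Ly, Lx):
    """NumPy analogue of gaussian_kernel (float64 instead of torch float32)."""
    y = np.arange(Ly, dtype=float)
    y -= y.mean()  # centre on (Ly - 1) / 2
    x = np.arange(Lx, dtype=float)
    x -= x.mean()  # centre on (Lx - 1) / 2
    ky = np.exp(-y**2 / (2 * sigma_y**2))
    kx = np.exp(-x**2 / (2 * sigma_x**2))
    kernel = ky[:, None] * kx  # separable: outer product of the 1D profiles
    return kernel / kernel.sum()


k_odd = gaussian_kernel_np(3.0, 1.5, 21, 15)
k_even = gaussian_kernel_np(2.0, 2.0, 8, 8)
```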

highpass_mean_image

highpass_mean_image(I, aspect=1.0)

Compute an enhanced mean image by applying a high-pass Gaussian filter.

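Subtracts low-frequency content with a Gaussian high-pass filter (sharpen_radius=3, passed to transforms.normalize_img) and rescales the result to [0, 1] using the 1st and 99th percentiles. When aspect differs from 1.0, the image is resized along y by aspect before filtering and resized back afterwards, so the filter acts on an aspect-corrected grid.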

Parameters:

Name Type Description Default
I ndarray

2D mean image of shape (Ly, Lx).

required
aspect float, optional (default 1.0)

Aspect ratio correction factor. Values != 1.0 resize the image along the y-axis by this factor before filtering (and back to (Ly, Lx) afterwards).

1.0

Returns:

Name Type Description
img_filt ndarray

High-pass filtered image of shape (Ly, Lx), clipped to [0, 1].

Source code in suite2p/registration/utils.py
def highpass_mean_image(I, aspect=1.):
    """
    Compute an enhanced mean image by applying a high-pass Gaussian filter.

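    Subtracts low-frequency content with a Gaussian high-pass filter (`sharpen_radius=3`,
    passed to `transforms.normalize_img`) and rescales the result to [0, 1] using the 1st
    and 99th percentiles. When `aspect` differs from 1.0, the image is resized along y by
    `aspect` before filtering and resized back afterwards, so the filter acts on an
    aspect-corrected grid.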

    Parameters
    ----------
    I : numpy.ndarray
        2D mean image of shape (Ly, Lx).
    aspect : float, optional (default 1.0)
        Aspect ratio correction factor. Values != 1.0 resize the image along the y-axis
        by this factor before filtering (and back to (Ly, Lx) afterwards).

    Returns
    -------
    img_filt : numpy.ndarray
        High-pass filtered image of shape (Ly, Lx), clipped to [0, 1].
    """
    Ly, Lx = I.shape
    img_filt = cv2.resize(I, (Lx, int(np.round(Ly * aspect))))
    img_filt = transforms.normalize_img(img_filt[..., np.newaxis], sharpen_radius=3)
    img_filt = cv2.resize(img_filt.squeeze(), (Lx, Ly))
    img_filt = np.clip(img_filt, 0, 1)

    return img_filt
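
The high-pass and rescaling steps can be sketched in NumPy as below. This is an assumption-laden illustration, not a reproduction of transforms.normalize_img: it subtracts a Gaussian-blurred copy computed with circular FFT boundaries, ignores aspect, and uses the hypothetical name highpass_sketch. Small bright features survive the filter and saturate at 1 after percentile rescaling, while a slowly varying background is largely removed.

```python
import numpy as np


def highpass_sketch(I, sigma=3.0, lower=1.0, upper=99.0):
    """Hypothetical high-pass filter plus percentile rescaling (circular boundaries)."""
    Ly, Lx = I.shape
    # Frequency response of a Gaussian blur with standard deviation `sigma` pixels.
    fy = np.fft.fftfreq(Ly)[:, None]
    fx = np.fft.fftfreq(Lx)[None, :]
    blur = np.exp(-2 * np.pi**2 * sigma**2 * (fy**2 + fx**2))
    low = np.real(np.fft.ifft2(np.fft.fft2(I) * blur))
    high = I - low
    # Map the lower/upper percentiles to 0 and 1, then clip.
    lo, hi = np.percentile(high, [lower, upper])
    return np.clip((high - lo) / (hi - lo + 1e-10), 0.0, 1.0)


Ly, Lx = 64, 64
rows = np.arange(Ly)[:, None] * np.ones((1, Lx))
I = np.sin(2 * np.pi * rows / Ly)  # smooth, periodic background
spots = [(10, 10), (20, 40), (40, 20), (50, 50), (30, 30)]
for r, c in spots:
    I[r, c] += 5.0  # small bright features
out = highpass_sketch(I)
```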

kernelD

kernelD(xs, ys, sigL=0.85)

Compute a Gaussian interpolation kernel between two sets of 2D grid coordinates.

Builds a kernel matrix K with K[i, j] = exp(-d^2 / (2 * sigL^2)), where d is the Euclidean distance between the i-th point of the 2D grid xs × xs and the j-th point of the 2D grid ys × ys (both grids flattened in row-major order). Used for sub-pixel up-sampling in registration.

Parameters:

Name Type Description Default
xs Tensor

1D tensor of grid coordinates for the source points.

required
ys Tensor

1D tensor of grid coordinates for the target points.

required
sigL float, optional (default 0.85)

Smoothing width of the Gaussian kernel. Values between 0.5 and 1.0 work best.

0.85

Returns:

Name Type Description
K Tensor

Gaussian kernel matrix of shape (len(xs)**2, len(ys)**2).

Source code in suite2p/registration/utils.py
def kernelD(xs: torch.Tensor, ys: torch.Tensor, sigL: float = 0.85) -> torch.Tensor:
    """
    Compute a Gaussian interpolation kernel between two sets of 2D grid coordinates.

    Builds a kernel matrix K with K[i, j] = exp(-d^2 / (2 * sigL^2)), where d is the
    Euclidean distance between the i-th point of the 2D grid `xs` × `xs` and the j-th
    point of the 2D grid `ys` × `ys` (both grids flattened in row-major order). Used for
    sub-pixel up-sampling in registration.

    Parameters
    ----------
    xs : torch.Tensor
        1D tensor of grid coordinates for the source points.
    ys : torch.Tensor
        1D tensor of grid coordinates for the target points.
    sigL : float, optional (default 0.85)
        Smoothing width of the Gaussian kernel. Values between 0.5 and 1.0 work best.

    Returns
    -------
    K : torch.Tensor
        Gaussian kernel matrix of shape (len(xs)**2, len(ys)**2).
    """
    xs0, xs1 = torch.meshgrid(xs, xs, indexing="ij")
    ys0, ys1 = torch.meshgrid(ys, ys, indexing="ij")
    dxs = xs0.reshape(-1, 1) - ys0.reshape(1, -1)
    dys = xs1.reshape(-1, 1) - ys1.reshape(1, -1)
    K = torch.exp(-(dxs**2 + dys**2) / (2 * sigL**2))
    return K
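
A NumPy port of the same construction (hypothetical name kernelD_np) makes the row and column layout concrete: rows index the flattened xs × xs grid, columns the flattened ys × ys grid.

```python
import numpy as np


def kernelD_np(xs, ys, sigL=0.85):
    """NumPy analogue of kernelD."""
    # Row i of K <-> point i of the flattened xs x xs grid; column j <-> point j of ys x ys.
    xs0, xs1 = np.meshgrid(xs, xs, indexing="ij")
    ys0, ys1 = np.meshgrid(ys, ys, indexing="ij")
    dxs = xs0.reshape(-1, 1) - ys0.reshape(1, -1)
    dys = xs1.reshape(-1, 1) - ys1.reshape(1, -1)
    return np.exp(-(dxs**2 + dys**2) / (2 * sigL**2))


xs = np.arange(-1, 2)            # 3 coordinates -> 3 x 3 = 9 source points
ys = np.linspace(-1.0, 1.0, 5)   # 5 coordinates -> 5 x 5 = 25 target points
K = kernelD_np(xs, ys)
K0 = kernelD_np(xs, xs)
```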

kernelD2

kernelD2(xs, ys)

Compute a normalized Gaussian kernel matrix from two 1D coordinate tensors.

Builds a 2D meshgrid from xs and ys, computes Gaussian weights exp(-d^2) from the squared distances d^2 between all pairs of flattened grid points, and normalizes each column to sum to 1. Used for smoothing phase-correlation maps across blocks.

Parameters:

Name Type Description Default
xs Tensor

1D tensor of grid coordinates along one axis.

required
ys Tensor

1D tensor of grid coordinates along the other axis.

required

Returns:

Name Type Description
R Tensor

Column-normalized Gaussian kernel matrix of shape (N, N), where N = len(xs) * len(ys); each column sums to 1 (so its transpose is row-normalized).

Source code in suite2p/registration/utils.py
def kernelD2(xs: torch.Tensor, ys: torch.Tensor) -> torch.Tensor:
    """
    Compute a normalized Gaussian kernel matrix from two 1D coordinate tensors.

    Builds a 2D meshgrid from `xs` and `ys`, computes Gaussian weights exp(-d^2) from the
    squared distances d^2 between all pairs of flattened grid points, and normalizes each
    column to sum to 1. Used for smoothing phase-correlation maps across blocks.

    Parameters
    ----------
    xs : torch.Tensor
        1D tensor of grid coordinates along one axis.
    ys : torch.Tensor
        1D tensor of grid coordinates along the other axis.

    Returns
    -------
    R : torch.Tensor
        Column-normalized Gaussian kernel matrix of shape (N, N), where N = len(xs) * len(ys);
        each column sums to 1 (so its transpose is row-normalized).
    """
    ys, xs = torch.meshgrid(xs, ys, indexing="ij")
    ys = ys.flatten().reshape(1, -1)
    xs = xs.flatten().reshape(1, -1)
    R = torch.exp(-((ys - ys.T)**2 + (xs - xs.T)**2))
    R = R / torch.sum(R, axis=0)
    return R
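
A NumPy port (hypothetical name kernelD2_np) confirms the normalization: dividing by torch.sum(R, axis=0) broadcasts over columns, so each column sums to 1 and the transpose is row-normalized.

```python
import numpy as np


def kernelD2_np(xs, ys):
    """NumPy analogue of kernelD2 for 1D coordinate arrays xs and ys."""
    a, b = np.meshgrid(xs, ys, indexing="ij")  # each of shape (len(xs), len(ys))
    a = a.reshape(1, -1)
    b = b.reshape(1, -1)
    R = np.exp(-((a - a.T) ** 2 + (b - b.T) ** 2))  # exp(-d^2) for every pair of grid points
    return R / R.sum(axis=0)  # divide column j by its sum


R = kernelD2_np(np.arange(3), np.arange(4))
```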

mat_upsample

mat_upsample(lpad, subpixel=10, device=torch.device('cpu'))

Build an interpolation matrix for sub-pixel upsampling of correlation peaks.

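Constructs a Gaussian interpolation matrix (Kmat) that maps values on the integer grid of size (2*lpad+1) × (2*lpad+1) onto a finer grid with spacing 1/subpixel. Kmat is obtained by solving the linear system kernelD(xs, xs) @ Kmat = kernelD(xs, xs_up), where xs holds the integer coordinates -lpad..lpad and xs_up the up-sampled coordinates.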

Parameters:

Name Type Description Default
lpad int

Half-width of the integer grid. The grid spans from -lpad to +lpad.

required
subpixel int, optional (default 10)

Up-sampling factor. The output grid has spacing 1/subpixel.

10
device torch.device, optional (default torch.device("cpu"))

Device on which to create the grid tensors.

device('cpu')

Returns:

Name Type Description
Kmat Tensor

Interpolation matrix of shape ((2*lpad+1)**2, nup**2) mapping the original grid to the up-sampled grid.

nup int

Number of points along one axis of the up-sampled grid (2*lpad*subpixel + 1).

Source code in suite2p/registration/utils.py
def mat_upsample(lpad: int, subpixel: int = 10, device=torch.device("cpu")):
    """
    Build an interpolation matrix for sub-pixel upsampling of correlation peaks.

    Constructs a Gaussian interpolation matrix (`Kmat`) that maps values on the integer grid
    of size (2*lpad+1) × (2*lpad+1) onto a finer grid with spacing 1/subpixel. `Kmat` is
    obtained by solving the linear system `kernelD(xs, xs) @ Kmat = kernelD(xs, xs_up)`.

    Parameters
    ----------
    lpad : int
        Half-width of the integer grid. The grid spans from -lpad to +lpad.
    subpixel : int, optional (default 10)
        Up-sampling factor. The output grid has spacing 1/subpixel.
    device : torch.device, optional (default torch.device("cpu"))
        Device on which to create the grid tensors.

    Returns
    -------
    Kmat : torch.Tensor
        Interpolation matrix of shape ((2*lpad+1)**2, nup**2) mapping the original grid
        to the up-sampled grid.
    nup : int
        Number of points along one axis of the up-sampled grid (2*lpad*subpixel + 1).
    """
    xs = torch.arange(-lpad, lpad + 1, device=device)
    xs_up = torch.arange(-lpad, lpad + .001, 1. / subpixel, device=device)
    kernel0 = kernelD(xs, xs)
    kernel_up = kernelD(xs, xs_up) 
    Kmat = torch.linalg.solve(kernel0, kernel_up)
    nup = len(xs_up)
    return Kmat, nup
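
The NumPy sketch below (hypothetical names, float64 instead of torch) builds the same matrix and checks a property that follows from Kmat = kernelD(xs, xs)^-1 kernelD(xs, xs_up): up-sampled values at the integer grid nodes reproduce the original samples. Applying it as patch @ Kmat to a flattened patch is an assumption about usage, made for illustration.

```python
import numpy as np


def kernelD_np(xs, ys, sigL=0.85):
    """NumPy analogue of kernelD."""
    xs0, xs1 = np.meshgrid(xs, xs, indexing="ij")
    ys0, ys1 = np.meshgrid(ys, ys, indexing="ij")
    dxs = xs0.reshape(-1, 1) - ys0.reshape(1, -1)
    dys = xs1.reshape(-1, 1) - ys1.reshape(1, -1)
    return np.exp(-(dxs**2 + dys**2) / (2 * sigL**2))


def mat_upsample_np(lpad, subpixel=10):
    """NumPy analogue of mat_upsample: Kmat solves kernelD(xs, xs) @ Kmat = kernelD(xs, xs_up)."""
    xs = np.arange(-lpad, lpad + 1, dtype=float)
    xs_up = np.arange(-lpad, lpad + 0.001, 1.0 / subpixel)
    Kmat = np.linalg.solve(kernelD_np(xs, xs), kernelD_np(xs, xs_up))
    return Kmat, len(xs_up)


lpad, subpixel = 1, 10
Kmat, nup = mat_upsample_np(lpad, subpixel)
rng = np.random.default_rng(3)
patch = rng.standard_normal((2 * lpad + 1) ** 2)  # flattened 3 x 3 patch (hypothetical data)
patch_up = (patch @ Kmat).reshape(nup, nup)       # values on the 21 x 21 up-sampled grid
```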

ref_smooth_fft

ref_smooth_fft(refImg, smooth_sigma=None)

Compute the smoothed, normalized complex-conjugate FFT of a reference image for phase-correlation.

Takes the complex conjugate of the 2D FFT of refImg, whitens it (divides each coefficient by its magnitude), and, if smooth_sigma is given, multiplies it by a Gaussian filter in the frequency domain (gaussian_fft) with standard deviation smooth_sigma.

Parameters:

Name Type Description Default
refImg Tensor

2D reference image of shape (Ly, Lx).

required
smooth_sigma float, optional

Standard deviation (in pixels) of the Gaussian smoothing applied in the frequency domain. If None, no smoothing is applied.

None

Returns:

Name Type Description
cfRefImg Tensor

Complex64 tensor of shape (Ly, Lx) containing the smoothed, whitened complex-conjugate FFT of the reference image.

Source code in suite2p/registration/utils.py
def ref_smooth_fft(refImg: torch.Tensor, smooth_sigma=None) -> torch.Tensor:
    """
    Compute the smoothed, normalized complex-conjugate FFT of a reference image for phase-correlation.

    Takes the complex conjugate of the 2D FFT of `refImg`, whitens it (divides each coefficient
    by its magnitude), and, if `smooth_sigma` is given, multiplies it by a Gaussian filter in the
    frequency domain (`gaussian_fft`) with standard deviation `smooth_sigma`.

    Parameters
    ----------
    refImg : torch.Tensor
        2D reference image of shape (Ly, Lx).
    smooth_sigma : float, optional
        Standard deviation (in pixels) of the Gaussian smoothing applied in the frequency
        domain. If None, no smoothing is applied.

    Returns
    -------
    cfRefImg : torch.Tensor
        Complex64 tensor of shape (Ly, Lx) containing the smoothed, whitened
        complex-conjugate FFT of the reference image.
    """
    cfRefImg = complex_fft2(img=refImg)
    cfRefImg /= (1e-5 + torch.abs(cfRefImg))
    if smooth_sigma is not None:
        cfRefImg *= gaussian_fft(smooth_sigma, cfRefImg.shape[0], cfRefImg.shape[1])
    return cfRefImg.type(torch.complex64)
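
A NumPy sketch (hypothetical names, NumPy in place of torch) showing the two stages: after whitening every coefficient has magnitude close to 1, and the optional Gaussian factor keeps the zero-frequency magnitude near 1 while suppressing high frequencies.

```python
import numpy as np


def gaussian_fft_np(sig, Ly, Lx):
    """NumPy analogue of gaussian_fft."""
    y = np.arange(Ly, dtype=float) - (Ly - 1) / 2
    x = np.arange(Lx, dtype=float) - (Lx - 1) / 2
    kernel = np.exp(-y**2 / (2 * sig**2))[:, None] * np.exp(-x**2 / (2 * sig**2))
    kernel /= kernel.sum()
    return np.real(np.fft.fft2(np.fft.ifftshift(kernel)))


def ref_smooth_fft_np(refImg, smooth_sigma=None):
    """NumPy analogue of ref_smooth_fft."""
    cf = np.conj(np.fft.fft2(refImg))
    cf /= 1e-5 + np.abs(cf)  # whiten: unit magnitude, phase preserved
    if smooth_sigma is not None:
        cf *= gaussian_fft_np(smooth_sigma, *cf.shape)  # Gaussian low-pass in frequency space
    return cf.astype(np.complex64)


rng = np.random.default_rng(5)
ref = rng.standard_normal((33, 33))
cf_white = ref_smooth_fft_np(ref)
cf_smooth = ref_smooth_fft_np(ref, smooth_sigma=1.5)
```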

spatial_taper

spatial_taper(sig, Ly, Lx)

Compute a spatial taper mask using a sigmoid function on the image edges.

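Creates a 2D multiplicative mask that smoothly reduces values near the borders. Along each axis the profile is a logistic sigmoid whose midpoint lies 2*sig pixels inside the border and whose width is set by sig; the 2D mask is the outer product of the y and x profiles.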

Parameters:

Name Type Description Default
sig float

Scalar parameter controlling the slope of the sigmoid taper. Higher values increase the size of the tapered border region.

required
Ly int

Frame height in pixels.

required
Lx int

Frame width in pixels.

required

Returns:

Name Type Description
maskMul Tensor

Float64 multiplicative mask of shape (Ly, Lx), with values near 1.0 in the center that decay smoothly toward the edges; each axis factor equals 1 / (1 + e^2) ≈ 0.12 at the outermost pixel.

Source code in suite2p/registration/utils.py
def spatial_taper(sig, Ly, Lx):
    """
    Compute a spatial taper mask using a sigmoid function on the image edges.

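    Creates a 2D multiplicative mask that smoothly reduces values near the borders. Along
    each axis the profile is a logistic sigmoid whose midpoint lies 2*sig pixels inside the
    border and whose width is set by `sig`; the 2D mask is the outer product of the y and x
    profiles.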

    Parameters
    ----------
    sig : float
        Scalar parameter controlling the slope of the sigmoid taper. Higher values
        increase the size of the tapered border region.
    Ly : int
        Frame height in pixels.
    Lx : int
        Frame width in pixels.

    Returns
    -------
    maskMul : torch.Tensor
        Float64 multiplicative mask of shape (Ly, Lx), with values near 1.0 in the center
        that decay smoothly toward the edges; each axis factor equals 1 / (1 + e^2) ≈ 0.12
        at the outermost pixel.
    """
    y = torch.arange(0, Ly, dtype=torch.double)
    x = torch.arange(0, Lx, dtype=torch.double)
    x = (x - x.mean()).abs()
    y = (y - y.mean()).abs()
    mY = ((Ly - 1) / 2) - 2 * sig
    mX = ((Lx - 1) / 2) - 2 * sig
    maskY = 1. / (1. + torch.exp((y - mY) / sig))
    maskX = 1. / (1. + torch.exp((x - mX) / sig))
    maskMul = maskY[:,None] * maskX
    return maskMul
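
A NumPy port (hypothetical name spatial_taper_np) that checks the edge values: at the outermost pixel each axis factor is 1 / (1 + e^2), and at a corner the mask is the square of that.

```python
import numpy as np


def spatial_taper_np(sig, Ly, Lx):
    """NumPy analogue of spatial_taper (float64, like the torch version)."""
    y = np.abs(np.arange(Ly, dtype=float) - (Ly - 1) / 2)  # distance from the centre row
    x = np.abs(np.arange(Lx, dtype=float) - (Lx - 1) / 2)  # distance from the centre column
    mY = (Ly - 1) / 2 - 2 * sig  # sigmoid midpoint, 2*sig pixels inside the border
    mX = (Lx - 1) / 2 - 2 * sig
    maskY = 1.0 / (1.0 + np.exp((y - mY) / sig))
    maskX = 1.0 / (1.0 + np.exp((x - mX) / sig))
    return maskY[:, None] * maskX


sig, L = 5.0, 65
mask = spatial_taper_np(sig, L, L)
edge = 1.0 / (1.0 + np.exp(2.0))            # per-axis factor at the outermost pixel, about 0.119
centre = 1.0 / (1.0 + np.exp(-22.0 / sig))  # per-axis factor at the centre pixel (mY = 32 - 10 = 22)
```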

temporal_smooth

temporal_smooth(data, sigma)

Apply 1D Gaussian smoothing along the time (first) axis of a 3D array.

TODO: convert to torch

Parameters:

Name Type Description Default
data ndarray

Input data of shape (nimg, Ly, Lx) to be smoothed along axis 0.

required
sigma float

Standard deviation of the Gaussian kernel used for temporal smoothing.

required

Returns:

Name Type Description
smoothed_data ndarray

Temporally smoothed data of shape (nimg, Ly, Lx).

Source code in suite2p/registration/utils.py
def temporal_smooth(data: np.ndarray, sigma: float) -> np.ndarray:
    """
    Apply 1D Gaussian smoothing along the time (first) axis of a 3D array.

    TODO: convert to torch

    Parameters
    ----------
    data : numpy.ndarray
        Input data of shape (nimg, Ly, Lx) to be smoothed along axis 0.
    sigma : float
        Standard deviation of the Gaussian kernel used for temporal smoothing.

    Returns
    -------
    smoothed_data : numpy.ndarray
        Temporally smoothed data of shape (nimg, Ly, Lx).
    """
    return gaussian_filter1d(data, sigma=sigma, axis=0)
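
To make the operation explicit, the smoothing can be sketched in NumPy. The sketch assumes gaussian_filter1d is scipy.ndimage.gaussian_filter1d with its defaults (mode='reflect', truncate=4.0), whose edge handling corresponds to NumPy's 'symmetric' padding; temporal_smooth_np is a hypothetical name. As written, temporal_smooth expects a NumPy array rather than a torch tensor.

```python
import numpy as np


def temporal_smooth_np(data, sigma, truncate=4.0):
    """Gaussian smoothing along axis 0 with half-sample symmetric edge padding."""
    radius = int(truncate * sigma + 0.5)
    t = np.arange(-radius, radius + 1)
    weights = np.exp(-0.5 * (t / sigma) ** 2)
    weights /= weights.sum()
    # Pad only the time axis; 'symmetric' repeats the edge sample (d c b a | a b c d).
    pad = [(radius, radius)] + [(0, 0)] * (data.ndim - 1)
    padded = np.pad(data, pad, mode="symmetric")
    n = data.shape[0]
    out = np.zeros(data.shape, dtype=float)
    for k, w in enumerate(weights):
        out += w * padded[k:k + n]  # weighted sum of time-shifted copies
    return out


constant = np.full((50, 4, 3), 2.5)
impulse = np.zeros((50, 1, 1))
impulse[25] = 1.0
smoothed = temporal_smooth_np(impulse, sigma=2.0)
```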


compute_zpos

compute_zpos()

Compute z-position estimates from registered frames.

Returns:

Type Description
None

Not yet implemented.

Source code in suite2p/registration/zalign.py
def compute_zpos():
    """
    Compute z-position estimates from registered frames.

    Returns
    -------
    None
        Not yet implemented.
    """
    return None

register_to_zstack

register_to_zstack(f_align_in, refImgs, nonrigid=False, settings=default_settings()['registration'], bidiphase=0, device=torch.device('cuda'))

Register frames to a z-stack of reference images and return the max correlation per z-plane.

Runs register_frames with apply_shifts=False to compute phase-correlation between each frame and every reference image in refImgs, without actually shifting the data.

Parameters:

Name Type Description Default
f_align_in Tensor or ndarray

Input frames of shape (n_frames, Ly, Lx).

required
refImgs Tensor or ndarray

Reference images from the z-stack, passed directly to register_frames as refImg.

required
nonrigid bool, optional (default False)

Whether to use nonrigid registration in addition to rigid registration.

False
settings dict, optional

Registration settings dictionary (from default_settings()["registration"]). Controls batch_size, norm_frames, smooth_sigma, spatial_taper, block_size, maxregshift, smooth_sigma_time, snr_thresh, and maxregshiftNR.

default_settings()['registration']
bidiphase int, optional (default 0)

Bidirectional phase offset to correct for bidirectional scanning artifacts.

0
device torch.device, optional (default torch.device("cuda"))

Device on which to run the registration.

device('cuda')

Returns:

Name Type Description
cmax_all ndarray

Maximum correlation values for each frame across z-planes.

Source code in suite2p/registration/zalign.py
def register_to_zstack(f_align_in, refImgs, nonrigid=False, settings=default_settings()["registration"],
                       bidiphase=0, device=torch.device("cuda")):
    """
    Register frames to a z-stack of reference images and return the max correlation per z-plane.

    Runs `register_frames` with `apply_shifts=False` to compute phase-correlation between
    each frame and every reference image in `refImgs`, without actually shifting the data.

    Parameters
    ----------
    f_align_in : torch.Tensor or numpy.ndarray
        Input frames of shape (n_frames, Ly, Lx).
    refImgs : torch.Tensor or numpy.ndarray
        Reference images from the z-stack, passed directly to `register_frames` as `refImg`.
    nonrigid : bool, optional (default False)
        Whether to use nonrigid registration in addition to rigid registration.
    settings : dict, optional
        Registration settings dictionary (from `default_settings()["registration"]`).
        Controls batch_size, norm_frames, smooth_sigma, spatial_taper, block_size,
        maxregshift, smooth_sigma_time, snr_thresh, and maxregshiftNR.
    bidiphase : int, optional (default 0)
        Bidirectional phase offset to correct for bidirectional scanning artifacts.
    device : torch.device, optional (default torch.device("cuda"))
        Device on which to run the registration.

    Returns
    -------
    cmax_all : numpy.ndarray
        Maximum correlation values for each frame across z-planes.
    """
    n_frames, Ly, Lx = f_align_in.shape

    ### ----- register frames to reference image -------------- ###
    outputs = register_frames(f_align_in, f_align_out=None, bidiphase=bidiphase,
                              refImg=refImgs, tif_root=None,
                              batch_size=settings["batch_size"],
                              norm_frames=settings["norm_frames"], smooth_sigma=settings["smooth_sigma"],
                              spatial_taper=settings["spatial_taper"], block_size=settings["block_size"],
                              nonrigid=nonrigid,
                              maxregshift=settings["maxregshift"], smooth_sigma_time=settings["smooth_sigma_time"],
                              snr_thresh=settings["snr_thresh"], maxregshiftNR=settings["maxregshiftNR"],
                              device=device, apply_shifts=False)
    rmin, rmax, mean_img, offsets_all, blocks = outputs
    yoff, xoff, corrXY, yoff1, xoff1, corrXY1, zest, cmax_all = offsets_all

    return cmax_all
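
Because compute_zpos is not yet implemented, the following is only a hypothetical sketch of how per-plane correlation peaks could be turned into a z estimate: for each frame, take the peak whitened cross-correlation with every reference plane and pick the plane with the largest peak. It uses NumPy in place of register_frames; the helper name max_phasecorr and the (n_planes, n_frames) layout are assumptions, and this is not suite2p's method.

```python
import numpy as np


def max_phasecorr(frames, refs, eps=1e-5):
    """Peak whitened cross-correlation of every frame with every reference plane.

    Returns an array of shape (n_planes, n_frames).
    """
    spec = np.fft.fft2(frames, axes=(-2, -1))
    spec /= eps + np.abs(spec)
    out = np.empty((refs.shape[0], frames.shape[0]))
    for z, ref in enumerate(refs):
        cf = np.conj(np.fft.fft2(ref))
        cf /= eps + np.abs(cf)
        cc = np.real(np.fft.ifft2(spec * cf, axes=(-2, -1)))
        out[z] = cc.reshape(frames.shape[0], -1).max(axis=1)
    return out


rng = np.random.default_rng(4)
refs = rng.standard_normal((5, 32, 32))  # hypothetical 5-plane z-stack
# Frame 0 is plane 2 shifted by (1, -2); frame 1 is plane 4 unshifted.
frames = np.stack([np.roll(refs[2], (1, -2), axis=(0, 1)), refs[4]])
cmax = max_phasecorr(frames, refs)
zest = cmax.argmax(axis=0)  # best-matching plane per frame
```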