"""Wavefront-sensor abstractions and common image pre-processing kernels.
This module defines the base class used by pyRTC wavefront-sensor adapters and
includes small image-processing helpers that are hot enough to warrant Numba
acceleration. Hardware-specific sensors subclass ``WavefrontSensor`` and reuse
its SHM publication, dark handling, and optional geometric pre-processing.
"""
import matplotlib.pyplot as plt
import numpy as np
from numba import jit, prange
from pyRTC.logging_utils import get_logger
from pyRTC.Pipeline import ImageSHM, launchComponent
from pyRTC.pyRTCComponent import pyRTCComponent
from pyRTC.utils import setFromConfig
logger = get_logger(__name__)
@jit(nopython=True, nogil=True, cache=True, fastmath=True)
def downsample_int32_image_jit(image, N):
"""
Numba-optimized function to downsample a 2D int32 NumPy array by a factor N, returning int32 output.
Parameters:
- image: 2D NumPy array of int32 with shape (H, W)
- N: int, downsampling factor
Returns:
- downsampled_image: 2D NumPy array of int32 with shape (H//N, W//N)
"""
H, W = image.shape
# Calculate padding sizes if H or W is not divisible by N
pad_H = (-H) % N
pad_W = (-W) % N
# Pad the image if necessary to make dimensions divisible by N
if pad_H > 0 or pad_W > 0:
# Create a new array with zeros
H_padded = H + pad_H
W_padded = W + pad_W
image_padded = np.zeros((H_padded, W_padded), dtype=np.int32)
image_padded[:H, :W] = image
else:
image_padded = image
H_padded, W_padded = H, W
# Initialize the output array
out_H = H_padded // N
out_W = W_padded // N
downsampled_image = np.zeros((out_H, out_W), dtype=np.int32)
# Loop over the output array indices with Numba's parallel loops
for i in range(out_H):
for j in range(out_W):
# Compute the sum over the N x N block
sum_block = 0
for di in range(N):
for dj in range(N):
sum_block += image_padded[i*N + di, j*N + dj]
# Compute the mean
mean_value = sum_block / (N * N)
# Round and cast to int32
downsampled_image[i, j] = np.int32(round(mean_value))
return downsampled_image
@jit(nopython=True, nogil=True, cache=True, fastmath=True, parallel=True)
def rotate_image_jit(image, angle_rad):
"""
Numba-optimized parallel bilinear interpolation rotation.
Parameters:
- image: 2D NumPy array (int32 or float) with shape (H, W)
- angle_rad: float, rotation angle in radians (positive = counter-clockwise)
Returns:
- rotated_image: 2D NumPy array with same shape and dtype as input
"""
h, w = image.shape
cos_angle = np.cos(angle_rad)
sin_angle = np.sin(angle_rad)
# Center of rotation
cx, cy = w / 2.0, h / 2.0
# Output image (same size as input)
rotated = np.zeros_like(image)
for y in prange(h):
for x in range(w):
# Translate to center
x_centered = x - cx
y_centered = y - cy
# Rotate (inverse transformation)
x_orig = x_centered * cos_angle + y_centered * sin_angle + cx
y_orig = -x_centered * sin_angle + y_centered * cos_angle + cy
# Check if the source coordinates are within bounds
if 0 <= x_orig < w-1 and 0 <= y_orig < h-1:
# Bilinear interpolation
x0, x1 = int(np.floor(x_orig)), int(np.ceil(x_orig))
y0, y1 = int(np.floor(y_orig)), int(np.ceil(y_orig))
# Ensure indices are within bounds
if x1 >= w:
x1 = w - 1
if y1 >= h:
y1 = h - 1
# Interpolation weights
wx = x_orig - x0
wy = y_orig - y0
# Bilinear interpolation
val = (image[y0, x0] * (1 - wx) * (1 - wy) +
image[y0, x1] * wx * (1 - wy) +
image[y1, x0] * (1 - wx) * wy +
image[y1, x1] * wx * wy)
rotated[y, x] = val
return rotated
[docs]
class WavefrontSensor(pyRTCComponent):
"""
Base class for cameras that feed the wavefront-sensing pipeline.
The class owns the common control-plane behavior for wavefront-sensor image
sources: configuration, dark subtraction, optional downsampling and
rotation, and publication of both raw and processed frames. Concrete sensor
adapters in ``pyRTC.hardware`` are responsible for talking to vendor SDKs
and filling ``self.data`` before delegating back to the base implementation.
Config
------
name : str
The name of the wavefront sensor. Default "wavefrontSensor"
width : int
The width of the wavefront sensor image. Required.
height : int
The width of the wavefront sensor image. Required.
darkCount : int
Number of dark frames to average. Default 1000.
darkFile : str
Path to the dark frame file. Default, empty string.
Attributes
----------
imageShape : tuple
The shape of the image (width, height).
imageRawDType : data-type
The data type for raw image.
imageDType : data-type
The data type for processed image.
imageRaw : ImageSHM
Shared memory object for raw image.
image : ImageSHM
Shared memory object for processed image.
data : ndarray
Array to store raw image data.
dark : ndarray
Array to store dark frame data.
affinity : int
The affinity configuration.
roiWidth : int
Width of the region of interest.
roiHeight : int
Height of the region of interest.
roiLeft : int
Left coordinate of the region of interest.
roiTop : int
Top coordinate of the region of interest.
exposure : float
Exposure time.
binning : int
Binning factor.
gain : float
Gain setting.
bitDepth : int
Bit depth of the image.
Methods
-------
setRoi(roi)
Sets the region of interest.
setExposure(exposure)
Sets the exposure time.
setBinning(binning)
Sets the binning factor.
setGain(gain)
Sets the gain.
setBitDepth(bitDepth)
Sets the bit depth.
expose()
Writes the current image data to shared memory.
read()
Reads the processed image data from shared memory.
takeDark()
Captures and sets the dark frame.
setDark(dark)
Sets the dark frame.
saveDark(filename='')
Saves the dark frame to a file.
loadDark(filename='')
Loads the dark frame from a file.
plot()
Plots the current image data.
rotateImage(angle_deg)
Rotates the current image data by the specified angle in degrees.
"""
def __init__(self, conf: dict) -> None:
"""
Constructs all the necessary attributes for the wavefront sensor object.
Parameters
----------
conf : dict
Configuration dictionary for the wavefront sensor. Typically it will just be
the "wfs" section of a pyRTC config.
"""
try:
super().__init__(conf)
self.name = setFromConfig(conf, "name", "wavefrontSensor")
self.width = setFromConfig(conf, "width", 1)
self.height = setFromConfig(conf, "height", 1)
self.darkCount = setFromConfig(conf, "darkCount", 1000)
self.darkFile = setFromConfig(conf, "darkFile", "")
self.downsampleFactor = setFromConfig(conf, "downsampleFactor", 0)
self.rotationAngle = setFromConfig(conf, "rotationAngle", 0.0)
self.imageRawShape = [self.width, self.height]
self.imageRawDType = np.uint16
self.imageDType = np.int32
self.imageShape = [self.width, self.height]
if self.downsampleFactor > 0:
self.imageShape[0] = self.imageShape[0] // self.downsampleFactor
self.imageShape[1] = self.imageShape[1] // self.downsampleFactor
self.imageRaw = ImageSHM("wfsRaw", self.imageRawShape, self.imageRawDType, gpuDevice=self.gpuDevice, consumer=False)
self.image = ImageSHM("wfs", self.imageShape, self.imageDType, gpuDevice=self.gpuDevice, consumer=False)
self.data = np.zeros(self.imageShape, dtype=self.imageRawDType)
self.dark = np.zeros(self.imageRawShape, dtype=self.imageDType)
self.loadDark()
self.logger.info(
"Initialized wavefront sensor name=%s raw_shape=%s image_shape=%s downsample=%s rotation=%s",
self.name,
self.imageRawShape,
self.imageShape,
self.downsampleFactor,
self.rotationAngle,
)
except Exception:
logger.exception("Failed to initialize wavefront sensor")
raise
return
[docs]
def setRoi(self, roi):
"""
Sets the region of interest (ROI) for the sensor.
Parameters
----------
roi : tuple
A tuple containing (width, height, left, top) of the ROI.
"""
try:
self.roiWidth = roi[0]
self.roiHeight = roi[1]
self.roiLeft = roi[2]
self.roiTop = roi[3]
self.logger.info("Set ROI width=%s height=%s left=%s top=%s", *roi)
except Exception:
self.logger.exception("Failed to set ROI from %s", roi)
raise
return
[docs]
def setExposure(self, exposure: float) -> None:
"""
Sets the exposure time for the sensor.
Parameters
----------
exposure : float
Exposure time in whatever unit your camera uses.
"""
try:
self.exposure = exposure
self.logger.info("Set exposure to %s", exposure)
except Exception:
self.logger.exception("Failed to set exposure to %s", exposure)
raise
return
[docs]
def setBinning(self, binning: int) -> None:
"""
Sets the binning factor for the sensor.
Parameters
----------
binning : int
Binning factor.
"""
try:
self.binning = binning
self.logger.info("Set binning to %s", binning)
except Exception:
self.logger.exception("Failed to set binning to %s", binning)
raise
return
[docs]
def setGain(self, gain: float) -> None:
"""
Sets the gain for the sensor.
Parameters
----------
gain : float
Gain value.
"""
try:
self.gain = gain
self.logger.info("Set gain to %s", gain)
except Exception:
self.logger.exception("Failed to set gain to %s", gain)
raise
return
[docs]
def setBitDepth(self, bitDepth: int) -> None:
"""
Sets the bit depth for the sensor.
Parameters
----------
bitDepth : int
Bit depth. pyRTC convention is this is the number of bits in the ADC,
e.g., 8, 16, 12, 10.
"""
try:
self.bitDepth = bitDepth
self.logger.info("Set bit depth to %s", bitDepth)
except Exception:
self.logger.exception("Failed to set bit depth to %s", bitDepth)
raise
return
[docs]
def expose(self) -> None:
"""
Writes the current image data to shared memory. Both raw, and dark subtracted.
Parameters
----------
"""
self.imageRaw.write(self.data)
img = self.data.astype(self.imageDType)
# Apply dark subtraction
processed_image = img - self.dark
# Apply downsampling if configured
if self.downsampleFactor > 0:
processed_image = downsample_int32_image_jit(processed_image, self.downsampleFactor)
# Apply rotation if specified
if self.rotationAngle != 0.0:
angle_rad = np.radians(self.rotationAngle)
processed_image = rotate_image_jit(processed_image, angle_rad)
# Write the processed image to shared memory
self.image.write(processed_image)
return
[docs]
def read(self, block = True) -> None:
"""
Reads the dark subtracted image data from shared memory.
Returns
-------
ndarray
Processed image data.
"""
if block:
return self.image.read(RELEASE_GIL = self.RELEASE_GIL)
else:
return self.image.read_noblock()
[docs]
def takeDark(self) -> None:
"""
Captures and sets the dark frame.
"""
try:
if self.darkCount < 1:
raise ValueError("darkCount must be at least 1 to acquire a dark frame")
self.logger.info("Taking dark frame using %s exposures", self.darkCount)
self.setDark(np.zeros_like(self.dark))
dark = np.zeros(self.imageShape, dtype=np.float64)
for _ in range(self.darkCount):
dark += self.read().astype(np.float64)
dark /= self.darkCount
self.setDark(dark)
self.logger.info("Completed dark frame acquisition")
except Exception:
self.logger.exception("Failed to acquire dark frame")
raise
return
[docs]
def setDark(self, dark) -> None:
"""
Sets the dark frame.
Parameters
----------
dark : ndarray
Dark frame data.
"""
try:
self.dark = dark.astype(self.imageDType)
self.logger.info("Updated dark frame")
except Exception:
self.logger.exception("Failed to update dark frame")
raise
return
[docs]
def saveDark(self,filename=''):
"""
Saves the dark frame to a file.
Parameters
----------
filename : str, optional
Filename to save the dark frame to. If not specified, uses the dark file path from the configuration.
"""
try:
if filename == '':
filename = self.darkFile
if filename == '':
raise ValueError("No dark frame filename provided")
np.save(filename, self.dark)
self.logger.info("Saved dark frame to %s", filename)
except Exception:
self.logger.exception("Failed to save dark frame to %s", filename or self.darkFile)
raise
return
[docs]
def loadDark(self,filename=''):
"""
Loads the dark frame from a file.
Parameters
----------
filename : str, optional
Filename to load the dark frame from. If not specified, uses the dark file path from the configuration.
"""
#If no file given, first try dark file
try:
if filename == '':
filename = self.darkFile
if filename == '':
self.dark = np.zeros_like(self.dark)
self.logger.info("No dark frame file configured; using zeros")
else:
self.dark = np.load(filename)
self.logger.info("Loaded dark frame from %s", filename)
except Exception:
self.logger.exception("Failed to load dark frame from %s", filename or self.darkFile)
raise
return
[docs]
def plot(self) -> None:
"""
Plots the current image data.
"""
try:
arr = self.read(block=False)
plt.figure(figsize=(8,8))
plt.imshow(arr, cmap = 'inferno', origin='lower')
plt.colorbar()
plt.show()
self.logger.info("Plotted wavefront sensor image")
except Exception:
self.logger.exception("Failed to plot wavefront sensor image")
raise
return
[docs]
def rotateImage(self, angle_deg: float) -> np.ndarray:
"""
Rotates the current image data by the specified angle.
This method uses a high-performance numba JIT-compiled bilinear interpolation
rotation algorithm that is significantly faster than scipy or opencv implementations
while maintaining good image quality.
Parameters
----------
angle_deg : float
Rotation angle in degrees. Positive values rotate counter-clockwise.
Returns
-------
ndarray
Rotated image data with the same shape and dtype as the original.
Examples
--------
>>> wfs = WavefrontSensor(config)
>>> rotated_img = wfs.rotateImage(45.0) # Rotate 45 degrees counter-clockwise
>>> rotated_img = wfs.rotateImage(-90.0) # Rotate 90 degrees clockwise
"""
# Get the current image data
try:
current_image = self.read(block=False)
angle_rad = np.radians(angle_deg)
rotated_image = rotate_image_jit(current_image, angle_rad)
self.logger.info("Rotated image by %s degrees", angle_deg)
return rotated_image
except Exception:
self.logger.exception("Failed to rotate image by %s degrees", angle_deg)
raise
if __name__ == "__main__":
launchComponent(WavefrontSensor, "wfs", start = True)