"""Wavefront-corrector abstractions and modal-to-zonal mapping helpers.
This module defines the base class used by pyRTC deformable mirrors and other
corrective devices. It manages command streams, flat handling, actuator masks,
and optional 2D layout views, while leaving hardware transport details to the
concrete adapter subclasses.
"""
import os
os.environ["OMP_NUM_THREADS"] = "1"
os.environ["OPENBLAS_NUM_THREADS"] = "1"
os.environ["MKL_NUM_THREADS"] = "1"
os.environ["VECLIB_MAXIMUM_THREADS"] = "1"
os.environ["NUMEXPR_NUM_THREADS"] = "1"
os.environ['NUMBA_NUM_THREADS'] = '1'
os.environ['TBB_NUM_THREADS'] = '1'
import numpy as np
import matplotlib.pyplot as plt
from numba import jit
from pyRTC.logging_utils import get_logger
from pyRTC.Pipeline import ImageSHM, launchComponent
from pyRTC.pyRTCComponent import pyRTCComponent
from pyRTC.utils import gaussian_2d_grid, setFromConfig
logger = get_logger(__name__)
@jit(nopython=True)
def ModaltoZonalWithFlat(correction=np.array([],dtype=np.float32),
                         M2C=np.array([[]],dtype=np.float32),
                         flat=np.array([],dtype=np.float32)):
    """Map a modal correction into actuator (zonal) space on top of the flat.

    Parameters
    ----------
    correction : numpy.ndarray
        Modal coefficients, one per column of ``M2C``.
    M2C : numpy.ndarray
        Mode-to-command matrix; callers in this module use axes
        [numActuators, numModes].
    flat : numpy.ndarray
        Flat shape added to the projected correction.

    Returns
    -------
    numpy.ndarray
        ``M2C @ correction + flat``.
    """
    zonal = M2C @ correction
    return zonal + flat
class WavefrontCorrector(pyRTCComponent):
    """
    Base class for deformable mirrors and other wavefront-correction devices.

    ``WavefrontCorrector`` is responsible for the control-plane machinery around
    command generation: SHM output, flat shapes, mode-to-command transforms,
    floating actuator handling, and delayed command buffers. Subclasses are left
    to implement the device-specific transport in ``sendToHardware``.

    Config
    ------
    name : str
        Name of the wavefront corrector. Required.
    numActuators : int
        Number of actuators. Required.
    numModes : int
        Number of modes. Required. Also sizes the "wfc" correction SHM.
    affinity : str
        Affinity setting.
    m2cFile : str, optional
        Path to the mode-to-command file (``.dat`` or ``.npy``). Default is "",
        which selects a (truncated) identity basis.
    flatFile : str, optional
        Path to the flat file (``.txt`` or ``.npy``). Default is "", which
        selects an all-zero flat.
    floatingInfluenceRadius : int, optional
        Radius for floating influence. Default is 1.
    frameDelay : int, optional
        Frame delay. Default is 0.
    saveFile : str, optional
        File to save the shape. Default is "wfcShape.npy".

    Attributes
    ----------
    name : str
        Name of the wavefront corrector.
    numActuators : int
        Number of actuators.
    numModes : int
        Number of modes; updated to the column count of ``M2C`` by ``setM2C``.
    affinity : str
        Affinity setting.
    m2cFile : str
        Path to the mode-to-command file.
    flatFile : str
        Path to the flat file.
    correctionVector : ImageSHM
        Modal correction vector SHM ("wfc").
    correctionVector2D : ImageSHM or None
        2D correction SHM ("wfc2D") for display; ``None`` without a layout.
    flat : numpy.ndarray
        Flat shape in actuator space.
    flatModal : numpy.ndarray
        Flat shape in the modal basis (``C2M @ flat``).
    currentShape : numpy.ndarray
        Latest actuator-space command computed by ``sendToHardware``.
    actuatorStatus : numpy.ndarray
        Boolean per-actuator status; ``False`` for floated actuators.
    index_map : numpy.ndarray or None
        2D map holding ``actuator index + 1`` at layout pixels and 0 elsewhere.
    floatingInfluenceRadius : int
        Radius for floating influence.
    floatMatrix : numpy.ndarray
        (numActuators, numActuators) matrix; identity except for floated
        actuators, whose rows weight their active neighbours.
    frameDelay : int
        Artificial frame delay.
    saveFile : str
        File to save the shape.
    layout : numpy.ndarray or None
        Boolean layout of the actuators.
    M2C : numpy.ndarray
        Mode-to-command matrix, axes [numActuators, numModes].
    f_M2C : numpy.ndarray
        ``floatMatrix @ M2C``; the basis used to build commands.
    C2M : numpy.ndarray
        Command-to-mode matrix (pseudo-inverse of ``M2C``).
    currentCorrection : numpy.ndarray
        Current modal correction vector.
    shapeBuffer : numpy.ndarray
        (frameDelay + 1, numActuators) buffer of delayed shapes.
    correctionVector2D_template : numpy.ndarray
        Reusable 2D array written into ``correctionVector2D``.
    """

    def __init__(self, conf) -> None:
        try:
            super().__init__(conf)
            self.name = conf["name"]
            self.numActuators = conf["numActuators"]
            self.numModes = conf["numModes"]
            self.m2cFile = setFromConfig(conf, "m2cFile", "")
            # The modal SHM is sized from the configured numModes.
            self.correctionVector = ImageSHM("wfc", (self.numModes,), np.float32, gpuDevice=self.gpuDevice, consumer=False)
            self.correctionVector2D = None
            self.setLayout(None)
            self.flat = np.zeros(self.numActuators, dtype=np.float32)
            self.flatModal = np.zeros(self.numModes, dtype=self.flat.dtype)
            self.currentShape = np.zeros_like(self.flat)
            self.flatFile = setFromConfig(conf, "flatFile", "")
            # Loaded before the M2C: setM2C derives flatModal from self.flat.
            self.loadFlat()
            self.actuatorStatus = np.ones(self.numActuators, dtype=bool)
            self.index_map = None
            self.floatingInfluenceRadius = setFromConfig(conf, "floatingInfluenceRadius", 1)
            # Identity until actuators are floated by deactivateActuators.
            self.floatMatrix = np.eye(self.numActuators, dtype=self.flat.dtype)
            self.setDelay(setFromConfig(conf, "frameDelay", 0))
            self.saveFile = setFromConfig(conf, "saveFile", "wfcShape.npy")
            self.readM2C()
            if self.numModes != conf["numModes"]:
                # setM2C adopts the M2C's column count, but the "wfc" SHM
                # above was already allocated with the configured size.
                self.logger.warning(
                    "M2C defines %s modes but numModes=%s was configured; "
                    "the wfc SHM was sized with the configured value",
                    self.numModes,
                    conf["numModes"],
                )
            self.logger.info(
                "Initialized wavefront corrector name=%s actuators=%s modes=%s",
                self.name,
                self.numActuators,
                self.numModes,
            )
        except Exception:
            # Module logger: self.logger may not exist if the base init failed.
            logger.exception("Failed to initialize wavefront corrector")
            raise
        return

    def setFlat(self, flat):
        """
        Set the flat shape.

        The flat is cast to the dtype of the current flat. Once a
        command-to-mode matrix exists, ``flatModal`` is refreshed as well so it
        stays consistent with the new flat.

        Parameters
        ----------
        flat : array_like
            Flat shape to set, one value per actuator.
        """
        try:
            flat = np.asarray(flat).astype(self.flat.dtype)
            # C2M only exists after setM2C has run. During __init__ the flat is
            # loaded first and setM2C computes flatModal itself.
            C2M = getattr(self, "C2M", None)
            if C2M is not None:
                self.flatModal = C2M @ flat
            self.flat = flat
            self.logger.info("Updated flat shape")
        except Exception:
            self.logger.exception("Failed to update flat shape")
            raise
        return

    def loadFlat(self, filename=''):
        """
        Load the flat from a file.

        Parameters
        ----------
        filename : str, optional
            File to load the flat from: names containing ``.txt`` are read with
            ``np.genfromtxt``, names containing ``.npy`` with ``np.load``. If
            not specified, uses the configured ``flatFile``; if that is empty
            too, an all-zero flat is used.

        Raises
        ------
        ValueError
            If the filename contains neither ``.txt`` nor ``.npy``.
        """
        # If no file is given, fall back to the configured flat file.
        try:
            if filename == '':
                filename = self.flatFile
            if filename == '':
                flat = np.zeros_like(self.flat)
                self.logger.info("No flat file configured; using zeros")
            else:
                if '.txt' in filename:
                    flat = np.genfromtxt(filename)
                elif '.npy' in filename:
                    flat = np.load(filename)
                else:
                    raise ValueError(f"Unsupported flat file format: {filename}")
                self.logger.info("Loaded flat from %s", filename)
            self.setFlat(flat)
        except Exception:
            self.logger.exception("Failed to load flat from %s", filename or self.flatFile)
            raise
        return

    def setLayout(self, layout):
        """
        Set the 2D layout of the actuators.

        An ndarray layout is converted to boolean (``layout > 0``). A "wfc2D"
        SHM of the same shape is created for display, and ``index_map`` is
        rebuilt. Any other value (e.g. ``None``) clears the layout, the 2D SHM
        handle and ``index_map``.

        Parameters
        ----------
        layout : numpy.ndarray or None
            Layout of the actuators. Is converted to boolean if not already.
        """
        try:
            self.layout = layout
            if isinstance(self.layout, np.ndarray):
                self.layout = self.layout > 0
                self.correctionVector2D = ImageSHM("wfc2D", self.layout.shape, np.float32, gpuDevice=self.gpuDevice, consumer=False)
                self.correctionVector2D.write(np.zeros(self.layout.shape, dtype=np.float32))
                self.correctionVector2D_template = self.correctionVector2D.read_noblock()
                # 1-based actuator numbering so that 0 marks "no actuator".
                self.index_map = np.zeros(self.layout.shape, dtype=int)
                self.index_map[self.layout] = np.arange(np.sum(self.layout)).astype(int) + 1
                self.logger.info("Configured 2D correction layout shape=%s", self.layout.shape)
            else:
                # Drop state derived from a previous layout. Otherwise
                # sendToHardware keeps writing the old 2D SHM and indexes its
                # template with a None layout.
                self.correctionVector2D = None
                self.index_map = None
                self.logger.info("Cleared 2D correction layout")
        except Exception:
            self.logger.exception("Failed to set wavefront corrector layout")
            raise
        return

    def deactivateActuators(self, actuators):
        """
        Deactivate (float) the specified actuators.

        Every inactive actuator gets its own ``floatMatrix`` row, including
        ones floated by earlier calls. That row is a normalised,
        Gaussian-weighted combination of the still-active actuators around it,
        with weights below 10% of the peak dropped. ``f_M2C`` is then rebuilt
        through ``setM2C``. An actuator with no active neighbour in reach gets
        a zero row, so its command stays at the flat.

        This requires a 2D ``layout``. Without one, only a warning is logged.

        Parameters
        ----------
        actuators : list of int
            List of actuator indices to deactivate.

        Raises
        ------
        TypeError
            If ``actuators`` has no length (is not a list or array).
        ValueError
            If ``actuators`` is empty or the layout is not 2D.
        """
        try:
            if not hasattr(actuators, '__len__'):
                raise TypeError("Actuators given as wrong type, please provide array or list")
            if len(actuators) < 1:
                raise ValueError("You have provided no actuators")
            if not isinstance(self.layout, np.ndarray):
                self.logger.warning("No layout set for DM")
                return
            if len(self.layout.shape) != 2:
                raise ValueError("Layout must be 2 dimensions to float actuators. To remove dead actuators, remove them from the M2C. OR set the layout to be 2D and the floatingInfluenceRadius to a 0")
            for act in actuators:
                self.actuatorStatus[act] = False
            # Float all inactive actuators together so that no floated
            # actuator is ever used as a neighbour of another one.
            floating = np.flatnonzero(~self.actuatorStatus)
            act_to_float_mask = np.zeros_like(self.index_map)
            for act in floating:
                act_to_float_mask[self.index_map == act + 1] = 1
            for act in floating:
                i, j = np.where(self.index_map == act + 1)
                # NOTE(review): layout.shape[0] is passed as the grid size,
                # which presumably assumes a square layout — confirm.
                influence_map = gaussian_2d_grid(i, j, self.floatingInfluenceRadius, self.layout.shape[0])
                influence_map *= self.layout * (1 - act_to_float_mask)
                total = np.sum(influence_map)
                if total == 0:
                    # Normalising would divide 0/0 and put NaNs into f_M2C.
                    # A zero row leaves this actuator at its flat value.
                    self.logger.warning(
                        "Actuator %s has no active neighbours within radius %s; holding it at the flat",
                        act,
                        self.floatingInfluenceRadius,
                    )
                    self.floatMatrix[act] = 0
                    continue
                influence_map /= total
                # Drop negligible weights (below 10% of the peak weight).
                influence_map[influence_map < np.max(influence_map) / 10] = 0
                self.floatMatrix[act] = influence_map[self.layout]
            self.setM2C(self.M2C)
            self.logger.info("Deactivated actuators %s", actuators)
        except Exception:
            self.logger.exception("Failed to deactivate actuators %s", actuators)
            raise
        return

    def reactivateActuators(self, actuators):
        """
        Reactivate the specified actuators.

        ``floatMatrix`` is reset to identity. Any actuators that remain
        inactive are floated again, and ``f_M2C`` is rebuilt in every case.

        Parameters
        ----------
        actuators : list of int
            List of actuator indices to reactivate.
        """
        try:
            for act in actuators:
                self.actuatorStatus[act] = True
            self.floatMatrix = np.eye(self.numActuators, dtype=self.flat.dtype)
            stillFloating = [i for i in range(self.numActuators) if not self.actuatorStatus[i]]
            if len(stillFloating) > 0:
                self.deactivateActuators(stillFloating)
            if len(stillFloating) == 0 or not isinstance(self.layout, np.ndarray):
                # deactivateActuators only rebuilds f_M2C when it actually
                # floats actuators. Otherwise f_M2C would keep stale floating
                # rows from before the identity reset.
                self.setM2C(self.M2C)
            self.logger.info("Reactivated actuators %s", actuators)
        except Exception:
            self.logger.exception("Failed to reactivate actuators %s", actuators)
            raise
        return

    def setM2C(self, M2C):
        """
        Set the mode-to-command matrix. This is the basis for correction.

        Besides storing ``M2C`` this method also:

        - recomputes ``f_M2C`` and ``C2M``;
        - updates ``numModes`` to the matrix's column count;
        - resets ``currentCorrection`` to zeros;
        - refreshes ``flatModal``.

        Parameters
        ----------
        M2C : numpy.ndarray or None
            Mode-to-command matrix to set. Axes are [numActuators, numModes].
            Anything that is not an ndarray selects the first ``numModes``
            columns of the identity.
        """
        try:
            if not isinstance(M2C, np.ndarray):
                self.M2C = np.eye(self.numActuators)[:, :self.numModes]
            else:
                self.M2C = M2C
            self.M2C = self.M2C.astype(self.flat.dtype)
            self.f_M2C = self.floatMatrix @ self.M2C
            self.C2M = np.linalg.pinv(self.M2C)
            self.numModes = self.M2C.shape[1]
            self.currentCorrection = np.zeros(self.numModes, dtype=self.flat.dtype)
            self.flatModal = self.C2M @ self.flat
            self.logger.info("Configured M2C matrix shape=%s", self.M2C.shape)
        except Exception:
            self.logger.exception("Failed to configure M2C matrix")
            raise
        return

    def setDelay(self, delay):
        """
        Set an artificial frame delay. Used for testing; nominally the delay
        should always be zero.

        The shape buffer is rebuilt with ``delay + 1`` slots, each pre-filled
        with the current flat.

        Parameters
        ----------
        delay : int
            Frame delay to set.
        """
        try:
            self.frameDelay = delay
            self.shapeBuffer = np.zeros((self.frameDelay + 1, *self.currentShape.shape), dtype=self.currentShape.dtype)
            for i in range(self.shapeBuffer.shape[0]):
                self.shapeBuffer[i] = self.flat.copy()
            self.logger.info("Set artificial frame delay to %s", delay)
        except Exception:
            self.logger.exception("Failed to set frame delay to %s", delay)
            raise
        return

    def readM2C(self, filename=''):
        """
        Read the mode-to-command matrix from a file.

        Names containing ``.dat`` are read as raw float64 and reshaped to
        (numActuators, numModes). Names containing ``.npy`` are read with
        ``np.load``. Any other name selects the identity basis.

        Parameters
        ----------
        filename : str, optional
            File to read the mode-to-command matrix from. If not specified, uses the configured m2cFile.
        """
        try:
            if filename == '':
                filename = self.m2cFile
            if '.dat' in filename:
                M2C = np.fromfile(filename, dtype=np.float64).reshape(self.numActuators, self.numModes)
            elif '.npy' in filename:
                M2C = np.load(filename)
            else:
                self.setM2C(None)
                self.logger.info("No M2C file configured; using identity basis")
                return
            self.setM2C(M2C)
            self.logger.info("Loaded M2C matrix from %s", filename)
        except Exception:
            self.logger.exception("Failed to read M2C matrix from %s", filename or self.m2cFile)
            raise
        return

    def sendToHardware(self):
        """
        Send the current correction to the hardware. Nominally, this function is overwritten by the
        child hardware class and registered to the real-time loop from the config.

        This base implementation does the following:

        - reads the next modal correction from the "wfc" SHM;
        - converts it to an actuator-space shape (``f_M2C @ correction + flat``),
          optionally delayed by ``frameDelay`` frames;
        - mirrors the shape minus the flat into the 2D SHM when a layout is set.
        """
        # Read a new modal correction in the M2C basis.
        self.currentCorrection = self.correctionVector.read()
        if self.frameDelay > 0:
            # Shift the delay line by one slot and append the newest shape.
            self.shapeBuffer[:-1] = self.shapeBuffer[1:]
            self.shapeBuffer[-1] = ModaltoZonalWithFlat(self.currentCorrection,
                                                        self.f_M2C,
                                                        self.flat)
            # The oldest slot is the shape that is due now.
            self.currentShape = self.shapeBuffer[0]
        else:
            self.currentShape = ModaltoZonalWithFlat(self.currentCorrection,
                                                     self.f_M2C,
                                                     self.flat)
        # If we have a 2D SHM instance, update it.
        if isinstance(self.correctionVector2D, ImageSHM):
            self.correctionVector2D_template[self.layout] = self.currentShape - self.flat
            self.correctionVector2D.write(self.correctionVector2D_template)
        # Subclasses send self.currentShape to the hardware after this.
        return

    def read(self, block=False):
        """
        Read the current correction vector.

        Parameters
        ----------
        block : bool, optional
            If True, use the SHM's blocking ``read``; otherwise ``read_noblock``.

        Returns
        -------
        numpy.ndarray
            Current correction vector.
        """
        if block:
            return self.correctionVector.read()
        return self.correctionVector.read_noblock()

    def write(self, correction):
        """
        Write a new modal correction to the "wfc" SHM.

        Parameters
        ----------
        correction : numpy.ndarray
            Correction vector to write.
        """
        self.currentCorrection = correction
        # We assume that sendToHardware is registered to the real-time loop
        # and that the WFC is running (i.e. start has been called).
        self.correctionVector.write(self.currentCorrection)
        return

    def flatten(self):
        """
        Flatten the wavefront corrector.
        """
        # A zero correction yields the flat, since corrections are always
        # applied on top of the flat.
        try:
            self.write(np.zeros_like(self.currentCorrection))
            self.logger.info("Flattened wavefront corrector")
        except Exception:
            self.logger.exception("Failed to flatten wavefront corrector")
            raise
        return

    def push(self, mode, amp):
        """
        Push a single mode with a given amplitude (all other modes zero).

        Parameters
        ----------
        mode : int
            Mode index to push.
        amp : float
            Amplitude to push the mode with.
        """
        try:
            corr = np.zeros_like(self.currentCorrection)
            corr[int(mode)] = float(amp)
            self.write(corr)
            self.logger.info("Pushed mode %s with amplitude %s", mode, amp)
        except Exception:
            self.logger.exception("Failed to push mode %s with amplitude %s", mode, amp)
            raise
        return

    def saveShape(self, filename=''):
        """
        Save the current shape to a file with ``np.save``.

        Parameters
        ----------
        filename : str, optional
            File to save the shape to. If not specified, uses the configured saveFile.

        Raises
        ------
        ValueError
            If no filename is given and ``saveFile`` is empty.
        """
        try:
            if filename == '':
                filename = self.saveFile
            if filename == '':
                raise ValueError("No output filename provided for shape save")
            np.save(filename, self.currentShape)
            self.logger.info("Saved current shape to %s", filename)
        except Exception:
            self.logger.exception("Failed to save current shape to %s", filename or self.saveFile)
            raise
        return

    def plot(self, addFlat=False):
        """
        Plot the current correction.

        With a layout set, the modal correction is projected through ``M2C``
        and shown as a 2D image. Otherwise the modal coefficients are plotted
        as a line.

        Parameters
        ----------
        addFlat : bool, optional
            If True, adds the modal flat (``flatModal``) to the current correction before plotting. Default is False.
        """
        curCorrection = self.read()
        if addFlat:
            # Out-of-place add: read() may return a buffer we must not modify.
            curCorrection = curCorrection + self.flatModal
        if isinstance(self.layout, np.ndarray):
            newShape = np.zeros(self.layout.shape)
            newShape[self.layout] = self.M2C@curCorrection
        else:
            newShape = curCorrection
        if len(newShape.shape) == 1:
            plt.plot(newShape)
            plt.show()
        elif len(newShape.shape) == 2:
            plt.imshow(newShape, cmap = "inferno", aspect='auto', origin='lower')
            plt.colorbar()
            plt.show()
        return
if __name__ == "__main__":
    # Standalone entry point: hand this class to pyRTC's launchComponent
    # helper under the "wfc" name and ask it to start immediately.
    launchComponent(WavefrontCorrector, "wfc", start=True)