# Source code for pyRTC.utils

"""General utility helpers shared across pyRTC.

The utilities in this module cover several small but widely used concerns:
configuration validation, file-path helpers, timing helpers, dtype encoding,
basic numerical helpers, and lightweight socket/process convenience functions.

They are kept here because they are broadly reusable across components and do
not belong to a single subsystem.
"""

import errno
import logging
import numbers
import os
import select
import socket
import sys
import time
from datetime import datetime
from typing import Any, Iterable, Mapping, Optional

import numpy as np
import psutil
import yaml
from astropy.io import fits
from scipy.ndimage import median_filter, gaussian_filter

from pyRTC.logging_utils import get_logger


# Module-level logger used by the helpers in this file.
logger = get_logger(__name__)

# Ordered registry of NumPy scalar types used to encode dtypes as small numeric
# codes: dtype_to_float() returns an entry's index in this list and
# float_to_dtype() indexes it to recover the dtype. The ORDER IS THEREFORE PART
# OF THE ENCODING. Only append new types at the end; inserting or reordering
# entries (including uncommenting the platform-dependent types below in place)
# would change the codes of every type after them.
NP_DATA_TYPES = [
    np.int8, np.int16, np.int32, np.int64,
    np.uint8, np.uint16, np.uint32, np.uint64,
    np.float16, np.float32, np.float64, #np.float128,  # np.float128 availability depends on the system
    np.complex64, np.complex128, #np.complex256,       # np.complex256 availability depends on the system
    np.bool_,
    np.object_,
    np.bytes_, np.str_,
    np.datetime64, np.timedelta64
]


class ConfigValidationError(ValueError):
    """Signal that a component configuration violates pyRTC's expectations.

    Derives from :class:`ValueError`, so callers that already catch
    ``ValueError`` continue to handle configuration problems.
    """


def _require_mapping(conf: Any, component: str) -> Mapping[str, Any]:
    if not isinstance(conf, Mapping):
        raise ConfigValidationError(f"{component}: config must be a mapping/dict, got {type(conf).__name__}")
    return conf


def _validate_optional_numeric(conf: Mapping[str, Any], key: str, component: str, minimum: Optional[float] = None):
    if key not in conf:
        return
    value = conf[key]
    if not isinstance(value, (int, float)):
        raise ConfigValidationError(f"{component}: '{key}' must be numeric, got {type(value).__name__}")
    if minimum is not None and value < minimum:
        raise ConfigValidationError(f"{component}: '{key}' must be >= {minimum}, got {value}")


def validate_wfs_config(conf: Any) -> None:
    """Validate the optional numeric settings of a wavefront-sensor config.

    Raises
    ------
    ConfigValidationError
        If *conf* is not a mapping or one of the numeric fields is invalid;
        fields are checked in the order listed below.
    """
    component = "wfs"
    conf = _require_mapping(conf, component)

    # (key, inclusive lower bound); ``None`` means type check only.
    numeric_rules = (
        ("width", 1),
        ("height", 1),
        ("darkCount", 0),
        ("downsampleFactor", 0),
        ("rotationAngle", None),
    )
    for key, lower in numeric_rules:
        _validate_optional_numeric(conf, key, component, minimum=lower)


def validate_wfc_config(conf: Any) -> None:
    """Validate a wavefront-corrector configuration.

    Requires ``name`` (non-empty string) and positive integer
    ``numActuators`` and ``numModes``. The optional ``floatingInfluenceRadius``
    and ``frameDelay`` must be non-negative numbers when present.

    Raises
    ------
    ConfigValidationError
        On the first violated rule.
    """
    component = "wfc"
    conf = _require_mapping(conf, component)

    required = ["name", "numActuators", "numModes"]
    missing = [key for key in required if key not in conf]
    if missing:
        missing_str = ", ".join(missing)
        raise ConfigValidationError(f"{component}: missing required config key(s): {missing_str}")

    if not isinstance(conf["name"], str) or not conf["name"].strip():
        raise ConfigValidationError(f"{component}: 'name' must be a non-empty string")

    for key in ("numActuators", "numModes"):
        value = conf[key]
        # ``bool`` subclasses ``int`` and must not count as a count; NumPy
        # integers are registered as ``numbers.Integral`` and are accepted.
        is_int = isinstance(value, numbers.Integral) and not isinstance(value, bool)
        if not is_int or value <= 0:
            raise ConfigValidationError(f"{component}: '{key}' must be a positive int, got {value}")

    _validate_optional_numeric(conf, "floatingInfluenceRadius", component, minimum=0)
    _validate_optional_numeric(conf, "frameDelay", component, minimum=0)


def validate_loop_config(conf: Any) -> None:
    """Validate the optional numeric and limit settings of a loop config.

    Numeric keys are type-checked and, where a minimum is listed, bounded
    below. The limit keys, when present, must be 2-element lists/tuples.

    Raises
    ------
    ConfigValidationError
        On the first violated rule, in the order listed below.
    """
    component = "loop"
    conf = _require_mapping(conf, component)

    # key -> inclusive lower bound (``None``: type check only).
    # Dict insertion order is the validation order.
    numeric_minimums = {
        "numDroppedModes": 0,
        "gain": None,
        "leakyGain": None,
        "hardwareDelay": 0,
        "pokeAmp": 0,
        "numItersIM": 1,
        "delay": 0,
        "pGain": None,
        "iGain": None,
        "dGain": None,
        "derivativeFilter": None,
    }
    for key, lower in numeric_minimums.items():
        _validate_optional_numeric(conf, key, component, minimum=lower)

    for key in ("controlLimits", "integralLimits", "absoluteLimits"):
        if key in conf:
            limits = conf[key]
            if isinstance(limits, (list, tuple)) and len(limits) == 2:
                continue
            raise ConfigValidationError(f"{component}: '{key}' must be a list/tuple of length 2")


def validate_component_config(conf: Any, mro_names: Iterable[str]) -> None:
    """Run the validators that apply to a component's class hierarchy.

    Parameters
    ----------
    conf : Any
        Candidate configuration object for one component.
    mro_names : Iterable[str]
        Class names from the component's method-resolution order. Every
        validator whose class name appears here runs, in the fixed order
        Loop, WavefrontSensor, WavefrontCorrector.

    Raises
    ------
    ConfigValidationError
        If *conf* is not a mapping or any applicable validator rejects it.
    """
    _require_mapping(conf, "component")

    present = set(mro_names)
    dispatch = (
        ("Loop", validate_loop_config),
        ("WavefrontSensor", validate_wfs_config),
        ("WavefrontCorrector", validate_wfc_config),
    )
    for class_name, validator in dispatch:
        if class_name in present:
            validator(conf)


def precise_delay(microseconds):
    """Busy-wait for approximately *microseconds* microseconds.

    Spins on :func:`time.perf_counter` instead of sleeping, trading CPU time
    for finer timing resolution. Non-positive values return after a single
    clock check.
    """
    deadline = time.perf_counter() + microseconds / 1_000_000
    # Compare plain floats. Wrapping every reading in ``np.float64`` only
    # added per-iteration overhead to the spin loop, coarsening the delay.
    while time.perf_counter() < deadline:
        pass

def measure_execution_time(f, args, numIters=10):
    """Measure repeated execution-time statistics for a callable.

    The return value is tailored to the repository's lightweight performance
    smoke checks: median, interquartile range, and approximate low/high bounds.

    ``f(*args)`` is called once untimed, then *numIters* more times, each
    timed with :func:`time.perf_counter`.

    Parameters
    ----------
    f : callable
        Function to benchmark.
    args : sequence
        Positional arguments unpacked into every call.
    numIters : int, optional
        Number of timed calls (default 10).

    Returns
    -------
    tuple of numpy.float64
        ``(median, iqr, CI_1, CI_99)`` in seconds: the median, the
        interquartile range (Q3 - Q1), and the 0.5th and 99.5th percentiles,
        all with linear interpolation. All four are 0.0 when ``numIters`` is 0.
    """
    # Untimed warm-up call so one-time initialisation is not measured.
    f(*args)

    exTimes = np.empty(numIters)
    for i in range(numIters):
        # perf_counter is monotonic and high-resolution; time.time() is
        # neither, which matters for sub-millisecond measurements.
        start_time = time.perf_counter()
        f(*args)
        exTimes[i] = time.perf_counter() - start_time

    if exTimes.size == 0:
        zero = np.float64(0.0)
        return zero, zero, zero, zero

    # np.percentile's default (linear) method interpolates between the
    # neighbouring order statistics, like the previous hand-rolled version.
    pct = np.percentile(exTimes, [0.5, 25.0, 50.0, 75.0, 99.5])
    CI_1, q1, median, q3, CI_99 = (np.float64(v) for v in pct)
    iqr = q3 - q1

    return median, iqr, CI_1, CI_99

def change_directory(directory):
    """Change the process working directory to *directory*, best-effort.

    Failures are logged rather than raised: a missing directory or a
    permission problem logs an error, and anything else logs the exception
    with its traceback. Always returns ``None``.
    """
    try:
        os.chdir(directory)
        logger.info("Successfully changed the current directory to %s", os.getcwd())
    except (FileNotFoundError, PermissionError) as err:
        template = (
            "The directory '%s' does not exist"
            if isinstance(err, FileNotFoundError)
            else "Permission denied to access the directory '%s'"
        )
        logger.error(template, directory)
    except Exception as err:
        logger.exception("Unexpected error while changing directory: %s", err)

def add_to_path(directory):
    """Prepend *directory* to the ``PATH`` environment variable if absent.

    Membership is tested per ``PATH`` entry (split on ``os.pathsep``), so a
    directory that only appears as a substring of another entry (e.g.
    ``/opt/bin`` inside ``/opt/bin2``) is still added. Entries are joined with
    ``os.pathsep`` so the helper also works on Windows. Logs an error and
    leaves ``PATH`` untouched if *directory* does not exist.
    """
    if not os.path.isdir(directory):
        logger.error("The directory '%s' does not exist", directory)
        return

    current_path = os.environ.get('PATH', '')
    # An empty PATH must not produce "dir:" — on POSIX an empty entry means
    # the current working directory.
    entries = current_path.split(os.pathsep) if current_path else []
    if directory in entries:
        logger.info("Directory '%s' is already in PATH", directory)
        return

    os.environ['PATH'] = os.pathsep.join([directory, *entries])
    logger.info("Directory '%s' added to PATH", directory)

def powerLawOG(numModes, k):
    """Return ``1 - (m / numModes) ** k`` for ``m = 0 .. numModes - 1``.

    The first entry is 1; for ``k > 0`` the values decrease monotonically.
    """
    normalized_index = np.arange(numModes) / numModes
    return 1 - normalized_index ** k


def append_to_file(filename, data, dtype=np.float32):
    """
    Append an array's raw bytes to a binary file, creating the file if needed.

    The data is written with ``ndarray.tofile`` (no header, C order) in the
    array's own dtype.

    Parameters:
    filename : str or path-like
        The file to which data will be appended.
    data : array_like
        Data to append; non-array inputs are converted with ``np.asarray``.
    dtype : data-type, optional
        Accepted for backward compatibility but currently NOT applied: the
        data is written in its own dtype. Cast beforehand (``data.astype(...)``)
        if a specific on-disk dtype is required.
    """
    # Mode 'ab' creates the file when it is missing, so a separate existence
    # check (and its check-then-truncate race with 'wb') is unnecessary.
    with open(filename, 'ab') as f:
        np.asarray(data).tofile(f)

def generate_circular_aperture_mask(N, R, ratio):
    """
    Build an N x N boolean annulus: outer radius R, inner radius R * ratio.

    Coordinates are N samples spanning [-N/2, N/2] along each axis
    (``np.linspace``), so for N > 1 the sample spacing is N / (N - 1) rather
    than exactly one pixel.

    Parameters:
    N (int): Side length of the mask.
    R (float): Outer radius; samples with x**2 + y**2 <= R**2 are kept.
    ratio (float): Inner-to-outer radius ratio. When R * ratio > 0, samples
        with x**2 + y**2 < (R * ratio)**2 are removed (central obscuration).

    Returns:
    numpy.ndarray: Boolean array of shape (N, N).
    """
    axis = np.linspace(-N/2, N/2, N)
    xx, yy = np.meshgrid(axis, axis)
    radius_sq = xx**2 + yy**2
    inner = R * ratio

    aperture = radius_sq <= R**2
    if inner > 0:
        aperture = aperture & (radius_sq >= inner**2)
    return aperture.astype(bool)

def load_data(filename, dtype=None):
    """Load an array from a ``.npy`` or ``.fits`` file.

    Parameters
    ----------
    filename : str or os.PathLike
        Path to the file; the format is chosen from the extension.
    dtype : data-type, optional
        If given, the loaded data is converted with ``astype(dtype)``.

    Returns
    -------
    The ``.npy`` array, or the data of the primary HDU (``hdul[0]``) of a
    FITS file.

    Raises
    ------
    ValueError
        If the extension is neither ``.npy`` nor ``.fits``.
    """
    # os.fspath lets pathlib.Path (and other path-like objects) through;
    # calling str methods such as endswith on them would raise AttributeError.
    path = os.fspath(filename)
    if path.endswith('.npy'):
        data = np.load(path)
    elif path.endswith('.fits'):
        with fits.open(path) as hdul:
            data = hdul[0].data
    else:
        raise ValueError("Unsupported file format. Please provide a .npy or .fits file.")

    if dtype is not None:
        return data.astype(dtype)
    return data

def generate_filepath(base_dir='.', prefix='file', extension='.dat'):
    """
    Build a timestamped path ``<base_dir>/<prefix>_<YYYYmmdd_HHMMSS><extension>``.

    The timestamp is the local time from ``datetime.now()`` at one-second
    resolution, so two calls within the same second return the same path.

    Parameters:
    base_dir : str
        Directory the file name is joined onto.
    prefix : str
        Leading part of the file name.
    extension : str
        Suffix appended verbatim (include the dot).

    Returns:
    str
        The joined file path.
    """
    stamp = datetime.now().strftime('%Y%m%d_%H%M%S')
    return os.path.join(base_dir, f"{prefix}_{stamp}{extension}")

def get_tmp_filepath(file_path, uniqueStr='tmp'):
    """
    Insert ``_<uniqueStr>`` between the stem and the extension of *file_path*.

    Example: ``get_tmp_filepath('out/run.fits')`` -> ``'out/run_tmp.fits'``
    (POSIX separators). Only the last extension is treated as the suffix.

    :param file_path: str, the original file path
    :param uniqueStr: str, the tag to insert (default ``'tmp'``)
    :return: str, the modified path; the directory part is unchanged
    """
    directory, base = os.path.split(file_path)
    stem, ext = os.path.splitext(base)
    return os.path.join(directory, f"{stem}_{uniqueStr}{ext}")

def centroid(array):
    """Return the intensity-weighted centre ``[x, y]`` of a 2-D array.

    ``x`` is the column (last-axis) coordinate and ``y`` the row coordinate.
    A constant 1e-4 is added to the total intensity to avoid division by
    zero, which slightly pulls the result towards the index origin when the
    total intensity is small and positive.
    """
    weights = np.asarray(array, dtype=np.float64)

    def _total(values):
        # Sum in float64 over the flattened array.
        return np.add.reduce(values.ravel(), dtype=np.float64)

    denom = _total(weights) + 1e-4
    rows, cols = np.indices(weights.shape, dtype=np.float64)
    x_center = _total(cols * weights) / denom
    y_center = _total(rows * weights) / denom
    return np.array([x_center, y_center], dtype=np.float64)

def add_to_buffer(buffer, vec):
    """Shift *buffer* one slot towards the front in place and store *vec* last.

    The entry at index 0 is discarded. Works on any sliceable, item-assignable
    sequence (NumPy arrays along their first axis, Python lists). Returns
    ``None``.
    """
    last = -1
    buffer[:last] = buffer[1:]  # drop the oldest entry
    buffer[last] = vec

def next_power_of_two(n):
    """Return the smallest power of two strictly greater than *n*.

    Non-positive inputs return 1. Note that exact powers of two are bumped
    to the next one (``next_power_of_two(4) == 8``).
    """
    result = 1
    if n <= 0:
        return result
    while result <= n:
        result <<= 1
    return result


def adjusted_cosine_similarity(a, b):
    """Cosine similarity of *a* and *b* scaled by the ratio of their norms.

    Returns ``cos(a, b) * min(|a|, |b|) / max(|a|, |b|)``, which reaches 1
    only for identical non-zero vectors. Returns ``0`` if either norm is 0.
    """
    dot = np.dot(a, b)
    len_a = np.linalg.norm(a)
    len_b = np.linalg.norm(b)
    if len_a == 0 or len_b == 0:
        return 0
    direction_term = dot / (len_a * len_b)
    scale_term = min(len_a, len_b) / max(len_a, len_b)
    return direction_term * scale_term


def robust_variance(data):
    """Estimate the variance of *data* from its median absolute deviation.

    Computes ``(MAD / 0.6745) ** 2``. 0.6745 is approximately the 75th
    percentile of the standard normal distribution, so for Gaussian data the
    estimate approximates the ordinary variance while resisting outliers.
    """
    center = np.median(data)
    mad = np.median(np.abs(data - center))
    sigma_estimate = mad / 0.6745
    return sigma_estimate ** 2

def cosine_similarity(v1, v2):
    """Return the cosine of the angle between *v1* and *v2*.

    Returns ``0`` when either vector has zero norm (the angle is undefined).
    """
    norm_1, norm_2 = np.linalg.norm(v1), np.linalg.norm(v2)
    inner = np.dot(v1, v2)
    if norm_1 == 0 or norm_2 == 0:
        return 0
    return inner / (norm_1 * norm_2)

def angle_between_vectors(v1, v2):
    """Return the angle between *v1* and *v2* in radians, in ``[0, pi]``.

    Built on ``cosine_similarity``; when that returns 0 (e.g. a zero-norm
    vector) the result is ``pi / 2``.
    """
    cos_theta = cosine_similarity(v1, v2)
    # Rounding can push the cosine of (anti)parallel vectors slightly outside
    # [-1, 1], where arccos returns NaN; clip to the valid domain first.
    # arccos already returns values in [0, pi], so no abs() is needed.
    return np.arccos(np.clip(cos_theta, -1.0, 1.0))

def compute_fwhm_dark_subtracted_image(image):
    """Estimate the FWHM of the low-value (noise) side of a dark-subtracted image.

    Pixels below 1 are histogrammed with unit-width bins whose edges are
    ``np.arange(np.min(negative_pixels), 1) + 0.5``. Assuming the underlying
    distribution is symmetric about zero, the histogram is mirrored to form a
    two-sided distribution, and the FWHM is the absolute distance between the
    bin centres at the first and last half-maximum crossings.

    NOTE(review): raises if no pixel is below 1 (``np.min`` of an empty
    selection) or if no half-maximum crossing is found (``cross_points`` is
    empty, so ``cross_points[-1]`` raises IndexError).

    NOTE(review): the mirrored arrays are not ordered monotonically in x
    (mirrored positive centres come first, followed by the original negative
    ones), so the first/last crossings are not a symmetric pair around the
    peak. Treat the result as an approximation and validate it against a
    reference FWHM before relying on it.
    """
    # Keep pixels below 1: the negative side of the distribution plus any
    # values in [0, 1)
    negative_pixels = image[image < 1]
    
    # Histogram those values; edges are offset by 0.5 so that bins are centred
    # on integers when the minimum is an integer.
    # Adjust bins and range as necessary for your specific image
    hist, bins = np.histogram(negative_pixels, bins=np.arange(np.min(negative_pixels), 1)+0.5)
    # Since the distribution is symmetric, we can mirror the histogram to get the full distribution
    hist_full = np.concatenate((hist[::-1], hist))
    
    # Compute the bin centers from the bin edges, mirrored the same way as the
    # counts so that each count stays paired with its own (mirrored) centre
    bin_centers = (bins[:-1] + bins[1:]) / 2
    bin_centers_full = np.concatenate((-bin_centers[::-1], bin_centers))

    # Peak histogram count (height of the mode)
    peak_value = np.max(hist_full)
    half_max = peak_value / 2
    
    # Indices i where the "above half maximum" flag changes between i and i+1
    cross_points = np.where(np.diff((hist_full > half_max).astype(int)))[0]
    
    # Assuming the distribution is sufficiently smooth and has a single peak,
    # the FWHM is the distance between the first and last crossing points
    fwhm_value = np.abs(bin_centers_full[cross_points[-1]] - bin_centers_full[cross_points[0]])
    
    return fwhm_value

def clean_image_for_strehl(img, median_filter_size=3, gaussian_sigma=1):
    """Median- and/or Gaussian-filter an image (preprocessing for Strehl estimates).

    The steps run in this order, each only when enabled:

    1. a median filter of size ``median_filter_size`` (skipped if it is
       ``None`` or <= 1);
    2. a Gaussian blur with ``gaussian_sigma`` (skipped if it is ``None`` or
       <= 0).

    Both steps use ``mode='reflect'`` boundaries. When both are skipped the
    result is just ``np.asarray(img)``.
    """
    result = np.asarray(img)

    use_median = median_filter_size is not None and median_filter_size > 1
    if use_median:
        result = median_filter(result, size=median_filter_size, output=None,
                               mode='reflect', cval=0.0, origin=0)

    use_gaussian = gaussian_sigma is not None and gaussian_sigma > 0
    if use_gaussian:
        result = gaussian_filter(result, sigma=gaussian_sigma, order=0, output=None,
                                 mode='reflect', cval=0.0, truncate=4.0)

    return result

def gaussian_2d_grid(i, j, sigma, grid_size):
    """Return a normalised 2-D Gaussian weight grid centred on ``(i, j)``.

    ``grid[x, y]`` is proportional to
    ``exp(-((x - i)**2 + (y - j)**2) / (2 * sigma**2))``. The centre cell
    ``(i, j)`` is forced to 0, and the grid is then scaled to sum to 1.

    Scalar-like inputs, including 1-element arrays, are accepted; only the
    first element of each argument is used.

    Returns an all-zero ``(grid_size, grid_size)`` array when ``sigma`` is 0
    or when every weight is 0 (e.g. ``grid_size == 1`` with the centre on the
    grid), instead of dividing by zero and producing NaNs.
    """
    i = int(np.asarray(i).reshape(-1)[0])
    j = int(np.asarray(j).reshape(-1)[0])
    sigma = float(np.asarray(sigma).reshape(-1)[0])
    grid_size = int(np.asarray(grid_size).reshape(-1)[0])

    grid = np.zeros((grid_size, grid_size))
    if sigma == 0:
        return grid

    # Vectorised over the whole grid: x indexes rows, y indexes columns.
    x = np.arange(grid_size)[:, None]
    y = np.arange(grid_size)[None, :]
    grid = np.exp(-((x - i)**2 + (y - j)**2) / (2 * sigma**2))
    # The centre point is excluded (weight 0); no-op if (i, j) is off-grid.
    grid[(x == i) & (y == j)] = 0.0

    total = np.sum(grid)
    if total == 0:
        return np.zeros((grid_size, grid_size))
    return grid / total

def set_affinity(affinity):
    """Pin the current process to the given CPU core(s).

    Parameters
    ----------
    affinity : int, float, NumPy scalar, list, tuple, range or numpy.ndarray
        A single core index or a collection of core indices. Every value is
        converted with ``int``.

    Returns
    -------
    None
        On success, and on macOS, where the call is skipped because CPU
        affinity is unsupported there.
    -1
        If *affinity* has an unsupported type (nothing is changed).

    Errors raised by ``psutil`` (e.g. an invalid core index) propagate.
    """
    if isinstance(affinity, (int, float, np.integer, np.floating)):
        cores = [int(affinity)]
    elif isinstance(affinity, (list, tuple, range, np.ndarray)):
        # YAML configs produce lists; these used to be rejected silently.
        # np.ravel also handles 0-d arrays, which plain list() cannot iterate.
        cores = [int(core) for core in np.ravel(affinity)]
    else:
        return -1
    # Unsupported by MacOS
    if sys.platform != 'darwin':
        psutil.Process(os.getpid()).cpu_affinity(cores)
    return None



def setFromConfig(conf, name, default):
    """Return a config value or a typed default.

    When a default is provided, this helper checks that any override found in
    the configuration matches the default's type. That makes many YAML
    mistakes fail early during component startup instead of surfacing later.

    Parameters
    ----------
    conf : Mapping
        Component configuration.
    name : str
        Key to look up.
    default : Any
        Value returned when *name* is absent. When not ``None``, the returned
        value must be an instance of ``type(default)``.

    Raises
    ------
    AssertionError
        On a type mismatch. Raised explicitly rather than via ``assert`` so
        the check still runs under ``python -O``.
    """
    val = conf[name] if name in conf else default
    if default is not None and not isinstance(val, type(default)):
        raise AssertionError(
            f"There is a type mismatch between the default value for config variable {name} "
            f"and the given value: {type(val).__name__} != {type(default).__name__}"
        )
    return val
def signal2D(signal, layout):
    """Scatter a flat slope signal into a 2-D image defined by *layout*.

    The left half of *layout* (split along the last axis, presumably a boolean
    mask) is used as the valid-pixel mask for BOTH halves. The first half of
    *signal* fills the masked pixels of the left half and the second half
    fills the right half, in row-major order. All other pixels are 0.
    """
    curSignal2D = np.zeros(layout.shape)
    half = layout.shape[1] // 2
    slopemask = layout[:, :half]
    # Column slices are views, so the masked assignment writes into curSignal2D.
    curSignal2D[:, :half][slopemask] = signal[:signal.size // 2]
    curSignal2D[:, half:][slopemask] = signal[signal.size // 2:]
    return curSignal2D


def dtype_to_float(dtype):
    """
    Encode a NumPy dtype as its index in ``NP_DATA_TYPES``.

    Parameters:
    - dtype: anything ``np.dtype`` accepts (a dtype object, a scalar type
      such as ``np.float32``, or a string such as ``'float32'``)

    Returns:
    - int: index of the matching entry (usable as a float code), or -1 if
      *dtype* is ``None``, not understood by NumPy, or not in the list
    """
    if dtype is None:
        return -1
    try:
        target = np.dtype(dtype)
    except (TypeError, ValueError):
        return -1
    for index, candidate in enumerate(NP_DATA_TYPES):
        if target == np.dtype(candidate):
            return index
    return -1


def float_to_dtype(dtype_float):
    """
    Decode a code produced by ``dtype_to_float`` back to a NumPy dtype.

    Parameters:
    - dtype_float: non-negative index into ``NP_DATA_TYPES`` (int or float)

    Returns:
    - np.dtype: NumPy dtype object

    Raises:
    - ValueError: if the code is negative (e.g. the -1 "unknown" marker)
    - IndexError: if the code is past the end of ``NP_DATA_TYPES``
    """
    index = int(dtype_float)
    # Without this guard, -1 would silently decode to NP_DATA_TYPES[-1].
    if index < 0:
        raise ValueError(f"Invalid dtype code {dtype_float!r}: expected a non-negative index into NP_DATA_TYPES")
    return np.dtype(NP_DATA_TYPES[index])


def bind_socket(host, start_port, max_attempts=5):
    """Bind a TCP socket, retrying across a short range of ports.

    This is primarily used by hard-RTC launcher/listener code so child
    hardware processes can recover from a busy preferred port without
    embedding their own retry logic.

    Ports ``start_port`` .. ``start_port + max_attempts - 1`` are tried in
    order. Returns the bound socket on success. If a bind fails with an error
    other than "address in use", stops and returns -1. If every attempt hits
    "address in use", raises RuntimeError. The socket is closed on both
    failure paths.
    """
    sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
    sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)  # Allow reuse of socket addresses
    for attempt in range(max_attempts):
        port = start_port + attempt
        try:
            sock.bind((host, port))
        except OSError as e:
            logger.warning("Failed to bind to %s:%s: %s", host, port, e)
            if e.errno == errno.EADDRINUSE:
                logger.info("Address already in use. Trying next port.")
                continue
            logger.error("Unexpected socket bind failure. Stopping attempts.")
            sock.close()  # do not leak the descriptor
            return -1
        logger.info("Bound socket to %s:%s", host, port)
        return sock
    # After all attempts, if no binding was successful, release the socket
    # and raise.
    sock.close()
    raise RuntimeError("Failed to bind socket after multiple attempts")


def decrease_nice():
    """Try to raise the current process's scheduling priority (nice -20).

    Skipped on macOS and Windows. Failures (e.g. missing privileges) are
    logged as a warning rather than raised.
    """
    # Unsupported by MacOS or Windows
    if sys.platform in ('darwin', 'win32'):
        return
    try:
        # Unix uses a numeric value (lower means higher priority)
        psutil.Process(os.getpid()).nice(-20)
    except Exception:
        logger.warning(
            "Unable to adjust nice level. Give your user sudo privileges "
            "without password to use this feature."
        )


def set_affinity_and_priority(thread_id, cpu_cores):
    """Set CPU affinity and raise scheduling priority.

    Both ``set_affinity`` and ``decrease_nice`` act on the whole process
    (``os.getpid()``), not only the calling thread; *thread_id* is used only
    in the log message.

    NOTE(review): the log line says REALTIME, but ``decrease_nice`` only
    requests nice -20 and merely warns if that fails.
    """
    set_affinity(cpu_cores)
    decrease_nice()
    logger.info("Thread %s: priority set to REALTIME", thread_id)


def read_yaml_file(file_path):
    """Load a YAML file and return the parsed Python object.

    Uses ``yaml.safe_load`` and decodes the file as UTF-8 regardless of the
    platform's locale encoding.
    """
    with open(file_path, 'r', encoding='utf-8') as file:
        return yaml.safe_load(file)


def read_input_with_timeout(timeout):
    """Read one line from stdin, waiting at most *timeout* seconds.

    Returns the stripped line, or ``None`` if nothing arrived in time.

    NOTE(review): ``select.select`` on ``sys.stdin`` is only supported on
    POSIX (on Windows ``select`` works with sockets only); confirm callers.
    """
    readable, _, _ = select.select([sys.stdin], [], [], timeout)
    if readable:
        return sys.stdin.readline().strip()
    return None


def is_numeric(s):
    """Return True if ``float(s)`` succeeds, otherwise False.

    Inputs that ``float`` rejects with TypeError (e.g. ``None`` or a list)
    return False instead of raising.
    """
    try:
        float(s)
    except (TypeError, ValueError):
        return False
    return True