Source code for waveline.utils
"""Utility functions and classes."""
from __future__ import annotations
import logging
from collections import defaultdict, deque
from dataclasses import dataclass
import numpy as np
from waveline.datatypes import AERecord, TRRecord
logger = logging.getLogger(__name__)
[docs]
def decibel_to_volts(decibel: float | np.ndarray) -> float | np.ndarray:
"""
Convert from dB(AE) to volts.
Args:
decibel: Input in decibel, scalar or array
Returns:
Input value(s) in volts
"""
return 1e-6 * np.power(10, np.asarray(decibel) / 20)
[docs]
def volts_to_decibel(volts: float | np.ndarray) -> float | np.ndarray:
"""
Convert from volts to dB(AE).
Args:
volts: Inpult in volts, scalar or array
Returns:
Input value(s) in dB(AE)
"""
return 20 * np.log10(np.asarray(volts) * 1e6)
[docs]
@dataclass
class HitRecord:
"""Merged hit record combining AERecord and TRRecord."""
ae: AERecord
tr: TRRecord | None
[docs]
class QueueFullError(Exception):
"""Exception raised when a queue is full."""
[docs]
class HitMerger:
"""
Merge AE and TR records into HitRecords based on the transient recorder index (trai).
Since AE and TR records arrive in order for each channel, AE records are stored in
channel-specific queues and merged with corresponding TR records as they become available.
"""
[docs]
@dataclass
class ChannelState:
queue: deque[AERecord]
last_trai: int = 0
[docs]
def __init__(self, max_queue_size: int | None = None):
"""
Initialize the HitMerger with an optional maximum queue size.
Args:
max_queue_size: Maximum queue size for each channel. If `None`, queues are unbounded.
"""
self._channel_state: dict[int, HitMerger.ChannelState] = defaultdict(
lambda: HitMerger.ChannelState(deque(maxlen=max_queue_size), 0)
)
def __enter__(self):
return self
def __exit__(self, *args, **kwargs):
self.clear()
[docs]
def clear(self):
"""
Clear all buffered AE records.
"""
self._channel_state.clear()
[docs]
def process(self, record: AERecord | TRRecord) -> HitRecord | None:
"""
Process a single AE or TR record.
Returns:
HitRecord if a merge occurred, otherwise None.
Raises:
QueueFullError: If the queue for a channel is full when processing an AERecord.
"""
if isinstance(record, AERecord):
return self._handle_ae_record(record)
if isinstance(record, TRRecord):
return self._handle_tr_record(record)
return None
def _handle_ae_record(self, ae_record: AERecord) -> HitRecord | None:
if ae_record.trai == 0:
return HitRecord(ae=ae_record, tr=None)
state = self._channel_state[ae_record.channel]
if state.queue.maxlen is not None and len(state.queue) >= state.queue.maxlen:
raise QueueFullError()
assert ae_record.trai > state.last_trai, "TRAI must be strictly increasing per channel"
state.queue.append(ae_record)
state.last_trai = ae_record.trai
return None
def _handle_tr_record(self, tr_record: TRRecord) -> HitRecord | None:
state = self._channel_state[tr_record.channel]
logger.debug("AE queue size for channel %d: %s", tr_record.channel, len(state.queue))
while state.queue and state.queue[0].trai < tr_record.trai:
ae_record = state.queue.popleft()
logger.warning("Missing TR for TRAI %d, discard AE", ae_record.trai)
if not state.queue or state.queue[0].trai > tr_record.trai:
logger.warning("Missing AE for TRAI %d, discard TR", tr_record.trai)
return None
ae_record = state.queue.popleft()
assert ae_record.trai == tr_record.trai
assert ae_record.channel == tr_record.channel
return HitRecord(ae=ae_record, tr=tr_record)