Source code for mu_mimo.processing.modulation

# mu-mimo/mu_mimo/processing/modulation.py

from __future__ import annotations
from abc import ABC, abstractmethod
from dataclasses import dataclass
import numpy as np
from ..types import ComplexArray, RealArray, IntArray, BitArray, ConstType



# HELPERS

@dataclass()
class Constellation:
    """
    A constellation.

    Parameters
    ----------
    type : ConstType
        The constellation type.
    size : int
        The constellation size (number of points in the constellation).
    points : ComplexArray, shape (size,)
        The constellation points.
    """

    type: ConstType
    size: int

    points : ComplexArray | None = None

    def __post_init__(self):

        # Validate the constellation size.
        is_power_of_2 = (self.size > 0) and ((self.size & (self.size - 1)) == 0)
        is_power_of_4 = is_power_of_2 and (int(np.log2(self.size)) % 2 == 0)

        if self.type in ("PAM", "PSK") and not is_power_of_2:
            raise ValueError("For PAM and PSK modulation, the constellation size must be a power of 2.")
        elif self.type == "QAM" and not is_power_of_4:
            raise ValueError("For QAM modulation, the constellation size must be a power of 4.")
        
        # Generate the constellation points for a PAM constellation.
        if self.type == "PAM":
            self.points = np.arange(-(self.size-1), (self.size-1) + 1, 2) * np.sqrt(3/(self.size**2-1))
        
        # Generate the constellation points for a PSK constellation.
        elif self.type == "PSK":
            self.points = np.exp(1j * 2*np.pi * np.arange(self.size) / self.size)
        
        # Generate the constellation points for a QAM constellation.
        elif self.type == "QAM":
            sqrtM_PAM = np.arange(-(np.sqrt(self.size)-1), (np.sqrt(self.size)-1) + 1, 2) * np.sqrt(3 / (2*(self.size-1)))
            real_grid, imaginary_grid = np.meshgrid(sqrtM_PAM, sqrtM_PAM)
            self.points = (real_grid + 1j*imaginary_grid).ravel()
        
        else:
            raise ValueError(f'The constellation type is invalid.\nChoose between "PAM", "PSK", or "QAM". Right now, type is {self.type}.')

    def __eq__(self, other: object) -> bool:
        
        if not isinstance(other, Constellation):
            return NotImplemented
        
        return (
            self.type == other.type and
            self.size == other.size and
            np.array_equal(self.points, other.points)
        )

class NumberRepresentation:
    """
    Number Representation Utility Class.

    This class provides helper functions for the conversion between different number representations.
    """

    @staticmethod
    def binary_to_decimal(b: BitArray) -> int:
        """
        Convert a binary number to its decimal representation.

        Parameters
        ----------
        b : BitArray
            The binary number.

        Returns
        -------
        d : int
            The corresponding decimal representation.
        """
        powers = 1 << np.arange(b.size - 1, -1, -1, dtype=int)
        d = int(b @ powers)
        return d

    @staticmethod
    def decimal_to_binary(d: int, length: int = None) -> BitArray:
        """
        Convert a decimal number to its binary representation.

        Parameters
        ----------
        d : int
            The decimal number.
        length : int, optional
            The length of the binary representation.\\
            If None, the length is the minimum required to represent the decimal number.

        Returns
        -------
        b : BitArray
            The corresponding binary representation.
        """
        if d < 0: 
            raise ValueError("The decimal input must be non-negative.")
        
        b = np.array(list(np.binary_repr(int(d), width=length)), dtype=int)
        return b

    @staticmethod
    def gray_to_binary(g: BitArray) -> BitArray:
        """
        Convert a Gray code number to its binary representation.

        Parameters
        ----------
        g : BitArray
            The Gray code number.

        Returns
        -------
        b : BitArray
            The corresponding binary representation.
        """
        b =  np.bitwise_xor.accumulate(g).astype(int)
        return b

    @staticmethod
    def binary_to_gray(b: BitArray) -> BitArray:
        """
        Convert a binary number to its Gray code representation.

        Parameters
        ----------
        b : BitArray
            The binary number.

        Returns
        -------
        g : BitArray
            The corresponding Gray code representation.
        """
        if b.size == 0: 
            return np.array([], dtype=int)
        
        g = np.empty_like(b, dtype=int)
        g[0] = b[0]
        if b.size > 1:
            g[1:] = np.bitwise_xor(b[:-1], b[1:])

        return g

    @staticmethod
    def gray_to_decimal(g: BitArray) -> int:
        """
        Convert a Gray code number to its decimal representation.

        Parameters
        ----------
        g : BitArray
            The Gray code number.

        Returns
        -------
        d : int
            The corresponding decimal representation.
        """
        b = NumberRepresentation.gray_to_binary(g)
        d = NumberRepresentation.binary_to_decimal(b)
        return d

    @staticmethod
    def decimal_to_gray(d: int, length: int = None) -> BitArray:
        """
        Convert a decimal number to its Gray code representation.

        Parameters
        ----------
        d : int
            The decimal number.
        length : int, optional
            The length of the Gray code representation.\\
            If None, the length is the minimum required to represent the decimal number.

        Returns
        -------
        g : BitArray
            The corresponding Gray code representation.
        """
        b = NumberRepresentation.decimal_to_binary(d, length=length)
        g = NumberRepresentation.binary_to_gray(b)
        return g


# MAPPING & DEMAPPING

[docs] class Mapper(ABC): """ Mapper Abstract Base Class. """
[docs] @staticmethod @abstractmethod def apply(b: list[BitArray], ibr: IntArray, c_types: list[ConstType], Ns: IntArray) -> ComplexArray: """ Apply the mapper operation to the bitstreams. For each active data stream of each user terminal, the mapper converts the bitstream into a data symbol stream according to the information bit rate (number of bits per symbol or thus the constellation size in bits). Parameters ---------- b : list[BitArray], shape (Ns_total, ibr[s] * M) The compound bitstream vector. ibr : IntArray, shape (K*Nr,) The information bit rate for each data stream. c_types : list[ConstType], shape (K,) The constellation types for the data streams to each UT. Ns : IntArray, shape (K,) The number of active data streams for each UT. Returns ------- a : ComplexArray, shape (Ns_total, M) The compound data symbol vector. """ raise NotImplementedError
[docs] class NeutralMapper(Mapper): """ Neutral Mapper. Acts as a 'neutral element' for mapping.\\ It simply converts the bitstream into a data symbol stream by interpreting the bits as integers, without applying any modulation scheme. So in practice, the bitstreams pass through the mapper without any change. A requirement for this mapper is that the information bit rate equals one bit per symbol for each data stream (neutral bit allocation)! """
[docs] @staticmethod def apply(b: list[BitArray], ibr: IntArray, c_types: list[ConstType], Ns: IntArray) -> ComplexArray: if not np.all(ibr == 1): raise ValueError("The information bit rate must be equal to one bit per symbol for each data stream when using the NeutralMapper.") a = np.array(b, dtype=complex) return a
[docs] class GrayCodeMapper(Mapper): """ Gray Code Mapper. """
[docs] @staticmethod def apply(b: list[BitArray], ibr: IntArray, c_types: list[ConstType], Ns: IntArray) -> ComplexArray: # Initialization. K = len(c_types) ibr = ibr[ibr > 0] Ns_total = np.sum(Ns) M = len(b[np.argmax(ibr)]) // np.max(ibr) if len(ibr) > 0 else 0 c_types = [c_types[k] for k in range(K) for _ in range(Ns[k])] # Convert the binary bitstreams, interpret as Gray code numbers, to their corresponding decimal representations. d = np.empty((Ns_total, M), dtype=int) for a_s in range(Ns_total): for m in range(M): d[a_s, m] = NumberRepresentation.gray_to_decimal(b[a_s][m*ibr[a_s] : (m+1)*ibr[a_s]]) # Map the decimal numbers to their corresponding constellation points. a = np.empty((Ns_total, M), dtype=complex) for a_s in range(Ns_total): constellation_points = Constellation(type=c_types[a_s], size=2**ibr[a_s]).points a[a_s] = constellation_points[d[a_s]] return a
[docs] class Demapper(ABC): """ Demapper Abstract Base Class. """
[docs] @staticmethod @abstractmethod def apply(cpi_k_hat: IntArray, ibr_k: IntArray) -> list[BitArray]: """ Apply the demapper operation to the reconstructed symbol streams. For each data stream of this user terminal, the demapper converts the reconstructed data symbol stream into a bitstream, based to the used modulation scheme (constellation typpe and sizes) and the constellation point indices of the reconstructed data symbols. Parameters ---------- cpi_k_hat : IntArray, shape (Ns_k, M) The indices of the constellation points (decimal integers) corresponding to the reconstructed data symbols, for each data stream of this user terminal. ibr_k : IntArray, shape (Nr,) The information bit rate for each data stream of this user terminal. Returns ------- b_k_hat : BitArray, shape (Ns_k, ibr_k[s] * M) The list of reconstructed bitstreams of this user terminal. """ raise NotImplementedError
[docs] class NeutralDemapper(Demapper): """ Neutral Demapper. Acts as a 'neutral element' for demapping.\\ It simply converts the constellation point indices into a bitstream by interpreting the indices of the constellation points as bits, without applying any demodulation scheme. So in practice, the constellation point indices pass through the demapper without any change. """
[docs] @staticmethod def apply(cpi_k_hat: IntArray, ibr_k: IntArray) -> list[BitArray]: if not np.all(ibr_k[ibr_k > 0] == 1): raise ValueError("The information bit rate must be equal to one bit per symbol for each data stream when using the NeutralDemapper.") if not np.all(np.isin(cpi_k_hat, [0, 1])): raise ValueError("The reconstructed data symbol stream must consist of symbols that are either 0 or 1 when using the NeutralDemapper.") b_k_hat = [cpi_k_hat[s] for s in range(np.sum(ibr_k > 0))] return b_k_hat
[docs] class GrayCodeDemapper(Demapper): """ Gray Code Demapper. """
[docs] @staticmethod def apply(cpi_k_hat: IntArray, ibr_k: IntArray) -> list[BitArray]: # Determine the number of data streams and the number of symbol vectors. M = cpi_k_hat.shape[1] ibr_k = ibr_k[ibr_k > 0] Ns_k = np.sum(ibr_k > 0) # Convert the decimal index numbers to their Gray code representations. b_k_hat = [np.empty(ibr_k[a_s]*M) for a_s in range(Ns_k)] for a_s in range(Ns_k): for m in range(M): b_k_hat[a_s][m*ibr_k[a_s] : (m+1)*ibr_k[a_s]] = NumberRepresentation.decimal_to_gray(cpi_k_hat[a_s, m], length=ibr_k[a_s]) return b_k_hat
# EQUALIZATION
[docs] class Equalizer(): """ Equalization Class. """
[docs] @staticmethod def apply(z_k: ComplexArray, C_eq_k: ComplexArray, ibr_k: IntArray) -> ComplexArray: """ Apply the equalization operation to the combined signal. Each data stream of this user terminal is multiplied by its corresponding equalization coefficient to rescale the received symbols before the decoding process. Parameters ---------- z_k : ComplexArray, shape (Ns_k, M) The combined signal for this UT. C_eq_k : ComplexArray, shape (Nr,) The equalization coefficients for each data stream of this UT. ibr_k : IntArray, shape (Nr,) The information bit rate for each data stream of this UT. Returns ------- u_k : ComplexArray, shape (Ns_k, M) The decision variable streams for this UT. """ u_k = z_k / C_eq_k[ibr_k > 0][:, np.newaxis] return u_k
# DETECTION
[docs] class Detector(ABC): """ Detector Abstract Base Class. """
[docs] @staticmethod @abstractmethod def apply(u_k: ComplexArray, ibr_k: IntArray, c_type_k: ConstType) -> IntArray: """ Apply the detector operation to the decision variable streams. For each data stream of this user terminal, the detector converts the decision variable stream into a stream of constellation point indices (decimal integers), based on the used modulation scheme. The indices correspond to the constellation points that are most likely transmitted by the base station. The desicion of which constellation points that are most likely depends on the type of the detector and might be suboptimal for certain detector types. Parameters ---------- u_k : ComplexArray, shape (Ns_k, M) The decision variable streams of this user terminal. ibr_k : IntArray, shape (Nr,) The information bit rate for each data stream of this user terminal. c_type_k : ConstType The constellation type for the data streams to this user terminal. Returns ------- cpi_k_hat : IntArray, shape (Ns_k, M) The indices of the constellation points (decimal integers) corresponding to the reconstructed data symbols, for each data stream of this user terminal. """ raise NotImplementedError
[docs] class NeutralDetector(Detector): """ Neutral Detector. Acts as a 'neutral element' for detection.\\ It simply converts the decision variable stream into a stream of constellation point indices by interpreting the decision variables as integers, without applying any detection scheme. So in practice, the decision variable streams pass through the detector without any change. """
[docs] @staticmethod def apply(u_k: ComplexArray, ibr_k: IntArray, c_type_k: ConstType) -> IntArray: if not np.all(ibr_k[ibr_k > 0] == 1): raise ValueError("The information bit rate must be equal to one bit per symbol for each data stream when using the NeutralDetector.") cpi_k_hat = np.array(u_k.real, dtype=int) return cpi_k_hat
[docs] class MDDetector(Detector): """ Minimum Distance (MD) Detector. """
[docs] @staticmethod def apply(u_k: ComplexArray, ibr_k: IntArray, c_type_k: ConstType) -> IntArray: # Determine the number of symbol vectors. M = u_k.shape[1] ibr_k = ibr_k[ibr_k > 0] Ns_k = np.sum(ibr_k > 0) # Decide the constellation points that are most likely transmitted by finding the constellation points that are closest to the decision variables. Then, retrieve the corresponding constellation point indices (decimal integers) of the decided constellation points. cpi_k_hat = np.empty((Ns_k, M), dtype=int) for a_s in range(Ns_k): constellation_points = Constellation(type=c_type_k, size=2**ibr_k[a_s]).points cpi_k_hat[a_s] = np.argmin( np.abs(np.tile(constellation_points, (M, 1)) - u_k[a_s][:, np.newaxis]), axis=1) return cpi_k_hat