import typing
from sklearn.neighbors import NearestNeighbors
from keras import backend as K
import numpy as np
import tensorflow as tf
class MMD:
    """
    Maximum Mean Discrepancy (MMD) between real and generated samples.

    The kernel is a sum of Gaussian kernels evaluated at three bandwidths
    (``med / 2``, ``med`` and ``med * 2``), where ``med`` is a bootstrap
    estimate of the median nearest-neighbour distance among the real cells.
    """

    def __init__(self, real_cells: np.ndarray):
        """
        Derive the kernel scales and weight from the real cells.

        The real cells are resampled with replacement 20 times; for each
        resample the median distance to the nearest neighbours (self
        excluded) is computed, and the median of those 20 values is used as
        the base scale.

        Parameters
        ----------
        real_cells : np.ndarray
            A NumPy array representing real cell data (cells x features).

        Raises
        ------
        ValueError
            If ``real_cells`` has fewer than two rows, since no
            neighbour distance can be computed.
        """
        n_cells = real_cells.shape[0]
        if n_cells < 2:
            raise ValueError(
                f"real_cells must contain at least 2 rows, got {n_cells}"
            )

        # kneighbors() cannot return more neighbours than there are fitted
        # points, so cap the default of 25 for small inputs.
        n_neighbors = min(25, n_cells)
        n_rounds = 20

        med = np.empty(n_rounds)
        # Fill every slot: starting at index 1 would leave the initial value
        # of med[0] in the final median.
        for ii in range(n_rounds):
            # randint's upper bound is exclusive, so `n_cells` (not
            # `n_cells - 1`) is needed for the last row to be drawable.
            sample = real_cells[np.random.randint(n_cells, size=n_cells), :]
            nbrs = NearestNeighbors(n_neighbors=n_neighbors).fit(sample)
            distances, _ = nbrs.kneighbors(sample)
            # nearest neighbor is the point so we need to exclude it
            med[ii] = np.median(distances[:, 1:n_neighbors])
        med = np.median(med)

        scales = [med / 2, med, med * 2]
        # NOTE(review): the weight is one scalar equal to the number of
        # scales (3.0), not one weight per scale, so it only rescales the
        # kernel sum by a constant. The value is kept so existing MMD scores
        # do not change; float32 mirrors Keras' default float type, which the
        # previous K.variable round trip produced.
        weights = np.asarray(len(scales), dtype=np.float32)

        # Shapes (3, 1, 1) and (1, 1) so both broadcast against a
        # (1, n, m) stack of squared distances in gaussian_kernel().
        self.scales = np.expand_dims(np.expand_dims(scales, -1), -1)
        self.weights = np.expand_dims(np.expand_dims(weights, -1), -1)

    def squaredDistance(
        self,
        X: typing.Union[np.ndarray, "tf.Tensor"],
        Y: typing.Union[np.ndarray, "tf.Tensor"],
    ) -> "tf.Tensor":
        """
        Compute pairwise squared Euclidean distances between rows of X and Y.

        Parameters
        ----------
        X : np.ndarray or tf.Tensor
            Input array of shape (n, d).
        Y : np.ndarray or tf.Tensor
            Input array of shape (m, d).

        Returns
        -------
        tf.Tensor
            A tensor of shape (n, m) representing squared distances.
        """
        # X is nxd, Y is mxd: broadcast (n, 1, d) - (m, d) to an (n, m, d)
        # difference, square it and sum over the feature axis. No square root
        # is taken, so the result holds *squared* distances.
        r = K.expand_dims(X, axis=1)
        return K.sum(K.square(r - Y), axis=-1)

    def gaussian_kernel(
        self,
        a: typing.Union[np.ndarray, "tf.Tensor"],
        b: typing.Union[np.ndarray, "tf.Tensor"],
    ) -> np.ndarray:
        """
        Compute the multi-scale Gaussian kernel between two datasets.

        For every scale ``s`` in ``self.scales`` the term
        ``self.weights * exp(-squared_distance / s**2)`` is evaluated and the
        terms are summed over the scales.

        Parameters
        ----------
        a : np.ndarray or tf.Tensor
            Input array of shape (n, d).
        b : np.ndarray or tf.Tensor
            Input array of shape (m, d).

        Returns
        -------
        np.ndarray
            An array of shape (n, m) representing the Gaussian kernel matrix.
            The result is a NumPy array because the squared distances are
            passed through NumPy functions; the distance tensor must
            therefore be convertible to NumPy.
        """
        # Leading axis of size 1 lines up with the scale axis of self.scales.
        numerator = np.expand_dims(self.squaredDistance(a, b), 0)
        return np.sum(self.weights * np.exp(-numerator / (np.power(self.scales, 2))), 0)

    def compute(
        self,
        a: typing.Union[np.ndarray, "tf.Tensor"],
        b: typing.Union[np.ndarray, "tf.Tensor"],
    ) -> np.floating:
        """
        Compute the Maximum Mean Discrepancy (MMD) between two samples.

        The score is ``mean(k(a, a)) + mean(k(b, b)) - 2 * mean(k(a, b))``
        with ``k`` being :meth:`gaussian_kernel`. Each mean runs over the
        full kernel matrix (diagonal included) and no square root is applied.

        Parameters
        ----------
        a : np.ndarray or tf.Tensor
            First sample of shape (n, d).
        b : np.ndarray or tf.Tensor
            Second sample of shape (m, d).

        Returns
        -------
        np.floating
            The MMD score between the two distributions, as a NumPy scalar.
        """
        return (
            self.gaussian_kernel(a, a).mean()
            + self.gaussian_kernel(b, b).mean()
            - 2 * self.gaussian_kernel(a, b).mean()
        )