"""Source code for gsmm.csm.visualisation: plotting and result-saving helpers for flux and sink-flux data."""

import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
import logging
from typing import Optional
from .config import *
import pickle

# Configure logging at import time.
# NOTE(review): basicConfig acts on the process-wide root logger (and does
# nothing if the root logger already has handlers), so importing this module
# affects logging for the whole application; consider leaving it to the caller.
logging.basicConfig(
    format='%(asctime)s - %(levelname)s - %(message)s',
    level=logging.INFO,
)

def load_data(filepath: str) -> Optional[pd.DataFrame]:
    """Read a pickled pandas object from ``filepath``.

    Failures are never raised to the caller. A missing file and any other
    read error are both logged at ERROR level and reported as ``None``.

    Security note: unpickling can execute arbitrary code, so only pass paths
    to files from a trusted source.

    Args:
        filepath (str): Location of the pickle file.

    Returns:
        Optional[pd.DataFrame]: The unpickled object on success, otherwise None.
    """
    try:
        frame = pd.read_pickle(filepath)
    except FileNotFoundError:
        logging.error(f"File not found: {filepath}")
        return None
    except Exception as exc:
        # Anything else (corrupt file, incompatible pickle, ...) is also
        # reported as "no data" rather than propagated.
        logging.error(f"Error loading data from {filepath}: {exc}")
        return None
    else:
        logging.info(f"Loaded data from {filepath}")
        return frame
def plot_flux_distribution_clustermap(df_fluxes: pd.DataFrame, save_path: str, show_plot: bool = False) -> None:
    """
    Generate a clustermap for flux distribution across different models and reactions.

    The long-format input is pivoted to a Reaction x Model matrix with
    ``pivot_table``, which averages duplicate (Reaction, Model) pairs. Rows and
    columns that are entirely NaN are dropped, and remaining gaps are filled
    with 0. The matrix is then clustered with average linkage on euclidean
    distances.

    Parameters:
    - df_fluxes (pd.DataFrame): DataFrame with 'Model', 'Reaction', and 'Flux' columns.
    - save_path (str): Path to save the clustermap.
    - show_plot (bool): Whether the plot should also be displayed (Default: False)

    Returns:
    None. Any error, including missing columns, is logged instead of raised.
    """
    grid = None
    try:
        logging.info("Generating flux distribution clustermap...")

        # Check if the required columns are present
        if not all(col in df_fluxes.columns for col in ['Model', 'Reaction', 'Flux']):
            raise ValueError("Required columns ('Model', 'Reaction', 'Flux') not found in the dataframe.")

        # Pivot to Reaction x Model and drop all-NaN rows/columns, which would
        # otherwise break the clustering. Then fill the remaining gaps with 0.
        pivot_table = (
            df_fluxes.pivot_table(index='Reaction', columns='Model', values='Flux')
            .dropna(axis=0, how='all')
            .dropna(axis=1, how='all')
            .fillna(0)
        )

        # NOTE(review): standard_scale=1 rescales each model column to [0, 1],
        # so center=0 maps every value onto the warm half of "coolwarm". A model
        # whose fluxes are all equal would presumably scale to 0/0 (NaN) and make
        # the clustering fail. Confirm these settings are intended.
        grid = sns.clustermap(pivot_table, cmap="coolwarm", center=0, figsize=(14, 10),
                              method='average', metric='euclidean', standard_scale=1)

        # A clustermap owns several axes (dendrograms, heatmap, colorbar), so
        # plt.title() would title whichever axes happens to be current. Title
        # the figure instead.
        #
        # ClusterGrid already applies its own tight layout (and repositions its
        # colorbar), so no extra plt.tight_layout() call is made here.
        #
        # ClusterGrid.savefig defaults to bbox_inches="tight", which keeps the
        # raised suptitle inside the saved image.
        grid.fig.suptitle('Clustermap of Flux Distribution across Models and All Reactions', y=1.02)

        # Save the plot
        grid.savefig(save_path)
        logging.info(f"Flux distribution clustermap saved as {save_path}")

        if show_plot:
            plt.show()
    except Exception as e:
        logging.error(f"Error generating flux distribution clustermap: {str(e)}")
    finally:
        # Release the figure so repeated calls do not accumulate open figures.
        if grid is not None:
            plt.close(grid.fig)
def plot_sink_fluxes_heatmap(df: pd.DataFrame, output_file: str, show_plot: bool = False) -> None:
    """
    Plot a heatmap for sink fluxes across different context models.

    The data is reshaped with ``DataFrame.pivot`` into a Metabolite x
    Context_Model matrix. Each (Metabolite, Context_Model) pair must therefore
    be unique; otherwise pivot fails and the error is logged.

    Parameters:
    - df (pd.DataFrame): DataFrame containing sink flux data with columns 'Metabolite', 'Flux', 'Context_Model'.
    - output_file (str): Path to save the sink fluxes heatmap.
    - show_plot (bool): Whether the plot should also be displayed (Default: False)

    Returns:
    None. A None/empty input is skipped with a warning; other errors are logged.
    """
    if df is None or df.empty:
        logging.warning("Data frame is empty or None. Skipping heatmap generation.")
        return

    logging.info("Generating sink fluxes heatmap...")
    fig = None
    try:
        pivot_table = df.pivot(index="Metabolite", columns="Context_Model", values="Flux")

        fig = plt.figure(figsize=(12, 10))
        sns.heatmap(pivot_table, annot=True, cmap="coolwarm", center=0)
        plt.title("Sink Fluxes of Reactions associated with Metabolites of interest across all Models")
        plt.tight_layout()
        fig.savefig(output_file)
        logging.info(f"Sink fluxes heatmap saved as {output_file}")

        # Show BEFORE closing. The original closed the figure first, so
        # show_plot=True never displayed anything.
        if show_plot:
            plt.show()
    except Exception as e:
        logging.error(f"Error generating sink fluxes heatmap: {e}")
    finally:
        # Always release the figure, including when plotting/saving failed.
        if fig is not None:
            plt.close(fig)
def plot_flux_correlation_heatmap(df_fluxes: pd.DataFrame, save_path: str, show_plot: bool = False) -> None:
    """
    Generate a heatmap of correlation coefficients for fluxes between different models for all reactions.

    Fluxes are pivoted to a Reaction x Model matrix with ``pivot_table``, which
    averages duplicate pairs. Pearson correlations are then computed between
    every pair of models using the reactions both models have values for. The
    diagonal (a model against itself) is left blank.

    Parameters:
    - df_fluxes (pd.DataFrame): DataFrame with 'Model', 'Reaction', and 'Flux' columns.
    - save_path (str): Path to save the heatmap.
    - show_plot (bool): Whether the plot should also be displayed (Default: False)

    Returns:
    None. Errors are logged instead of raised.
    """
    fig = None
    try:
        logging.info("Generating flux correlation heatmap...")

        # Pivot the dataframe for correlation analysis
        pivot_df = df_fluxes.pivot_table(index='Reaction', columns='Model', values='Flux')

        # One vectorised call computes all pairwise Pearson correlations
        # (pairwise-complete observations, like Series.corr), instead of a
        # double loop that computed every pair twice.
        corr_matrix = pivot_df.corr()
        # Only cross-model correlations are of interest: blank the diagonal.
        for model in corr_matrix.columns:
            corr_matrix.loc[model, model] = float('nan')

        # Plotting heatmap
        fig = plt.figure(figsize=(10, 8))
        sns.heatmap(corr_matrix.astype(float), annot=True, cmap="coolwarm", vmin=-1, vmax=1, center=0, square=True)
        plt.title('Correlation Heatmap of All Reaction Fluxes across Models')
        plt.tight_layout()

        # Save the plot
        fig.savefig(save_path)
        logging.info(f"Flux correlation heatmap saved as {save_path}")

        if show_plot:
            plt.show()
    except Exception as e:
        logging.error(f"Error generating flux correlation heatmap: {str(e)}")
    finally:
        # Release the figure so repeated calls do not accumulate open figures.
        if fig is not None:
            plt.close(fig)
def plot_sink_flux_correlation_heatmap(df_sink_fluxes: pd.DataFrame, save_path: str, show_plot: bool = False) -> None:
    """
    Generate a heatmap of correlation coefficients for sink fluxes between different models.

    Sink fluxes are pivoted to a Metabolite x Context_Model matrix with
    ``pivot_table``, which averages duplicate pairs. Pearson correlations are
    computed between every pair of context models using the metabolites both
    have values for. The diagonal (a model against itself) is left blank.

    Parameters:
    - df_sink_fluxes (pd.DataFrame): DataFrame with 'Metabolite', 'Flux', 'Context_Model' columns.
    - save_path (str): Path to save the heatmap.
    - show_plot (bool): Whether the plot should also be displayed (Default: False)

    Returns:
    None. Errors are logged instead of raised.
    """
    fig = None
    try:
        logging.info("Generating sink flux correlation heatmap...")

        # Pivot the dataframe for correlation analysis
        pivot_df = df_sink_fluxes.pivot_table(index='Metabolite', columns='Context_Model', values='Flux')

        # One vectorised call computes all pairwise Pearson correlations
        # (pairwise-complete observations, like Series.corr), instead of a
        # double loop that computed every pair twice.
        corr_matrix = pivot_df.corr()
        # Only cross-model correlations are of interest: blank the diagonal.
        for model in corr_matrix.columns:
            corr_matrix.loc[model, model] = float('nan')

        # Plotting heatmap
        fig = plt.figure(figsize=(10, 8))
        sns.heatmap(corr_matrix.astype(float), annot=True, cmap="coolwarm", vmin=-1, vmax=1, center=0, square=True)
        plt.title('Correlation Heatmap of Sink Fluxes across all Models')
        plt.tight_layout()

        # Save the plot
        fig.savefig(save_path)
        logging.info(f"Sink flux correlation heatmap saved as {save_path}")

        if show_plot:
            plt.show()
    except Exception as e:
        logging.error(f"Error generating sink flux correlation heatmap: {str(e)}")
    finally:
        # Release the figure so repeated calls do not accumulate open figures.
        if fig is not None:
            plt.close(fig)
def filter_and_save_results(filtered_results: dict, file_path: str) -> None:
    """Pickle only the well-formed DataFrames found in ``filtered_results``.

    An entry is kept when its value is a pandas DataFrame whose columns include
    at least 'ids', 'growth' and 'status' (extra columns are allowed). Every
    other entry is dropped, with a WARNING naming its key.

    The kept entries are written as a single dict to ``file_path``, under their
    original keys and in their original order. The file is overwritten even when
    nothing is kept.

    Args:
        filtered_results (dict): Mapping of result name to candidate DataFrame.
        file_path (str): Destination pickle file.

    Returns:
        None
    """
    required_columns = {'ids', 'growth', 'status'}
    kept = {}

    for name, candidate in filtered_results.items():
        well_formed = isinstance(candidate, pd.DataFrame) and required_columns.issubset(candidate.columns)
        if not well_formed:
            logging.warning(f"Skipping '{name}' as it is not a valid DataFrame for saving.")
            continue
        kept[name] = candidate

    with open(file_path, 'wb') as handle:
        pickle.dump(kept, handle)
def plot_fluxes(flux_filepath: Optional[str] = flux_filepath,
                sink_flux_filepath: Optional[str] = sink_flux_filepath,
                show_plot: bool = False,
                output_dir: Optional[str] = None) -> None:
    """
    Plot flux distributions and sink fluxes using default or provided file paths.

    Both pickle files are loaded first. Each dataset that loads successfully
    then produces two PNGs:
    - flux data: flux_distribution_clustermap.png and flux_correlation_heatmap.png
    - sink-flux data: sink_fluxes_heatmap.png and sink_flux_correlation_heatmap.png

    Parameters:
    - flux_filepath (str, optional): Pickle file with flux data ('Model', 'Reaction', 'Flux'
      columns). Defaults to ``flux_filepath`` from the package ``config`` module.
    - sink_flux_filepath (str, optional): Pickle file with sink-flux data ('Metabolite', 'Flux',
      'Context_Model' columns). Defaults to ``sink_flux_filepath`` from the package ``config`` module.
    - show_plot (bool): Whether each plot should also be displayed (Default: False)
    - output_dir (str, optional): Directory to write the PNG files into; it is created if it does
      not exist. When None (the default), the bare file names are used, i.e. the files go to the
      current working directory, as before.

    Returns:
    - None
    """
    import os

    def _output_path(filename: str) -> str:
        # Keep the exact historical (bare) file names when no directory is given.
        return filename if output_dir is None else os.path.join(output_dir, filename)

    if output_dir is not None:
        os.makedirs(output_dir, exist_ok=True)

    df_flux = load_data(flux_filepath)
    df_sink_fluxes = load_data(sink_flux_filepath)

    if df_flux is not None:
        plot_flux_distribution_clustermap(df_flux, _output_path('flux_distribution_clustermap.png'), show_plot)
        plot_flux_correlation_heatmap(df_flux, _output_path('flux_correlation_heatmap.png'), show_plot)

    if df_sink_fluxes is not None:
        plot_sink_fluxes_heatmap(df_sink_fluxes, _output_path('sink_fluxes_heatmap.png'), show_plot)
        plot_sink_flux_correlation_heatmap(df_sink_fluxes, _output_path('sink_flux_correlation_heatmap.png'), show_plot)