Source code for mr_toolkit.reweighting.analysis

"""
Code for reweighting MSMs and obtaining reweighted estimates of steady-state.
"""
import deeptime.markov.tools.analysis
import numpy as np
from scipy.special import rel_entr
import pandas as pd
import mr_toolkit.trajectory_analysis.traj_analysis as ta
import logging
import pyemma
from .splicing import get_receiving_distribution, splice_trajectories, iterative_trajectory_splicing

# Module-wide logger. Calling getLogger() without a name returns the root logger.
log = logging.getLogger()

# msm_we is treated as an optional dependency. If the import fails, only a warning is emitted
# and the name MarkovFPT is left undefined, so any later use of it will raise a NameError.
try:
    from msm_we.fpt import MarkovFPT
except ImportError:
    log.warning("msm_we not found, fpt_distribution calculations will be unavailable")


def compute_reweighted_stationary(
    discrete_trajectories,
    N,
    lag,
    n_clusters,
    last_frac=1.0,
    min_weight=1e-12,
    n_reweighting_iters=100,
):
    """
    Estimates a stationary distribution from a discrete trajectory using reweighted MSMs.

    Parameters
    ----------
    discrete_trajectories: array-like
        2-D array or list of lists with discretized trajectories
    N: int
        Fragment length for reweighting
    lag: int
        Lagtime used in reweighting MSMs
    n_clusters: int
        Number of total states in the reweighted models (or in the discretization)
    last_frac: float
        Fraction of the trajectories to use.
        I.e., last_frac=0.25 only uses the last 1/4 of the trajectories.
    min_weight: float
        Minimum bound on weights during reweighting iteration
    n_reweighting_iters: int
        Maximum number of reweighting iterations

    Returns
    -------
    (Set of state indices,
    Stationary distributions at each reweighting iteration,
    Total number of iterations before convergence,
    estimated transition matrices at each reweighting iteration)

    The two per-iteration arrays always have ``n_reweighting_iters`` rows. Rows that were
    not produced by the reweighting (or every row, if the reweighting failed) are NaN.
    If the reweighting raised an AssertionError or LinAlgError, the error is logged and the
    returned iteration count is 0.
    """

    # Pre-fill with NaN rather than np.empty: only the rows actually returned by the
    # reweighting are overwritten below, and np.empty would leave the rest as
    # uninitialized memory that looks like real data.
    reweighted_stationaries = np.full(
        shape=(n_reweighting_iters, n_clusters), fill_value=np.nan
    )
    reweighted_matrices = np.full(
        shape=(n_reweighting_iters, n_clusters, n_clusters), fill_value=np.nan
    )

    try:
        (
            reweighted_distributions,
            _reweighted_matrices,
            weighted_count_matrices,
            last_iter,
        ) = ta.optimized_resliced_reweighted(
            discrete_trajectories,
            n_reweighting_iters,
            N,
            lagtime=lag,
            n_states=n_clusters,
            last_frac=last_frac,
            return_matrices=True,
            min_weight=min_weight,
            convergence=1e-10,  # Convergence threshold in units of kT
        )

    except (AssertionError, np.linalg.LinAlgError) as e:
        # This may trip if something goes awry in solving the matrices.
        # The outputs are already all-NaN, so only the iteration count needs setting.
        log.error(e)
        last_iter = 0

    else:
        for i, _distribution in enumerate(reweighted_distributions):
            reweighted_stationaries[i, :] = _distribution
            reweighted_matrices[i] = _reweighted_matrices[i]

    states = np.arange(n_clusters)

    return states, reweighted_stationaries, last_iter, reweighted_matrices
def get_set_kls(distributions):
    """
    Get the mean pairwise KL divergence between multiple distributions.

    Every unique pair (i, j) with i < j is compared with ``get_kl(distributions[i],
    distributions[j], return_nan=True)``. I.e., a (4x10) input corresponds to 4x 10-element
    distributions, giving 6 unique pairs. The pairwise values are collected in the upper
    triangle of an otherwise-NaN square matrix, and the NaN-ignoring mean of that matrix
    is returned.

    Parameters
    ----------
    distributions: sequence of array-like
        The distributions to compare, one per entry.

    Returns
    -------
    The mean of the pairwise KL divergences, with NaN entries (including pairs for which
    get_kl returned NaN) ignored.
    """

    n_sets = len(distributions)
    pairwise_kls = np.full(shape=(n_sets, n_sets), fill_value=np.nan)

    # Strictly-upper-triangular index pairs, i.e. each unordered pair exactly once
    row_indices, col_indices = np.triu_indices(n_sets, 1)
    for row, col in zip(row_indices, col_indices):
        pairwise_kls[row, col] = get_kl(
            distributions[row], distributions[col], return_nan=True
        )

    return np.nanmean(pairwise_kls)
def get_kl(test_dist, ref_dist, return_nan=False):
    """
    Obtain the KL divergence between two distributions.

    The elementwise relative entropy is computed with scipy's rel_entr, and only the
    entries that compare less than +inf are summed (so infinite and NaN terms are dropped).
    A sum of exactly 0 is treated as invalid.

    :param test_dist: The distribution to test
    :param ref_dist: The reference distribution
    :param return_nan: If the KL divergence is invalid for some reason, return NaN if true or -1 otherwise.
    :return: The KL divergence of the two distributions. If invalid, -1 or NaN depending on the value of return_nan.
    """

    elementwise = rel_entr(test_dist, ref_dist)

    # NaN < inf is False, so this mask discards NaN terms as well as +inf terms
    finite_terms = elementwise < np.inf
    total = np.nansum(elementwise[finite_terms])

    if total != 0:
        return total

    return np.nan if return_nan else -1
class AnalysisRun:
    """
    WARNING: You should almost certainly not use this!

    Convenience class to handle computing various observables from a trajectory set and storing
    relevant hyperparameters.

    Results are stored in two dataframes:

    - ``equil_df``, indexed by (trajectory set, method), holding stationary distributions.
    - ``ness_df``, indexed by (trajectory set, direction, method), holding the distributions
      computed alongside MFPTs.

    Callers are expected to set ``current_traj_set``, ``active_df`` and (for MFPTs)
    ``current_direction`` before calling the compute_* methods.
    """

    def __init__(
        self,
        _run,
        _reference,
        _trajectory_sets,
        dt=1,
        lag=None,
        metaparameters=None,
    ):
        """
        Parameters
        ----------
        _run:
            Run-tracking object. This class calls ``_run.log_metric(name, value)`` and reads
            ``_run.id``.
        _reference: array-like
            Reference distribution; its length sets the number of states.
        _trajectory_sets:
            Indexable collection of discrete trajectory sets.
        dt:
            Multiplier applied to MFPTs to convert them to time units.
        lag: int
            Lagtime used when building MSMs for MFPTs.
        metaparameters: dict
            Must contain "n_trajectory_sets" and "mfpt_method" (a KeyError is raised
            otherwise). "last_frac", "min_weight" and "n_reweighting_iters" are read as
            fallbacks by compute_reweighted_stationary.
        """
        # Avoid a shared mutable default argument
        if metaparameters is None:
            metaparameters = {}

        self.run = _run
        self.reference = _reference
        self.trajectory_sets = _trajectory_sets
        self.dt = dt
        self.lag = lag

        methods = [
            "histogram",
            "pyemma_rev",
            "pyemma_irrev",
            "naive",
            "reweighted",
            "resliced",
        ]

        multi_index = pd.MultiIndex.from_product(
            [range(metaparameters["n_trajectory_sets"]), methods],
            names=["trajectory set", "method"],
        )

        self.n_stratified_clusters = len(_reference)

        self.equil_df = (
            pd.DataFrame(
                index=multi_index,
                columns=range(self.n_stratified_clusters),
                dtype=np.float64,
            )
            .apply(pd.to_numeric)
            .sort_index()
        )

        # TODO: Generalize to more than 2 directions
        self.directions = ["unfolding", "folding"]
        multi_index = pd.MultiIndex.from_product(
            [range(metaparameters["n_trajectory_sets"]), self.directions, methods],
            names=["trajectory set", "direction", "method"],
        )
        self.ness_df = (
            pd.DataFrame(
                index=multi_index,
                columns=range(self.n_stratified_clusters),
                dtype=np.float64,
            )
            .apply(pd.to_numeric)
            .sort_index()
        )

        self.current_traj_set = -1
        self.active_df = None
        self.spliced_trajs = None
        self.current_direction = None

        self.metaparameters = metaparameters
        self.mfpt_method = self.metaparameters["mfpt_method"]

    def compute_avg_kl(self, method):
        """
        Compute the average set-to-set KL divergence for one method.

        For every unique pair of trajectory sets, the KL divergence between their rows of
        ``equil_df`` for ``method`` is computed.

        Returns
        -------
        (NaN-ignoring mean of the pairwise KLs,
        2 * NaN-ignoring standard deviation / sqrt(number of trajectory sets))
        """
        n_sets = self.metaparameters["n_trajectory_sets"]

        kls = np.full(shape=(n_sets, n_sets), fill_value=np.nan)

        for x, y in np.array(np.triu_indices(n_sets, 1)).T:
            setA_distribution = self.equil_df.loc[(x, method)].values.astype(float)
            setB_distribution = self.equil_df.loc[(y, method)].values.astype(float)

            kls[x, y] = get_kl(setA_distribution, setB_distribution, return_nan=True)

        mean_kl = np.nanmean(kls)
        std_kl = 2 * np.nanstd(kls) / np.sqrt(n_sets)

        return mean_kl, std_kl

    def compute_stationary(self, method, **kwargs):
        """
        Estimate the stationary distribution of the current trajectory set with ``method``,
        log its KL divergence to the reference, and store it in ``equil_df``.

        Parameters
        ----------
        method: str
            One of "reweighted", "naive", "pyemma_rev", "pyemma_irrev", "histogram".
        kwargs:
            Forwarded to the underlying estimator. "N" is dropped (if present) for the
            pyemma methods.

        Returns
        -------
        (stationary distribution, KL divergence to the reference).
        For "reweighted", these are for the last reweighting iteration; the first iteration
        is additionally stored/logged as "resliced".

        Raises
        ------
        NotImplementedError
            If ``method`` is not recognized.
        """
        # TODO: Optionally disable logging / data storage here, so I can reuse this in compute_mfpt

        set_idx = self.current_traj_set
        assert self.active_df is not None, "No dataframe is active for storing results"

        if method == "reweighted":
            (
                states,
                stationaries,
                last_iter,
                _,
            ) = self.compute_reweighted_stationary(
                self.trajectory_sets[set_idx], **kwargs
            )

            assert np.isclose(
                np.sum(stationaries[:last_iter], axis=1), 1.0
            ).all(), "Stationary distributions not normalized!"

            # Compute KL to reference
            resliced_kl = get_kl(stationaries[0], self.reference)
            reweighted_kl = get_kl(stationaries[last_iter], self.reference)
            self.run.log_metric(f"set{set_idx}_resliced_kl", resliced_kl)
            self.run.log_metric(f"set{set_idx}_reweighted_kl", reweighted_kl)

            self.equil_df.loc[(set_idx, "resliced"), states] = stationaries[0]
            self.equil_df.loc[(set_idx, "reweighted"), states] = stationaries[last_iter]

            return stationaries[last_iter], reweighted_kl

        if method == "naive":
            _states, _stationary, _ = self.compute_stationary_naive(
                self.trajectory_sets[set_idx]
            )

        elif method in ("pyemma_rev", "pyemma_irrev"):
            # The pyemma estimator takes no fragment length. Tolerate callers that
            # did not pass N at all instead of raising a KeyError.
            kwargs.pop("N", None)
            _states, _stationary, _ = self.compute_pyemma_stationary(
                self.trajectory_sets[set_idx],
                reversible=(method == "pyemma_rev"),
                **kwargs,
            )

        elif method == "histogram":
            _states, counts = np.unique(
                self.trajectory_sets[set_idx], return_counts=True
            )
            _stationary = counts / sum(counts)

        else:
            raise NotImplementedError("Invalid method specified")

        # States not visited by the estimator keep zero probability
        stationary = np.zeros(shape=self.n_stratified_clusters)
        stationary[_states] = _stationary
        assert np.isclose(
            np.sum(stationary), 1.0
        ), "Stationary distribution not normalized!"

        # Compute KL to reference
        kl = get_kl(stationary, self.reference)
        self.run.log_metric(f"set{set_idx}_{method}_kl", kl)

        self.equil_df.loc[(set_idx, method), :] = stationary

        return stationary, kl

    @staticmethod
    def compute_stationary_naive(discrete_trajectories):
        """
        Build an MSM with ta.build_msm (no reslicing) and take the eigenvector of the
        transposed transition matrix selected by argmin(1 - eigenvalue), normalized to sum to 1.

        Returns
        -------
        (list of states in the MSM's state map, stationary distribution, transition matrix)
        """
        tmatrix, state_map, cmatrix, weights = ta.build_msm(
            discrete_trajectories, reslicing=False, normalize_initial=False
        )

        evals, evecs = np.linalg.eig(tmatrix.T)
        max_eig_index = np.argmin(1 - evals)
        stationary = np.real(evecs[:, max_eig_index]) / np.real(
            sum(evecs[:, max_eig_index])
        )

        states = list(state_map.keys())

        return states, stationary, tmatrix

    @staticmethod
    def compute_pyemma_stationary(discrete_trajectories, lag, reversible=False):
        """
        Estimate an MSM with pyemma at the given lag.

        Returns
        -------
        (active set of the MSM, its stationary distribution, its transition matrix)
        """
        # TODO: Support different lagtimes
        pyemma_msm = pyemma.msm.estimate_markov_model(
            list(discrete_trajectories), lag=lag, reversible=reversible
        )

        pyemma_stationary = pyemma_msm.stationary_distribution
        pyemma_states = pyemma_msm.active_set

        return pyemma_states, pyemma_stationary, pyemma_msm.transition_matrix

    def compute_reweighted_stationary(
        self,
        discrete_trajectories,
        N,
        lag,
        last_frac=None,
        min_weight=None,
        n_reweighting_iters=None,
        store_matrices=False,
    ):
        """
        Estimate stationary distributions using iteratively reweighted MSMs.

        Parameters
        ----------
        discrete_trajectories:
            Discretized trajectories to analyze
        N: int
            Fragment length for reweighting
        lag: int
            Lagtime used in reweighting MSMs
        last_frac, min_weight, n_reweighting_iters:
            If None, the value under the same key in ``self.metaparameters`` is used.
        store_matrices: bool
            If True, pickle the per-iteration transition matrices (and the weighted count
            matrices, when available) under ``../results/``.

        Returns
        -------
        (state indices,
        stationary distributions per iteration,
        last iteration index (0 if the reweighting failed),
        transition matrices per iteration)

        The per-iteration arrays have ``n_reweighting_iters`` rows; rows not produced by the
        reweighting are NaN.
        """
        # TODO: This is obsolete now, remove this and just wrap a call to the static version

        if last_frac is None:
            last_frac = self.metaparameters.get("last_frac")

        if min_weight is None:
            min_weight = self.metaparameters.get("min_weight")

        if n_reweighting_iters is None:
            n_reweighting_iters = self.metaparameters.get("n_reweighting_iters")

        n_clusters = self.n_stratified_clusters

        # Pre-fill with NaN rather than np.empty, so rows that are never written do not
        # contain uninitialized memory.
        reweighted_stationaries = np.full(
            shape=(n_reweighting_iters, n_clusters), fill_value=np.nan
        )
        reweighted_matrices = np.full(
            shape=(n_reweighting_iters, n_clusters, n_clusters), fill_value=np.nan
        )

        try:
            (
                reweighted_distributions,
                _reweighted_matrices,
                weighted_count_matrices,
                last_iter,
            ) = ta.optimized_resliced_reweighted(
                discrete_trajectories,
                n_reweighting_iters,
                N,
                lagtime=lag,
                n_states=n_clusters,
                last_frac=last_frac,
                return_matrices=True,
                min_weight=min_weight,
                convergence=1e-10,  # Convergence threshold in units of kT
            )

        except (AssertionError, np.linalg.LinAlgError) as e:
            # This may trip if something goes awry in solving the matrices
            log.error(e)
            last_iter = 0
            # No count matrices exist in this case; keep the name bound for the block below
            weighted_count_matrices = None

        else:
            for i, _distribution in enumerate(reweighted_distributions):
                reweighted_stationaries[i, :] = _distribution
                reweighted_matrices[i] = _reweighted_matrices[i]

        if store_matrices:
            index = pd.MultiIndex.from_product(
                [
                    range(last_iter + 1),
                    range(n_clusters),
                    range(n_clusters),
                ],
                names=["Iteration", "From", "To"],
            )
            # The index only covers iterations 0..last_iter, so slice the data to match.
            # Passing all n_reweighting_iters rows fails whenever convergence is early.
            reweighted_matrix_df = pd.DataFrame(
                np.array(reweighted_matrices[: last_iter + 1]).flatten(), index=index
            )
            reweighted_matrix_df.to_pickle(
                f"../results/{self.run.id}_set{self.current_traj_set}_reweighted_matrix_df.pkl"
            )

            if weighted_count_matrices is not None:
                # NOTE(review): this index assumes weighted_count_matrices holds
                # n_clusters**3 entries -- confirm against ta.optimized_resliced_reweighted
                index = pd.MultiIndex.from_product(
                    [
                        range(n_clusters),
                        range(n_clusters),
                        range(n_clusters),
                    ],
                    names=["FragStart", "From", "To"],
                )
                count_matrix_df = pd.DataFrame(
                    np.array(weighted_count_matrices).flatten(), index=index
                )
                count_matrix_df.to_pickle(
                    f"../results/{self.run.id}_set{self.current_traj_set}_weighted_count_matrix_df.pkl"
                )

        states = np.arange(n_clusters)

        return states, reweighted_stationaries, last_iter, reweighted_matrices

    def iterative_trajectory_splicing(
        self,
        source_states,
        sink_states,
        splice_msm_lag=None,
        msm_reversible=False,
        target_steps_to_keep=1,
        convergence=1e-9,
        max_iterations=100,
    ):
        """
        Call the module-level iterative_trajectory_splicing and store the result in
        ``self.spliced_trajs``. If ``splice_msm_lag`` is None, ``self.lag`` is used.
        """
        # This just wraps the external call for backwards-compatibility with some existing analysis scripts.

        if splice_msm_lag is None:
            print(f"No lag provided for splice MSM -- using set value of {self.lag}")
            splice_msm_lag = self.lag

        # Inside a method, this name resolves to the imported module-level function
        spliced_trajectories = iterative_trajectory_splicing(
            source_states=source_states,
            sink_states=sink_states,
            splice_msm_lag=splice_msm_lag,
            msm_reversible=msm_reversible,
            target_steps_to_keep=target_steps_to_keep,
            convergence=convergence,
            max_iterations=max_iterations,
            n_clusters=self.n_stratified_clusters,
        )

        self.spliced_trajs = spliced_trajectories

    def splice_trajectories(
        self,
        source_states,
        sink_states,
        msm_lag=1,
        msm_reversible=False,
        target_steps_to_keep=1,
        trajs_to_splice=None,
        pbar_visible=True,
    ):
        """
        Call the module-level splice_trajectories and store the result in
        ``self.spliced_trajs``. If ``trajs_to_splice`` is None, the current trajectory set
        is spliced.
        """
        # This just wraps the external call for backwards-compatibility with some existing analysis scripts.

        set_idx = self.current_traj_set

        # Build an MSM to approximate the equilibrium distribution over the boundary states
        # TODO: Do we want to just use the PyEmma MSM? Or is there a better choice?

        if trajs_to_splice is None:
            trajs_to_splice = self.trajectory_sets[set_idx]

        # Inside a method, this name resolves to the imported module-level function
        spliced_trajs = splice_trajectories(
            trajs_to_splice=trajs_to_splice,
            source_states=source_states,
            sink_states=sink_states,
            msm_lag=msm_lag,
            msm_reversible=msm_reversible,
            target_steps_to_keep=target_steps_to_keep,
            pbar_visible=pbar_visible,
        )

        self.spliced_trajs = spliced_trajs

    def compute_mfpt(self, method, source_states, target_states, **kwargs):
        """
        Compute an MFPT from ``source_states`` to ``target_states`` using the spliced
        trajectories in ``self.spliced_trajs``, with the MFPT estimator named by
        ``self.mfpt_method``. ``self.lag`` is always used as the lagtime; a "lag" keyword
        argument is discarded with a printed warning.

        Parameters
        ----------
        method: str
            One of "reweighted", "naive", "pyemma_rev", "pyemma_irrev".
        kwargs:
            For "reweighted", forwarded to compute_reweighted_stationary. Unused otherwise.

        Returns
        -------
        For "reweighted": ((resliced MFPT, reweighted MFPT), (resliced tmatrix, reweighted
        tmatrix), (resliced stationary, reweighted stationary)), or (None, None, None) if the
        reweighting reported 0 iterations.
        Otherwise: (MFPT, transition matrix, stationary distribution).

        Raises
        ------
        NotImplementedError
            If ``method`` is not recognized.
        ValueError
            If ``self.current_direction`` is not one of ``self.directions``.
        """
        set_idx = self.current_traj_set
        n_clusters = self.n_stratified_clusters

        assert self.active_df is not None, "No dataframe is active for storing results"
        assert (
            self.current_direction is not None
        ), "No direction is specified for NESS/MFPTs"
        # The result is not needed; list.index raises ValueError for an unknown direction
        self.directions.index(self.current_direction)

        if "lag" in kwargs:
            kwargs.pop("lag")
            print("Warning: Lag specification in arguments to compute_mfpt is unnecessary and unused")

        # Get the stationary distribution and transition matrix
        if method == "reweighted":
            (
                states,
                stationaries,
                last_iter,
                tmatrices,
            ) = self.compute_reweighted_stationary(self.spliced_trajs, lag=self.lag, **kwargs)

            if last_iter == 0:
                log.warning("Bad stationary distribution -- not attempting an MFPT")
                return None, None, None

            assert np.isclose(
                np.sum(stationaries[:last_iter], axis=1), 1.0
            ).all(), "Stationary distributions not normalized!"

            # First iteration: the "resliced" estimate
            resliced_stationary = np.zeros(shape=n_clusters)
            resliced_tmatrix = np.zeros(shape=(n_clusters, n_clusters))
            resliced_stationary[states] = stationaries[0]
            resliced_tmatrix[np.ix_(states, states)] = tmatrices[0]

            resliced_mfpt = self.mfpt(
                self.mfpt_method,
                resliced_tmatrix,
                source_states,
                target_states,
                lag=self.lag,
                dt=self.dt,
                stationary=resliced_stationary,
            )

            self.run.log_metric(
                f"set{set_idx}_resliced_mfpt_{self.current_direction}", resliced_mfpt
            )
            self.ness_df.loc[
                (set_idx, self.current_direction, "resliced"), states
            ] = resliced_stationary

            # Last iteration: the "reweighted" estimate
            reweighted_stationary = np.zeros(shape=n_clusters)
            reweighted_tmatrix = np.zeros(shape=(n_clusters, n_clusters))
            reweighted_stationary[states] = stationaries[last_iter]
            reweighted_tmatrix[np.ix_(states, states)] = tmatrices[last_iter]

            reweighted_mfpt = self.mfpt(
                self.mfpt_method,
                reweighted_tmatrix,
                source_states,
                target_states,
                lag=self.lag,
                dt=self.dt,
                stationary=reweighted_stationary,
            )

            # NOTE(review): only the metric logging is taken to be guarded by the NaN
            # check; the distribution is stored regardless -- confirm against upstream
            if not np.isnan(reweighted_mfpt):
                self.run.log_metric(
                    f"set{set_idx}_reweighted_mfpt_{self.current_direction}",
                    reweighted_mfpt,
                )
            self.ness_df.loc[
                (set_idx, self.current_direction, "reweighted")
            ] = reweighted_stationary

            return (
                (resliced_mfpt, reweighted_mfpt),
                (resliced_tmatrix, reweighted_tmatrix),
                (resliced_stationary, reweighted_stationary),
            )

        if method == "naive":
            print("Warning: Naive estimator assumes a lag of 1!")
            _states, _stationary, _tmatrix = self.compute_stationary_naive(
                self.spliced_trajs
            )

        elif method in ("pyemma_rev", "pyemma_irrev"):
            kwargs.pop("N", None)
            _states, _stationary, _tmatrix = self.compute_pyemma_stationary(
                self.spliced_trajs,
                reversible=(method == "pyemma_rev"),
                lag=self.lag,  # TODO: Manage lag correctly
            )

        else:
            raise NotImplementedError("Invalid method specified")

        # Embed the estimator's (possibly smaller) state space into the full one
        stationary = np.zeros(shape=n_clusters)
        tmatrix = np.zeros(shape=(n_clusters, n_clusters))

        stationary[_states] = _stationary
        assert np.isclose(
            np.sum(stationary), 1.0
        ), "Stationary distribution not normalized!"
        tmatrix[np.ix_(_states, _states)] = _tmatrix

        mfpt = self.mfpt(
            self.mfpt_method,
            tmatrix,
            source_states,
            target_states,
            lag=self.lag,
            dt=self.dt,
            stationary=stationary,
        )

        # NOTE(review): only the metric logging is taken to be guarded by the NaN check;
        # the distribution is stored regardless -- confirm against upstream
        if not np.isnan(mfpt):
            self.run.log_metric(
                f"set{set_idx}_{method}_mfpt_{self.current_direction}", mfpt
            )
        self.ness_df.loc[(set_idx, self.current_direction, method), :] = stationary

        return mfpt, tmatrix, stationary

    @staticmethod
    def mfpt(
        method,
        tmatrix,
        source_states,
        sink_states,
        dt,
        lag,
        stationary=None,
        clean_stationary=True,
    ):
        """
        Compute the mean first-passage time from ``source_states`` to ``sink_states``.

        Parameters
        ----------
        method: str
            One of "hill", "pyemma", "first_step", "fpt_distribution".
        tmatrix: np.ndarray
            Transition matrix. Not modified.
        source_states, sink_states:
            Indices of the source and sink states.
        dt:
            Multiplier applied to the result.
        lag:
            Accepted but not used; the result is not scaled by the lag.
        stationary: np.ndarray
            Stationary distribution. Required by every method, despite the None default.
            Not modified.
        clean_stationary: bool
            For "hill" only: zero the probability in the sink states and renormalize before
            computing the flux.

        Returns
        -------
        The MFPT multiplied by ``dt``.

        Raises
        ------
        AssertionError
            If ``method`` is not one of the valid methods.
        ValueError
            If ``stationary`` is None.
        """
        valid_methods = ["hill", "pyemma", "first_step", "fpt_distribution"]
        assert (
            method in valid_methods
        ), f"Invalid method -- choose one of {valid_methods}"

        # Every branch below reads the stationary distribution, so fail with a clear
        # message instead of an AttributeError on None.copy()
        if stationary is None:
            raise ValueError("A stationary distribution is required to compute an MFPT")

        # Work on copies so the caller's arrays are left untouched
        _stationary = stationary.copy()
        _tmatrix = tmatrix.copy()

        # * Get MFPT via Hill relation
        # We shouldn't need to do this as a distribution, unless we don't actually have the stationary distribution
        #   -- the target flux should be constant in steady state.
        if method == "hill":
            if clean_stationary:
                _stationary[sink_states] = 0
                _stationary = _stationary / sum(_stationary)

            # Total probability flux into the sink states
            flux = 0
            for state in sink_states:
                flux += np.dot(_stationary, _tmatrix[:, state]).sum()

            mfpt = 1.0 / flux

        # * Get MFPT via PyEmma
        # To do this, we need a "proper" transition matrix -- i.e., not including any fully zeroed out rows
        elif method == "pyemma" or method == "first_step":
            # These methods fail by default for transition matrices with zero rows, because that'll fail deeptime's
            #   is_transition_matrix check.
            # However, by specifying a stationary distribution, it won't try to recalculate the stationary distribution,
            #   which is where that check is.
            # So even an "invalid" (by their definition) transition matrix will work.
            receiving_distribution = np.zeros_like(_stationary)
            receiving_distribution[source_states] = get_receiving_distribution(
                _tmatrix, _stationary, source_states
            )

            mfpt = deeptime.markov.tools.analysis.mfpt(
                _tmatrix,
                origin=source_states,
                target=sink_states,
                mu=receiving_distribution,
            )

        elif method == "fpt_distribution":
            initial_probs = get_receiving_distribution(
                _tmatrix, _stationary, source_states
            )
            initial_states = source_states

            fpt_probs, _, _, times = MarkovFPT.adaptive_fpt_distribution(
                Tmatrix=_tmatrix,
                initial_states=initial_states,
                initial_state_probs=initial_probs,
                target_states=sink_states,
                verbose=False,
                fine_increment=1.1,
                increment=1.2,
            )

            mfpt = np.average(times, weights=fpt_probs)

        # Do NOT need to adjust for lag time here!
        mfpt = mfpt * dt

        return mfpt