Source code for riboraptor.plotting

"""Plotting methods."""
from __future__ import (absolute_import, division, print_function,
                        unicode_literals)

from collections import Counter
from itertools import cycle
from itertools import islice
import os
import pickle
import sys

import matplotlib
matplotlib.use('Agg')

import matplotlib.pyplot as plt
from matplotlib.ticker import MultipleLocator
from matplotlib.ticker import FormatStrFormatter
from matplotlib.ticker import LinearLocator
from matplotlib.ticker import MaxNLocator
from matplotlib.ticker import AutoMinorLocator

import numpy as np
import pandas as pd
import seaborn as sns
import six

from .helpers import identify_peaks
from .helpers import load_pickle
from .helpers import millify
from .helpers import round_to_nearest
from .helpers import set_xrotation

# Bar colours cycled over the three reading frames by
# plot_framewise_counts (frame 1, frame 2, frame 3, then repeating).
# A second palette used to be assigned to this name immediately before this
# one; it was overwritten before anything could read it, so only the
# effective value is kept.
__FRAME_COLORS__ = ['#fc8d62', '#66c2a5', '#8da0cb']

# Resolution (dots per inch) passed to ``fig.savefig`` by the plotting
# functions in this module.
DPI = 300


def setup_plot():
    """Install the module's default matplotlib and seaborn styling.

    Writes a fixed set of entries into ``plt.rcParams`` and then switches
    seaborn to the 'white' style with the 'paper' context (font scale 2).
    The settings are global: they affect every figure created afterwards.
    """
    rc_defaults = (
        ('savefig.dpi', 120),
        ('figure.dpi', 120),
        ('figure.autolayout', False),
        ('figure.figsize', (12, 8)),
        ('axes.labelsize', 18),
        ('axes.titlesize', 20),
        ('font.size', 10),
        ('lines.linewidth', 2.0),
        ('lines.markersize', 8),
        ('legend.fontsize', 14),
    )
    for rc_key, rc_value in rc_defaults:
        plt.rcParams[rc_key] = rc_value

    # seaborn is applied last, so it may override some of the values above.
    sns.set_style('white')
    sns.set_context('paper', font_scale=2)
def setup_axis(ax,
               axis='x',
               majorticks=5,
               minorticks=1,
               xrotation=45,
               yrotation=0):
    """Configure tick locators, tick sizes and label rotation on ``ax``.

    Parameters
    ----------
    ax : matplotlib.Axes
        Axes to configure
    axis : str
        'x', 'y' or 'both': the axis (or axes) that receive the
        fixed-interval tick locators. Any other value skips that step.
    majorticks : int
        Length of interval between two major ticks
    minorticks : int
        Length of interval between two minor ticks
    xrotation : int
        Rotate x axis labels by xrotation degrees
    yrotation : int
        Rotate y axis labels by yrotation degrees

    Notes
    -----
    Whatever ``axis`` is, the y axis major locator is finally replaced by
    ``MaxNLocator(integer=True)``, so ``majorticks`` only controls the
    major ticks of the x axis.
    """
    # Constructed up front for every call, including calls where ``axis``
    # selects neither branch below.
    major_locator = MultipleLocator(majorticks)
    major_formatter = FormatStrFormatter('%d')
    minor_locator = MultipleLocator(minorticks)

    if axis == 'both':
        # Each recursive call builds its own locator instances.
        for single_axis in ('x', 'y'):
            setup_axis(ax, single_axis, majorticks, minorticks, xrotation,
                       yrotation)
    elif axis == 'x' or axis == 'y':
        target = ax.xaxis if axis == 'x' else ax.yaxis
        target.set_major_locator(major_locator)
        target.set_major_formatter(major_formatter)
        target.set_minor_locator(minor_locator)

    ax.yaxis.set_major_locator(MaxNLocator(integer=True))

    for tick_kind, tick_width, tick_length in (('major', 2, 10),
                                               ('minor', 1, 6)):
        ax.tick_params(which=tick_kind, width=tick_width, length=tick_length)
    ax.tick_params(axis='x', labelrotation=xrotation)
    ax.tick_params(axis='y', labelrotation=yrotation)
def plot_read_length_dist(read_lengths,
                          ax=None,
                          millify_labels=True,
                          input_is_stream=False,
                          title=None,
                          saveto=None,
                          ascii=False,
                          **kwargs):
    """Plot read length distribution.

    Parameters
    ----------
    read_lengths : array_like, Counter, Series or str
        One of:

        * an array of individual read lengths (tallied here),
        * a Counter/Series mapping read length -> count,
        * a path containing '.pickle' (loaded with ``load_pickle``),
        * any other path: a tab-separated table with ``read_length`` and
          ``count`` columns,
        * with ``input_is_stream=True``, an iterable of
          'length<TAB>count' lines.
    ax : matplotlib.Axes
        Axis object; a new figure is created when None
    millify_labels : bool
        True if labels should be formatted to read millions/trillions etc
    input_is_stream : bool
        True if input is sent through stdin
    title : str
        Optional title, shown above the total read count
    saveto : str
        Path to save output file to (<filename>.png/<filename>.pdf)
    ascii : bool
        Also draw a text-mode plot on stdout (requires gnuplotlib)
    **kwargs
        Forwarded to ``setup_axis``; ``majorticks``, ``minorticks`` and
        ``xrotation`` default to 5, 1 and 0.

    Returns
    -------
    ax : matplotlib.Axes
    fig : matplotlib.Figure
    """
    if input_is_stream:
        parsed = {}
        for line in read_lengths:
            fields = [int(x) for x in line.strip().split('\t')]
            parsed[fields[0]] = fields[1]
        read_lengths = Counter(parsed)
    elif isinstance(read_lengths, six.string_types):
        if '.pickle' in str(read_lengths):
            # Try opening as a pickle first
            read_lengths = load_pickle(read_lengths)
        else:
            # Otherwise expect a table with read_length/count columns.
            try:
                read_lengths = pd.read_table(read_lengths)
                read_lengths = pd.Series(
                    read_lengths['count'].tolist(),
                    index=read_lengths.read_length.tolist())
            except KeyError:
                # Best effort kept from the original: a table without a
                # 'count' column is passed on unchanged.
                pass

    if ax is None:
        fig, ax = plt.subplots()
    else:
        fig = ax.get_figure()

    kwargs.setdefault('majorticks', 5)
    kwargs.setdefault('minorticks', 1)
    kwargs.setdefault('xrotation', 0)

    # Normalise every input into one Series: index = read length,
    # value = number of reads. Raw arrays of lengths are tallied first so
    # that bar positions (lengths) and bar heights (counts) line up;
    # previously the positional index of the raw array was used for x.
    if isinstance(read_lengths, (Counter, pd.Series)):
        length_counts = pd.Series(read_lengths)
    else:
        length_counts = pd.Series(read_lengths).value_counts().sort_index()
    lengths = length_counts.index
    count_values = length_counts.values

    ax.set_ylim(
        min(count_values),
        round_to_nearest(max(count_values), 5) + 0.5)
    ax.set_xlim(
        min(lengths) - 0.5,
        round_to_nearest(max(lengths), 10) + 0.5)
    ax.bar(lengths, count_values)
    setup_axis(ax, **kwargs)

    reads_total = millify(count_values.sum())
    if title:
        ax.set_title('{}\n Total reads = {}'.format(title, reads_total))
    else:
        ax.set_title('Total reads = {}'.format(reads_total))
    if millify_labels:
        ax.set_yticklabels([millify(x) for x in ax.get_yticks()])
    sns.despine(trim=True, offset=20)

    if saveto:
        fig.tight_layout()
        if '.dat' in saveto:
            # The extension would not be recognised by matplotlib, so the
            # output format is forced to png.
            fig.savefig(saveto, format='png', dpi=DPI)
        else:
            fig.savefig(saveto, dpi=DPI)

    if ascii:
        import gnuplotlib as gp
        sys.stdout.write(os.linesep)
        gp.plot(
            (lengths, count_values, {
                'with': 'boxes'
            }),
            terminal='dumb 160, 40',
            unset='grid')
        sys.stdout.write(os.linesep)
    return ax, fig
def plot_framewise_counts(counts,
                          frames_to_plot='all',
                          ax=None,
                          title=None,
                          millify_labels=False,
                          position_range=None,
                          saveto=None,
                          ascii=False,
                          input_is_stream=False,
                          **kwargs):
    """Plot framewise distribution of reads.

    Bars are coloured by cycling through the three frame colours in
    position order.

    Parameters
    ----------
    counts : Series, Counter or str
        A series with position as index and value as counts; a path is
        loaded with ``load_pickle``. With ``input_is_stream=True``, an
        iterable of 'position<TAB>count' lines.
    frames_to_plot : str or range
        A comma separated list of frames to highlight or a range.
        It is parsed (and therefore validated) but not otherwise used yet.
    ax : matplotlib.Axes
        Default none
    title : str
        Optional axes title
    millify_labels : bool
        True if labels should be formatted to read millions/trillions etc
    position_range : str or sequence of int
        Positions to keep, either 'start:stop' (inclusive) or an explicit
        sequence of index labels
    saveto : str
        Path to save output file to (<filename>.png/<filename>.pdf)
    ascii : bool
        Also draw a text-mode plot on stdout (requires gnuplotlib)
    input_is_stream : bool
        True if input is sent through stdin
    **kwargs
        Forwarded to ``setup_axis``; ``majorticks``, ``minorticks`` and
        ``xrotation`` default to 10, 5 and 90.

    Returns
    -------
    ax : matplotlib.Axes
    """
    if input_is_stream:
        parsed = {}
        for line in counts:
            fields = [int(x) for x in line.strip().split('\t')]
            parsed[fields[0]] = fields[1]
        counts = Counter(parsed)
    elif isinstance(counts, six.string_types):
        try:
            # Try opening as a pickle first
            counts = load_pickle(counts)
        except KeyError:
            pass
    if isinstance(counts, Counter):
        counts = pd.Series(counts)

    # TODO: frames_to_plot is only parsed; highlighting is not implemented.
    if isinstance(frames_to_plot,
                  six.string_types) and frames_to_plot != 'all':
        frames_to_plot = [int(x) for x in frames_to_plot.rstrip().split(',')]

    if isinstance(position_range, six.string_types):
        bounds = [int(x) for x in position_range.strip().split(':')]
        position_range = list(range(bounds[0], bounds[1] + 1))

    # An explicit length check avoids the ambiguous truth value of
    # array-like ranges while still skipping None and empty ranges.
    if position_range is not None and len(position_range) > 0:
        counts = counts[list(position_range)]

    if ax is None:
        fig, ax = plt.subplots()
    else:
        fig = ax.get_figure()

    kwargs.setdefault('majorticks', 10)
    kwargs.setdefault('minorticks', 5)
    kwargs.setdefault('xrotation', 90)
    setup_axis(ax, **kwargs)
    ax.set_ylabel('Number of reads')

    barlist = ax.bar(counts.index, counts.values)
    for bar_patch, frame_color in zip(barlist, cycle(__FRAME_COLORS__)):
        bar_patch.set_color(frame_color)

    # Only label as many frames as there are bars; indexing three bars
    # unconditionally raised IndexError for fewer than three positions.
    frame_labels = ('Frame 1', 'Frame 2', 'Frame 3')
    n_legend = min(len(frame_labels), len(barlist))
    if n_legend:
        ax.legend(
            tuple(barlist[i] for i in range(n_legend)),
            frame_labels[:n_legend],
            bbox_to_anchor=(0., 1.02, 1., .102),
            loc=3,
            ncol=3,
            mode='expand',
            borderaxespad=0.)

    if title:
        ax.set_title(title)
    if millify_labels:
        ax.set_yticklabels([millify(x) for x in ax.get_yticks()])

    if ascii:
        sys.stdout.write(os.linesep)
        import gnuplotlib as gp
        gp.plot(
            np.array(counts.index.tolist()),
            np.array(counts.values.tolist()),
            _with='boxes',
            terminal='dumb 200,40',
            unset='grid')
        sys.stdout.write(os.linesep)

    set_xrotation(ax, kwargs['xrotation'])
    fig.tight_layout()
    if saveto:
        fig.savefig(saveto, dpi=DPI)
    return ax
def plot_read_counts(counts,
                     ax=None,
                     marker=None,
                     color='royalblue',
                     title=None,
                     label=None,
                     millify_labels=False,
                     identify_peak=True,
                     saveto=None,
                     position_range=None,
                     ascii=False,
                     input_is_stream=False,
                     ylabel='Normalized RPF density',
                     **kwargs):
    """Plot RPF density around start/stop codons.

    Parameters
    ----------
    counts : Series/Counter/str
        A series with coordinates as index and counts as values. A path is
        first loaded with ``load_pickle``; if that raises IndexError it is
        read as a tab-separated table with ``position`` and ``count``
        columns. With ``input_is_stream=True``, an iterable of
        'position<TAB>count' lines.
    ax : matplotlib.Axes
        Axis to create object on
    marker : string
        Marker style, e.g. 'o'/'x'; no marker is requested when falsy
    color : string
        Line color
    title : string
        Optional axes title
    label : string
        Label (useful only if plotting multiple objects on same axes)
    millify_labels : bool
        True if labels should be formatted to read millions/trillions etc
    identify_peak : bool
        Mark the position returned by ``identify_peaks`` with a dashed
        vertical line and a text label
    saveto : str
        Path to save output file to (<filename>.png/<filename>.pdf)
    position_range : str or array_like
        Positions to keep, either 'start:stop' (inclusive) or an explicit
        collection of index labels
    ascii : bool
        Also draw a text-mode plot on stdout (requires gnuplotlib)
    input_is_stream : bool
        True if input is sent through stdin
    ylabel : str
        y axis label; skipped when falsy
    **kwargs
        Forwarded to ``setup_axis``; ``majorticks``, ``minorticks``,
        ``xrotation`` and ``yrotation`` default to 10, 5, 0 and 0.

    Returns
    -------
    ax : matplotlib.Axes
    fig : matplotlib.Figure
    peak
        Value returned by ``identify_peaks``, or None when
        ``identify_peak`` is False
    """
    if input_is_stream:
        parsed = {}
        for line in counts:
            fields = [int(x) for x in line.strip().split('\t')]
            parsed[fields[0]] = fields[1]
        counts = Counter(parsed)
    elif isinstance(counts, six.string_types):
        try:
            # Try opening as a pickle first
            counts = load_pickle(counts)
        except IndexError:
            counts_pd = pd.read_table(counts)
            counts = pd.Series(
                counts_pd['count'].tolist(),
                index=counts_pd['position'].tolist())
        except KeyError:
            pass
    if not isinstance(counts, pd.Series):
        counts = pd.Series(counts)

    if isinstance(position_range, six.string_types):
        bounds = [int(x) for x in position_range.strip().split(':')]
        position_range = np.arange(bounds[0], bounds[1] + 1)
    if position_range is not None:
        counts = counts[position_range]

    if ax is None:
        fig, ax = plt.subplots()
    else:
        fig = ax.get_figure()

    kwargs.setdefault('majorticks', 10)
    kwargs.setdefault('minorticks', 5)
    kwargs.setdefault('xrotation', 0)
    kwargs.setdefault('yrotation', 0)

    line_style = dict(color=color, linewidth=1, markersize=1.5, label=label)
    if marker:
        # Honour the requested marker; it used to be replaced by 'o'.
        line_style['marker'] = marker
    ax.plot(counts.index, counts.values, **line_style)

    peak = None
    if identify_peak:
        peak = identify_peaks(counts)
        ax.axvline(x=peak, color='r', linestyle='dashed')
        ax.text(
            peak + 0.5,
            ax.get_ylim()[1] * 0.9,
            '{}'.format(peak),
            color='r')

    setup_axis(ax, **kwargs)
    ax.set_xlim(
        round_to_nearest(min(counts.index), 10) - 1,
        round_to_nearest(max(counts.index), 10) + 1)

    # Tick labels are fixed only after setup_axis has installed its
    # locators; fixing them earlier left labels computed for ticks that
    # setup_axis then replaced. Same order as plot_read_length_dist.
    if millify_labels:
        ax.set_yticklabels([millify(x) for x in ax.get_yticks()])

    if ylabel:
        ax.set_ylabel(ylabel)
    if title:
        ax.set_title(title)
    sns.despine(trim=True, offset=10)

    if saveto:
        fig.tight_layout()
        fig.savefig(saveto, dpi=DPI)

    if ascii:
        sys.stdout.write(os.linesep)
        import gnuplotlib as gp
        gp.plot(
            np.array(counts.index.tolist()),
            np.array(counts.values.tolist()),
            _with='lines',
            terminal='dumb 200,40',
            unset='grid')
        sys.stdout.write(os.linesep)
    return ax, fig, peak
def plot_featurewise_barplot(utr5_counts,
                             cds_counts,
                             utr3_counts,
                             ax=None,
                             saveto=None,
                             **kwargs):
    """Plot barplots for 5'UTR/CDS/3'UTR counts.

    Parameters
    ----------
    utr5_counts : int
        Total number of reads in 5'UTR region
    cds_counts : int
        Total number of reads in CDS region
    utr3_counts : int
        Total number of reads in 3'UTR region
    ax : matplotlib.Axes
        Axis to draw on; a new figure is created when None
    saveto : str
        Path to save output file to (<filename>.png/<filename>.pdf)
    **kwargs
        Accepted for call compatibility; not used.

    Returns
    -------
    ax : matplotlib.Axes
    fig : matplotlib.Figure

    Notes
    -----
    NOTE(review): earlier documentation also advertised per-gene
    dictionaries/series as input. The three values are handed to
    ``ax.bar`` as the heights of exactly three bars, so only scalar totals
    are handled here; aggregate per-gene counts before calling.
    """
    if ax is None:
        fig, ax = plt.subplots()
    else:
        fig = ax.get_figure()

    barlist = ax.bar([0, 1, 2], [utr5_counts, cds_counts, utr3_counts])
    for bar_patch, bar_color in zip(barlist,
                                    ('#1b9e77', '#d95f02', '#7570b3')):
        bar_patch.set_color(bar_color)
    ax.set_xticks([0, 1, 2])
    ax.set_xticklabels(["5'UTR", "CDS", "3'UTR"])

    max_counts = np.max(np.hstack([utr5_counts, cds_counts, utr3_counts]))
    # Tick intervals must be positive: for totals below 20 the floor
    # divisions give 0, which is not a valid MultipleLocator step.
    setup_axis(
        ax=ax,
        axis='y',
        majorticks=max(1, max_counts // 10),
        minorticks=max(1, max_counts // 20))
    ax.set_ylabel('# RPFs')
    sns.despine(trim=True, offset=10)

    if saveto:
        fig.tight_layout()
        fig.savefig(saveto, dpi=DPI)
    return ax, fig
def create_wavelet(data, ax):
    """Draw a wavelet power spectrum of ``data`` on ``ax``.

    The series is linearly detrended, scaled by its standard deviation and
    passed through ``pycwt``'s continuous wavelet transform with a
    ``Morlet(6)`` mother wavelet. The scale-normalised power is drawn as
    filled contours, with a contour of the significance ratio on top.

    Parameters
    ----------
    data : Series
        Values to analyse; ``data.index`` is used as the x coordinate.
        Presumably a pandas Series with a numeric index -- TODO confirm
        against callers.
    ax : matplotlib.Axes
        Axes to draw on

    Returns
    -------
    tuple
        ``(iwave, period, power, sig95, coi)`` where ``iwave`` is the
        inverse transform multiplied back by the standard deviation,
        ``period`` is ``1 / freqs``, ``power`` is ``|wave|**2`` divided by
        the scales, ``sig95`` is ``power`` divided by the threshold from
        ``wavelet.significance`` and ``coi`` is returned unchanged from
        ``wavelet.cwt``.
    """
    # Imported lazily so pycwt is only required when this function runs.
    import pycwt as wavelet
    t = data.index
    N = len(data.index)
    # Remove a first-degree (linear) trend before normalising.
    p = np.polyfit(data.index, data, 1)
    data_notrend = data - np.polyval(p, data.index)
    std = data_notrend.std()  # Standard deviation
    var = std**2  # Variance
    data_normalized = data_notrend / std  # Normalized dataset
    mother = wavelet.Morlet(6)
    dt = 1
    s0 = 2 * dt  # Starting scale: twice the sampling step (dt = 1, so 2)
    dj = 1 / 12  # Twelve sub-octaves per octaves
    J = 7 / dj  # Seven powers of two with dj sub-octaves
    # Note: estimated on the raw ``data``, not the detrended series.
    alpha, _, _ = wavelet.ar1(data)  # Lag-1 autocorrelation for red noise
    wave, scales, freqs, coi, fft, fftfreqs = wavelet.cwt(
        data_normalized, dt=dt, dj=dj, s0=s0, J=J, wavelet=mother)
    iwave = wavelet.icwt(wave, scales, dt, dj, mother) * std
    power = (np.abs(wave))**2
    fft_power = np.abs(fft)**2  # computed but not used below
    period = 1 / freqs
    # Normalise the power by scale (one row per scale).
    power /= scales[:, None]
    signif, fft_theor = wavelet.significance(
        1.0, dt, scales, 0, alpha, significance_level=0.95, wavelet=mother)
    # Broadcast the per-scale threshold across all N positions; values of
    # the ratio above 1 exceed the threshold.
    sig95 = np.ones([1, N]) * signif[:, None]
    sig95 = power / sig95
    glbl_power = power.mean(axis=1)  # computed but not used below
    dof = N - scales  # Correction for padding at edges
    # Result is not used below; the call is kept as in the original.
    glbl_signif, tmp = wavelet.significance(
        var,
        dt,
        scales,
        1,
        alpha,
        significance_level=0.95,
        dof=dof,
        wavelet=mother)
    levels = [0.0625, 0.125, 0.25, 0.5, 1, 2, 4, 8, 16]
    # Both the y axis (period) and the colour levels are on a log2 scale.
    ax.contourf(
        t,
        np.log2(period),
        np.log2(power),
        np.log2(levels),
        extend='both',
        cmap=plt.cm.viridis)
    extent = [t.min(), t.max(), 0, max(period)]
    ax.contour(
        t,
        np.log2(period),
        sig95, [-99, 1],
        colors='k',
        linewidths=2,
        extent=extent)
    # Hatched polygon built from ``coi`` -- presumably the cone of
    # influence marking edge-affected regions; confirm against pycwt docs.
    ax.fill(
        np.concatenate([t, t[-1:] + dt, t[-1:] + dt, t[:1] - dt, t[:1] - dt]),
        np.concatenate([
            np.log2(coi), [1e-9],
            np.log2(period[-1:]),
            np.log2(period[-1:]), [1e-9]
        ]),
        'k',
        alpha=0.3,
        hatch='x')
    ax.set_title('Wavelet Power Spectrum')
    ax.set_ylabel('Frequency')
    # Ticks sit at power-of-two periods but are labelled with the
    # reciprocal (1 / period), matching the 'Frequency' axis label.
    Yticks = 2**np.arange(0, np.ceil(np.log2(period.max())))
    ax.set_yticks(np.log2(Yticks))
    ax.set_yticklabels(np.round(1 / Yticks, 3))
    return (iwave, period, power, sig95, coi)
def plot_periodicity_df(df, saveto, cbar=False, figsize=(8, 8)):
    """Plot periodicity values across fragment lengths as a matrix.

    The table is read with its first column as the index and its transpose
    is drawn as an annotated heatmap, which is then written to ``saveto``.

    Parameters
    ----------
    df : string
        Path to dataframe containing fragment length specific periodicities
    saveto : string
        Path to output plot file
    cbar : bool
        Whether to plot cbar or not
    figsize : tuple
        Figure size handed to ``plt.subplots``
    """
    figure, heatmap_ax = plt.subplots(figsize=figsize)
    periodicity_table = pd.read_table(df, index_col=0)
    heatmap_options = dict(
        cmap='Blues', square=True, annot=True, cbar=cbar, ax=heatmap_ax)
    sns.heatmap(periodicity_table.T, **heatmap_options)
    figure.tight_layout()
    figure.savefig(saveto)