"""Core paxplot functions from: https://github.com/kravitsjacob/paxplot
"""
import copy
import functools
import warnings
import matplotlib as mpl
import matplotlib.pyplot as plt
import numpy as np
from matplotlib import cm
from matplotlib.figure import Figure
class PaxFigure(Figure):
    """Matplotlib ``Figure`` subclass that draws parallel-coordinate plots.

    Each column of the plotted data gets its own matplotlib axes. Every axes
    except the last one holds line segments that connect a row's value on
    that axis (left end) to its value on the next axis (right end). Values
    are stored scaled by the per-axis limits (bottom -> 0, top -> 1), every
    axes keeps a fixed y-range of [0, 1], and ticks are placed at scaled
    positions while their labels show data values.

    Instances are normally created through ``pax_parallel``.
    """

    # Inherited Figure methods that pax_parallel does NOT wrap with an
    # "unsupported function" warning
    _safe_inherited_functions = ["savefig", "set_size_inches", "draw", "show"]

    def __init__(self, *args, data=None, **kwargs):
        """
        Paxplot extension of Matplotlib Figure

        Parameters
        ----------
        *args, **kwargs
            Forwarded to ``matplotlib.figure.Figure``.
        data : optional
            Accepted for backward compatibility; not used.
        """
        # Setup
        super().__init__(*args, **kwargs)
        self._show_unsafe_warning = True

        # Paxplot attributes (per-axis lists are sized by _default_format)
        self._pax_data = []  # raw plotted data, one column per axis
        self._pax_data_scale = []  # plotted data scaled by the axis limits
        self._pax_lims = []  # [bottom, top] per axis in data units
        self._pax_ticks = []  # tick locations per axis in data units
        self._pax_ticks_scale = []  # tick locations scaled by the limits
        self._pax_ticks_labels = []  # tick labels per axis
        self._pax_custom_lims = []  # True where the user set limits
        self._pax_custom_ticks = []  # True where the user set ticks
        self._pax_invert = []  # True where invert_axis flipped the axis
        self._pax_colorbar = False

    def _check_ax_idx(self, ax_idx):
        """
        Validate `ax_idx` and return the matching matplotlib axes

        Parameters
        ----------
        ax_idx : int
            Index of matplotlib axes

        Returns
        -------
        matplotlib.axes.Axes
            The axes at `ax_idx`

        Raises
        ------
        IndexError
            If `ax_idx` is out of range
        TypeError
            If `ax_idx` cannot be used as a list index
        """
        try:
            return self.axes[ax_idx]
        except IndexError:
            raise IndexError(
                "You are trying to set the limits of axis with index "
                f"{ax_idx}. However, axis index only goes up to "
                f"{len(self.axes)-1}."
            ) from None
        except TypeError:
            raise TypeError(
                f"Type of `ax_idx` must be integer not {type(ax_idx)}"
            ) from None

    def _scale_vals(self, data, lower=None, upper=None):
        """
        Linearly scale `data` so that `lower` maps to 0 and `upper` to 1

        Parameters
        ----------
        data : array-like
            Data to be scaled
        lower : numeric, optional
            Value mapped to 0, by default the minimum of `data`
        upper : numeric, optional
            Value mapped to 1, by default the maximum of `data`

        Returns
        -------
        numpy.ndarray
            Scaled data. When `lower` equals `upper` (e.g. constant data)
            every entry is 0 instead of the NaN/inf of a zero division.
        """
        data = np.array(data)
        # Each bound falls back to the data independently
        if lower is None:
            lower = data.min()
        if upper is None:
            upper = data.max()
        span = upper - lower
        if span == 0:
            return np.zeros(data.shape)
        return (data - lower) / span

    def _get_color_gradient(self, val, lower, upper, colormap):
        """
        Get the colormap color for `val` scaled between `lower` and `upper`

        Parameters
        ----------
        val : float
            Value to get the color for
        lower : float
            Value mapped to the start of the colormap
        upper : float
            Value mapped to the end of the colormap
        colormap : str or matplotlib.colors.Colormap
            Matplotlib colormap to use for coloring

        Returns
        -------
        color : str
            Hex color code
        """
        # pyplot.get_cmap replaces matplotlib.cm.get_cmap (removed in 3.9)
        cmap = plt.get_cmap(colormap)
        return mpl.colors.rgb2hex(cmap(self._scale_vals(val, lower, upper)))

    def _update_plot_lines(self, ax_idx):
        """
        Update plotted lines based on scaled data (_pax_data_scale)

        Column `ax_idx` is the left end of the segments drawn on axes
        `ax_idx` and the right end of the segments drawn on axes
        `ax_idx - 1`. The first column only has left ends and the last
        plotted column only has right ends.

        Parameters
        ----------
        ax_idx : int
            Axis index to update line data
        """
        last_idx = self._pax_data.shape[1] - 1
        col_scaled = self._pax_data_scale[:, ax_idx]
        if ax_idx == 0 or ax_idx != last_idx:
            # Replace left y value
            for i, line in enumerate(self.axes[ax_idx].lines):
                line.set_ydata([col_scaled[i], line.get_ydata()[1]])
        if ax_idx != 0:
            # Replace right y value
            for i, line in enumerate(self.axes[ax_idx - 1].lines):
                line.set_ydata([line.get_ydata()[0], col_scaled[i]])

    def _update_plot_ticks(self, ax_idx):
        """
        Apply stored tick labels (_pax_ticks_labels) at the stored scaled
        tick locations (_pax_ticks_scale) and reset the y-range to [0, 1]

        Parameters
        ----------
        ax_idx : int
            Index of matplotlib axes

        Raises
        ------
        ValueError
            If the number of labels differs from the number of ticks
        """
        ax = self.axes[ax_idx]
        ax.set_yticks(ticks=self._pax_ticks_scale[ax_idx])
        try:
            ax.set_yticklabels(labels=self._pax_ticks_labels[ax_idx])
        except ValueError as err:
            raise ValueError(
                "Length of `labels` must be same as length of `ticks`"
            ) from err
        # Set bounds on axis (always between 0 and 1)
        ax.set_ylim([0.0, 1.0])

    def _default_format(self):
        """
        Set the default format of a Paxplot Figure

        Initialises the per-axis bookkeeping (limits and ticks [0, 1], no
        custom limits/ticks, no inversion), removes the spacing between
        subplots, hides the top/bottom/right spines, fixes both axis ranges
        to [0, 1], blanks the x tick, and moves the last axes' y ticks to
        the right.
        """
        # Set attributes (deepcopy so every axis owns its own [0, 1] list)
        n_axes = len(self.axes)
        def_vals = [[0, 1]] * n_axes
        def_bools = [False] * n_axes
        self._pax_lims = copy.deepcopy(def_vals)
        self._pax_ticks = copy.deepcopy(def_vals)
        self._pax_ticks_scale = copy.deepcopy(def_vals)
        self._pax_ticks_labels = copy.deepcopy(def_vals)
        self._pax_custom_lims = copy.deepcopy(def_bools)
        self._pax_custom_ticks = copy.deepcopy(def_bools)
        self._pax_invert = copy.deepcopy(def_bools)

        # Remove space between plots
        self.subplots_adjust(wspace=0.0, hspace=0.0)

        for ax in self.axes:
            # Remove axes frame (only the left spine stays visible)
            for side in ("top", "bottom", "right"):
                ax.spines[side].set_visible(False)
            # Set limits
            ax.set_ylim([0, 1])
            ax.set_xlim([0, 1])
            # Set x ticks (a blank label reserves room for set_label)
            ax.set_xticks([0], [" "])
            ax.tick_params(axis="x", length=0.0, pad=10)
            # Set y ticks
            ax.set_yticks([0, 1])

        # Adjust ticks on last axis
        self.axes[-1].yaxis.tick_right()

    def _default_lim(self, ax_idx):
        """
        Set the default limits for an axis: the minimum and maximum of the
        plotted column, swapped if the axis was inverted with `invert_axis`.

        Parameters
        ----------
        ax_idx : int
            Index of matplotlib axes
        """
        self._pax_custom_lims[ax_idx] = False
        col = self._pax_data[:, ax_idx]
        bottom = min(col)
        top = max(col)
        if self._pax_invert[ax_idx]:
            # Keep an inverted axis inverted when limits are recomputed
            bottom, top = top, bottom
        self._set_lim(ax_idx=ax_idx, bottom=bottom, top=top)

    def _default_ticks(self, ax_idx):
        """
        Set the default ticks for an axis: seven evenly spaced ticks (six
        intervals) spanning the current limits, labelled with their values
        rounded to two decimals.

        Parameters
        ----------
        ax_idx : int
            Index of matplotlib axes
        """
        self._pax_custom_ticks[ax_idx] = False
        n_ticks = 6  # number of intervals; n_ticks + 1 tick marks
        precision = 2
        bottom, top = self._pax_lims[ax_idx]
        ticks = np.linspace(bottom, top, num=n_ticks + 1)
        labels = ticks.round(precision)
        self._set_ticks(ax_idx=ax_idx, ticks=ticks, labels=labels)

    def _convert_string_data(self, data: list):
        """
        Convert string columns of `data` to numerical positions

        Every column whose first entry is a string is treated as
        categorical: its unique values (in order of first appearance) are
        spread evenly over [0, 1], the first-seen value at 1 and the last at
        0, and custom ticks labelled with the category names are set.

        Parameters
        ----------
        data : list
            Data to be plotted from `plot` (rows of values). It is not
            modified.

        Returns
        -------
        numpy.ndarray
            Converted `data`
        """
        # Row-wise copy: never mutate the caller's rows, and allow tuples
        data = [list(row) for row in data]
        for col_i in range(len(data[0])):
            column = [row[col_i] for row in data]
            if not isinstance(column[0], str):  # also matches numpy.str_
                continue
            strings = list(dict.fromkeys(column))  # unique, order preserved
            numbers = self._scale_vals(list(range(len(strings))))
            strings.reverse()  # first-seen category gets the top position
            translate = dict(zip(strings, numbers))
            for row in data:
                row[col_i] = translate.get(row[col_i], row[col_i])
            # Set ticks
            self.set_ticks(ax_idx=col_i, ticks=numbers, labels=strings)
        return np.array(data)

    def plot(self, data: list, line_kwargs=None):
        """
        Plot the supplied data

        Parameters
        ----------
        data : array-like
            Data to be plotted: one row per line, one column per axis.
            String columns are plotted as categories.
        line_kwargs : dict, optional
            Keyword arguments for lines corresponding to data, forwarded to
            ``matplotlib.axes.Axes.plot``

        Raises
        ------
        ValueError
            If `data` has more columns than the figure has axes
        """
        if line_kwargs is None:
            line_kwargs = {}

        # Initial Checking
        n_cols = len(data[0])
        if n_cols < len(self.axes) and not self._pax_colorbar:
            warnings.warn(
                "Supplied data has fewer columns than figure. Figure created "
                "with empty column(s)",
                Warning,
            )
        elif n_cols > len(self.axes):
            raise ValueError(
                "Supplied data has more columns than figure. Please recreate "
                "paxfigure with appropriate n_axes"
            )

        # Convert input data to numpy, translating string columns if needed
        data_input = np.array(data)
        if not np.issubdtype(data_input.dtype.type, np.number):
            data_input = self._convert_string_data(data)

        # Update data attributes
        if len(self._pax_data) == 0:
            self._pax_data = data_input
        else:
            self._pax_data = np.vstack([self._pax_data, data_input])

        # Scale input data based on current limits
        data_input_scale = data_input.astype(np.single)
        for col_idx, col in enumerate(data_input.T):
            bottom, top = self._pax_lims[col_idx]
            data_input_scale[:, col_idx] = self._scale_vals(
                data=col, lower=bottom, upper=top
            )

        # Update scaled data attributes
        if len(self._pax_data_scale) == 0:
            self._pax_data_scale = data_input_scale
        else:
            self._pax_data_scale = np.vstack(
                [self._pax_data_scale, data_input_scale]
            )

        # Each axes draws the segments from its column to the next one
        for ax_idx, ax in enumerate(self.axes[:-1]):
            ax.plot(data_input_scale[:, ax_idx : ax_idx + 2].T, **line_kwargs)

        # Limits; _set_lim also re-applies custom ticks or default ticks
        for ax_idx in range(self._pax_data.shape[1]):
            if self._pax_custom_lims[ax_idx]:  # Respect custom limits
                bottom, top = self._pax_lims[ax_idx]
                self._set_lim(ax_idx=ax_idx, bottom=bottom, top=top)
            else:  # Default limits of data
                self._default_lim(ax_idx=ax_idx)

    def set_lim(self, ax_idx: int, bottom: float, top: float):
        """
        Set custom limits on axis

        Parameters
        ----------
        ax_idx : int
            Index of matplotlib axes
        bottom : numeric
            Lower limit
        top : numeric
            Upper limit
        """
        self._check_ax_idx(ax_idx)
        self._pax_custom_lims[ax_idx] = True
        self._set_lim(ax_idx=ax_idx, bottom=bottom, top=top)

    def _set_lim(self, ax_idx: int, bottom: float, top: float):
        """
        Private function to set limits on axis and rescale its data

        Parameters
        ----------
        ax_idx : int
            Index of matplotlib axes
        bottom : numeric
            Lower limit (maps to 0)
        top : numeric
            Upper limit (maps to 1)

        Raises
        ------
        TypeError
            If `bottom` or `top` is not numeric
        AttributeError
            If no data has been plotted yet
        """
        # Check bottom top values
        try:
            span = top - bottom
        except TypeError:
            raise TypeError(
                "Both `bottom` and `top` must be numeric values. Currently "
                f"`bottom` is of type {type(bottom)} and `top` is of type "
                f"{type(top)}"
            ) from None
        if span == 0:
            # A zero-height range cannot be scaled; widen it around the value
            bottom = bottom - 1.0
            top = top + 1.0

        # Checking if data is plotted
        if len(self._pax_data) == 0:
            raise AttributeError(
                "Paxplot does not support set_lim if no data has been plotted"
            )
        col = self._pax_data[:, ax_idx]

        # Set attribute data
        self._pax_lims[ax_idx] = [bottom, top]

        # Scale data
        self._pax_data_scale[:, ax_idx] = self._scale_vals(
            col, lower=bottom, upper=top
        ).astype(np.single)

        # Update plot of scaled data
        self._update_plot_lines(ax_idx)

        # Ticks
        if self._pax_custom_ticks[ax_idx]:  # Preserve custom ticks
            self._set_ticks(
                ax_idx=ax_idx,
                ticks=self._pax_ticks[ax_idx],
                labels=self._pax_ticks_labels[ax_idx],
            )
        else:  # Default ticks
            self._default_ticks(ax_idx=ax_idx)

    def set_ticks(self, ax_idx: int, ticks: list, labels=None):
        """
        Set the axis tick locations and optionally labels.

        Parameters
        ----------
        ax_idx : int
            Index of matplotlib axes
        ticks : list of floats
            List of tick locations.
        labels : list of str, optional
            List of tick labels. If not set, the labels show the data value.
        """
        self._check_ax_idx(ax_idx)
        # Must be True before _set_ticks: its possible _set_lim call reads
        # this flag to decide whether to keep these ticks
        self._pax_custom_ticks[ax_idx] = True
        self._set_ticks(ax_idx=ax_idx, ticks=ticks, labels=labels)

    def _set_ticks(self, ax_idx: int, ticks: list, labels=None):
        """
        Private function to set the axis tick locations and optionally
        labels. Limits are widened if a tick falls outside them.

        Parameters
        ----------
        ax_idx : int
            Index of matplotlib axes
        ticks : list of floats
            List of tick locations.
        labels : list of str, optional
            List of tick labels. If not set, the labels show the data value.

        Raises
        ------
        TypeError
            If `ticks` is not a 1-D sequence of numbers
        """
        # Tick tests
        if np.ndim(ticks) != 1:
            raise TypeError(
                f"`ticks` must be array-like not type {type(ticks)}"
            )
        if np.asarray(ticks).dtype.kind not in "biuf":
            raise TypeError(
                "All entries in `ticks` must be numeric. To set string ticks,"
                " use the `labels` argument"
            )
        self._check_ax_idx(ax_idx)

        # Store non-array sequences (list, tuple, ...) as a fresh list so the
        # stored ticks do not alias the caller's object
        if not isinstance(ticks, np.ndarray):
            ticks = list(ticks)

        # Set tick attributes, scaled by the current limits
        self._pax_ticks[ax_idx] = ticks
        lim_bottom, lim_top = self._pax_lims[ax_idx]
        self._pax_ticks_scale[ax_idx] = self._scale_vals(
            ticks, lower=lim_bottom, upper=lim_top
        )

        # Tick labels
        if labels is None:
            labels = ticks.copy()
        self._pax_ticks_labels[ax_idx] = labels

        # Update ticks on plots
        self._update_plot_ticks(ax_idx)

        # Widen limits if any tick lies outside them
        if len(ticks) > 0:
            lim_min = min(self._pax_lims[ax_idx])
            lim_max = max(self._pax_lims[ax_idx])
            tick_min = min(ticks)
            tick_max = max(ticks)
            if tick_min < lim_min or tick_max > lim_max:
                bottom = min(tick_min, lim_min)
                top = max(tick_max, lim_max)
                if self._pax_lims[ax_idx][0] > self._pax_lims[ax_idx][1]:
                    bottom, top = top, bottom  # keep inverted axis inverted
                self._set_lim(ax_idx=ax_idx, bottom=bottom, top=top)

    def set_even_ticks(
        self, ax_idx: int, n_ticks=6, minimum=None, maximum=None, precision=2
    ):
        """
        Set evenly spaced axis ticks between minimum and maximum value. A
        bound that is not given defaults to the minimum/maximum of the
        plotted data in that column.

        Parameters
        ----------
        ax_idx : int
            Index of matplotlib axes
        n_ticks : int
            Number of intervals; ``n_ticks + 1`` ticks are placed
        minimum : numeric, optional
            minimum value for ticks
        maximum : numeric, optional
            maximum value for ticks
        precision : int
            number of decimal points for tick labels

        Raises
        ------
        ValueError
            If `minimum` is greater than `maximum`
        TypeError
            If `ax_idx` or `n_ticks` is not an integer
        """
        self._check_ax_idx(ax_idx)

        # Fill missing bounds from the data
        if minimum is None:
            minimum = self._pax_data[:, ax_idx].min()
        if maximum is None:
            maximum = self._pax_data[:, ax_idx].max()

        # Minimum/maximum check
        if minimum > maximum:
            raise ValueError("Value for `minimum` cannot be greater than `maximum`")

        # Generate ticks
        try:
            ticks = np.linspace(minimum, maximum, num=n_ticks + 1)
        except TypeError:
            raise TypeError(
                f"Type of `n_ticks` must be integer not {type(n_ticks)}"
            ) from None
        labels = ticks.round(precision)

        # Must be True before _set_ticks (see set_ticks)
        self._pax_custom_ticks[ax_idx] = True
        self._set_ticks(ax_idx=ax_idx, ticks=ticks, labels=labels)

    def set_label(self, ax_idx: int, label: str):
        """Set the label for the axis

        Parameters
        ----------
        ax_idx : int
            Index of matplotlib axes
        label : str
            The label text
        """
        ax = self._check_ax_idx(ax_idx)
        # The label is drawn as the single x tick label of the axes
        ax.set_xticks(ticks=[0.0])
        ax.set_xticklabels([label])

    def set_labels(self, labels: list):
        """
        Set labels for all axes. A wrapper for set_label

        Parameters
        ----------
        labels : list
            Labels for each axis. Must be the same length as the number of
            plotted data columns (or the number of axes if nothing has been
            plotted yet).

        Raises
        ------
        IndexError
            If the number of labels is wrong
        """
        if len(self._pax_data) > 0:
            n_expected = self._pax_data.shape[1]
        else:
            n_expected = len(self.axes)
        if n_expected != len(labels):
            raise IndexError("Length of `labels` must equal number of axes")
        for i, label in enumerate(labels):
            self.set_label(i, label)

    def invert_axis(self, ax_idx: int):
        """Invert axis by swapping its current limits

        The inversion is remembered, so default limits recomputed by later
        `plot` calls stay inverted; calling this again restores the
        original orientation.

        Parameters
        ----------
        ax_idx : int
            Index of matplotlib axes

        Raises
        ------
        AttributeError
            If no data has been plotted yet
        """
        self._check_ax_idx(ax_idx)
        # Checking if data is plotted
        if len(self._pax_data) == 0:
            raise AttributeError(
                "Paxplot does not support invert_axis if no data has been "
                "plotted"
            )
        bottom, top = self._pax_lims[ax_idx]
        self._set_lim(ax_idx=ax_idx, bottom=top, top=bottom)
        self._pax_invert[ax_idx] = not self._pax_invert[ax_idx]

    def add_legend(self, labels=None):
        """Create a legend in a new blank axes to the right of the plot

        Parameters
        ----------
        labels : list, optional
            List of data labels, one per plotted line. If omitted, the
            lines' existing labels are used.

        Raises
        ------
        IndexError
            If fewer labels than lines are supplied
        """
        if labels is None:
            labels = []

        # Check if too many labels supplied
        if len(labels) > len(self.axes[0].lines):
            warnings.warn(
                "More labels supplied than data. Some labels are unused.", Warning
            )

        if len(labels) > 0:
            try:
                for ax in self.axes:
                    for i, line in enumerate(ax.lines):
                        line.set_label(labels[i])
            except IndexError:
                raise IndexError(
                    f"Incorrect number of labels specified. You have supplied "
                    f"{len(labels)} labels, but {len(ax.lines)} were expected"
                ) from None

        # Create blank axis for legend
        n_axes = len(self.axes)
        width_ratios = self.axes[0].get_gridspec().get_width_ratios()
        gs = self.add_gridspec(1, n_axes + 1, width_ratios=width_ratios + [1.0])
        ax_legend = self.add_subplot(gs[0, n_axes])

        # Create legend
        lines = self.axes[0].lines
        ax_legend.legend(
            lines, [line.get_label() for line in lines], loc="center right"
        )

        # Move existing axes onto the new grid
        for i in range(n_axes):
            self.axes[i].set_subplotspec(gs[0:1, i : i + 1])
        ax_legend.set_axis_off()

    def add_colorbar(self, ax_idx: int, cmap="viridis", colorbar_kwargs=None):
        """Color the lines by their value on one axis and add a colorbar

        Parameters
        ----------
        ax_idx : int
            Index of the axes whose values determine the line colors
        cmap : str
            Matplotlib colormap to use for coloring
        colorbar_kwargs : dict, optional
            Matplotlib colorbar keyword arguments
        """
        if colorbar_kwargs is None:
            colorbar_kwargs = {}
        self._check_ax_idx(ax_idx)
        self._pax_colorbar = True

        n_lines = len(self.axes[0].lines)
        n_axes = len(self.axes)
        line_axes = self.axes[:-1]

        # Change line colors using each line's scaled value on `ax_idx`
        for i in range(n_lines):
            if ax_idx < n_axes - 1:
                scale_val = self.axes[ax_idx].lines[i].get_ydata()[0]
            else:
                scale_val = self.axes[ax_idx - 1].lines[i].get_ydata()[1]
            color = self._get_color_gradient(scale_val, 0, 1, cmap)
            for ax in line_axes:
                ax.lines[i].set_color(color)

        # Create blank axis for colorbar
        width_ratios = self.axes[0].get_gridspec().get_width_ratios()
        gs = self.add_gridspec(1, n_axes + 1, width_ratios=width_ratios + [0.5])
        ax_colorbar = self.add_subplot(gs[0, n_axes])

        # Create colorbar spanning the axis limits in data units
        sm = plt.cm.ScalarMappable(
            norm=plt.Normalize(
                vmin=self._pax_lims[ax_idx][0], vmax=self._pax_lims[ax_idx][1]
            ),
            cmap=cmap,
        )
        self.colorbar(sm, orientation="vertical", ax=ax_colorbar, **colorbar_kwargs)

        # Move existing axes onto the new grid
        for i in range(n_axes):
            self.axes[i].set_subplotspec(gs[0:1, i : i + 1])
        ax_colorbar.set_axis_off()
def add_unsafe_warning(func, fig):
    """
    Wrap `func` so that each call warns it is not supported by Paxplot.

    The warning is only emitted while ``fig._show_unsafe_warning`` is
    truthy; the flag is read at call time.

    Parameters
    ----------
    func : callable
        Function to wrap
    fig : PaxFigure
        Figure whose ``_show_unsafe_warning`` flag gates the warning

    Returns
    -------
    callable
        Wrapper carrying `func`'s metadata and returning its result
    """

    @functools.wraps(func)
    def _warn_then_call(*args, **kwargs):
        if not fig._show_unsafe_warning:
            return func(*args, **kwargs)
        warnings.warn(
            f"The function you have called ({func.__name__}) is not "
            "officially supported by Paxplot, but it may still work. "
            "Report issues to "
            "https://github.com/kravitsjacob/paxplot/issues",
            Warning,
        )
        return func(*args, **kwargs)

    return _warn_then_call
def disable_unsafe_warnings(func, fig):
    """
    Temporarily disable safety warnings for the duration of a call.

    This lets a known-safe function make calls to otherwise unsafe
    functions without triggering a warning. The previous value of
    ``fig._show_unsafe_warning`` is restored afterwards, including when
    `func` raises.

    Parameters
    ----------
    func : callable
        Function to wrap
    fig : PaxFigure
        Figure whose ``_show_unsafe_warning`` flag is suppressed

    Returns
    -------
    callable
        Wrapper carrying `func`'s metadata and returning its result
    """

    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        original_flag_value = fig._show_unsafe_warning
        fig._show_unsafe_warning = False
        try:
            return func(*args, **kwargs)
        finally:
            # Restore even on exceptions; otherwise warnings stay disabled
            fig._show_unsafe_warning = original_flag_value

    return wrapper
def pax_parallel(n_axes: int):
    """
    Create a Paxplot figure, analogous to ``matplotlib.pyplot.subplots``

    Public methods of the returned figure are wrapped so that inherited
    Figure methods not supported by Paxplot warn when called, while Paxplot
    methods (and the whitelisted safe ones) run with those warnings
    suppressed.

    Parameters
    ----------
    n_axes : int
        Number of axes to create

    Returns
    -------
    fig : PaxFigure
        Paxplot figure class

    Raises
    ------
    TypeError
        If `n_axes` is not an integer
    ValueError
        If `n_axes` is less than 1
    """
    import operator  # local import keeps the module header unchanged

    # Check type and value of n_axes
    try:
        n_axes = operator.index(n_axes)
    except TypeError:
        raise TypeError(
            "n_axes should be of type int. You have supplied a type "
            f"{type(n_axes)}"
        ) from None
    if n_axes < 1:
        raise ValueError(f"n_axes must be at least 1, not {n_axes}")

    # Create figure: equal widths for the axes holding lines, zero width for
    # the last axis, which only carries the final y-axis
    width_ratios = [1.0] * (n_axes - 1) + [0.0]
    fig, _ = plt.subplots(
        1,
        n_axes,
        sharey=False,
        gridspec_kw={"width_ratios": width_ratios},
        FigureClass=PaxFigure,
    )
    fig._default_format()

    # Callables defined (or overridden) directly on PaxFigure
    pax_figure_functions = {
        name for name in vars(PaxFigure) if callable(getattr(PaxFigure, name))
    }
    # Inherited Figure names that are neither whitelisted nor overridden
    unsafe_functions = {
        name
        for name in dir(Figure)
        if name not in PaxFigure._safe_inherited_functions
        and name not in pax_figure_functions
    }

    # Wrap public, non-getter methods on the instance
    for func_name in dir(PaxFigure):
        if func_name.startswith(("_", "get")):
            continue
        if not callable(getattr(PaxFigure, func_name)):
            continue
        func = getattr(fig, func_name)
        if func_name in unsafe_functions:
            func = add_unsafe_warning(func, fig)
        else:
            func = disable_unsafe_warnings(func, fig)
        setattr(fig, func_name, func)

    return fig