#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
Implements the BaseGraph, Graph, and InteractiveGraph classes.
"""
import warnings
import numpy as np
import matplotlib as mpl
import matplotlib.pyplot as plt
from uuid import uuid4
from scipy.spatial import cKDTree
from ._utils import (
_get_unique_nodes,
_get_angle,
_get_interior_angle_between,
_get_orthogonal_unit_vector,
_get_point_along_spline,
_get_tangent_at_point,
_get_text_object_dimensions,
_make_pretty,
_rank,
_get_n_points_on_a_circle,
_edge_list_to_adjacency_list,
)
from ._node_layout import (
get_fruchterman_reingold_layout,
get_random_layout,
get_sugiyama_layout,
get_radial_tree_layout,
get_circular_layout,
get_linear_layout,
get_bipartite_layout,
get_multipartite_layout,
get_shell_layout,
get_community_layout,
get_geometric_layout,
_reduce_node_overlap,
_remove_node_overlap,
)
from ._edge_layout import (
get_straight_edge_paths,
_shift_edge,
get_curved_edge_paths,
get_arced_edge_paths,
get_bundled_edge_paths,
get_selfloop_paths,
_get_selfloop_path,
)
from ._artists import NodeArtist, EdgeArtist
from ._parser import parse_graph, _parse_edge_list, _is_directed
# Factor by which node sizes, node edge widths, and edge widths are multiplied
# in `BaseGraph.__init__` before any layout is computed or any artist is drawn.
BASE_SCALE = 1e-2
# Default value of the `node_edge_color` and `edge_color` arguments of `BaseGraph`.
DEFAULT_COLOR = '#2c404c' # '#677e8c' # '#121f26' # '#23343f' # 'k',
[docs]
class BaseGraph(object):
"""The Graph base class.
Parameters
----------
edges : list
The edges of the graph, with each edge being represented by a (source node ID, target node ID) tuple.
nodes : list or None, default None
List of nodes. Required argument if any node in the graph is unconnected.
If None, `nodes` is initialised to the set of the flattened `edges`.
node_layout : str or dict, default 'spring'
If `node_layout` is a string, the node positions are computed using the indicated method:
- 'random' : place nodes in random positions;
- 'circular' : place nodes regularly spaced on a circle;
- 'spring' : place nodes using a force-directed layout (Fruchterman-Reingold algorithm);
- 'dot' : place nodes using the Sugiyama algorithm; the graph should be directed and acyclic;
- 'radial' : place nodes radially using the Sugiyama algorithm; the graph should be directed and acyclic;
- 'community' : place nodes such that nodes belonging to the same community are grouped together;
- 'bipartite' : place nodes regularly spaced on two parallel lines;
- 'multipartite' : place nodes regularly spaced on several parallel lines;
- 'shell' : place nodes regularly spaced on concentric circles;
- 'geometric' : place nodes according to the length of the edges between them.
If `node_layout` is a dict, keys are nodes and values are (x, y) positions.
node_layout_kwargs : dict or None, default None
Keyword arguments passed to node layout functions.
See the documentation of the following functions for a full description of available options:
- get_random_layout
- get_circular_layout
- get_fruchterman_reingold_layout
- get_sugiyama_layout
- get_radial_tree_layout
- get_community_layout
- get_bipartite_layout
- get_multipartite_layout
- get_shell_layout
- get_geometric_layout
node_shape : str or dict, default 'o'
Node shape.
If the type is str, all nodes have the same shape.
If the type is dict, maps each node to an individual string representing the shape.
The string specification is as for matplotlib.scatter marker, i.e. one of 'so^>v<dph8'.
node_size : float or dict, default 3.
Node size (radius).
If the type is float, all nodes will have the same size.
If the type is dict, maps each node to an individual size.
.. note:: Values are rescaled by BASE_SCALE (1e-2) to be compatible with layout routines in igraph and networkx.
node_edge_width : float or dict, default 0.5
Line width of node marker border.
If the type is float, all nodes have the same line width.
If the type is dict, maps each node to an individual line width.
.. note:: Values are rescaled by BASE_SCALE (1e-2) to be compatible with layout routines in igraph and networkx.
node_color : matplotlib color specification or dict, default 'w'
Node color.
If the type is a string or RGBA array, all nodes have the same color.
If the type is dict, maps each node to an individual color.
node_edge_color : matplotlib color specification or dict, default {DEFAULT_COLOR}
Node edge color.
If the type is a string or RGBA array, all nodes have the same edge color.
If the type is dict, maps each node to an individual edge color.
node_alpha : scalar or dict, default 1.
Node transparency.
If the type is a float, all nodes have the same transparency.
If the type is dict, maps each node to an individual transparency.
node_zorder : int or dict, default 2
Order in which to plot the nodes.
If the type is an int, all nodes have the same zorder.
If the type is dict, maps each node to an individual zorder.
node_labels : bool or dict, (default False)
If False, the nodes are unlabelled.
If True, the nodes are labelled with their node IDs.
If the node labels are to be distinct from the node IDs, supply a dictionary mapping nodes to node labels.
Only nodes in the dictionary are labelled.
node_label_offset: float or tuple, default (0., 0.)
A (dx, dy) tuple specifies the exact offset from the node position.
If a single scalar delta is specified, the value is interpreted as a distance,
and the label is placed delta away from the node position while trying to
reduce node/label, node/edge, and label/label overlaps.
node_label_fontdict : dict
Keyword arguments passed to matplotlib.text.Text.
For a full list of available arguments see the matplotlib documentation.
The following default values differ from the defaults for matplotlib.text.Text:
- size (adjusted to fit into node artists if offset is (0, 0))
- horizontalalignment (default here: 'center')
- verticalalignment (default here: 'center')
- clip_on (default here: False)
- zorder (default here: inf)
edge_width : float or dict, default 1.
Width of edges.
If the type is a float, all edges have the same width.
If the type is dict, maps each edge to an individual width.
.. note:: Value is rescaled by BASE_SCALE (1e-2) to be compatible with layout routines in igraph and networkx.
edge_color : matplotlib color specification or dict, default {DEFAULT_COLOR}
Edge color.
If the type is a string or RGBA array, all edges have the same color.
If the type is dict, maps each edge to an individual color.
edge_alpha : float or dict, default 0.5
The edge transparency.
If the type is a float, all edges have the same transparency.
If the type is dict, maps each edge to an individual transparency.
edge_zorder : int or dict, default 1
Order in which to plot the edges.
If the type is an int, all edges have the same zorder.
If the type is dict, maps each edge to an individual zorder.
If None, the edges will be plotted in the order they appear in 'adjacency'.
Hint: graphs typically appear more visually pleasing if darker edges are plotted on top of lighter edges.
arrows : bool, default False
If True, draw edges with arrow heads.
edge_layout : str or dict (default 'straight')
If edge_layout is a string, determine the layout internally:
- 'straight' : draw edges as straight lines
- 'curved' : draw edges as curved splines; the spline control points are optimised to avoid other nodes and edges
- 'arc' : draw edges as arcs with a fixed curvature
- 'bundled' : draw edges as edge bundles
If edge_layout is a dict, the keys are edges and the values are edge paths
in the form iterables of (x, y) tuples, the edge segments.
edge_layout_kwargs : dict, default None
Keyword arguments passed to edge layout functions.
See the documentation of the following functions for a full description of available options:
- get_straight_edge_paths
- get_curved_edge_paths
- get_bundled_edge_paths
edge_labels : bool or dict, default False
If False, the edges are unlabelled.
If True, the edges are labelled with their edge IDs.
If the edge labels are to be distinct from the edge IDs, supply a dictionary mapping edges to edge labels.
Only edges in the dictionary are labelled.
edge_label_position : float, default 0.5
Relative position along the edge where the label is placed.
- head : 0.
- centre : 0.5
- tail : 1.
edge_label_rotate : bool, default True
If True, edge labels are rotated such that they have the same orientation as their edge.
If False, edge labels are not rotated; the angle of the text is parallel to the axis.
edge_label_fontdict : dict
Keyword arguments passed to matplotlib.text.Text.
For a full list of available arguments see the matplotlib documentation.
The following default values differ from the defaults for matplotlib.text.Text:
- horizontalalignment (default here: 'center'),
- verticalalignment (default here: 'center')
- clip_on (default here: False),
- bbox (default here: dict(boxstyle='round', ec=(1.0, 1.0, 1.0), fc=(1.0, 1.0, 1.0)),
- zorder (default here: inf),
- rotation (determined by edge_label_rotate argument)
origin : tuple, default (0., 0.)
The lower left hand corner of the bounding box specifying the extent of the canvas.
scale : tuple, default (1., 1.)
The width and height of the bounding box specifying the extent of the canvas.
prettify : bool, default True
If True, despine and remove ticks and tick labels.
Set figure background to white. Set axis aspect to equal.
ax : matplotlib.axis instance or None, default None
Axis to plot onto; if none specified, one will be instantiated with plt.gca().
Attributes
----------
node_artists : dict
Mapping of node IDs to matplotlib PathPatch artists.
edge_artists : dict
Mapping of edge IDs to matplotlib PathPatch artists.
node_label_artists : dict
Mapping of node IDs to matplotlib text objects (if applicable).
edge_label_artists : dict
Mapping of edge IDs to matplotlib text objects (if applicable).
node_positions : dict node : (x, y) tuple
Mapping of node IDs to node positions.
See also
--------
Graph, InteractiveGraph
"""
def __init__(self, edges,
             nodes=None,
             node_layout='spring',
             node_layout_kwargs=None,
             node_shape='o',
             node_size=3.,
             node_edge_width=0.5,
             node_color='w',
             node_edge_color=DEFAULT_COLOR,
             node_alpha=1.0,
             node_zorder=2,
             node_labels=False,
             node_label_offset=(0., 0.),
             node_label_fontdict=None,
             edge_width=1.,
             edge_color=DEFAULT_COLOR,
             edge_alpha=0.5,
             edge_zorder=1,
             arrows=False,
             edge_layout='straight',
             edge_layout_kwargs=None,
             edge_labels=False,
             edge_label_position=0.5,
             edge_label_rotate=True,
             edge_label_fontdict=None,
             origin=(0., 0.),
             scale=(1., 1.),
             prettify=True,
             ax=None,
             *args, **kwargs
):
    """Parse the graph, compute the node and edge layouts, and draw all artists.

    The parameters are described in the class docstring.
    Additional positional and keyword arguments are accepted but are not
    used anywhere in this method.
    """
    self.edges = _parse_edge_list(edges)
    self.nodes = self._initialize_nodes(nodes)

    # Convert all node and edge parameters to dictionaries.
    node_shape = self._normalize_string_argument(node_shape, self.nodes, 'node_shape')
    node_size = self._normalize_numeric_argument(node_size, self.nodes, 'node_size')
    node_edge_width = self._normalize_numeric_argument(node_edge_width, self.nodes, 'node_edge_width')
    node_color = self._normalize_color_argument(node_color, self.nodes, 'node_color')
    node_edge_color = self._normalize_color_argument(node_edge_color, self.nodes, 'node_edge_color')
    node_alpha = self._normalize_numeric_argument(node_alpha, self.nodes, 'node_alpha')
    node_zorder = self._normalize_numeric_argument(node_zorder, self.nodes, 'node_zorder')

    edge_width = self._normalize_numeric_argument(edge_width, self.edges, 'edge_width')
    edge_color = self._normalize_color_argument(edge_color, self.edges, 'edge_color')
    edge_alpha = self._normalize_numeric_argument(edge_alpha, self.edges, 'edge_alpha')
    edge_zorder = self._normalize_numeric_argument(edge_zorder, self.edges, 'edge_zorder')

    # Warn if a node border is wider than the node radius while the face and
    # border colours differ: the face colour would then be mostly hidden.
    # NOTE(review): the `!=` comparison presumably assumes colours that compare
    # to a single bool (strings, tuples); array-valued colours may not -- confirm.
    for node in self.nodes:
        if (node_size[node] < node_edge_width[node]) & (node_color[node] != node_edge_color[node]):
            msg = f"The border around the node {node} is broader than its radius."
            msg += f" The node will mostly have the color of the border ({node_edge_color[node]}), even though a different face color was specified ({node_color[node]})."
            msg += f" To address this issue, reduce the value given for `node_edge_width`."
            warnings.warn(msg)

    # Rescale.
    node_size = self._rescale(node_size, BASE_SCALE)
    node_edge_width = self._rescale(node_edge_width, BASE_SCALE)
    edge_width = self._rescale(edge_width, BASE_SCALE)

    # The rescaled node sizes are stored on the instance because the layout
    # routines (`_get_node_positions`, `_get_edge_paths`) read `self.node_size`.
    self.node_size = node_size

    # Initialise node and edge layouts.
    self.origin = origin
    self.scale = scale
    self.node_positions = self._initialize_node_layout(
        node_layout, node_layout_kwargs, origin, scale, node_size)
    # The edge layout is computed after the node layout as it reads `self.node_positions`.
    self.edge_paths, self.edge_layout, self.edge_layout_kwargs = self._initialize_edge_layout(
        edge_layout, edge_layout_kwargs, origin, scale, edge_width)

    # Draw plot elements
    self.ax = self._initialize_axis(ax)
    self.edge_artists = dict()
    self.draw_edges(self.edge_paths, edge_width, edge_color, edge_alpha,
                    edge_zorder, arrows, node_size)
    self.node_artists = dict()
    self.draw_nodes(self.nodes, self.node_positions,
                    node_shape, node_size, node_edge_width,
                    node_color, node_edge_color, node_alpha, node_zorder)

    # This function needs to be called before any font sizes are adjusted,
    # as the axis dimensions affect the effective font size.
    self._update_view()

    if node_labels:
        # `node_labels=True` labels each node with its own ID.
        if isinstance(node_labels, bool):
            node_labels = dict(zip(self.nodes, self.nodes))
        self.node_label_fontdict = self._initialize_node_label_fontdict(
            node_label_fontdict, node_labels, node_label_offset)
        self.node_label_offset, self._recompute_node_label_offsets =\
            self._initialize_node_label_offset(node_labels, node_label_offset)
        # The flag is True if `node_label_offset` was given as a scalar distance.
        if self._recompute_node_label_offsets:
            self._update_node_label_offsets()
        self.node_label_artists = dict()
        self.draw_node_labels(node_labels, self.node_label_fontdict)

    if edge_labels:
        # `edge_labels=True` labels each edge with its own ID.
        if isinstance(edge_labels, bool):
            edge_labels = dict(zip(self.edges, self.edges))
        self.edge_label_fontdict = self._initialize_edge_label_fontdict(edge_label_fontdict)
        self.edge_label_position = edge_label_position
        self.edge_label_rotate = edge_label_rotate
        self.edge_label_artists = dict()
        self.draw_edge_labels(edge_labels, self.edge_label_position,
                              self.edge_label_rotate, self.edge_label_fontdict)

    if prettify:
        _make_pretty(self.ax)
def _initialize_nodes(self, nodes):
    """Determine the nodes of the graph.

    Parameters
    ----------
    nodes : list or None
        The user-supplied nodes, or None.

    Returns
    -------
    nodes : list
        The unique nodes found in `self.edges` if `nodes` is None;
        otherwise `nodes` itself.

    Raises
    ------
    ValueError
        If `nodes` is given but does not contain every node present in `self.edges`.
    """
    edge_nodes = _get_unique_nodes(self.edges)

    if nodes is None:
        return edge_nodes

    if set(nodes).issuperset(edge_nodes):
        return nodes

    # Report every node that occurs in the edge list but not in `nodes`.
    message_parts = [
        "There are some node IDs in the edgelist not present in `nodes`. ",
        "`nodes` has to be the superset of `edges`.",
        "\nThe following nodes are missing:",
    ]
    message_parts.extend(f"\n\t{node}" for node in set(edge_nodes) - set(nodes))
    raise ValueError("".join(message_parts))
def _normalize_numeric_argument(self, numeric_or_dict, dict_keys, variable_name):
if isinstance(numeric_or_dict, (int, float)):
return {key : numeric_or_dict for key in dict_keys}
elif isinstance(numeric_or_dict, dict):
self._check_completeness(numeric_or_dict, dict_keys, variable_name)
self._check_types(numeric_or_dict.values(), (int, float), variable_name)
return numeric_or_dict
else:
msg = f"The type of {variable_name} has to be either a int, float, or a dict."
msg += f"\nThe current type is {type(numeric_or_dict)}."
raise TypeError(msg)
def _check_completeness(self, given_set, desired_set, variable_name):
# ensure that iterables are sets
# TODO: check that iterables can safely be converted to sets (unlike dict keys)
given_set = set(given_set)
desired_set = set(desired_set)
complete = given_set.issuperset(desired_set)
if not complete:
missing = desired_set - given_set
msg = f"{variable_name} is incomplete. The following elements are missing:"
for item in missing:
if isinstance(item, str):
msg += f"\n\'{item}\'"
else:
msg += f"\n{item}"
raise ValueError(msg)
def _check_types(self, items, types, variable_name):
for item in items:
if not isinstance(item, types):
msg = f"Item {item} in {variable_name} is of the wrong type."
msg += f"\nExpected type: {types}"
msg += f"\nActual type: {type(item)}"
raise TypeError(msg)
def _normalize_string_argument(self, str_or_dict, dict_keys, variable_name):
if isinstance(str_or_dict, str):
return {key : str_or_dict for key in dict_keys}
elif isinstance(str_or_dict, dict):
self._check_completeness(set(str_or_dict), dict_keys, variable_name)
self._check_types(str_or_dict.values(), str, variable_name)
return str_or_dict
else:
msg = f"The type of {variable_name} has to be either a str or a dict."
msg += f"The current type is {type(str_or_dict)}."
raise TypeError(msg)
def _normalize_color_argument(self, color_or_dict, dict_keys, variable_name):
    """Convert a colour argument into a mapping of keys to colours.

    Parameters
    ----------
    color_or_dict : matplotlib color specification, None, or dict
        Either a single colour (or None) that applies to all keys,
        or a dict mapping each key to an individual colour.
    dict_keys : iterable
        The keys (nodes or edges) that the result has to cover.
    variable_name : str
        The name of the argument; only used in error messages.

    Returns
    -------
    dict
        A new dict mapping each key to the given value if a single colour or
        None was given; otherwise the given dict itself (not a copy).

    Raises
    ------
    TypeError
        If the argument is not a valid colour, None, or a dict.
    ValueError
        If a dict is given that does not contain all of `dict_keys`.
    """
    applies_to_all = mpl.colors.is_color_like(color_or_dict) or (color_or_dict is None)
    if applies_to_all:
        uniform = dict()
        for key in dict_keys:
            uniform[key] = color_or_dict
        return uniform

    if not isinstance(color_or_dict, dict):
        msg = f"The type of {variable_name} has to be either a valid matplotlib color specification or a dict."
        raise TypeError(msg)

    self._check_completeness(set(color_or_dict), dict_keys, variable_name)
    # TODO: assert that each element is a valid color
    return color_or_dict
def _rescale(self, mydict, scalar):
return {key: value * scalar for (key, value) in mydict.items()}
def _initialize_node_layout(self, node_layout, node_layout_kwargs, origin, scale, node_size):
if node_layout_kwargs is None:
node_layout_kwargs = dict()
if isinstance(node_layout, str):
if (node_layout == 'spring') or (node_layout == 'dot') or (node_layout == 'radial'):
node_layout_kwargs.setdefault('node_size', node_size)
return self._get_node_positions(node_layout, node_layout_kwargs, origin, scale)
elif isinstance(node_layout, dict):
self._check_completeness(set(node_layout), set(self.nodes), 'node_layout')
return node_layout
def _get_node_positions(self, node_layout, node_layout_kwargs, origin, scale):
if len(self.nodes) == 1:
return {self.nodes[0]: np.array([origin[0] + 0.5 * scale[0], origin[1] + 0.5 * scale[1]])}
if node_layout == 'spring':
node_positions = get_fruchterman_reingold_layout(
self.edges, nodes=self.nodes, origin=origin, scale=scale, **node_layout_kwargs)
if len(node_positions) > 3: # Qhull fails for 2 or less nodes
node_positions = _remove_node_overlap(node_positions, node_size=self.node_size, origin=origin, scale=scale)
return node_positions
if node_layout == 'community':
node_positions = get_community_layout(
self.edges, nodes=self.nodes, origin=origin, scale=scale, **node_layout_kwargs)
if len(node_positions) > 3: # Qhull fails for 2 or less nodes
node_positions = _remove_node_overlap(node_positions, node_size=self.node_size, origin=origin, scale=scale)
return node_positions
elif node_layout == 'circular':
return get_circular_layout(
self.edges, nodes=self.nodes, origin=origin, scale=scale, **node_layout_kwargs)
elif node_layout == 'linear':
return get_linear_layout(
self.edges, nodes=self.nodes, origin=origin, scale=scale, **node_layout_kwargs)
elif node_layout == 'bipartite':
return get_bipartite_layout(
self.edges, nodes=self.nodes, origin=origin, scale=scale, **node_layout_kwargs)
elif node_layout == 'multipartite':
return get_multipartite_layout(
self.edges, origin=origin, scale=scale, **node_layout_kwargs)
elif node_layout == 'shell':
return get_shell_layout(
self.edges, origin=origin, scale=scale, **node_layout_kwargs)
elif node_layout == 'dot':
return get_sugiyama_layout(
self.edges, nodes=self.nodes, origin=origin, scale=scale, **node_layout_kwargs)
elif node_layout == 'radial':
return get_radial_tree_layout(
self.edges, nodes=self.nodes, origin=origin, scale=scale, **node_layout_kwargs)
elif node_layout == 'random':
return get_random_layout(
self.edges, nodes=self.nodes, origin=origin, scale=scale, **node_layout_kwargs)
elif node_layout == 'geometric':
return get_geometric_layout(
self.edges, nodes=self.nodes, origin=origin, scale=scale, **node_layout_kwargs)
else:
implemented = ['spring', 'community', 'circular', 'linear', 'bipartite', 'multipartite', 'shell', 'dot', 'radial', 'random', 'geometric']
msg = f"Node layout {node_layout} not implemented. Available layouts are:"
for method in implemented:
msg += f"\n\t{method}"
raise NotImplementedError(msg)
def _initialize_edge_layout(self, edge_layout, edge_layout_kwargs, origin, scale, edge_width):
if edge_layout_kwargs is None:
edge_layout_kwargs = dict()
if edge_layout == "straight":
edge_layout_kwargs.setdefault('edge_width', edge_width)
edge_layout_kwargs.setdefault('origin', origin)
edge_layout_kwargs.setdefault('scale', scale)
edge_layout_kwargs.setdefault('selfloop_radius', 0.05 * np.linalg.norm(scale))
edge_layout_kwargs.setdefault('selfloop_angle', None)
elif edge_layout == 'curved':
edge_layout_kwargs.setdefault('origin', origin)
edge_layout_kwargs.setdefault('scale', scale)
edge_layout_kwargs.setdefault('selfloop_radius', 0.05 * np.linalg.norm(scale))
# area = np.product(scale)
# k = np.sqrt(area / float(len(self.nodes))) # expected distance between nodes
# # As there are multiple control points per edge,
# # edge segments should be much shorter. k hence needs to be smaller.
# k *= 0.1
# edge_layout_kwargs.setdefault('k', k)
edge_layout_kwargs.setdefault('k', 0.1)
elif edge_layout == 'arc':
edge_layout_kwargs.setdefault('rad', 1.)
edge_layout_kwargs.setdefault('origin', origin)
edge_layout_kwargs.setdefault('scale', scale)
edge_layout_kwargs.setdefault('selfloop_radius', 0.05 * np.linalg.norm(scale))
edge_layout_kwargs.setdefault('selfloop_angle', np.pi/2)
elif edge_layout == 'bundled':
edge_layout_kwargs.setdefault('k', 500)
edge_layout_kwargs.setdefault('total_cycles', 6)
if isinstance(edge_layout, str):
edge_paths = self._get_edge_paths(self.edges, self.node_positions,
edge_layout, edge_layout_kwargs)
elif isinstance(edge_layout, dict):
self._check_completeness(edge_layout, self.edges, 'edge_layout')
edge_paths = edge_layout
# determine a sensible edge_layout in case node positions change
path_lengths = np.array([len(path) for path in edge_paths.values()])
if np.any(path_lengths) > 2:
edge_layout = 'curved'
else:
edge_layout = 'straight'
else:
raise TypeError("Variable `edge_layout` either a string or a dict mapping edges to edge paths.")
return edge_paths, edge_layout, edge_layout_kwargs
def _initialize_axis(self, ax):
    """Return the axis to draw onto.

    Parameters
    ----------
    ax : matplotlib axis instance or None
        If None, the current axis as returned by plt.gca() is used.

    Raises
    ------
    TypeError
        If `ax` is neither None nor a matplotlib axis instance.
    """
    if ax is None:
        return plt.gca()

    if not isinstance(ax, mpl.axes.Axes):
        raise TypeError(f"Variable 'ax' either None or a matplotlib axis instance. However, type(ax) is {type(ax)}.")

    return ax
def draw_nodes(self, nodes, node_positions, node_shape, node_size,
node_edge_width, node_color, node_edge_color, node_alpha,
node_zorder):
"""Draw or update node artists.
Parameters
----------
nodes : list
List of nodes IDs.
node_positions : dict
Mapping of nodes to (x, y) positions.
node_shape : dict
Mapping of nodes to shapes.
Specification is as for matplotlib.scatter marker, i.e. one of 'so^>v<dph8'.
node_size : dict
Mapping of nodes to sizes.
node_edge_width : dict
Mapping of nodes to marker edge widths.
node_color : dict
Mapping of nodes to valid matplotlib color specifications.
node_edge_color : dict
Mapping of nodes to valid matplotlib color specifications.
node_alpha : dict
Mapping of nodes to node transparencies.
node_zorder : dict
Mapping of nodes to z-orders.
Returns
-------
node_artists: dict
Updates mapping of nodes to corresponding node artists.
"""
for node in nodes:
node_artist = NodeArtist(shape=node_shape[node],
xy=node_positions[node],
radius=node_size[node],
facecolor=node_color[node],
edgecolor=node_edge_color[node],
linewidth=node_edge_width[node],
alpha=node_alpha[node],
zorder=node_zorder[node])
self.ax.add_patch(node_artist)
if node in self.node_artists:
self.node_artists[node].remove()
self.node_artists[node] = node_artist
def _update_node_artists(self, nodes):
for node in nodes:
self.node_artists[node].xy = self.node_positions[node]
def _get_edge_paths(self, edges, node_positions, edge_layout, edge_layout_kwargs):
"""Compute the edge routing.
Parameters
----------
edges : list
The edges of the graph, with each edge being represented by a (source node ID, target node ID) tuple.
node_positions : dict
Mapping of nodes to (x, y) positions
edge_layout : 'straight', 'curved' or 'bundled' (default 'straight')
If 'straight', draw edges as straight lines.
If 'curved', draw edges as curved splines. The spline control points are optimised to avoid other nodes and edges.
If 'bundled', draw edges as edge bundles.
edge_layout_kwargs : dict
Keyword arguments passed to edge layout functions.
See the documentation of the following functions for a full list of available options:
- get_straight_edge_paths
- get_curved_edge_paths
- get_bundled_edge_paths
Returns
-------
edge_paths : dict
Mapping of edges to arrays of (x, y) tuples, the edge path coordinates.
"""
if edge_layout == 'straight':
edge_paths = get_straight_edge_paths(edges, node_positions,
edge_layout_kwargs['edge_width'])
selfloop_paths = get_selfloop_paths(edges, node_positions,
edge_layout_kwargs['selfloop_radius'],
edge_layout_kwargs['origin'],
edge_layout_kwargs['scale'],
edge_layout_kwargs['selfloop_angle'])
edge_paths.update(selfloop_paths)
elif edge_layout == 'curved':
edge_paths = get_curved_edge_paths(edges, node_positions, node_size=self.node_size, **edge_layout_kwargs)
elif edge_layout == 'arc':
edge_paths = get_arced_edge_paths(edges, node_positions,
rad=edge_layout_kwargs['rad'],
origin=edge_layout_kwargs['origin'],
scale=edge_layout_kwargs['scale'])
selfloop_paths = get_selfloop_paths(edges, node_positions,
edge_layout_kwargs['selfloop_radius'],
edge_layout_kwargs['origin'],
edge_layout_kwargs['scale'],
edge_layout_kwargs['selfloop_angle'])
edge_paths.update(selfloop_paths)
elif edge_layout == 'bundled':
edge_paths = get_bundled_edge_paths(edges, node_positions, **edge_layout_kwargs)
else:
raise NotImplementedError(f"Variable edge_layout one of 'straight', 'curved', 'arc' or 'bundled', not {edge_layout}")
return edge_paths
def draw_edges(self, edge_path, edge_width, edge_color, edge_alpha,
edge_zorder, arrows, node_size):
"""Draw or update edge artists.
Parameters
----------
edge_path : dict
Mapping of edges to arrays of (x, y) tuples, the edge path coordinates.
edge_width : dict
Mapping of edges to floats, the edge widths.
edge_color : dict
Mapping of edges to valid matplotlib color specifications, the edge colors.
edge_alpha : dict
Mapping of edges to floats, the edge transparencies.
edge_zorder : dict
Mapping of edges to ints, the edge z-order values.
arrows : bool
If True, draw edges with arrow heads.
node_size : dict
Mapping of nodes to node sizes. Required to offset edges from nodes.
Returns
-------
self.edge_artists: dict
Updates mapping of edges to corresponding edge artists.
"""
for edge in edge_path:
curved = False if (len(edge_path[edge]) == 2) else True
source, target = edge
if ((target, source) in edge_path) and (source != target): # i.e. bidirectional edges excluding self-loops
if np.allclose(edge_path[(source, target)], edge_path[(target, source)][::-1]): # i.e. same path
shape = 'right' # i.e. plot half arrow / thin line shifted to the right
else:
shape = 'full'
else:
shape = 'full'
if arrows:
head_length = 2 * edge_width[edge]
head_width = 3 * edge_width[edge]
else:
head_length = 0
head_width = 0
edge_artist = EdgeArtist(
midline = edge_path[edge],
width = edge_width[edge],
facecolor = edge_color[edge],
alpha = edge_alpha[edge],
head_length = head_length,
head_width = head_width,
edgecolor = 'none',
linewidth = 0.,
offset = node_size[target],
shape = shape,
curved = curved,
zorder = edge_zorder[edge],
)
self.ax.add_patch(edge_artist)
if edge in self.edge_artists:
self.edge_artists[edge].remove()
self.edge_artists[edge] = edge_artist
def _update_edge_artists(self, edge_paths=None):
if edge_paths is None:
edge_paths = self.edge_paths
for edge, path in edge_paths.items():
self.edge_artists[edge].update_midline(path)
self.ax.draw_artist(self.edge_artists[edge])
def _update_edges(self, edges):
edge_paths = dict()
if self.edge_layout == 'straight':
edge_paths.update(self._update_straight_edge_paths([(source, target) for (source, target) in edges if source != target]))
edge_paths.update(self._update_selfloop_paths([(source, target) for (source, target) in edges if source == target]))
elif self.edge_layout == 'curved':
edge_paths.update(self._update_curved_edge_paths(edges))
elif self.edge_layout == 'bundled':
edge_paths.update(self._update_bundled_edge_paths(edges))
elif self.edge_layout == 'arc':
edge_paths.update(self._update_arced_edge_paths([(source, target) for (source, target) in edges if source != target]))
edge_paths.update(self._update_selfloop_paths([(source, target) for (source, target) in edges if source == target]))
self.edge_paths.update(edge_paths)
self._update_edge_artists(edge_paths)
def _update_straight_edge_paths(self, edges):
# remove self-loops
edges = [(source, target) for source, target in edges if source != target]
edge_paths = dict()
for (source, target) in edges:
x0, y0 = self.node_positions[source]
x1, y1 = self.node_positions[target]
# # shift edge right if bi-directional
# if (target, source) in edges:
# x0, y0, x1, y1 = _shift_edge(x0, y0, x1, y1, delta=-0.1*self.edge_artists[(source, target)].width)
edge_paths[(source, target)] = np.c_[[x0, x1], [y0, y1]]
return edge_paths
def _update_selfloop_paths(self, edges):
# restrict to self-loops
edges = [(source, target) for source, target in edges if source == target]
edge_paths = dict()
for (source, target) in edges:
edge_paths[(source, target)] = _get_selfloop_path(
source,
node_positions = self.node_positions,
selfloop_radius = self.edge_layout_kwargs['selfloop_radius'],
origin = self.edge_layout_kwargs['origin'],
scale = self.edge_layout_kwargs['scale'],
angle = self.edge_layout_kwargs['selfloop_angle']
)
return edge_paths
def _update_curved_edge_paths(self, stale_edges):
    """Compute a new layout for the stale curved edges while keeping all other edges constant.

    Points along the unchanged edges are passed to the layout routine as
    additional positions (under unique, randomly generated keys), together
    with the current node positions.
    """
    n_samples = 100
    obstacles = dict()

    for edge in self.edges:
        if edge in stale_edges:
            continue

        midline = self.edge_artists[edge].midline
        if self.edge_artists[edge].curved:
            # Use the interior points of the path.
            for point in midline[1:-1]:
                obstacles[uuid4()] = point
        else:
            # Densely sample points along the straight edge such that updated
            # edges avoid the whole edge, not just the end points.
            start = midline[0]
            delta = midline[-1] - midline[0]
            for step in range(1, n_samples + 1):
                fraction = step / (n_samples + 1)
                obstacles[uuid4()] = fraction * delta + start

    obstacles.update(self.node_positions)
    return get_curved_edge_paths(stale_edges, obstacles, node_size=self.node_size, **self.edge_layout_kwargs)
def _update_bundled_edge_paths(self, edges):
    """Recompute the bundled edge layout from the current node positions.

    The argument `edges` is not used: the paths are recomputed for all edges
    in `self.edges`.
    """
    all_edges = self.edges
    bundled_paths = get_bundled_edge_paths(all_edges, self.node_positions, **self.edge_layout_kwargs)
    return bundled_paths
def _update_arced_edge_paths(self, edges):
    """Recompute the arced paths of the given edges from the current node positions.

    The layout routine is called with the same 'rad', 'origin', and 'scale'
    settings as in the initial layout computation in `_get_edge_paths`,
    such that updated edges are routed consistently with the initial layout.

    Returns
    -------
    edge_paths : dict
        Mapping of edges to edge paths.
    """
    layout_kwargs = self.edge_layout_kwargs
    # 'origin' and 'scale' are only passed on if present, such that instances
    # that only define 'rad' continue to work.
    canvas_kwargs = {key : layout_kwargs[key] for key in ('origin', 'scale') if key in layout_kwargs}
    return get_arced_edge_paths(edges, self.node_positions, rad=layout_kwargs['rad'], **canvas_kwargs)
def _initialize_node_label_offset(self, node_labels, node_label_offset):
if isinstance(node_label_offset, (int, float)):
node_label_offset = {node : node_label_offset * self._get_vector_pointing_outwards(self.node_positions[node]) for node in node_labels}
recompute = True
return node_label_offset, recompute
elif isinstance(node_label_offset, (tuple, list, np.ndarray)):
if len(node_label_offset) == 2:
node_label_offset = {node : node_label_offset for node in node_labels}
recompute = False
return node_label_offset, recompute
else:
msg = "If the variable `node_label_offset` is an iterable, it should have length 2."
msg+= f"Current length: {len(node_label_offset)}."
raise ValueError(msg)
else:
msg = "The variable `node_label_offset` has to be either a float, an int, a tuple, a list, or a numpy ndarray."
msg += f"\nCurrent type: {type(node_label_offset)}."
raise TypeError(msg)
def _get_centroid(self):
return np.mean([position for position in self.node_positions.values()], axis=0)
def _get_vector_pointing_outwards(self, xy):
centroid = self._get_centroid()
delta = xy - centroid
distance = np.linalg.norm(delta)
unit_vector = delta / distance
return unit_vector
def _initialize_node_label_fontdict(self, node_label_fontdict, node_labels, node_label_offset):
if node_label_fontdict is None:
node_label_fontdict = dict()
node_label_fontdict.setdefault('horizontalalignment', 'center')
node_label_fontdict.setdefault('verticalalignment', 'center')
node_label_fontdict.setdefault('clip_on', False)
node_label_fontdict.setdefault('zorder', np.inf)
if np.all(np.isclose(node_label_offset, (0, 0))):
# Labels are centered on node artists.
# Set fontsize such that labels fit the diameter of the node artists.
size = self._get_font_size(node_labels, node_label_fontdict) * 0.75 # conservative fudge factor
if ('size' not in node_label_fontdict) and ('fontsize' not in node_label_fontdict):
node_label_fontdict.setdefault('size', size)
return node_label_fontdict
def _get_font_size(self, node_labels, node_label_fontdict):
"""Determine the maximum font size such that all labels fit inside their node artist."""
# TODO:
# -----
# - potentially rescale font sizes individually on a per node basis
rescale_factor = np.inf
for node, label in node_labels.items():
artist = self.node_artists[node]
diameter = 2 * (artist.radius - artist._lw_data/artist.linewidth_correction)
width, height = _get_text_object_dimensions(self.ax, label, **node_label_fontdict)
rescale_factor = min(rescale_factor, diameter/np.sqrt(width**2 + height**2))
if 'size' in node_label_fontdict:
size = rescale_factor * node_label_fontdict['size']
elif 'fontsize' in node_label_fontdict:
size = rescale_factor * node_label_fontdict['fontsize']
else:
size = rescale_factor * plt.rcParams['font.size']
return size
def draw_node_labels(self, node_labels, node_label_fontdict):
    """Draw or update node labels.

    Each label is placed at the position of its node shifted by the
    corresponding entry in `self.node_label_offset`.

    Parameters
    ----------
    node_labels : dict
        Mapping of nodes to strings, the node labels.
        Only nodes in the dictionary are labelled.
    node_label_fontdict : dict
        Keyword arguments passed to matplotlib.text.Text.
        For a full list of available arguments see the matplotlib documentation.

    Notes
    -----
    Nothing is returned; `self.node_label_artists`, the mapping of nodes to
    text objects, is updated in place. A previously drawn label of a node
    is removed from the axis once its replacement has been created.
    """
    for node in node_labels:
        label = node_labels[node]
        x, y = self.node_positions[node]
        dx, dy = self.node_label_offset[node]
        new_artist = self.ax.text(x+dx, y+dy, label, **node_label_fontdict)

        if node in self.node_label_artists:
            stale_artist = self.node_label_artists[node]
            stale_artist.remove()

        self.node_label_artists[node] = new_artist
def _update_node_label_positions(self):
if self._recompute_node_label_offsets:
self._update_node_label_offsets()
for node, (dx, dy) in self.node_label_offset.items():
x, y = self.node_positions[node]
self.node_label_artists[node].set_position((x + dx, y + dy))
def _update_node_label_offsets(self, total_samples_per_edge=100):
fixed = []
for xy in self.node_positions.values():
fixed.append(xy)
for path in self.edge_paths.values():
fixed.extend([_get_point_along_spline(path, fraction) for fraction in np.arange(0, 1, 1./total_samples_per_edge)])
fixed = np.array(fixed)
offsets = np.array(list(self.node_label_offset.values()))
anchors = np.array([self.node_positions[node] for node in self.node_label_offset.keys()])
offsets = self._optimise_offsets(anchors, offsets, fixed)
for ii, node in enumerate(self.node_label_offset):
self.node_label_offset[node] = offsets[ii]
# # Variant no 1: use a force directed layout to determine suitable node label placements
# # pros : labels repel each other
# # cons : does not work very well; the optimum placement can still result in a collision
# def _optimise_offsets(self, anchors, offsets, fixed, total_iterations=5):
# # Compute the net repulsion exerted on each label by nodes, edges and other labels.
# # Place the label in the direction of net repulsion at the desired distance from the corresponding node (anchor).
# # TODO Test if gradually stepping in the direction of net repulsion improves results.
# for ii in range(total_iterations):
# repulsion = self._get_repulsion(anchors + offsets, fixed)
# directions = repulsion / np.linalg.norm(repulsion, axis=-1)[:, np.newaxis]
# offsets = np.linalg.norm(offsets, axis=-1)[:, np.newaxis] * directions
# return offsets
# def _get_repulsion(self, mobile, fixed, minimum_distance=0.01):
# combined = np.concatenate([mobile, fixed], axis=0)
# delta = mobile[np.newaxis, :, :] - combined[:, np.newaxis, :]
# distance = np.linalg.norm(delta, axis=-1)
# direction = delta / distance[..., None] # i.e. the unit vector
# # 1. We clip the distance as we want to reduce overlaps with
# # all nearby plot elements, not just the one that overlaps the
# # most.
# # 2. We only care about interactions with nearby objects, so
# # we heavily penalise repulsion from far away items by using an
# # exponent.
# magnitude = 1. / np.clip(distance, minimum_distance, np.inf)**6
# repulsion = direction * magnitude[..., None]
# for ii in range(repulsion.shape[-1]):
# np.fill_diagonal(repulsion[:, :, ii], 0)
# return np.sum(repulsion, axis=0)
# Variant no 2:
# pros : straightforward optimisation; works very well
# cons : labels can still collide with each other
def _optimise_offsets(self, anchors, offsets, fixed, total_queries_per_point=360):
    """Rotate each offset around its anchor to maximise the distance to the fixed points.

    For every anchor, `total_queries_per_point` candidate positions are
    generated on a circle around the anchor with the radius given by the
    length of the current offset. The candidate with the largest summed
    distance to its two nearest fixed points is selected.

    Returns the new offsets, i.e. the selected positions relative to `anchors`.
    """
    tree = cKDTree(fixed)
    best_positions = np.zeros_like(offsets)
    for index, (anchor, offset) in enumerate(zip(anchors, offsets)):
        radius = np.linalg.norm(offset)
        candidates = _get_n_points_on_a_circle(anchor, radius, total_queries_per_point)
        # Querying only the single nearest neighbour can result in many
        # ties, of which the first would be chosen arbitrarily; hence we
        # consider the two nearest neighbours.
        nearest_distances, _ = tree.query(candidates, 2)
        clearance = np.sum(nearest_distances, axis=1)
        best_positions[index] = candidates[np.argmax(clearance)]
    return best_positions - anchors
def _initialize_edge_label_fontdict(self, edge_label_fontdict):
if edge_label_fontdict is None:
edge_label_fontdict = dict()
edge_label_fontdict.setdefault('bbox', dict(boxstyle='round',
ec=(1.0, 1.0, 1.0),
fc=(1.0, 1.0, 1.0)))
edge_label_fontdict.setdefault('horizontalalignment', 'center')
edge_label_fontdict.setdefault('verticalalignment', 'center')
edge_label_fontdict.setdefault('clip_on', False)
edge_label_fontdict.setdefault('zorder', np.inf)
return edge_label_fontdict
def draw_edge_labels(self, edge_labels, edge_label_position,
                     edge_label_rotate, edge_label_fontdict):
    """Draw or update edge labels.

    Parameters
    ----------
    edge_labels : dict
        Mapping of edges to strings, the edge labels.
        Only edges in the dictionary are labelled.
    edge_label_position : float
        Relative position along the edge where the label is placed.
            head   : 0.
            centre : 0.5
            tail   : 1.
    edge_label_rotate : bool
        If True, edge labels are rotated such that they have the same orientation as their corresponding edge.
        If False, edge labels are not rotated; the angle of the text is parallel to the axis.
    edge_label_fontdict : dict
        Keyword arguments passed to matplotlib.text.Text.

    Notes
    -----
    Nothing is returned; `self.edge_label_artists`, the mapping of edges to
    text objects, is updated in place.
    Labels of self-loops that are drawn as straight edges are skipped with a warning.
    """
    for edge, label in edge_labels.items():
        edge_artist = self.edge_artists[edge]

        if self._is_selfloop(edge) and (edge_artist.curved is False):
            msg = "Plotting of edge labels for self-loops not supported for straight edges."
            msg = msg + "\nIgnoring edge with label: {}".format(label)
            warnings.warn(msg)
            continue

        x, y = _get_point_along_spline(edge_artist.midline, edge_label_position)

        angle = None
        if edge_label_rotate:
            # orientation of the edge at the label position
            # NOTE(review): the comparisons below assume that `_get_angle`
            # returns degrees when called with radians=True -- confirm.
            dx, dy = _get_tangent_at_point(edge_artist.midline, edge_label_position)
            angle = _get_angle(dx, dy, radians=True)
            # flip labels that would otherwise be drawn upside down
            if angle > 90:
                angle -= 180
            if angle < - 90:
                angle += 180

        new_artist = self.ax.text(x, y, label, rotation=angle, **edge_label_fontdict)

        if edge in self.edge_label_artists:
            stale_artist = self.edge_label_artists[edge]
            stale_artist.remove()

        self.edge_label_artists[edge] = new_artist
def _is_selfloop(self, edge):
return True if edge[0] == edge[1] else False
def _update_edge_label_positions(self, edges):
labeled_edges = [edge for edge in edges if edge in self.edge_label_artists]
for (n1, n2) in labeled_edges:
edge_artist = self.edge_artists[(n1, n2)]
if edge_artist.curved:
x, y = _get_point_along_spline(edge_artist.midline, self.edge_label_position)
dx, dy = _get_tangent_at_point(edge_artist.midline, self.edge_label_position)
elif not edge_artist.curved and (n1 != n2):
(x1, y1) = self.node_positions[n1]
(x2, y2) = self.node_positions[n2]
if (n2, n1) in self.edges: # i.e. bidirectional edge
x1, y1, x2, y2 = _shift_edge(x1, y1, x2, y2, delta=1.5*self.edge_artists[(n1, n2)].width)
x, y = (x1 * self.edge_label_position + x2 * (1.0 - self.edge_label_position),
y1 * self.edge_label_position + y2 * (1.0 - self.edge_label_position))
dx, dy = x2 - x1, y2 - y1
else: # self-loop but edge is straight so we skip it
pass
self.edge_label_artists[(n1, n2)].set_position((x, y))
if self.edge_label_rotate:
angle = _get_angle(dx, dy, radians=True)
# make label orientation "right-side-up"
if angle > 90:
angle -= 180
if angle < -90:
angle += 180
# transform data coordinate angle to screen coordinate angle
trans_angle = self.ax.transData.transform_angles(np.array((angle,)), np.atleast_2d((x, y)))[0]
self.edge_label_artists[(n1, n2)].set_rotation(trans_angle)
def _update_view(self):
# Pad x and y limits as patches are not registered properly
# when matplotlib sets axis limits automatically.
# Hence we need to set them manually.
# max_radius = np.max([artist.radius for artist in self.node_artists.values()])
# maxx, maxy = np.max(list(self.node_positions.values()), axis=0)
# minx, miny = np.min(list(self.node_positions.values()), axis=0)
# w = maxx-minx
# h = maxy-miny
# padx, pady = 0.05*w + max_radius, 0.05*h + max_radius
# corners = (minx-padx, miny-pady), (maxx+padx, maxy+pady)
# self.ax.update_datalim(corners)
self.ax.autoscale_view()
self.ax.get_figure().canvas.draw()
[docs]
class Graph(BaseGraph):
    """Parses the given graph data object and initialises the BaseGraph object.

    If the given graph includes edge weights, then these are mapped to colors using the `edge_cmap` parameter.

    Parameters
    ----------
    graph : various formats
        Graph object to plot. Various input formats are supported.
        In order of precedence:

        - Edge list:
          Iterable of (source, target) or (source, target, weight) tuples,
          or equivalent (E, 2) or (E, 3) ndarray, where E is the number of edges.
        - Adjacency matrix:
          Full-rank (V, V) ndarray, where V is the number of nodes/vertices.
          The absence of a connection is indicated by a zero.

          .. note:: If V <= 3, any (2, 2) or (3, 3) matrices will be interpreted as edge lists.

        - networkx.Graph, igraph.Graph, or graph_tool.Graph object

    node_layout : str or dict, default 'spring'
        If `node_layout` is a string, the node positions are computed using the indicated method:

        - 'random'       : place nodes in random positions;
        - 'circular'     : place nodes regularly spaced on a circle;
        - 'spring'       : place nodes using a force-directed layout (Fruchterman-Reingold algorithm);
        - 'dot'          : place nodes using the Sugiyama algorithm; the graph should be directed and acyclic;
        - 'radial'       : place nodes radially using the Sugiyama algorithm; the graph should be directed and acyclic;
        - 'community'    : place nodes such that nodes belonging to the same community are grouped together;
        - 'bipartite'    : place nodes regularly spaced on two parallel lines;
        - 'multipartite' : place nodes regularly spaced on several parallel lines;
        - 'shell'        : place nodes regularly spaced on concentric circles;
        - 'geometric'    : place nodes according to the length of the edges between them.

        If `node_layout` is a dict, keys are nodes and values are (x, y) positions.
    node_layout_kwargs : dict or None, default None
        Keyword arguments passed to node layout functions.
        See the documentation of the following functions for a full description of available options:

        - get_random_layout
        - get_circular_layout
        - get_fruchterman_reingold_layout
        - get_sugiyama_layout
        - get_radial_tree_layout
        - get_community_layout
        - get_bipartite_layout
        - get_multipartite_layout
        - get_shell_layout
        - get_geometric_layout

    node_shape : str or dict, default 'o'
        Node shape.
        If the type is str, all nodes have the same shape.
        If the type is dict, maps each node to an individual string representing the shape.
        The string specification is as for matplotlib.scatter marker, i.e. one of 'so^>v<dph8'.
    node_size : float or dict, default 3.
        Node size (radius).
        If the type is float, all nodes will have the same size.
        If the type is dict, maps each node to an individual size.

        .. note:: Values are rescaled by BASE_SCALE (1e-2) to be compatible with layout routines in igraph and networkx.

    node_edge_width : float or dict, default 0.5
        Line width of node marker border.
        If the type is float, all nodes have the same line width.
        If the type is dict, maps each node to an individual line width.

        .. note:: Values are rescaled by BASE_SCALE (1e-2) to be compatible with layout routines in igraph and networkx.

    node_color : matplotlib color specification or dict, default 'w'
        Node color.
        If the type is a string or RGBA array, all nodes have the same color.
        If the type is dict, maps each node to an individual color.
    node_edge_color : matplotlib color specification or dict, default DEFAULT_COLOR
        Node edge color.
        If the type is a string or RGBA array, all nodes have the same edge color.
        If the type is dict, maps each node to an individual edge color.
    node_alpha : scalar or dict, default 1.
        Node transparency.
        If the type is a float, all nodes have the same transparency.
        If the type is dict, maps each node to an individual transparency.
    node_zorder : int or dict, default 2
        Order in which to plot the nodes.
        If the type is an int, all nodes have the same zorder.
        If the type is dict, maps each node to an individual zorder.
    node_labels : bool or dict, default False
        If False, the nodes are unlabelled.
        If True, the nodes are labelled with their node IDs.
        If the node labels are to be distinct from the node IDs, supply a dictionary mapping nodes to node labels.
        Only nodes in the dictionary are labelled.
    node_label_offset : float or tuple, default (0., 0.)
        A (dx, dy) tuple specifies the exact offset from the node position.
        If a single scalar delta is specified, the value is interpreted as a distance,
        and the label is placed delta away from the node position while trying to
        reduce node/label, node/edge, and label/label overlaps.
    node_label_fontdict : dict
        Keyword arguments passed to matplotlib.text.Text.
        For a full list of available arguments see the matplotlib documentation.
        The following default values differ from the defaults for matplotlib.text.Text:

        - size (adjusted to fit into node artists if offset is (0, 0))
        - horizontalalignment (default here: 'center')
        - verticalalignment (default here: 'center')
        - clip_on (default here: False)
        - zorder (default here: inf)

    edge_width : float or dict, default 1.
        Width of edges.
        If the type is a float, all edges have the same width.
        If the type is dict, maps each edge to an individual width.

        .. note:: Value is rescaled by BASE_SCALE (1e-2) to be compatible with layout routines in igraph and networkx.

    edge_cmap : matplotlib color map, default 'RdGy'
        Color map used to map edge weights to edge colors. Should be diverging.
        If edge weights are strictly positive, weights are mapped to the
        left hand side of the color map with vmin=0 and vmax=np.max(weights).
        If edge weights are positive and negative, then weights are mapped
        to colors such that a weight of zero corresponds to the center of the
        color map; the boundaries are set to +/- the maximum absolute weight.
        If the graph is unweighted or the edge colors are specified explicitly,
        this parameter is ignored.
    edge_color : matplotlib color specification or dict, default DEFAULT_COLOR
        Edge color. If provided explicitly, overrides `edge_cmap`.
        If the type is a string or RGBA array, all edges have the same color.
        If the type is dict, maps each edge to an individual color.
    edge_alpha : float or dict, default 1.
        The edge transparency.
        If the type is a float, all edges have the same transparency.
        If the type is dict, maps each edge to an individual transparency.
    edge_zorder : int or dict, default 1
        Order in which to plot the edges.
        If the type is an int, all edges have the same zorder.
        If the type is dict, maps each edge to an individual zorder.
        If None, the edges will be plotted in the order they appear in 'adjacency'.
        Hint: graphs typically appear more visually pleasing if darker edges are plotted on top of lighter edges.
    arrows : bool, default False
        If True, draw edges with arrow heads.
    edge_layout : str or dict, default 'straight'
        If edge_layout is a string, determine the layout internally:

        - 'straight' : draw edges as straight lines
        - 'curved'   : draw edges as curved splines; the spline control points are optimised to avoid other nodes and edges
        - 'bundled'  : draw edges as edge bundles

        If edge_layout is a dict, the keys are edges and the values are edge paths
        in the form of iterables of (x, y) tuples, the edge segments.
    edge_layout_kwargs : dict, default None
        Keyword arguments passed to edge layout functions.
        See the documentation of the following functions for a full description of available options:

        - get_straight_edge_paths
        - get_curved_edge_paths
        - get_bundled_edge_paths

    edge_labels : bool or dict, default False
        If False, the edges are unlabelled.
        If True, the edges are labelled with their edge IDs.
        If the edge labels are to be distinct from the edge IDs, supply a dictionary mapping edges to edge labels.
        Only edges in the dictionary are labelled.
    edge_label_position : float, default 0.5
        Relative position along the edge where the label is placed.

        - head   : 0.
        - centre : 0.5
        - tail   : 1.

    edge_label_rotate : bool, default True
        If True, edge labels are rotated such that they have the same orientation as their edge.
        If False, edge labels are not rotated; the angle of the text is parallel to the axis.
    edge_label_fontdict : dict
        Keyword arguments passed to matplotlib.text.Text.
        For a full list of available arguments see the matplotlib documentation.
        The following default values differ from the defaults for matplotlib.text.Text:

        - horizontalalignment (default here: 'center'),
        - verticalalignment (default here: 'center')
        - clip_on (default here: False),
        - bbox (default here: dict(boxstyle='round', ec=(1.0, 1.0, 1.0), fc=(1.0, 1.0, 1.0))),
        - zorder (default here: inf),
        - rotation (determined by edge_label_rotate argument)

    origin : tuple, default (0., 0.)
        The lower left hand corner of the bounding box specifying the extent of the canvas.
    scale : tuple, default (1., 1.)
        The width and height of the bounding box specifying the extent of the canvas.
    prettify : bool, default True
        If True, despine and remove ticks and tick labels.
        Set figure background to white. Set axis aspect to equal.
    ax : matplotlib.axis instance or None, default None
        Axis to plot onto; if none specified, one will be instantiated with plt.gca().

    Attributes
    ----------
    node_artists : dict
        Mapping of node IDs to matplotlib PathPatch artists.
    edge_artists : dict
        Mapping of edge IDs to matplotlib PathPatch artists.
    node_label_artists : dict
        Mapping of node IDs to matplotlib text objects (if applicable).
    edge_label_artists : dict
        Mapping of edge IDs to matplotlib text objects (if applicable).
    node_positions : dict node : (x, y) tuple
        Mapping of node IDs to node positions.

    See also
    --------
    BaseGraph, InteractiveGraph

    """

    def __init__(self, graph, edge_cmap='RdGy', *args, **kwargs):
        # Reduce the supported input formats to nodes, edges, and (optionally) weights.
        nodes, edges, edge_weight = parse_graph(graph)
        kwargs.setdefault('nodes', nodes)

        if edge_weight:
            # Weights are visualised using color rather than edge width, as
            # a large variance in weights typically results in less visually
            # pleasing plots when mapped to width.
            edge_color = _get_color(edge_weight, cmap=edge_cmap)

            # Darker edges are plotted on top of lighter edges, and nodes
            # on top of all edges.
            edge_zorder = _get_zorder(edge_color)
            node_zorder = np.max(list(edge_zorder.values())) + 1

            # Explicitly provided keyword arguments take precedence.
            weight_derived_defaults = (
                ('edge_color', edge_color),
                ('edge_zorder', edge_zorder),
                ('node_zorder', node_zorder),
            )
            for key, value in weight_derived_defaults:
                kwargs.setdefault(key, value)

        super().__init__(edges, *args, **kwargs)
def _get_color(mydict, cmap='RdGy', vmin=None, vmax=None):
"""Map positive and negative floats to a diverging colormap, such that
- the midpoint of the colormap corresponds to a value of 0., and
- values above and below the midpoint are mapped linearly and in equal measure
to increases in color intensity.
Parameters
----------
mydict: dict
Mapping of graph element (node, edge) to a float.
For example (source, target) : edge weight.
cmap: str, default 'RdGy'
Matplotlib colormap specification.
vmin, vmax: float or None, default None
Minimum and maximum float corresponding to the dynamic range of the colormap.
Returns
-------
newdict: dict
Mapping of graph elements to RGBA tuples.
"""
keys = mydict.keys()
values = np.array(list(mydict.values()), dtype=float)
# apply vmin, vmax
if vmin or vmax:
values = np.clip(values, vmin, vmax)
def abs(value):
try:
return np.abs(value)
except TypeError as e: # value is probably None
if isinstance(value, type(None)):
return 0
else:
raise e
# rescale values such that
# - the colormap midpoint is at zero-value, and
# - negative and positive values have comparable intensity values
values /= np.nanmax([np.nanmax(np.abs(values)), abs(vmax), abs(vmin)]) # [-1, 1]
values += 1. # [0, 2]
values /= 2. # [0, 1]
# convert value to color
mapper = mpl.cm.ScalarMappable(cmap=cmap)
mapper.set_clim(vmin=0., vmax=1.)
colors = mapper.to_rgba(values)
return {key: color for (key, color) in zip(keys, colors)}
def _get_zorder(color_dict):
    """Assign a z-order to plot elements such that darker items are plotted last.

    Elements are ranked by their grayscale intensity and the ranking is
    reversed, as greater intensities correspond to lighter colors.
    This assumes that the background is white.

    Returns a dict mapping the keys of `color_dict` to their z-order.
    """
    grayscale_values = [rgba_to_grayscale(*rgba) for rgba in color_dict.values()]
    ranks = _rank(grayscale_values)
    darkest_last = np.max(ranks) - ranks
    return dict(zip(color_dict.keys(), darkest_last))
def rgba_to_grayscale(r, g, b, a=1):
    """Convert RGBA values to a single grayscale intensity.

    The colour channels are combined as a weighted sum (0.299, 0.587, 0.114)
    and the result is scaled by the alpha value.

    Notes
    -----
    Adapted from: https://stackoverflow.com/a/689547/2912349

    """
    weighted_sum = 0.299 * r + 0.587 * g + 0.114 * b
    return weighted_sum * a
class ClickableArtists(object):
    """Implements selection of matplotlib artists via the mouse left click (+/- ctrl or command key).

    Selected artists are highlighted with a thicker, black outline; their
    original line width and edge color are restored on deselection.

    Notes:
    ------
    Adapted from: https://stackoverflow.com/a/47312637/2912349

    """

    def __init__(self, artists):

        try:
            self.fig, = {artist.figure for artist in artists}
        except ValueError:
            raise Exception("All artists have to be on the same figure!")

        try:
            self.ax, = {artist.axes for artist in artists}
        except ValueError:
            raise Exception("All artists have to be on the same axis!")

        # Selection is evaluated on button release, not on button press.
        self.fig.canvas.mpl_connect('button_release_event', self._on_release)

        self._clickable_artists = list(artists)
        self._selected_artists = []
        # Remember the original appearance so that it can be restored on deselection.
        self._base_linewidth = {artist: artist._lw_data for artist in artists}
        self._base_edgecolor = {artist: artist.get_edgecolor() for artist in artists}

        if mpl.get_backend() == 'MacOSX':
            msg = "\n".join([
                "You appear to be using the MacOSX backend.",
                "Modifier key presses are bugged on this backend. See https://github.com/matplotlib/matplotlib/issues/20486",
                "Consider using a different backend, e.g. TkAgg (import matplotlib; matplotlib.use('TkAgg')).",
                "Note that you must set the backend before importing any package depending on matplotlib (includes pyplot, networkx, netgraph).",
            ])
            warnings.warn(msg)

    def _on_release(self, event):
        if event.inaxes != self.ax:
            print("Warning: clicked outside axis limits!")
            return

        # NOTE: if two artists are overlapping, only the first one encountered is selected!
        clicked_artist = next(
            (artist for artist in self._clickable_artists if artist.contains(event)[0]),
            None,
        )
        modifier_held = event.key in ('control', 'super+??', 'ctrl+??')

        if clicked_artist is None:
            # Clicking on empty space clears the selection unless the modifier key is held.
            if not modifier_held:
                self._deselect_all_artists()
            return

        if not modifier_held:
            self._deselect_all_other_artists(clicked_artist)
        self._toggle_select_artist(clicked_artist)

    def _toggle_select_artist(self, artist):
        toggle = self._deselect_artist if artist in self._selected_artists else self._select_artist
        toggle(artist)

    def _select_artist(self, artist):
        if artist in self._selected_artists:
            return
        highlighted_linewidth = max(1.5 * artist._lw_data, 0.003)
        artist.set_linewidth(highlighted_linewidth)
        artist.set_edgecolor('black')
        self._selected_artists.append(artist)
        self.fig.canvas.draw_idle()

    def _deselect_artist(self, artist):
        if artist not in self._selected_artists: # should never be the case?
            return
        artist.set_linewidth(self._base_linewidth[artist])
        artist.set_edgecolor(self._base_edgecolor[artist])
        self._selected_artists.remove(artist)
        self.fig.canvas.draw_idle()

    def _deselect_all_artists(self):
        # Iterate over a copy, as deselecting modifies the list of selected artists.
        for artist in list(self._selected_artists):
            self._deselect_artist(artist)

    def _deselect_all_other_artists(self, artist_to_keep):
        # Iterate over a copy, as deselecting modifies the list of selected artists.
        for artist in list(self._selected_artists):
            if artist != artist_to_keep:
                self._deselect_artist(artist)
class SelectableArtists(ClickableArtists):
    """Augments ClickableArtists with a rectangle selector.

    Pressing the mouse button on empty space and dragging spans a rectangle;
    on release, node artists whose centre lies within the rectangle and
    edge artists whose midline lies entirely within it are selected.

    Notes:
    ------
    Adapted from: https://stackoverflow.com/a/47312637/2912349

    """

    def __init__(self, artists):
        super().__init__(artists)

        # The release event is already connected by the parent class.
        self.fig.canvas.mpl_connect('button_press_event', self._on_press)
        self.fig.canvas.mpl_connect('motion_notify_event', self._on_motion)

        self._selectable_artists = list(artists)
        self._currently_selecting = False

        # The selection rectangle is created once and only shown while selecting.
        self._rect = plt.Rectangle((0, 0), 1, 1, linestyle="--", edgecolor="crimson", fill=False)
        self.ax.add_patch(self._rect)
        self._rect.set_visible(False)

        # Opposite corners of the selection rectangle in data coordinates.
        self._x0, self._y0 = 0, 0
        self._x1, self._y1 = 0, 0

    def _on_press(self, event):
        if event.inaxes != self.ax:
            return

        # reset rectangle
        self._x0 = self._x1 = event.xdata
        self._y0 = self._y1 = event.ydata

        # Only start a rectangle selection if the click did not hit an artist.
        if not any(artist.contains(event)[0] for artist in self._clickable_artists):
            self._currently_selecting = True

    def _on_release(self, event):
        super()._on_release(event)

        if not self._currently_selecting:
            return

        # select artists inside window
        for artist in self._selectable_artists:
            if isinstance(artist, NodeArtist):
                enclosed = self._is_inside_rect(*artist.xy)
            elif isinstance(artist, EdgeArtist):
                enclosed = np.all([self._is_inside_rect(x, y) for x, y in artist.midline])
            else:
                enclosed = False

            if enclosed:
                # The distinction is probably superfluous, as no artists
                # will be selected if control was not held previously.
                if event.key in ('control', 'super+??', 'ctrl+??'):
                    self._toggle_select_artist(artist)
                else:
                    self._select_artist(artist)

        # stop window selection and draw new state
        self._currently_selecting = False
        self._rect.set_visible(False)
        self.fig.canvas.draw_idle()

    def _on_motion(self, event):
        if (event.inaxes == self.ax) and self._currently_selecting:
            self._x1 = event.xdata
            self._y1 = event.ydata
            # show the rectangle spanned so far
            self._selector_on()

    def _is_inside_rect(self, x, y):
        xmin, xmax = np.sort([self._x0, self._x1])
        ymin, ymax = np.sort([self._y0, self._y1])
        return bool((xmin <= x) and (x < xmax) and (ymin <= y) and (y < ymax))

    def _selector_on(self):
        self._rect.set_visible(True)
        xmin, xmax = np.sort([self._x0, self._x1])
        ymin, ymax = np.sort([self._y0, self._y1])
        self._rect.set_xy((xmin, ymin))
        self._rect.set_width(xmax - xmin)
        self._rect.set_height(ymax - ymin)
        self.fig.canvas.draw_idle()
class DraggableArtists(SelectableArtists):
    """Augments SelectableArtists to support dragging of artists by holding the left mouse button.

    When the mouse is moved after pressing the button on a draggable artist,
    all selected draggable artists follow the cursor while keeping their
    positions relative to it.

    Notes:
    ------
    Adapted from: https://stackoverflow.com/a/47312637/2912349

    """

    def __init__(self, artists):
        super().__init__(artists)
        self._draggable_artists = list(artists)
        self._currently_clicking_on_artist = None
        self._currently_dragging = False
        # Mapping of dragged artists to their position relative to the cursor.
        self._offset = dict()

    def _on_press(self, event):
        super()._on_press(event)

        if event.inaxes == self.ax:
            for artist in self._draggable_artists:
                if artist.contains(event)[0]:
                    self._currently_clicking_on_artist = artist
                    break
        else:
            print("Warning: clicked outside axis limits!")

    def _on_motion(self, event):
        super()._on_motion(event)

        if event.inaxes == self.ax:
            if self._currently_clicking_on_artist:
                # The first motion after a press on an artist starts the drag.
                if self._currently_clicking_on_artist not in self._selected_artists:
                    if event.key not in ('control', 'super+??', 'ctrl+??'):
                        self._deselect_all_artists()
                    self._select_artist(self._currently_clicking_on_artist)
                self._offset = {artist : artist.xy - np.array([event.xdata, event.ydata]) for artist in self._selected_artists if artist in self._draggable_artists}
                self._currently_clicking_on_artist = None
                self._currently_dragging = True

            if self._currently_dragging:
                self._move(event)

    def _on_release(self, event):
        if self._currently_dragging:
            # The end of a drag must not alter the selection.
            self._currently_dragging = False
        else:
            self._currently_clicking_on_artist = None
            super()._on_release(event)

    def _move(self, event):
        cursor_position = np.array([event.xdata, event.ydata])
        # Iterate over the recorded offsets rather than over all selected
        # artists: offsets only exist for draggable artists, whereas the
        # selection may contain artists that are not draggable.
        for artist, offset in self._offset.items():
            artist.xy = cursor_position + offset
        self.fig.canvas.draw_idle()
class DraggableGraph(Graph, DraggableArtists):
"""Augments `Graph` to support selection and dragging of node artists with the mouse."""
def __init__(self, *args, **kwargs):
Graph.__init__(self, *args, **kwargs)
DraggableArtists.__init__(self, self.node_artists.values())
self._draggable_artist_to_node = dict(zip(self.node_artists.values(), self.node_artists.keys()))
self._clickable_artists.extend(list(self.edge_artists.values()))
self._selectable_artists.extend(list(self.edge_artists.values()))
self._base_linewidth.update(dict([(artist, artist._lw_data) for artist in self.edge_artists.values()]))
self._base_edgecolor.update(dict([(artist, artist.get_edgecolor()) for artist in self.edge_artists.values()]))
# # trigger resize of labels when canvas size changes
# self.fig.canvas.mpl_connect('resize_event', self._on_resize)
def _move(self, event):
cursor_position = np.array([event.xdata, event.ydata])
nodes = self._get_stale_nodes()
self._update_node_positions(nodes, cursor_position)
self._update_node_artists(nodes)
if hasattr(self, 'node_label_artists'):
self._update_node_label_positions()
edges = self._get_stale_edges(nodes)
# In the interest of speed, we only compute the straight edge paths here.
# We will re-compute other edge layouts only on mouse button release,
# i.e. when the dragging motion has stopped.
edge_paths = dict()
edge_paths.update(self._update_straight_edge_paths([(source, target) for (source, target) in edges if source != target]))
edge_paths.update(self._update_selfloop_paths([(source, target) for (source, target) in edges if source == target]))
self.edge_paths.update(edge_paths)
self._update_edge_artists(edge_paths)
if hasattr(self, 'edge_label_artists'):
self._update_edge_label_positions(edges)
self.fig.canvas.draw_idle()
def _get_stale_nodes(self):
return [self._draggable_artist_to_node[artist] for artist in self._selected_artists if artist in self._draggable_artists]
def _update_node_positions(self, nodes, cursor_position):
for node in nodes:
self.node_positions[node] = cursor_position + self._offset[self.node_artists[node]]
def _get_stale_edges(self, nodes=None):
    """Return all edges that start or end in one of the given nodes.

    Parameters
    ----------
    nodes : container or None, default None
        The nodes of interest. If None, the result of `_get_stale_nodes` is used.

    Returns
    -------
    list
        The matching (source, target) tuples, in the order of `self.edges`.
    """
    if nodes is None:
        nodes = self._get_stale_nodes()

    stale_edges = []
    for (source, target) in self.edges:
        if (source in nodes) or (target in nodes):
            stale_edges.append((source, target))
    return stale_edges
def _on_release(self, event):
    """Re-compute non-straight edge paths once a drag has ended, then defer to the parent handler.

    While dragging, `_move` only draws straight edge paths; for any other
    edge layout the affected edges (and their labels, if present) are
    updated here before the parent class processes the release event.
    """
    needs_edge_update = self._currently_dragging and not (self.edge_layout == 'straight')
    if needs_edge_update:
        stale_edges = self._get_stale_edges(self._get_stale_nodes())
        self._update_edges(stale_edges)
        if hasattr(self, 'edge_label_artists'):
            # move edge labels
            self._update_edge_label_positions(stale_edges)

    super()._on_release(event)
# def _on_resize(self, event):
# if hasattr(self, 'node_labels'):
# self.draw_node_labels(self.node_labels)
# # print("As node label font size was not explicitly set, automatically adjusted node label font size to {:.2f}.".format(self.node_label_font_size))
class EmphasizeOnHover(object):
    """Emphasize matplotlib artists when hovering over them by desaturating all other artists.

    Parameters
    ----------
    artists : iterable of matplotlib artists
        The artists that can be emphasized. All of them have to belong to the
        same figure and the same axis.

    Raises
    ------
    Exception
        If the artists do not share exactly one figure or exactly one axis.
    """

    def __init__(self, artists):
        self.emphasizeable_artists = artists
        self._base_alpha = {artist : artist.get_alpha() for artist in self.emphasizeable_artists}
        self.deemphasized_artists = []

        try:
            self.fig, = {artist.figure for artist in artists}
        except ValueError:
            raise Exception("All artists have to be on the same figure!")

        try:
            self.ax, = {artist.axes for artist in artists}
        except ValueError:
            raise Exception("All artists have to be on the same axis!")

        self.fig.canvas.mpl_connect("motion_notify_event", self._on_motion)

    def _on_motion(self, event):
        """De-emphasize all artists other than the one under the cursor (if any)."""
        if event.inaxes != self.ax:
            return

        # determine if the cursor is on an artist
        selected_artist = None
        for artist in self.emphasizeable_artists:
            if artist.contains(event)[0]: # returns a (bool, dict) tuple
                selected_artist = artist
                break

        # not on any artist
        if selected_artist is None:
            if self.deemphasized_artists:
                self._restore_deemphasized_artists()
                self.fig.canvas.draw_idle()
            return

        # Undo the previous de-emphasis first: the cursor may have moved
        # directly from one artist onto another, in which case the newly
        # selected artist is still dimmed.
        self._restore_deemphasized_artists()
        others = [artist for artist in self.emphasizeable_artists if artist is not selected_artist]
        for artist in others:
            artist.set_alpha(self._get_deemphasized_alpha(artist))
        self.deemphasized_artists = others
        self.fig.canvas.draw_idle()

    def _get_deemphasized_alpha(self, artist):
        """Return a fifth of the artist's base alpha; an unset (None) base alpha counts as 1."""
        base_alpha = self._base_alpha[artist]
        if base_alpha is None:
            base_alpha = 1.
        return base_alpha/5

    def _restore_deemphasized_artists(self):
        """Reset all de-emphasized artists to their base alpha."""
        for artist in self.deemphasized_artists:
            artist.set_alpha(self._base_alpha[artist])
        self.deemphasized_artists = []
class DraggableGraphWithGridMode(DraggableGraph):
    """
    Implements a grid-mode, in which node positions are fixed to a grid.
    To activate, press the letter 'g'.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

        self.grid = False
        # the grid spacing is 5% of the canvas extent in each dimension
        width, height = self.scale[0], self.scale[1]
        self.grid_dx = 0.05 * width
        self.grid_dy = 0.05 * height
        self._grid_lines = []

        self.fig.canvas.mpl_connect('key_press_event', self._on_key_toggle)

    def _on_key_toggle(self, event):
        """Switch the grid mode on or off when the letter 'g' is pressed."""
        if event.key != 'g':
            return

        grid_was_off = self.grid is False
        self.grid = grid_was_off
        if grid_was_off:
            self._draw_grid()
        else:
            self._remove_grid()
        self.fig.canvas.draw_idle()

    def _draw_grid(self):
        """Draw dashed grid lines across the canvas and remember them for later removal."""
        eps = 1e-13 # such that the last grid line is included
        line_style = dict(color='k', alpha=0.1, linestyle='--')

        x_start = self.origin[0]
        x_stop = self.origin[0] + self.scale[0] + eps
        for x in np.arange(x_start, x_stop, self.grid_dx):
            self._grid_lines.append(self.ax.axvline(x, **line_style))

        y_start = self.origin[1]
        y_stop = self.origin[1] + self.scale[1] + eps
        for y in np.arange(y_start, y_stop, self.grid_dy):
            self._grid_lines.append(self.ax.axhline(y, **line_style))

    def _remove_grid(self):
        """Remove all grid lines from the axis."""
        for grid_line in self._grid_lines:
            grid_line.remove()
        self._grid_lines = []

    def _on_release(self, event):
        """Snap dragged nodes to the grid (if grid mode is active), then defer to the parent handler."""
        if self._currently_dragging and self.grid:
            stale_nodes = self._get_stale_nodes()
            for node in stale_nodes:
                x, y = self.node_positions[node]
                self.node_positions[node] = self._get_nearest_grid_coordinate(x, y)
            self._update_node_artists(stale_nodes)
            if hasattr(self, 'node_label_artists'):
                self._update_node_label_positions()

            stale_edges = self._get_stale_edges(stale_nodes)
            self._update_edges(stale_edges)
            if hasattr(self, 'edge_label_artists'):
                self._update_edge_label_positions(stale_edges)

        super()._on_release(event)

    def _get_nearest_grid_coordinate(self, x, y):
        """Return the grid point closest to the point (x, y)."""
        x0, y0 = self.origin[0], self.origin[1]
        columns = np.round((x - x0) / self.grid_dx)
        rows = np.round((y - y0) / self.grid_dy)
        return columns * self.grid_dx + x0, rows * self.grid_dy + y0
class EmphasizeOnHoverGraph(Graph, EmphasizeOnHover):
    """Combines `EmphasizeOnHover` with the `Graph` class such that nodes are emphasized when hovering over them with the mouse.

    Parameters
    ----------
    graph : various formats
        Graph object to plot. Various input formats are supported.
        In order of precedence:

        - Edge list:
          Iterable of (source, target) or (source, target, weight) tuples,
          or equivalent (E, 2) or (E, 3) ndarray, where E is the number of edges.
        - Adjacency matrix:
          Full-rank (V, V) ndarray, where V is the number of nodes/vertices.
          The absence of a connection is indicated by a zero.

          .. note:: If V <= 3, any (2, 2) or (3, 3) matrices will be interpreted as edge lists.

        - networkx.Graph, igraph.Graph, or graph_tool.Graph object

    mouseover_highlight_mapping : dict or None, default None
        Determines which nodes and/or edges are highlighted when hovering over any given node or edge.
        The keys of the dictionary are node and/or edge IDs, while the values are iterables of node and/or edge IDs.
        If the parameter is None, a default dictionary is constructed, which maps

        - edges to themselves as well as their source and target nodes, and
        - nodes to themselves as well as their immediate neighbours and any edges between them.

    *args, **kwargs
        Parameters passed through to `Graph`. See its documentation for a full list of available arguments.

    Attributes
    ----------
    node_artists : dict
        Mapping of node IDs to matplotlib PathPatch artists.
    edge_artists : dict
        Mapping of edge IDs to matplotlib PathPatch artists.
    node_label_artists : dict
        Mapping of node IDs to matplotlib text objects (if applicable).
    edge_label_artists : dict
        Mapping of edge IDs to matplotlib text objects (if applicable).
    node_positions : dict node : (x, y) tuple
        Mapping of node IDs to node positions.

    See also
    --------
    Graph

    """

    def __init__(self, graph, mouseover_highlight_mapping=None, *args, **kwargs):
        Graph.__init__(self, graph, *args, **kwargs)

        artists = list(self.node_artists.values()) + list(self.edge_artists.values())
        keys = list(self.node_artists.keys()) + list(self.edge_artists.keys())
        self.artist_to_key = dict(zip(artists, keys))
        EmphasizeOnHover.__init__(self, artists)

        if mouseover_highlight_mapping is None: # construct default mapping
            self.mouseover_highlight_mapping = self._get_default_mouseover_highlight_mapping()
        else: # this includes empty mappings!
            self._check_mouseover_highlight_mapping(mouseover_highlight_mapping)
            self.mouseover_highlight_mapping = mouseover_highlight_mapping

    def _get_default_mouseover_highlight_mapping(self):
        """Construct the default mapping of node and edge IDs to the IDs highlighted alongside them."""
        mapping = dict()

        # mapping for edges: source node, target node and the edge itself
        for (source, target) in self.edges:
            mapping[(source, target)] = [(source, target), source, target]

        # mapping for nodes: the node itself, its neighbours, and any edges between them
        adjacency_list = _edge_list_to_adjacency_list(self.edges, directed=False)
        for node, neighbours in adjacency_list.items():
            mapping[node] = [node]
            for neighbour in neighbours:
                mapping[node].append(neighbour)
                if (node, neighbour) in self.edge_artists:
                    mapping[node].append((node, neighbour))
                if (neighbour, node) in self.edge_artists:
                    mapping[node].append((neighbour, node))

        return mapping

    def _check_mouseover_highlight_mapping(self, mapping):
        """Validate a user-supplied mapping.

        Raises
        ------
        TypeError
            If `mapping` is not a dictionary.
        ValueError
            If any key or any value is neither a node ID nor an edge ID.
        """
        if not isinstance(mapping, dict):
            raise TypeError(f"Parameter `mouseover_highlight_mapping` is a dictionary, not {type(mapping)}.")

        invalid_keys = [key for key in mapping if not self._is_node_or_edge(key)]
        if invalid_keys:
            msg = "Parameter `mouseover_highlight_mapping` contains invalid keys:"
            for key in invalid_keys:
                msg += f"\n\t- {key}"
            raise ValueError(msg)

        invalid_values = []
        for values in mapping.values():
            for value in values:
                if not self._is_node_or_edge(value):
                    invalid_values.append(value)
        if invalid_values:
            msg = "Parameter `mouseover_highlight_mapping` contains invalid values:"
            for value in set(invalid_values):
                msg += f"\n\t- {value}"
            raise ValueError(msg)

    def _is_node_or_edge(self, key):
        """Return True if `key` identifies a node artist or an edge artist."""
        return (key in self.node_artists) or (key in self.edge_artists)

    def _on_motion(self, event):
        """De-emphasize all artists not mapped to the node or edge under the cursor (if any)."""
        if event.inaxes != self.ax:
            return

        # determine if the cursor is on an artist
        selected_artist = None
        for artist in self.emphasizeable_artists:
            if artist.contains(event)[0]: # returns a (bool, dict) tuple
                selected_artist = artist
                break

        # Neither the empty canvas nor an artist without an entry in the
        # mapping highlights anything, so any remaining de-emphasis is undone.
        key = None
        if selected_artist is not None:
            key = self.artist_to_key[selected_artist]
        if (selected_artist is None) or (key not in self.mouseover_highlight_mapping):
            if self.deemphasized_artists:
                self._restore_deemphasized_artists()
                self.fig.canvas.draw_idle()
            return

        emphasized_artists = []
        for value in self.mouseover_highlight_mapping[key]:
            if value in self.node_artists:
                emphasized_artists.append(self.node_artists[value])
            elif value in self.edge_artists:
                emphasized_artists.append(self.edge_artists[value])

        # Undo the previous de-emphasis first: the cursor may have moved
        # directly from one artist onto another (e.g. from a node onto one of
        # its edges), in which case artists that are emphasized now may still
        # be dimmed. Rebuilding the list also prevents it from growing with
        # every motion event.
        self._restore_deemphasized_artists()
        deemphasized_artists = []
        for artist in self.emphasizeable_artists:
            if artist not in emphasized_artists:
                artist.set_alpha(self._base_alpha[artist]/5)
                deemphasized_artists.append(artist)
        self.deemphasized_artists = deemphasized_artists
        self.fig.canvas.draw_idle()

    def _restore_deemphasized_artists(self):
        """Reset all de-emphasized artists to their base alpha."""
        for artist in self.deemphasized_artists:
            try:
                artist.set_alpha(self._base_alpha[artist])
            except KeyError:
                # This mitigates issue #66: artists without a recorded
                # base alpha are deliberately skipped.
                pass
        self.deemphasized_artists = []
class AnnotateOnClick(object):
    """Show or hide annotations when clicking on matplotlib artists.

    Parameters
    ----------
    artist_to_annotation : dict
        Mapping of matplotlib artists to their annotations. An annotation is
        either a string, or a dictionary of keyword arguments for the `text`
        method of the axis (which then has to provide the text itself, `s`).
    annotation_fontdict : dict or None, default None
        Keyword arguments applied to all annotations. They take precedence
        over the defaults (white background, infinite zorder, no clipping).

    Raises
    ------
    Exception
        If the instance has no `fig` or `ax` attribute yet, and the artists
        do not share exactly one figure or exactly one axis.
    """

    def __init__(self, artist_to_annotation, annotation_fontdict=None):
        self.artist_to_annotation = artist_to_annotation
        self.annotated_artists = set()
        self.artist_to_text_object = dict()
        self.annotation_fontdict = dict(
            backgroundcolor = 'white',
            zorder          = np.inf,
            clip_on         = False
        )
        if annotation_fontdict:
            self.annotation_fontdict.update(annotation_fontdict)

        # When combined with a class that already provides the figure and the
        # axis, these are used as they are (the mapping may well be empty in
        # that case). Otherwise, they are derived from the artists.
        if not hasattr(self, 'fig'):
            try:
                self.fig, = {artist.figure for artist in artist_to_annotation}
            except ValueError:
                raise Exception("All artists have to be on the same figure!")

        if not hasattr(self, 'ax'):
            try:
                self.ax, = {artist.axes for artist in artist_to_annotation}
            except ValueError:
                raise Exception("All artists have to be on the same axis!")

        self.fig.canvas.mpl_connect("button_release_event", self._on_release)

    def _on_release(self, event):
        """Toggle the annotation of the artist that was clicked on (if any)."""
        if event.inaxes == self.ax:

            # clicked on already annotated artist
            for artist in self.annotated_artists:
                if artist.contains(event)[0]:
                    self._remove_annotation(artist)
                    self.fig.canvas.draw()
                    return # the set was modified, so the iteration must not continue

            # clicked on un-annotated artist
            for artist in self.artist_to_annotation:
                if artist.contains(event)[0]:
                    placement = self._get_annotation_placement(artist)
                    self._add_annotation(artist, *placement)
                    self.fig.canvas.draw()
                    return

            # # clicked outside of any artist
            # for artist in list(self.annotated_artists): # list to force copy
            #     self._remove_annotation(artist)
            # self.fig.canvas.draw()

    def _get_annotation_placement(self, artist):
        """Return the position (x, y) and the text alignment of the annotation for `artist`.

        The annotation is placed two artist radii away from `artist.xy`,
        in the direction pointing away from the centroid.
        """
        vector = self._get_vector_pointing_outwards(artist.xy)
        x, y = artist.xy + 2 * artist.radius * vector
        horizontalalignment, verticalalignment = self._get_text_alignment(vector)
        return x, y, horizontalalignment, verticalalignment

    def _get_centroid(self):
        """Return the mean position of all artists that have an annotation."""
        return np.mean([artist.xy for artist in self.artist_to_annotation], axis=0)

    def _get_vector_pointing_outwards(self, xy):
        """Return the unit vector pointing from the centroid towards `xy`."""
        centroid = self._get_centroid()
        delta = xy - centroid
        distance = np.linalg.norm(delta)
        if distance == 0:
            # `xy` coincides with the centroid (e.g. if there is only a single
            # artist), and the direction is hence undefined. Point to the right
            # rather than dividing by zero and placing the annotation at NaN.
            return np.array([1., 0.])
        return delta / distance

    def _get_text_alignment(self, vector):
        """Return the horizontal and vertical text alignment for an annotation offset along `vector`."""
        dx, dy = vector
        # NOTE(review): the thresholds below are in degrees, so
        # `_get_angle(..., radians=True)` is presumably expected to return
        # degrees -- confirm against the definition of `_get_angle`.
        angle = _get_angle(dx, dy, radians=True) % 360

        if (45 <= angle < 135):
            horizontalalignment = 'center'
            verticalalignment = 'bottom'
        elif (135 <= angle < 225):
            horizontalalignment = 'right'
            verticalalignment = 'center'
        elif (225 <= angle < 315):
            horizontalalignment = 'center'
            verticalalignment = 'top'
        else:
            horizontalalignment = 'left'
            verticalalignment = 'center'

        return horizontalalignment, verticalalignment

    def _add_annotation(self, artist, x, y, horizontalalignment, verticalalignment):
        """Create the text object for the annotation of `artist` at position (x, y).

        Raises
        ------
        TypeError
            If the annotation is neither a string nor a dictionary.
        """
        annotation = self.artist_to_annotation[artist]

        params = self.annotation_fontdict.copy()
        params.setdefault('horizontalalignment', horizontalalignment)
        params.setdefault('verticalalignment', verticalalignment)

        if isinstance(annotation, str):
            text_object = self.ax.text(x, y, annotation, **params)
        elif isinstance(annotation, dict):
            params.update(annotation.copy())
            text_object = self.ax.text(x, y, **params)
        else:
            # Without a text object, the artist must not be marked as
            # annotated, as removing the annotation would fail later on.
            raise TypeError(f"Annotations have to be strings or dictionaries, not {type(annotation)}.")

        self.artist_to_text_object[artist] = text_object
        self.annotated_artists.add(artist)

    def _remove_annotation(self, artist):
        """Remove the text object for the annotation of `artist`."""
        text_object = self.artist_to_text_object[artist]
        text_object.remove()
        del self.artist_to_text_object[artist]
        self.annotated_artists.discard(artist)
class AnnotateOnClickGraph(Graph, AnnotateOnClick):
    """Combines `AnnotateOnClick` with the `Graph` class such that nodes or edges can have toggleable annotations.

    The annotations are read from the keyword argument `annotations`, a
    mapping of node and/or edge IDs to strings or dictionaries; the optional
    keyword argument `annotation_fontdict` is passed on to `AnnotateOnClick`.
    All arguments are also forwarded to `Graph`.

    Raises
    ------
    ValueError
        If a key in `annotations` is neither a node nor an edge of the graph.
    """

    def __init__(self, *args, **kwargs):
        Graph.__init__(self, *args, **kwargs)

        artist_to_annotation = dict()
        if 'annotations' in kwargs:
            for key, annotation in kwargs['annotations'].items():
                # Test membership of edges first, as edge keys may
                # result in a ValueError when testing membership of nodes.
                if key in self.edges:
                    artist_to_annotation[self.edge_artists[key]] = annotation
                elif key in self.nodes:
                    artist_to_annotation[self.node_artists[key]] = annotation
                else:
                    raise ValueError(f"There is no node or edge with the ID {key} for the annotation '{annotation}'.")

        AnnotateOnClick.__init__(self, artist_to_annotation, kwargs.get('annotation_fontdict'))

    def _get_centroid(self):
        """Return the centroid as computed by `Graph` rather than by `AnnotateOnClick`."""
        return Graph._get_centroid(self)

    def _get_annotation_placement(self, artist):
        """Return the position and text alignment of the annotation for a node or edge artist."""
        if isinstance(artist, NodeArtist):
            return self._get_node_annotation_placement(artist)
        elif isinstance(artist, EdgeArtist):
            return self._get_edge_annotation_placement(artist)
        else:
            raise NotImplementedError

    def _get_node_annotation_placement(self, artist):
        """Place node annotations as defined by the next class in the method resolution order."""
        return super()._get_annotation_placement(artist)

    def _get_edge_annotation_placement(self, artist):
        """Place edge annotations next to the midpoint of the edge, on the side facing away from the centroid."""
        midpoint = _get_point_along_spline(artist.midline, 0.5)

        tangent = _get_tangent_at_point(artist.midline, 0.5)
        orthogonal_vector = _get_orthogonal_unit_vector(np.atleast_2d(tangent)).ravel()
        vector_pointing_outwards = self._get_vector_pointing_outwards(midpoint)
        # Flip the orthogonal vector if it points towards the centroid.
        # NOTE(review): the threshold is in degrees, so
        # `_get_interior_angle_between(..., radians=True)` is presumably
        # expected to return degrees -- confirm against its definition.
        if _get_interior_angle_between(orthogonal_vector, vector_pointing_outwards, radians=True) > 90:
            orthogonal_vector *= -1

        x, y = midpoint + 2 * artist.width * orthogonal_vector
        horizontalalignment, verticalalignment = self._get_text_alignment(orthogonal_vector)
        return x, y, horizontalalignment, verticalalignment
class TableOnClick(object):
    """Show or hide tabular information when clicking on matplotlib artists.

    Parameters
    ----------
    artist_to_table : dict
        Mapping of matplotlib artists to the tables shown for them; each table
        has to provide `values`, `index`, and `columns` (e.g. a pandas dataframe).
    table_kwargs : dict or None, default None
        Keyword arguments passed to the `table` method of the axis.
        If a 'fontsize' is given, automatic font sizing is switched off.

    Raises
    ------
    Exception
        If the artists do not share exactly one figure or exactly one axis.
    """

    def __init__(self, artist_to_table, table_kwargs=None):
        self.artist_to_table = artist_to_table
        self.table = None
        self.table_fontsize = None
        self.table_kwargs = dict(
            # bbox = [1.1, 0.1, 0.5, 0.8],
            # edges = 'horizontal',
        )

        if table_kwargs:
            if 'fontsize' in table_kwargs:
                self.table_fontsize = table_kwargs['fontsize']
            self.table_kwargs.update(table_kwargs)

        figures = {artist.figure for artist in artist_to_table}
        try:
            self.fig, = figures
        except ValueError:
            raise Exception("All artists have to be on the same figure!")

        axes = {artist.axes for artist in artist_to_table}
        try:
            self.ax, = axes
        except ValueError:
            raise Exception("All artists have to be on the same axis!")

        self.fig.canvas.mpl_connect("button_release_event", self._on_release)

    def _on_release(self, event):
        """Show the table of the artist that was clicked on; hide the current table otherwise."""
        if not (event.inaxes == self.ax):
            return

        clicked_artist = None
        for artist in self.artist_to_table:
            if artist.contains(event)[0]:
                clicked_artist = artist
                break

        # clicked outside of any artist
        if clicked_artist is None:
            if self.table:
                self._remove_table()
                self.fig.canvas.draw()
            return

        # clicked on an artist: replace the current table (if any)
        if self.table:
            self._remove_table()
        self._add_table(clicked_artist)
        self.fig.canvas.draw()

    def _add_table(self, artist):
        """Draw the table that belongs to `artist`."""
        data = self.artist_to_table[artist]
        cell_text = data.values.tolist()
        row_labels = data.index.values
        column_labels = data.columns.values
        self.table = self.ax.table(
            cellText  = cell_text,
            rowLabels = row_labels,
            colLabels = column_labels,
            **self.table_kwargs,
        )
        if self.table_fontsize:
            # an explicit font size takes precedence over automatic sizing
            self.table.auto_set_font_size(False)
            self.table.set_fontsize(self.table_fontsize)

    def _remove_table(self):
        """Remove the table that is currently shown."""
        self.table.remove()
        self.table = None
class TableOnClickGraph(Graph, TableOnClick):
    """Combines `TableOnClick` with the `Graph` class such that nodes or edges can have toggleable tabular annotations.

    The tables are read from the keyword argument `tables`, a mapping of node
    and/or edge IDs to tables; the optional keyword argument `table_kwargs`
    is passed on to `TableOnClick`. All arguments are also forwarded to `Graph`.

    Raises
    ------
    ValueError
        If a key in `tables` is neither a node nor an edge of the graph.
    """

    def __init__(self, *args, **kwargs):
        Graph.__init__(self, *args, **kwargs)

        artist_to_table = dict()
        if 'tables' in kwargs:
            for key, table in kwargs['tables'].items():
                # Test membership of edges first, as edge keys may
                # result in a ValueError when testing membership of nodes.
                if key in self.edges:
                    artist_to_table[self.edge_artists[key]] = table
                elif key in self.nodes:
                    artist_to_table[self.node_artists[key]] = table
                else:
                    raise ValueError(f"There is no node or edge with the ID {key} for the table '{table}'.")

        TableOnClick.__init__(self, artist_to_table, kwargs.get('table_kwargs'))
[docs]
class InteractiveGraph(DraggableGraphWithGridMode, EmphasizeOnHoverGraph, AnnotateOnClickGraph, TableOnClickGraph):
"""Extends the `Graph` class to support node placement with the mouse, emphasis of graph elements when hovering over them, and toggleable annotations.
- Nodes can be selected and dragged around with the mouse.
- Nodes and edges are emphasized when hovering over them.
- Supports additional annotations that can be toggled on and off by clicking on the corresponding node or edge.
- These annotations can also be tables.
Parameters
----------
graph : various formats
Graph object to plot. Various input formats are supported.
In order of precedence:
- Edge list:
Iterable of (source, target) or (source, target, weight) tuples,
or equivalent (E, 2) or (E, 3) ndarray, where E is the number of edges.
- Adjacency matrix:
Full-rank (V, V) ndarray, where V is the number of nodes/vertices.
The absence of a connection is indicated by a zero.
.. note:: If V <= 3, any (2, 2) or (3, 3) matrices will be interpreted as edge lists.
- networkx.Graph, igraph.Graph, or graph_tool.Graph object
node_layout : str or dict, default 'spring'
If `node_layout` is a string, the node positions are computed using the indicated method:
- 'random' : place nodes in random positions;
- 'circular' : place nodes regularly spaced on a circle;
- 'spring' : place nodes using a force-directed layout (Fruchterman-Reingold algorithm);
- 'dot' : place nodes using the Sugiyama algorithm; the graph should be directed and acyclic;
- 'radial' : place nodes radially using the Sugiyama algorithm; the graph should be directed and acyclic;
- 'community' : place nodes such that nodes belonging to the same community are grouped together;
- 'bipartite' : place nodes regularly spaced on two parallel lines;
- 'multipartite' : place nodes regularly spaced on several parallel lines;
- 'shell' : place nodes regularly spaced on concentric circles;
- 'geometric' : place nodes according to the length of the edges between them.
If `node_layout` is a dict, keys are nodes and values are (x, y) positions.
node_layout_kwargs : dict or None, default None
Keyword arguments passed to node layout functions.
See the documentation of the following functions for a full description of available options:
- get_random_layout
- get_circular_layout
- get_fruchterman_reingold_layout
- get_sugiyama_layout
- get_radial_tree_layout
- get_community_layout
- get_bipartite_layout
- get_multipartite_layout
- get_shell_layout
- get_geometric_layout
node_shape : str or dict, default 'o'
Node shape.
If the type is str, all nodes have the same shape.
If the type is dict, maps each node to an individual string representing the shape.
The string specification is as for matplotlib.scatter marker, i.e. one of 'so^>v<dph8'.
node_size : float or dict, default 3.
Node size (radius).
If the type is float, all nodes will have the same size.
If the type is dict, maps each node to an individual size.
.. note:: Values are rescaled by BASE_SCALE (1e-2) to be compatible with layout routines in igraph and networkx.
node_edge_width : float or dict, default 0.5
Line width of node marker border.
If the type is float, all nodes have the same line width.
If the type is dict, maps each node to an individual line width.
.. note:: Values are rescaled by BASE_SCALE (1e-2) to be compatible with layout routines in igraph and networkx.
node_color : matplotlib color specification or dict, default 'w'
Node color.
If the type is a string or RGBA array, all nodes have the same color.
If the type is dict, maps each node to an individual color.
node_edge_color : matplotlib color specification or dict, default DEFAULT_COLOR
Node edge color.
If the type is a string or RGBA array, all nodes have the same edge color.
If the type is dict, maps each node to an individual edge color.
node_alpha : scalar or dict, default 1.
Node transparency.
If the type is a float, all nodes have the same transparency.
If the type is dict, maps each node to an individual transparency.
node_zorder : int or dict, default 2
Order in which to plot the nodes.
If the type is an int, all nodes have the same zorder.
If the type is dict, maps each node to an individual zorder.
node_labels : bool or dict, (default False)
If False, the nodes are unlabelled.
If True, the nodes are labelled with their node IDs.
If the node labels are to be distinct from the node IDs, supply a dictionary mapping nodes to node labels.
Only nodes in the dictionary are labelled.
node_label_offset: float or tuple, default (0., 0.)
A (dx, dy) tuple specifies the exact offset from the node position.
If a single scalar delta is specified, the value is interpreted as a distance,
and the label is placed delta away from the node position while trying to
reduce node/label, node/edge, and label/label overlaps.
node_label_fontdict : dict
Keyword arguments passed to matplotlib.text.Text.
For a full list of available arguments see the matplotlib documentation.
The following default values differ from the defaults for matplotlib.text.Text:
- size (adjusted to fit into node artists if offset is (0, 0))
- horizontalalignment (default here: 'center')
- verticalalignment (default here: 'center')
- clip_on (default here: False)
- zorder (default here: inf)
edge_width : float or dict, default 1.
Width of edges.
If the type is a float, all edges have the same width.
If the type is dict, maps each edge to an individual width.
.. note:: Value is rescaled by BASE_SCALE (1e-2) to be compatible with layout routines in igraph and networkx.
edge_cmap : matplotlib color map (default 'RdGy')
Color map used to map edge weights to edge colors. Should be diverging.
If edge weights are strictly positive, weights are mapped to the
left hand side of the color map with vmin=0 and vmax=np.max(weights).
If edge weights are positive and negative, then weights are mapped
to colors such that a weight of zero corresponds to the center of the
color map; the boundaries are set to +/- the maximum absolute weight.
If the graph is unweighted or the edge colors are specified explicitly,
this parameter is ignored.
edge_color : matplotlib color specification or dict, default DEFAULT_COLOR
Edge color. If provided explicitly, overrides `edge_cmap`.
If the type is a string or RGBA array, all edges have the same color.
If the type is dict, maps each edge to an individual color.
edge_alpha : float or dict, default 1.
The edge transparency.
If the type is a float, all edges have the same transparency.
If the type is dict, maps each edge to an individual transparency.
edge_zorder : int or dict, default 1
Order in which to plot the edges.
If the type is an int, all edges have the same zorder.
If the type is dict, maps each edge to an individual zorder.
If None, the edges will be plotted in the order they appear in 'adjacency'.
Hint: graphs typically appear more visually pleasing if darker edges are plotted on top of lighter edges.
arrows : bool, default False
If True, draw edges with arrow heads.
edge_layout : str or dict (default 'straight')
If edge_layout is a string, determine the layout internally:
- 'straight' : draw edges as straight lines
- 'curved' : draw edges as curved splines; the spline control points are optimised to avoid other nodes and edges
- 'bundled' : draw edges as edge bundles
If edge_layout is a dict, the keys are edges and the values are edge paths
in the form iterables of (x, y) tuples, the edge segments.
edge_layout_kwargs : dict, default None
Keyword arguments passed to edge layout functions.
See the documentation of the following functions for a full description of available options:
- get_straight_edge_paths
- get_curved_edge_paths
- get_bundled_edge_paths
edge_labels : bool or dict, default False
If False, the edges are unlabelled.
If True, the edges are labelled with their edge IDs.
If the edge labels are to be distinct from the edge IDs, supply a dictionary mapping edges to edge labels.
Only edges in the dictionary are labelled.
edge_label_position : float, default 0.5
Relative position along the edge where the label is placed.
- head : 0.
- centre : 0.5
- tail : 1.
edge_label_rotate : bool, default True
If True, edge labels are rotated such that they have the same orientation as their edge.
If False, edge labels are not rotated; the angle of the text is parallel to the axis.
edge_label_fontdict : dict
Keyword arguments passed to matplotlib.text.Text.
For a full list of available arguments see the matplotlib documentation.
The following default values differ from the defaults for matplotlib.text.Text:
- horizontalalignment (default here: 'center'),
- verticalalignment (default here: 'center')
- clip_on (default here: False),
- bbox (default here: dict(boxstyle='round', ec=(1.0, 1.0, 1.0), fc=(1.0, 1.0, 1.0)),
- zorder (default here: inf),
- rotation (determined by edge_label_rotate argument)
annotations : dict
Mapping of nodes or edges to strings or dictionaries, the annotations.
The visibility of the annotations can be toggled on or off by clicking on the corresponding node or edge.
.. line-block::
annotations = {
0 : 'Normal node',
1 : {s : 'Less important node', fontsize : 2},
2 : {s : 'Very important node', fontcolor : 'red'},
(0, 1) : 'Normal edge',
(1, 2) : {s : 'Less important edge', fontsize : 2},
(2, 0) : {s : 'Very important edge', fontcolor : 'red'},
}
annotation_fontdict : dict
Keyword arguments passed to matplotlib.text.Text if only the annotation string is given.
For a full list of available arguments see the matplotlib documentation.
The following default values differ from the defaults for matplotlib.text.Text:
- horizontalalignment (depends on node position or edge orientation),
- verticalalignment (depends on node position or edge orientation),
- clip_on (default here: False),
- backgroundcolor (default here: 'white'),
- zorder (default here: inf),
tables : dict node/edge : pandas dataframe
Mapping of nodes and/or edges to pandas dataframes.
The visibility of the tables can be toggled on or off by clicking on the corresponding node or edge.
table_kwargs : dict
Keyword arguments passed to matplotlib.pyplot.table.
origin : tuple, default (0., 0.)
The lower left hand corner of the bounding box specifying the extent of the canvas.
scale : tuple, default (1., 1.)
The width and height of the bounding box specifying the extent of the canvas.
prettify : bool, default True
If True, despine and remove ticks and tick labels.
Set figure background to white. Set axis aspect to equal.
ax : matplotlib.axis instance or None, default None
Axis to plot onto; if none specified, one will be instantiated with plt.gca().
Attributes
----------
node_artists : dict
Mapping of node IDs to matplotlib PathPatch artists.
edge_artists : dict
Mapping of edge IDs to matplotlib PathPatch artists.
node_label_artists : dict
Mapping of node IDs to matplotlib text objects (if applicable).
edge_label_artists : dict
Mapping of edge IDs to matplotlib text objects (if applicable).
node_positions : dict node : (x, y) tuple
Mapping of node IDs to node positions.
See also
--------
Graph
Notes
-----
You must retain a reference to the plot instance!
Otherwise, the plot instance will be garbage collected after the initial draw
and you won't be able to move the plot elements around.
Examples
--------
>>> import matplotlib.pyplot as plt
>>> from netgraph import InteractiveGraph
>>> plt.ion()
>>> plot_instance = InteractiveGraph(my_graph_obj)
>>> plt.show()
"""
def __init__(self, *args, **kwargs):
    """Draw the graph and initialise dragging, hover emphasis, annotations, and tables.

    All arguments are forwarded to `DraggableGraphWithGridMode`. In addition,
    the keyword arguments `annotations`, `annotation_fontdict`, `tables`, and
    `table_kwargs` are evaluated here. The table functionality is only
    initialised if `tables` is given.

    Raises
    ------
    ValueError
        If a key in `annotations` or `tables` is neither a node nor an edge of the graph.
    """
    DraggableGraphWithGridMode.__init__(self, *args, **kwargs)

    artists = list(self.node_artists.values()) + list(self.edge_artists.values())
    keys = list(self.node_artists.keys()) + list(self.edge_artists.keys())
    self.artist_to_key = dict(zip(artists, keys))
    EmphasizeOnHover.__init__(self, artists)
    self.mouseover_highlight_mapping = self._get_default_mouseover_highlight_mapping()

    def get_artist(key, kind, value):
        # Test membership of edges first, as edge keys may
        # result in a ValueError when testing membership of nodes.
        if key in self.edges:
            return self.edge_artists[key]
        elif key in self.nodes:
            return self.node_artists[key]
        else:
            raise ValueError(f"There is no node or edge with the ID {key} for the {kind} '{value}'.")

    artist_to_annotation = dict()
    if 'annotations' in kwargs:
        for key, annotation in kwargs['annotations'].items():
            artist_to_annotation[get_artist(key, 'annotation', annotation)] = annotation
    AnnotateOnClick.__init__(self, artist_to_annotation, kwargs.get('annotation_fontdict'))

    # `_on_release` tests for the existence of `artist_to_table`,
    # which is only set by `TableOnClick.__init__`.
    if 'tables' in kwargs:
        artist_to_table = dict()
        for key, table in kwargs['tables'].items():
            artist_to_table[get_artist(key, 'table', table)] = table
        TableOnClick.__init__(self, artist_to_table, kwargs.get('table_kwargs'))
def _on_motion(self, event):
    """Forward mouse motion events to the dragging handler first, then to the hover-emphasis handler."""
    for handler in (DraggableGraphWithGridMode._on_motion, EmphasizeOnHoverGraph._on_motion):
        handler(self, event)
def _on_release(self, event):
    """Dispatch mouse button release events.

    The dragging handler always runs first. A release without a preceding
    drag additionally toggles annotations and tables; a release that ends a
    drag re-draws the annotations that are currently shown.
    """
    # Evaluate the flag before the dragging handler gets to see the event.
    was_not_dragging = self._currently_dragging is False

    DraggableGraphWithGridMode._on_release(self, event)

    if was_not_dragging:
        if self.artist_to_annotation:
            AnnotateOnClickGraph._on_release(self, event)
        if hasattr(self, 'artist_to_table'):
            TableOnClickGraph._on_release(self, event)
    elif self.artist_to_annotation:
        self._redraw_annotations(event)
def _redraw_annotations(self, event):
    """Remove and re-create all annotations that are currently shown.

    Nothing happens if the event occurred outside of the axis of the graph.
    """
    if event.inaxes == self.ax:
        # Iterate over a copy, as removing and adding annotations
        # modifies the set of annotated artists.
        for artist in list(self.annotated_artists):
            self._remove_annotation(artist)
            placement = self._get_annotation_placement(artist)
            self._add_annotation(artist, *placement)
        self.fig.canvas.draw()