"""Ground projection axes module."""
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.axes import Axes
from matplotlib.axes._base import _process_plot_format as mpl_fmt
from matplotlib.axis import XAxis, YAxis
from matplotlib.cm import ScalarMappable
from matplotlib.collections import LineCollection, PatchCollection
from matplotlib.colors import Normalize
from matplotlib.legend_handler import HandlerPolyCollection
from matplotlib.patches import PathPatch
from matplotlib.ticker import FixedLocator, NullLocator
from ..ticks import (
UnitFormatter, deg_ticks, hr_ticks, km_pix_ticks,
km_s_ticks, km_ticks, lat_ticks, lon_e_ticks, lon_west_ticks
)
# Known trajectory properties: attribute name -> (colorbar label, ticks formatter).
# `ProjAxes.plot` uses it to label colored lines and `ProjAxes.colorbar`
# uses it as a label shortcut.
PROPS = dict(
    alt=('Altitude', km_ticks),
    dist=('Distance', km_ticks),
    target_size=('Target angular size', deg_ticks),
    local_time=('Local time', hr_ticks),
    inc=('Incidence angle', deg_ticks),
    emi=('Emission angle', deg_ticks),
    phase=('Phase angle', deg_ticks),
    solar_zenith_angle=('Solar zenith angle', deg_ticks),
    solar_longitude=('Seasonal solar longitude', deg_ticks),
    true_anomaly=('True anomaly angle', deg_ticks),
    groundtrack_velocity=('Groundtrack velocity', km_s_ticks),
    pixel_scale=('Pixel scale', km_pix_ticks),
)
def get_values(traj, attr):
    """Get trajectory attribute values.

    Parameters
    ----------
    traj: object
        Trajectory-like object.
    attr: str
        Name of a ``traj`` attribute, or a matplotlib format string.

    Returns
    -------
    object or list
        The ``traj`` attribute when it exists, otherwise an empty list
        (``attr`` is then a valid matplotlib format string).

    Raises
    ------
    ValueError
        If ``attr`` is neither a ``traj`` attribute nor a valid
        matplotlib format string.
    """
    if not hasattr(traj, attr):
        # Not a property: it must be a valid matplotlib format string (eg. `ro`)
        try:
            mpl_fmt(attr)
        except ValueError:
            raise ValueError(
                f'The second argument `{attr}` must be a '
                f'`{traj.__class__.__name__}` property '
                'or a valid matplotlib format string (eg. `ro`).'
            ) from None
        return []

    return getattr(traj, attr)
class ProjAxes(Axes):
    """An abstract base class for geographic projections.

    Parameters
    ----------
    proj: optional
        Map projection. The axes use its ``extent`` attribute and its
        ``xy``, ``xy_plot``, ``xy_patch`` and ``xy_collection`` methods to
        convert East longitudes / latitudes (in degrees) into map coordinates.
    bg: str, optional
        Background image file, read with :py:func:`matplotlib.pyplot.imread`.
    bg_extent: optional
        Background image extent, forwarded to :py:meth:`imshow`.
    target: str, optional
        Projection target name. It is compared (case-insensitive) with the
        ``target`` attribute of the plotted objects.

    """

    # Grid steps (in degrees) used when both axis limits are set:
    # ``(max_span, step)`` pairs, checked in order.
    _LON_GRID_STEPS = ((5, 1), (15, 2), (45, 5), (90, 10), (180, 15))
    _LAT_GRID_STEPS = ((5, 1), (15, 2), (45, 5), (90, 10))
    _DEFAULT_GRID_STEP = 30

    def __init__(self, *args, proj='equi', bg=None, bg_extent=False, target='', **kwargs):
        # These attributes are set before `Axes.__init__`, which clears the
        # axes (and `clear` below reads `self.proj` and `self.bg`).
        # NOTE(review): `clear` reads `self.proj.extent`, so the `'equi'`
        # string default cannot work on its own; callers presumably always
        # provide a projection object — confirm.
        self.proj = proj
        self.bg = bg
        self.bg_extent = bg_extent
        self.target = target
        self._cbar = None
        super().__init__(*args, **kwargs)

    def _init_axis(self):
        """Create plain x/y axes (the data are drawn in projected map coordinates)."""
        self.xaxis = XAxis(self)
        self.yaxis = YAxis(self)
        self._update_transScale()

    def clear(self):
        """Clear axes and restore the default map layout.

        The aspect ratio is fixed to 1, minor ticks are removed, the
        longitude/latitude grids are reset to 30°, the background image
        (if any) is redrawn and the limits are reset to the projection extent.
        """
        Axes.clear(self)
        self.set_aspect(1)
        self.xaxis.set_minor_locator(NullLocator())
        self.yaxis.set_minor_locator(NullLocator())
        self.set_longitude_grid(30)
        self.set_latitude_grid(30)
        self.set_background()
        self.grid(lw=.5, color='k')
        # Bypass the overridden `set_xlim`/`set_ylim` to keep the 30° grid.
        Axes.set_xlim(self, *self.proj.extent[:2])
        Axes.set_ylim(self, *self.proj.extent[2:])

    def _check_target(self, obj):
        """Check that ``obj`` target (if any) matches the map target.

        Raises
        ------
        ProjectionMapTargetError
            If the (case-insensitive) target names differ.
        """
        if hasattr(obj, 'target'):
            proj_target = str(self.target).upper()
            obj_target = str(obj.target).upper()
            if proj_target != obj_target:
                raise ProjectionMapTargetError(
                    f'Target mismatch: {proj_target} map with {obj_target} data.')

    @staticmethod
    def _get_traj_values(traj, key):
        """Normalized property name and trajectory values for ``key``.

        ``key`` is lower-cased (spaces replaced by underscores) to look up the
        trajectory property. If that fails, the raw ``key`` is validated as a
        matplotlib format string, because lower-casing would reject
        case-sensitive formats such as ``'C0'``.

        Returns
        -------
        tuple
            The normalized property name and the property values
            (an empty list when ``key`` is a valid format string).

        Raises
        ------
        ValueError
            If ``key`` is neither a property nor a valid format string.
        """
        attr = key.lower().replace(' ', '_')
        try:
            return attr, get_values(traj, attr)
        except ValueError:
            if attr == key:
                raise
        return attr, get_values(traj, key)

    @staticmethod
    def _cbar_extend(data_min, data_max, vmin, vmax):
        """Colorbar ``extend`` value when the data overflow the ``[vmin, vmax]`` range."""
        below, above = data_min < vmin, data_max > vmax
        if below and above:
            return 'both'
        if below:
            return 'min'
        if above:
            return 'max'
        return 'neither'

    def plot(self, *args, scalex=True, scaley=True, data=None, **kwargs):
        """Generic plot function with map projection.

        The first argument can be:

        - a trajectory-like object (with a ``lonlat`` property), optionally
          followed by the name of one of its properties (eg. ``'alt'``) to draw
          a colored line (see :py:meth:`plot_colorline`) or by a format string,
        - a single ``lon_e, lat`` point,
        - a ``[lons_e, lats]`` pair (list, tuple or 2-D array),
        - explicit ``lons_e, lats`` values.

        Warning
        -------
        If explicit X and Y values are provided, they will be considered as
        East Longitude and Latitude angles (in degrees).

        See Also
        --------
        matplotlib.pyplot.plot
        """
        if not args:
            return super().plot(scalex=scalex, scaley=scaley, data=data, **kwargs)

        first = args[0]
        if hasattr(first, 'lonlat'):
            self._check_target(first)

            if len(args) > 1 and isinstance(args[1], str):
                attr, values = self._get_traj_values(first, args[1])
                # `np.size` (not `any`) so all-zero or 2-D values are still plotted.
                if np.size(values) > 0:
                    x, y, colors = self.proj.xy_plot(*first.lonlat, values=values)
                    label, fmt = PROPS.get(attr, (None, None))
                    kwargs = {'label': label, 'fmt': fmt, **kwargs}
                    return self.plot_colorline(x, y, colors, **kwargs)

            # The remaining string (if any) is a matplotlib format.
            x, y = self.proj.xy_plot(*first.lonlat)
            args = args[1:]

        elif len(args) >= 2 and isinstance(first, (int, float, np.number)) \
                and isinstance(args[1], (int, float, np.number)):
            # Single point
            x, y = self.proj.xy_plot([first], [args[1]])
            args = args[2:]

        elif isinstance(first, (tuple, list, np.ndarray)) and len(first) == 2 \
                and np.ndim(first) == 2:
            # [lons_e, lats] pair
            x, y = self.proj.xy_plot(*first)
            args = args[1:]

        elif len(args) > 2 and '.' not in args[2] and 'o' not in args[2]:
            # Formats without `.`/`o` markers are projected with `xy_plot`.
            x, y = self.proj.xy_plot(*args[:2])
            args = args[2:]

        else:
            x, y = self.proj.xy(*args[:2])
            args = args[2:]

        return super().plot(x, y, *args,
                            scalex=scalex, scaley=scaley, data=data, **kwargs)

    def scatter(self, *args, **kwargs):
        """Scatter plot with map projection.

        The first argument can be a trajectory-like object (with a ``lonlat``
        property), optionally followed by the name of one of its properties
        (eg. ``'alt'``) to color the points. In that case:

        - ``cmap`` defaults to ``'turbo_r'``,
        - ``vmin``/``vmax`` default to the (not-NaN) data range,
        - ``cbar=True`` adds a colorbar (extended when the data overflow
          the ``[vmin, vmax]`` range).

        Otherwise, the first two arguments are considered as East Longitude
        and Latitude angles (in degrees) and are projected on the map
        (matplotlib ``data`` string keys are passed unchanged).

        See Also
        --------
        matplotlib.pyplot.scatter
        """
        if args and hasattr(args[0], 'lonlat'):
            traj, *others = args
            self._check_target(traj)

            if others and isinstance(others[0], str):
                attr, values = self._get_traj_values(traj, others[0])
                if np.size(values) > 0:
                    vmin, vmax = np.nanmin(values), np.nanmax(values)
                    kwargs = {
                        'cmap': 'turbo_r',  # default kwargs
                        **kwargs,           # user kwargs
                        'c': values,        # always colored with the values
                    }
                    # `None` means "autoscale" in matplotlib: use the data range.
                    if kwargs.get('vmin') is None:
                        kwargs['vmin'] = vmin
                    if kwargs.get('vmax') is None:
                        kwargs['vmax'] = vmax

                    if kwargs.pop('cbar', None):
                        extend = self._cbar_extend(vmin, vmax, kwargs['vmin'], kwargs['vmax'])
                        self.colorbar(kwargs['vmin'], kwargs['vmax'],
                                      label=attr, extend=extend, cmap=kwargs['cmap'])

                    return self.scatter(*traj.lonlat, **kwargs)

            return self.scatter(*traj.lonlat, *others, **kwargs)

        if len(args) >= 2 and not isinstance(args[0], str) and not isinstance(args[1], str):
            args = (*self.proj.xy(*args[:2]), *args[2:])

        return super().scatter(*args, **kwargs)

    def plot_colorline(self, x, y, data, cmap=None, vmin=None, vmax=None, norm=None,
                       label=None, fmt=None, orientation='horizontal', cbar=True,
                       **kwargs):  # pylint: disable=too-many-locals
        """Plot a colored line with a colorbar.

        Parameters
        ----------
        x: numpy.ndarray
            Projected x-coordinates.
        y: numpy.ndarray
            Projected y-coordinates.
        data: numpy.ndarray
            Value to use to color the line (one value per point; each segment
            is colored with the mean of its two end values).
        cmap: str, optional
            Matplotlib colormap name (default: `turbo_r`)
        vmin: int or float
            Color scaling min value. If ``None`` is provided (default)
            the data are scaled to the lowest (not-NaN) value.
        vmax: int or float
            Color scaling max value. If ``None`` is provided (default)
            the data are scaled to the highest (not-NaN) value.
        norm: matplotlib.colors.Normalize
            Normalization colors normalizer. By default
            the values will be normalized between :py:attr:`vmin`
            and :py:attr:`vmax`.
        label: str, optional
            Colorbar label.
        fmt: str, optional
            Colorbar ticks formatter.
        orientation: str, optional
            Colorbar orientation (default: `horizontal`).
        cbar: bool, optional
            Add a colorbar (default: ``True``).
        **kwargs:
            Keyword attributes for :py:class:`LineCollection`.

        Returns
        -------
        matplotlib.colorbar.Colorbar or matplotlib.collections.LineCollection
            The colorbar when ``cbar`` is ``True``, the line collection otherwise.

        Note
        ----
        If the range provided (with :py:attr:`vmin` and :py:attr:`vmax`)
        is smaller than the range of the data, the colorbar will
        be extended with arrows.
        """
        # One segment between each pair of consecutive points
        points = np.transpose([x, y]).reshape(-1, 1, 2)
        segments = np.concatenate([points[:-1], points[1:]], axis=1)

        data = np.asarray(data)
        values = .5 * (data[1:] + data[:-1])

        if vmin is None:
            vmin = np.nanmin(data)
        if vmax is None:
            vmax = np.nanmax(data)
        if cmap is None:
            cmap = 'turbo_r'
        if norm is None:
            norm = Normalize(vmin, vmax)

        lc = LineCollection(segments, cmap=cmap, norm=norm, **kwargs)
        lc.set_array(values)

        # `x` and `y` are already projected: skip `ProjAxes.add_collection`
        # (which re-projects the collection) and use the base implementation.
        lines = super(Axes, self).add_collection(lc)  # pylint: disable=bad-super-call

        if not cbar:
            return lines

        # Colorbar extend is based on the data range
        extend = self._cbar_extend(np.nanmin(data), np.nanmax(data), vmin, vmax)

        return self.colorbar(vmin, vmax, cmap=cmap, orientation=orientation,
                             extend=extend, format=fmt, label=label)

    def text(self, x, y, s, fontdict=None, clip_on=True, **kwargs):
        """Add text to the axes at the projected ``(x, y)`` East longitude/latitude.

        Note
        ----
        Set clip on to `True` by default.
        """
        return super().text(*self.proj.xy(x, y), s, fontdict=fontdict,
                            clip_on=clip_on, **kwargs)

    def add_path(self, path, *args, **kwargs):
        """Draw path (as a projected :py:class:`PathPatch`) and return the patch."""
        return self.add_patch(PathPatch(path, *args, **kwargs))

    def add_patch(self, p):
        """Check the patch target, draw the projected patch and return it."""
        self._check_target(p)
        return super().add_patch(self.proj.xy_patch(p))

    def add_collection(self, collection, autolim=True):
        """Check the collection target, draw the projected collection and return it."""
        self._check_target(collection)
        return super().add_collection(self.proj.xy_collection(collection), autolim=autolim)

    def legend(self, *args, **kwargs):
        """Add a legend with a HandlerPolyCollection for `PatchCollection`.

        The user ``handler_map`` (if any) is copied, so the caller's
        dictionary is not modified. User handlers take precedence.
        """
        handler_map = dict(kwargs.get('handler_map') or {})
        handler_map.setdefault(PatchCollection, HandlerPolyCollection())
        kwargs['handler_map'] = handler_map
        return super().legend(*args, **kwargs)

    def colorbar(self, vmin, vmax, cmap='turbo_r',
                 orientation='horizontal',
                 shrink=.6,
                 aspect=40,
                 pad=0.075,
                 **kwargs):
        """Add a standalone colorbar on the axis.

        Parameters
        ----------
        vmin: int or float
            Color scaling min value.
        vmax: int or float
            Color scaling max value.
        cmap: str, optional
            Matplotlib colormap name (default: `turbo_r`)
        orientation: str, optional
            Colorbar orientation (default: `horizontal`).
        shrink, aspect, pad: float, optional
            Colorbar size and spacing (see
            :py:meth:`matplotlib.figure.Figure.colorbar`).
        label: str, optional
            Colorbar label. Shortcuts are available: a known property name
            (eg. ``'alt'``) is replaced by its full label and sets the ticks
            formatter (unless ``format`` is explicitly provided).
        **kwargs:
            Keyword attributes for :py:class:`Colorbar`.

        Returns
        -------
        matplotlib.colorbar.Colorbar
            Output colorbar (also used as parent by :py:meth:`twin_colorbar`).
        """
        norm = Normalize(vmin, vmax)

        # Shortcut to format the ticks of known units
        if kwargs.get('label') in PROPS:
            label, fmt = PROPS[kwargs['label']]
            kwargs['label'] = label
            if kwargs.get('format') is None:
                kwargs['format'] = fmt

        self._cbar = self.figure.colorbar(
            ScalarMappable(norm=norm, cmap=cmap),
            ax=self,
            orientation=orientation,
            shrink=shrink,
            aspect=aspect,
            pad=pad,
            **kwargs
        )

        return self._cbar

    def twin_colorbar(self, label=None, format=None,  # pylint: disable=redefined-builtin
                      offset=.05, ticks=None):
        """Twin colorbar with a secondary axis.

        Parameters
        ----------
        label: str, optional
            Twin colorbar label (no shortcut).
        format: matplotlib.ticker.Formatter, optional
            Optional ticks formatter.
        offset: float, optional
            Colorbar offset (default: `0.05`).
        ticks: list, optional
            Custom list of ticks (default: ``None``).

        Returns
        -------
        Secondary axis created on the parent colorbar.

        Raises
        ------
        ValueError
            If no parent colorbar was created with :py:meth:`colorbar`.
        """
        if self._cbar is None:
            raise ValueError('No parent colorbar found.')

        pos = self._cbar.ax.get_position()
        self._cbar.ax.set_aspect('auto')  # change default `equal` to `auto`

        # `if ticks:` is ambiguous for numpy arrays: test the size instead.
        has_ticks = ticks is not None and np.size(ticks) > 0

        # Relocate ticks for UnitFormatter
        if has_ticks and isinstance(format, UnitFormatter):
            ticks = format @ ticks

        # Shift the colorbar to avoid to overlap the figure
        if self._cbar.orientation == 'horizontal':
            pos.y0 -= offset
            pos.y1 -= offset

            ax = self._cbar.ax.secondary_xaxis('top')
            ax.set_xlim(self._cbar.ax.get_xlim())

            set_label = ax.set_xlabel
            set_ticks = ax.set_xticks
            set_formatter = ax.xaxis.set_major_formatter
        else:
            pos.x0 += offset
            pos.x1 += offset

            ax = self._cbar.ax.secondary_yaxis('left')
            ax.set_ylim(self._cbar.ax.get_ylim())

            set_label = ax.set_ylabel
            set_ticks = ax.set_yticks
            set_formatter = ax.yaxis.set_major_formatter

        # Change label/ticks and formatter
        if label:
            set_label(label)
        if has_ticks:
            set_ticks(ticks)
        if format:
            set_formatter(format)

        # Move original colorbar position
        self._cbar.ax.set_position(pos)

        # Remove the frame borders
        for border in ax.spines:
            ax.spines[border].set_visible(False)

        return ax

    @staticmethod
    def _multiples(vmin, vmax, step):
        """Multiples of ``step`` within ``[vmin, vmax]``.

        Integer ticks are returned when ``step`` is integral, floats
        otherwise (so fractional steps such as 2.5° are not truncated).

        Raises
        ------
        ValueError
            If ``step`` is not strictly positive.
        """
        if step <= 0:
            raise ValueError(f'Grid step must be strictly positive: `{step}`')

        # Tiny tolerance to keep the boundaries despite float rounding.
        start = np.ceil(vmin / step - 1e-9)
        stop = np.floor(vmax / step + 1e-9)
        grid = np.arange(start, stop + 1) * step

        if float(step).is_integer():
            return np.round(grid).astype(int)
        return grid

    @classmethod
    def _grid_step(cls, span, steps):
        """Grid step (in degrees) adapted to the ``span`` of the view."""
        return next((step for max_span, step in steps if span <= max_span),
                    cls._DEFAULT_GRID_STEP)

    def set_longitude_grid(self, degrees):
        """Set the number of degrees between each longitude grid line.

        The grid lines are placed on the multiples of ``degrees``
        between 0° and 360° East.
        """
        self.xaxis.set_major_locator(FixedLocator(self._multiples(0, 360, degrees)))
        self.xaxis.set_major_formatter(lon_e_ticks)

    def set_latitude_grid(self, degrees):
        """Set the number of degrees between each latitude grid line.

        The grid lines are placed on the multiples of ``degrees``
        between -90° and 90°.
        """
        self.yaxis.set_major_locator(FixedLocator(self._multiples(-90, 90, degrees)))
        self.yaxis.set_major_formatter(lat_ticks)

    def set_lon_ticks(self, key, secondary=False):
        """Toggle longitude ticks (East/West and top/bottom).

        Parameters
        ----------
        key: str
            Longitude ticks format. Possible values:
            ``'east'`` | ``'0 360'`` or ``'west'`` | ``'360 0'``
        secondary: bool, optional
            Display the ticks on top secondary axis (default: False)

        Raises
        ------
        KeyError
            If ``key`` is not one of the accepted values.

        Warning
        -------
        The values provided in the plot are always in East longitude.
        Here, only the axis ticks are changed (the data are not re-projected).
        """
        if secondary:
            xaxis = self.secondary_xaxis('top').xaxis
            xaxis.set_ticks(self.xaxis.get_ticklocs())
        else:
            xaxis = self.xaxis

        normalized = key.lower()
        if normalized in ['east', '0 360']:
            xaxis.set_major_formatter(lon_e_ticks)
        elif normalized in ['west', '360 0']:
            xaxis.set_major_formatter(lon_west_ticks)
        else:
            raise KeyError(
                f'Only `east`/`west` (or `0 360`/`360 0`) are accepted. Provided: `{key}`'
            )

    def set_lat_ticks(self, secondary=False):
        """Set the latitude ticks formatter.

        Parameters
        ----------
        secondary: bool, optional
            Display the ticks on a right secondary axis (default: False).
            Otherwise, the formatter is applied on the primary axis.

        Warning
        -------
        Only the axis ticks are changed (the data are not re-projected).
        """
        if secondary:
            yaxis = self.secondary_yaxis('right').yaxis
            yaxis.set_ticks(self.yaxis.get_ticklocs())
        else:
            yaxis = self.yaxis

        yaxis.set_major_formatter(lat_ticks)

    def set_view(self, *args, margin=5):
        """Center view on object coordinates.

        Parameters
        ----------
        *args: [float, float, float, float] or object
            East longitudes and latitudes to center the view on.
            It can be either:
            - lon_e_min, lon_e_max, lat_min, lat_max
            - [lon_e_min, lon_e_max], [lat_min, lat_max]
            - [lon_e_min, lon_e_max, lat_min, lat_max]
            - an object with `lonlat` property
            - an object with `lons_e` and `lats` properties
        margin: int or float
            Margin percentage fraction of the object to add to the sides.
            Default: 5%.

        Returns
        -------
        tuple
            The new x and y limits.

        Raises
        ------
        ValueError
            If the provided coordinates are invalid.

        Note
        ----
        The limits are clipped on the side of the projection extent.
        """
        if len(args) == 1:
            if isinstance(args[0], (list, tuple, np.ndarray)):
                return self.set_view(*args[0], margin=margin)

            # Check target name (if present)
            self._check_target(args[0])

            if hasattr(args[0], 'lonlat'):
                return self.set_view(*args[0].lonlat, margin=margin)

            if hasattr(args[0], 'lons_e') and hasattr(args[0], 'lats'):
                return self.set_view(args[0].lons_e, args[0].lats, margin=margin)

        if len(args) == 2:
            lons_e, lats = args
        elif len(args) == 4:
            lons_e, lats = args[:2], args[2:]
        else:
            raise ValueError(
                f'Invalid view: {args}. It should be either:\n'
                '- lon_e_min, lon_e_max, lat_min, lat_max\n'
                '- [lon_e_min, lon_e_max], [lat_min, lat_max]\n'
                '- [lon_e_min, lon_e_max, lat_min, lat_max]\n'
                '- an object with `lonlat` property\n'
                '- an object with `lons_e` and `lats` properties'
            )

        # Project the data on the map
        x, y = self.proj.xy(lons_e, lats)
        xmin, xmax, ymin, ymax = np.nanmin(x), np.nanmax(x), np.nanmin(y), np.nanmax(y)

        # Adjust the margin (percentage of the largest side)
        margin *= max(xmax - xmin, ymax - ymin) / 100

        # Get projection extent to clip the limits
        x0, x1, y0, y1 = self.proj.extent

        return self.set_xlim(max(x0, xmin - margin), min(xmax + margin, x1)), \
            self.set_ylim(max(y0, ymin - margin), min(ymax + margin, y1))

    def set_xlim(self, left=None, right=None, emit=True, auto=False, **kwargs):
        """Rescale the x-map coordinate limits.

        Note
        ----
        If both sides are provided the ticks grid is readjusted
        (based on the absolute span, so inverted limits are supported).
        """
        if left is not None and right is not None:
            self.set_longitude_grid(
                self._grid_step(abs(right - left), self._LON_GRID_STEPS))

        return super().set_xlim(left=left, right=right, emit=emit, auto=auto, **kwargs)

    def set_ylim(self, bottom=None, top=None, emit=True, auto=False, **kwargs):
        """Rescale the y-map coordinate limits.

        Note
        ----
        If both sides are provided the ticks grid is readjusted
        (based on the absolute span, so inverted limits are supported).
        """
        if bottom is not None and top is not None:
            self.set_latitude_grid(
                self._grid_step(abs(top - bottom), self._LAT_GRID_STEPS))

        return super().set_ylim(bottom=bottom, top=top, emit=emit, auto=auto, **kwargs)

    def set_background(self):
        """Set image basemap background (if a background image file is provided)."""
        if self.bg:
            im = plt.imread(self.bg)
            self.imshow(im, extent=self.bg_extent, cmap='gray')
class ProjectionMapTargetError(Exception):
    """Mismatch between the projection map target and the data.

    Raised when an object with a ``target`` attribute is drawn on a map
    whose target name differs (the comparison is case-insensitive).
    """