import matplotlib.pyplot as plt
import pandas as pd
import xarray as xr
import numpy as np
from pandas import DataFrame
from pandas.io.formats.style import Styler
# ------------------------------------------------------------------------------------
# Views of the ds or nc file
# ------------------------------------------------------------------------------------
[docs]
def show_contents(
data: str | xr.Dataset,
content_type: str = "variables",
) -> Styler | pd.DataFrame:
"""Wrapper function to show contents of an xarray Dataset or a netCDF file.
Parameters
----------
data : str or xr.Dataset
The input data, either a file path to a netCDF file or an xarray Dataset.
content_type : str, optional
The type of content to show, either 'variables' (or 'vars') or 'attributes' (or 'attrs').
Default is 'variables'.
Returns
-------
pandas.io.formats.style.Styler or pandas.DataFrame
A styled DataFrame with details about the variables or attributes.
Raises
------
TypeError
If the input data is not a file path (str) or an xarray Dataset.
ValueError
If the content_type is not 'variables' (or 'vars') or 'attributes' (or 'attrs').
"""
if content_type in ["variables", "vars"]:
if isinstance(data, (str, xr.Dataset)):
return show_variables(data)
else:
raise TypeError("Input data must be a file path (str) or an xarray Dataset")
elif content_type in ["attributes", "attrs"]:
if isinstance(data, (str, xr.Dataset)):
return show_attributes(data)
else:
raise TypeError(
"Attributes can only be shown for netCDF files (str) or xarray Datasets",
)
else:
raise ValueError(
"content_type must be either 'variables' (or 'vars') or 'attributes' (or 'attrs')",
)
[docs]
def show_variables(data: str | xr.Dataset) -> Styler:
"""Processes an xarray Dataset or a netCDF file, extracts variable information,
and returns a styled DataFrame with details about the variables.
Parameters
----------
data : str or xr.Dataset
The input data, either a file path to a netCDF file or an xarray Dataset.
Returns
-------
pd.io.formats.style.Styler
A styled DataFrame containing the following columns:
- dims: The dimension of the variable (or "string" if it is a string type).
- name: The name of the variable.
- units: The units of the variable (if available).
- comment: Any additional comments about the variable (if available).
- standard_name: The standard name of the variable (if available).
- dtype: The data type of the variable.
Raises
------
TypeError
If the input data is not a file path (str) or an xarray Dataset.
"""
from netCDF4 import Dataset
from pandas import DataFrame
if isinstance(data, str):
print(f"information is based on file: {data}")
dataset = Dataset(data)
variables = dataset.variables
elif isinstance(data, xr.Dataset):
print("information is based on xarray Dataset")
variables = data.variables
else:
raise TypeError("Input data must be a file path (str) or an xarray Dataset")
info = {}
for i, key in enumerate(variables):
var = variables[key]
if isinstance(data, str):
dims = var.dimensions[0] if len(var.dimensions) == 1 else "string"
units = "" if not hasattr(var, "units") else var.units
comment = "" if not hasattr(var, "comment") else var.comment
else:
dims = var.dims[0] if len(var.dims) == 1 else "string"
units = var.attrs.get("units", "")
comment = var.attrs.get("comment", "")
info[i] = {
"name": key,
"dims": dims,
"units": units,
"comment": comment,
"standard_name": var.attrs.get("standard_name", ""),
"dtype": str(var.dtype) if isinstance(data, str) else str(var.data.dtype),
}
vars = DataFrame(info).T
dim = vars.dims
dim[dim.str.startswith("str")] = "string"
vars["dims"] = dim
vars = (
vars.sort_values(["dims", "name"])
.reset_index(drop=True)
.loc[:, ["dims", "name", "units", "comment", "standard_name", "dtype"]]
.set_index("name")
.style
)
return vars
[docs]
def show_attributes(data: str | xr.Dataset) -> pd.DataFrame:
"""Processes an xarray Dataset or a netCDF file, extracts attribute information,
and returns a DataFrame with details about the attributes.
Parameters
----------
data : str or xr.Dataset
The input data, either a file path to a netCDF file or an xarray Dataset.
Returns
-------
pandas.DataFrame
A DataFrame containing the following columns:
- Attribute: The name of the attribute.
- Value: The value of the attribute.
- DType: The data type of the attribute.
Raises
------
TypeError
If the input data is not a file path (str) or an xarray Dataset.
"""
from netCDF4 import Dataset
from pandas import DataFrame
if isinstance(data, str):
print(f"information is based on file: {data}")
rootgrp = Dataset(data, "r", format="NETCDF4")
attributes = rootgrp.ncattrs()
get_attr = lambda key: getattr(rootgrp, key)
elif isinstance(data, xr.Dataset):
print("information is based on xarray Dataset")
attributes = data.attrs.keys()
get_attr = lambda key: data.attrs[key]
else:
raise TypeError("Input data must be a file path (str) or an xarray Dataset")
info = {}
for i, key in enumerate(attributes):
dtype = type(get_attr(key)).__name__
info[i] = {"Attribute": key, "Value": get_attr(key), "DType": dtype}
attrs = DataFrame(info).T
return attrs
[docs]
def show_variables_by_dimension(
data: str | xr.Dataset,
dimension_name: str = "trajectory",
) -> Styler:
"""Extracts variable information from an xarray Dataset or a netCDF file and returns a styled DataFrame
with details about the variables filtered by a specific dimension.
Parameters
----------
data : str or xr.Dataset
The input data, either a file path to a netCDF file or an xarray Dataset.
dimension_name : str, optional
The name of the dimension to filter variables by, by default "trajectory".
Returns
-------
pandas.io.formats.style.Styler
A styled DataFrame containing the following columns:
- dims: The dimension of the variable (or "string" if it is a string type).
- name: The name of the variable.
- units: The units of the variable (if available).
- comment: Any additional comments about the variable (if available).
Raises
------
TypeError
If the input data is not a file path (str) or an xarray Dataset.
"""
if isinstance(data, str):
print(f"information is based on file: {data}")
dataset = xr.open_dataset(data)
variables = dataset.variables
elif isinstance(data, xr.Dataset):
print("information is based on xarray Dataset")
variables = data.variables
else:
raise TypeError("Input data must be a file path (str) or an xarray Dataset")
info = {}
for i, key in enumerate(variables):
var = variables[key]
if isinstance(data, str):
dims = var.dimensions[0] if len(var.dimensions) == 1 else "string"
units = "" if not hasattr(var, "units") else var.units
comment = "" if not hasattr(var, "comment") else var.comment
else:
dims = var.dims[0] if len(var.dims) == 1 else "string"
units = var.attrs.get("units", "")
comment = var.attrs.get("comment", "")
if dims == dimension_name:
info[i] = {
"name": key,
"dims": dims,
"units": units,
"comment": comment,
}
vars = DataFrame(info).T
dim = vars.dims
dim[dim.str.startswith("str")] = "string"
vars["dims"] = dim
vars = (
vars.sort_values(["dims", "name"])
.reset_index(drop=True)
.loc[:, ["dims", "name", "units", "comment"]]
.set_index("name")
.style
)
return vars
[docs]
def monthly_resample(da: xr.DataArray) -> xr.DataArray:
"""Resample to monthly mean if data is not already monthly."""
time_key = [c for c in da.coords if c.lower() == "time"]
if not time_key:
raise ValueError("No time coordinate found.")
time_key = time_key[0]
# Extract time values and check spacing
time_values = da[time_key].values
dt_days = np.nanmean(np.diff(time_values) / np.timedelta64(1, "D"))
if 20 <= dt_days <= 40:
return da # Already monthly
# Drop NaT timestamps
mask_valid_time = ~np.isnat(time_values)
da = da.isel({time_key: mask_valid_time})
# Drop duplicate timestamps (keep first)
_, unique_indices = np.unique(da[time_key].values, return_index=True)
da = da.isel({time_key: np.sort(unique_indices)})
# Ensure strictly increasing time
da = da.sortby(time_key)
# Now resample
return da.resample({time_key: "1MS"}).mean()
[docs]
def plot_amoc_timeseries(
data,
varnames=None,
labels=None,
colors=None,
title="AMOC Time Series",
ylabel=None,
time_limits=None,
ylim=None,
figsize=(10, 3),
resample_monthly=True,
plot_raw=True,
):
"""
Plot original and optionally monthly-averaged AMOC time series for one or more datasets.
Parameters
----------
data : list of xarray.Dataset or xarray.DataArray
List of datasets or DataArrays to plot.
varnames : list of str, optional
List of variable names to extract from each dataset. Not needed if DataArrays are passed.
labels : list of str, optional
Labels for the legend.
colors : list of str, optional
Colors for monthly-averaged plots.
title : str
Title of the plot.
ylabel : str, optional
Label for the y-axis. If None, inferred from attributes.
time_limits : tuple of str or pd.Timestamp, optional
X-axis time limits (start, end).
ylim : tuple of float, optional
Y-axis limits (min, max).
figsize : tuple
Size of the figure.
resample_monthly : bool
If True, monthly averages are computed and plotted.
plot_raw : bool
If True, raw data is plotted.
"""
if not isinstance(data, list):
data = [data]
if varnames is None:
varnames = [None] * len(data)
if labels is None:
labels = [f"Dataset {i+1}" for i in range(len(data))]
if colors is None:
colors = ["red", "darkblue", "green", "purple", "orange"]
fig, ax = plt.subplots(figsize=figsize)
for i, item in enumerate(data):
label = labels[i]
color = colors[i % len(colors)]
var = varnames[i]
# Extract DataArray
if isinstance(item, xr.Dataset):
da = item[var]
else:
da = item
# Get time coordinate (case sensitive)
for coord in da.coords:
if coord.lower() == "time":
time_key = coord
break
else:
raise ValueError("No time coordinate found in dataset.")
# Plot original
if plot_raw:
ax.plot(
da[time_key],
da,
color="grey",
alpha=0.5,
linewidth=0.5,
label=f"{label} (raw)" if label else "Original",
)
# Plot monthly average if requested
if resample_monthly:
da_monthly = monthly_resample(da)
ax.plot(
da_monthly[time_key],
da_monthly,
color=color,
linewidth=1.5,
label=f"{label} Monthly Avg",
)
# Attempt to extract ylabel from metadata if not provided
if ylabel is None and "standard_name" in da.attrs and "units" in da.attrs:
ylabel = f"{da.attrs['standard_name']} [{da.attrs['units']}]"
# Horizontal zero line
ax.axhline(0, color="black", linestyle="--", linewidth=0.5)
# Styling
ax.spines["top"].set_visible(False)
ax.spines["right"].set_visible(False)
ax.set_title(title)
ax.set_xlabel("Time")
ax.set_ylabel(ylabel if ylabel else "Transport [Sv]")
ax.legend(loc="best")
ax.grid(True, linestyle="--", alpha=0.5)
# Limits
if time_limits:
ax.set_xlim(pd.Timestamp(time_limits[0]), pd.Timestamp(time_limits[1]))
if ylim:
ax.set_ylim(ylim)
plt.tight_layout()
return fig, ax
[docs]
def plot_monthly_anomalies(**kwargs) -> tuple[plt.Figure, list[plt.Axes]]:
"""
Plot the monthly anomalies for various datasets.
Pass keyword arguments in the form: `label_name_data`, `label_name_label`.
For example:
osnap_data = standardOSNAP[0]["MOC_all"], osnap_label = "OSNAP"
...
"""
color_cycle = [
"blue",
"red",
"green",
"purple",
"orange",
"darkblue",
"darkred",
"darkgreen",
]
# Extract and sort data/labels by name to ensure consistent ordering
names = ["dso", "osnap", "fortyone", "rapid", "fw2015", "move", "samba"]
datasets = [monthly_resample(kwargs[f"{name}_data"]) for name in names]
labels = [kwargs[f"{name}_label"] for name in names]
fig, axes = plt.subplots(len(datasets), 1, figsize=(10, 16), sharex=True)
for i, (data, label, color) in enumerate(zip(datasets, labels, color_cycle)):
time = data["TIME"]
axes[i].plot(time, data, color=color, label=label)
axes[i].axhline(0, color="black", linestyle="--", linewidth=0.5)
axes[i].set_title(label)
axes[i].set_ylabel("Transport [Sv]")
axes[i].legend()
axes[i].grid(True, linestyle="--", alpha=0.5)
# Dynamic ylim
ymin = float(data.min()) - 1
ymax = float(data.max()) + 1
axes[i].set_ylim([ymin, ymax])
# Style choices
axes[i].spines["top"].set_visible(False)
axes[i].spines["right"].set_visible(False)
axes[i].set_xlim([pd.Timestamp("2000-01-01"), pd.Timestamp("2023-12-31")])
axes[i].set_clip_on(False)
axes[-1].set_xlabel("Time")
plt.tight_layout()
return fig, axes