"""AMOCatlas analysis tools for data processing, filtering, and calculations."""
import re
import numpy as np
import pandas as pd
import xarray as xr
from scipy.signal.windows import tukey
from amocatlas import logger
from amocatlas.logger import log_info, log_debug
log = logger.log
def generate_reverse_conversions(
forward_conversions: dict[str, dict[str, float]],
) -> dict[str, dict[str, float]]:
"""Create a unit conversion dictionary with both forward and reverse conversions.
Parameters
----------
forward_conversions : dict of {str: dict of {str: float}}
Mapping of source units to target units and conversion factors.
Example: {"m": {"cm": 100, "km": 0.001}}
Returns
-------
dict of {str: dict of {str: float}}
Complete mapping of units including reverse conversions.
Example: {"cm": {"m": 0.01}, "km": {"m": 1000}}
Notes
-----
If a conversion factor is zero, a warning is printed, and the reverse conversion is skipped.
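    Examples
    --------
    A minimal doctest-style sketch of the forward/reverse expansion:
    >>> generate_reverse_conversions({"m": {"cm": 100}})
    {'m': {'cm': 100}, 'cm': {'m': 0.01}}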
"""
complete_conversions: dict[str, dict[str, float]] = {}
for from_unit, targets in forward_conversions.items():
complete_conversions.setdefault(from_unit, {})
for to_unit, factor in targets.items():
complete_conversions[from_unit][to_unit] = factor
complete_conversions.setdefault(to_unit, {})
if factor == 0:
print(
f"Warning: zero factor in conversion from {from_unit} to {to_unit}",
)
continue
complete_conversions[to_unit][from_unit] = 1 / factor
return complete_conversions
# Forward conversions: each key unit maps to target units and the multiplicative conversion factor
base_unit_conversion = {
"cm/s": {"m/s": 0.01},
"cm s-1": {"m s-1": 0.01},
"S/m": {"mS/cm": 0.1},
"dbar": {"Pa": 10000, "kPa": 10},
"degrees_Celsius": {"degree_Celsius": 1},
"m": {"cm": 100, "km": 0.001},
"g m-3": {"kg m-3": 0.001},
"sverdrup": {"Sv": 1},
"PW": {"W": 1e15},
}
unit_conversion = generate_reverse_conversions(base_unit_conversion)
# Preferred units; variables are converted to these when a conversion is available in unit_conversion
preferred_units = [
"m s-1",
"dbar",
"S m-1",
"sverdrup",
"degree_Celsius",
"kg m-3",
"m",
"degree_north",
"degree_east",
"watt",
]
# String formats for units. The key is the original, the value is the desired format
unit_str_format = {
"m/s": "m s-1",
"cm/s": "cm s-1",
"S/m": "S m-1",
"meters": "m",
"degrees_Celsius": "degree_Celsius",
"g/m^3": "g m-3",
}
def convert_units_var(
var_values: np.ndarray | float,
current_unit: str,
new_unit: str,
unit_conversion: dict[str, dict[str, float]] = unit_conversion,
) -> np.ndarray | float:
"""Converts variable values from one unit to another using a predefined conversion factor.
Parameters
----------
var_values : numpy.ndarray or float
The values to be converted.
current_unit : str
The current unit of the variable values.
new_unit : str
The target unit to which the variable values should be converted.
unit_conversion : dict of {str: dict of {str: float}}, optional
A dictionary containing conversion factors between units. The default is `unit_conversion`.
Returns
-------
numpy.ndarray or float
The converted variable values. If no conversion factor is found, the original values are returned.
    Notes
    -----
    A missing conversion factor raises a ``KeyError`` internally; it is caught, a message is
    printed, and the original values are returned without any conversion.
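    Examples
    --------
    A minimal sketch using the module-level conversion table:
    >>> convert_units_var(2.0, "m", "cm")
    200.0
    >>> convert_units_var(5.0, "m", "furlong")
    No conversion information found for m to furlong
    5.0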
"""
try:
conversion_factor = unit_conversion[current_unit][new_unit]
return var_values * conversion_factor
except KeyError:
print(f"No conversion information found for {current_unit} to {new_unit}")
return var_values
def find_best_dtype(var_name: str, da: xr.DataArray) -> np.dtype:
"""Determines the most suitable data type for a given variable.
Parameters
----------
var_name : str
The name of the variable.
da : xarray.DataArray
The data array containing the variable's values.
Returns
-------
numpy.dtype
The optimal data type for the variable based on its name and values.
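    Examples
    --------
    A small illustrative check (latitude-like variables are promoted to double):
    >>> import numpy as np
    >>> import xarray as xr
    >>> find_best_dtype("latitude", xr.DataArray(np.array([45.0])))
    <class 'numpy.float64'>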
"""
input_dtype = da.dtype.type
if "latitude" in var_name.lower() or "longitude" in var_name.lower():
return np.double
if var_name[-2:].lower() == "qc":
return np.int8
if "time" in var_name.lower():
return input_dtype
if var_name[-3:] == "raw" or "int" in str(input_dtype):
if np.nanmax(da.values) < 2**16 / 2:
return np.int16
elif np.nanmax(da.values) < 2**32 / 2:
return np.int32
if input_dtype == np.float64:
return np.float32
return input_dtype
def set_fill_value(new_dtype: np.dtype) -> int:
"""Calculate the fill value for a given data type.
Parameters
----------
new_dtype : numpy.dtype
The data type for which the fill value is to be calculated.
Returns
-------
int
The calculated fill value based on the bit-width of the data type.
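    Examples
    --------
    A quick sketch for a 16-bit integer dtype:
    >>> import numpy as np
    >>> set_fill_value(np.int16)
    32767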
"""
fill_val: int = 2 ** (int(re.findall(r"\d+", str(new_dtype))[0]) - 1) - 1
return fill_val
def set_best_dtype(ds: xr.Dataset) -> xr.Dataset:
"""Adjust the data types of variables in a dataset to optimize memory usage.
Parameters
----------
ds : xarray.Dataset
The input dataset whose variables' data types will be adjusted.
Returns
-------
xarray.Dataset
The dataset with updated data types for its variables, potentially saving memory.
Notes
-----
- The function determines the best data type for each variable using `find_best_dtype`.
- Attributes like `valid_min` and `valid_max` are updated to match the new data type.
- If the new data type is integer-based, NaN values are replaced with a fill value.
- Logs the percentage of memory saved after the data type adjustments.
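    Examples
    --------
    A minimal sketch of the float64-to-float32 downgrade (any log output is not shown):
    >>> import numpy as np
    >>> import xarray as xr
    >>> ds = xr.Dataset({"TEMP": ("N", np.array([1.0, 2.0]))})
    >>> set_best_dtype(ds)["TEMP"].dtype
    dtype('float32')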
"""
bytes_in: int = ds.nbytes
for var_name in list(ds):
da: xr.DataArray = ds[var_name]
input_dtype: np.dtype = da.dtype.type
new_dtype: np.dtype = find_best_dtype(var_name, da)
for att in ["valid_min", "valid_max"]:
if att in da.attrs.keys():
da.attrs[att] = np.array(da.attrs[att]).astype(new_dtype)
if new_dtype == input_dtype:
continue
log_debug(f"{var_name} input dtype {input_dtype} change to {new_dtype}")
da_new: xr.DataArray = da.astype(new_dtype)
ds = ds.drop_vars(var_name)
if "int" in str(new_dtype):
fill_val: int = set_fill_value(new_dtype)
da_new[np.isnan(da)] = fill_val
da_new.encoding["_FillValue"] = fill_val
ds[var_name] = da_new
bytes_out: int = ds.nbytes
log_info(
f"Space saved by dtype downgrade: {int(100 * (bytes_in - bytes_out) / bytes_in)} %",
)
return ds
# ------------------------------------------------------------------------------------
# Time series filtering and binning functions
# ------------------------------------------------------------------------------------
def to_decimal_year(dates: pd.Series) -> pd.Series:
"""Convert datetime series to decimal years, handling NaN values safely.
Parameters
----------
dates : pandas.Series or pandas.DatetimeIndex
Series or Index of datetime objects to convert.
Returns
-------
pandas.Series
Series of decimal years with NaN preserved for invalid dates.
Examples
--------
>>> import pandas as pd
>>> dates = pd.Series(['2020-01-01', '2020-07-01', '2021-01-01'])
>>> dates = pd.to_datetime(dates)
>>> decimal_years = to_decimal_year(dates)
"""
# Convert to Series if DatetimeIndex
if isinstance(dates, pd.DatetimeIndex):
dates = pd.Series(dates)
# Drop NaN values and handle them separately
valid_dates = dates.dropna()
if len(valid_dates) == 0:
return pd.Series([np.nan] * len(dates), index=dates.index)
year = valid_dates.dt.year
start = pd.to_datetime(year.astype(str) + "-01-01")
end = pd.to_datetime((year + 1).astype(str) + "-01-01")
decimal_years = year + (valid_dates - start) / (end - start)
# Create full series with NaN for invalid dates
result = pd.Series([np.nan] * len(dates), index=dates.index)
result.loc[valid_dates.index] = decimal_years
return result
def extract_time_and_time_num(ds: xr.Dataset, time_var: str = "TIME") -> pd.DataFrame:
"""Extract time coordinates from xarray Dataset and convert to pandas DataFrame.
Parameters
----------
ds : xarray.Dataset
Dataset containing time coordinate.
time_var : str, default "TIME"
Name of the time variable in the dataset.
Returns
-------
pandas.DataFrame
DataFrame with 'time' (datetime) and 'time_num' (decimal year) columns.
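    Examples
    --------
    A minimal sketch with a synthetic TIME coordinate:
    >>> import pandas as pd
    >>> import xarray as xr
    >>> ds = xr.Dataset(coords={"TIME": pd.to_datetime(["2020-01-01", "2020-07-01"])})
    >>> extract_time_and_time_num(ds).columns.tolist()
    ['time', 'time_num']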
"""
time = pd.to_datetime(ds[time_var].values)
df = pd.DataFrame({"time": time})
df["time_num"] = to_decimal_year(df["time"])
return df
def bin_average_5day(
df: pd.DataFrame, time_column: str = "time", value_column: str = "moc"
) -> pd.DataFrame:
"""Bin-average a time series into 5-day means.
Parameters
----------
df : pandas.DataFrame
Input DataFrame with time and value columns.
time_column : str, default "time"
Name of the datetime column.
value_column : str, default "moc"
Name of the data column to average.
Returns
-------
pandas.DataFrame
DataFrame with 5-day averaged time and values.
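    Examples
    --------
    A minimal sketch with ten daily values falling into two 5-day bins:
    >>> import pandas as pd
    >>> df = pd.DataFrame(
    ...     {"time": pd.date_range("2020-01-01", periods=10, freq="D"), "moc": range(10)}
    ... )
    >>> bin_average_5day(df)["moc"].tolist()
    [2.0, 7.0]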
"""
df = df.copy()
df[time_column] = pd.to_datetime(df[time_column])
# Bin by 5-day frequency and take the mean of each bin
df_binned = (
df.set_index(time_column).resample("5D")[value_column].mean().reset_index()
)
    # Drop NaNs and rename the time column back to "time"
    df_binned = df_binned.dropna().rename(columns={time_column: "time"})
return df_binned
def bin_average_monthly(df: pd.DataFrame, time_column: str = "time") -> pd.DataFrame:
"""Bin-average a time series into monthly means.
Parameters
----------
df : pandas.DataFrame
Input DataFrame with time column.
time_column : str, default "time"
Name of the datetime column.
Returns
-------
pandas.DataFrame
DataFrame with monthly averaged data.
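    Examples
    --------
    A minimal sketch covering exactly two full months (2020 is a leap year):
    >>> import pandas as pd
    >>> df = pd.DataFrame(
    ...     {"time": pd.date_range("2020-01-01", periods=60, freq="D"), "moc": [1.0] * 60}
    ... )
    >>> len(bin_average_monthly(df))
    2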
"""
df = df.copy()
df[time_column] = pd.to_datetime(df[time_column])
df_binned = df.set_index(time_column).resample("ME").mean().reset_index()
df_binned = df_binned.dropna().rename(columns={time_column: "time"})
return df_binned
def check_and_bin(df: pd.DataFrame, time_column: str = "time") -> pd.DataFrame:
"""Check temporal resolution and bin to monthly if needed.
Parameters
----------
df : pandas.DataFrame
Input DataFrame with time column.
time_column : str, default "time"
Name of the datetime column.
Returns
-------
pandas.DataFrame
Original DataFrame if already monthly, or monthly-binned version.
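    Examples
    --------
    A minimal sketch: daily input is detected as sub-monthly and binned to monthly means:
    >>> import pandas as pd
    >>> df = pd.DataFrame(
    ...     {"time": pd.date_range("2020-01-01", periods=60, freq="D"), "moc": [1.0] * 60}
    ... )
    >>> len(check_and_bin(df))
    2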
"""
# Calculate median time difference in days
time_diffs = df[time_column].sort_values().diff().dt.total_seconds().dropna() / (
3600 * 24
)
median_diff = time_diffs.median()
if median_diff < 15:
return bin_average_monthly(df, time_column)
else:
return df
def apply_tukey_filter(
df: pd.DataFrame,
column: str,
window_months: int = 6,
samples_per_day: float = 0.2,
alpha: float = 0.5,
add_back_mean: bool = False,
output_column: str | None = None,
) -> pd.DataFrame:
"""Apply a Tukey filter using NumPy convolution (safely handles NaN values).
This function uses pandas DataFrame input to leverage NumPy's convolution
capabilities with Tukey windows, which provides more flexibility than
xarray's built-in rolling operations for this specific filtering approach.
Parameters
----------
df : pandas.DataFrame
Input DataFrame containing the column to filter.
column : str
Name of the column to apply the filter to.
window_months : int, default 6
Filter window size in months.
samples_per_day : float, default 0.2
Expected number of samples per day in the data.
alpha : float, default 0.5
Tukey window parameter (0=rectangular, 1=Hann).
add_back_mean : bool, default False
Whether to remove and add back the overall mean.
output_column : str, optional
Name for the filtered output column. If None, uses "{column}_filtered".
Returns
-------
pandas.DataFrame
Copy of input DataFrame with filtered column added.
Notes
-----
Uses pandas DataFrame rather than xarray Dataset because pandas provides
better access to convolution operations with custom window functions.
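    Examples
    --------
    A minimal sketch on synthetic noise (the column name is illustrative):
    >>> import numpy as np
    >>> import pandas as pd
    >>> df = pd.DataFrame({"moc": np.random.default_rng(0).normal(size=400)})
    >>> out = apply_tukey_filter(df, "moc")
    >>> "moc_filtered" in out.columns
    True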
"""
df = df.copy()
data = df[column].astype(float).values
# Replace NaNs with nanmean for stable filtering
nan_mask = np.isnan(data)
safe_data = np.where(nan_mask, np.nanmean(data), data)
if add_back_mean:
overall_mean = np.nanmean(safe_data)
safe_data = safe_data - overall_mean
else:
overall_mean = 0.0
samples_per_month = int(round(30.44 * samples_per_day))
window_len = window_months * samples_per_month
if window_len % 2 == 0:
window_len += 1
half_width = window_len // 2
# Build normalized Tukey window
win = tukey(window_len, alpha)
win /= win.sum()
# Apply convolution
filtered = np.convolve(safe_data, win, mode="same")
filtered += overall_mean
    # Restore original NaNs and mask the edges, where the window is truncated
    filtered[nan_mask] = np.nan
    if half_width > 0:  # guard: with half_width == 0, filtered[-0:] would blank the whole array
        filtered[:half_width] = np.nan
        filtered[-half_width:] = np.nan
if output_column is None:
output_column = f"{column}_filtered"
df[output_column] = filtered
return df
def handle_samba_gaps(df: pd.DataFrame, time_column: str = "time") -> pd.DataFrame:
"""Handle temporal gaps in SAMBA MOC data to prevent plotting artifacts.
SAMBA data has significant gaps (e.g., 2011-2014) that cause plotting functions
to draw connecting lines across missing periods. This function creates a regular
monthly grid and masks interpolation to only occur within existing data periods,
preventing spurious connections across large gaps.
Parameters
----------
df : pandas.DataFrame
Input DataFrame with time and MOC columns.
time_column : str, default "time"
Name of the datetime column.
Returns
-------
pandas.DataFrame
DataFrame with regular monthly grid and gap-aware data masking.
Notes
-----
PyGMT and other plotting functions connect all valid (non-NaN) data points
regardless of temporal gaps. This function prevents artifacts by:
1. Creating a regular monthly time grid
2. Preserving NaN values where no original data existed
3. Only interpolating within continuous data segments
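    Examples
    --------
    A minimal sketch with month-end samples and a multi-year gap:
    >>> import pandas as pd
    >>> df = pd.DataFrame(
    ...     {
    ...         "time": pd.to_datetime(["2010-01-31", "2010-02-28", "2015-03-31"]),
    ...         "moc": [1.0, 2.0, 3.0],
    ...     }
    ... )
    >>> out = handle_samba_gaps(df)
    >>> int(out["had_data"].sum())
    3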
"""
df_input = df.copy()
# Create a regular monthly time grid covering the range
monthly_time = pd.date_range(
start=df_input[time_column].min(), end=df_input[time_column].max(), freq="ME"
)
# Reindex to the monthly grid
df_monthly = df_input.set_index(time_column).reindex(monthly_time)
# For MOC column, preserve gaps by masking interpolation
if "moc" in df_monthly.columns:
# Track where original data existed
mask = df_monthly["moc"].notna()
# Create gap-aware MOC column
df_monthly["moc_interp"] = np.nan
df_monthly.loc[mask, "moc_interp"] = df_monthly.loc[mask, "moc"]
# Mark locations with original data
df_monthly["had_data"] = mask
# Replace original moc with gap-aware version
df_monthly["moc"] = df_monthly["moc_interp"].where(mask)
df_monthly = df_monthly.drop(columns=["moc_interp"])
# Reset index for further use
df_monthly = df_monthly.reset_index().rename(columns={"index": time_column})
# Recalculate time_num if it exists
if "time_num" in df_monthly.columns:
df_monthly["time_num"] = to_decimal_year(df_monthly[time_column])
return df_monthly