Source code for amocatlas.tools

"""AMOCatlas analysis tools for data processing, filtering, and calculations."""

import re

import numpy as np
import pandas as pd
import xarray as xr
from scipy.signal.windows import tukey

from amocatlas import logger
from amocatlas.logger import log_info, log_debug

log = logger.log


def generate_reverse_conversions(
    forward_conversions: dict[str, dict[str, float]],
) -> dict[str, dict[str, float]]:
    """Create a unit conversion dictionary with both forward and reverse conversions.

    Parameters
    ----------
    forward_conversions : dict of {str: dict of {str: float}}
        Mapping of source units to target units and conversion factors.
        Example: {"m": {"cm": 100, "km": 0.001}}

    Returns
    -------
    dict of {str: dict of {str: float}}
        Complete mapping of units including reverse conversions.
        Example: {"cm": {"m": 0.01}, "km": {"m": 1000}}

    Notes
    -----
    If a conversion factor is zero, a warning is printed, and the reverse
    conversion is skipped.
    """
    complete_conversions: dict[str, dict[str, float]] = {}

    for from_unit, targets in forward_conversions.items():
        complete_conversions.setdefault(from_unit, {})
        for to_unit, factor in targets.items():
            complete_conversions[from_unit][to_unit] = factor
            complete_conversions.setdefault(to_unit, {})
            if factor == 0:
                print(
                    f"Warning: zero factor in conversion from {from_unit} to {to_unit}",
                )
                continue
            complete_conversions[to_unit][from_unit] = 1 / factor

    return complete_conversions

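# A minimal usage sketch (illustrative values, matching the docstring example):
# reversing the metre conversions yields the reciprocal factor for each target.
#
#     >>> generate_reverse_conversions({"m": {"cm": 100, "km": 0.001}})
#     {'m': {'cm': 100, 'km': 0.001}, 'cm': {'m': 0.01}, 'km': {'m': 1000.0}}
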
# Various conversions from the key to units_name with the multiplicative conversion factor
base_unit_conversion = {
    "cm/s": {"m/s": 0.01},
    "cm s-1": {"m s-1": 0.01},
    "S/m": {"mS/cm": 10},  # 1 S/m = 1000 mS / 100 cm = 10 mS/cm
    "dbar": {"Pa": 10000, "kPa": 10},
    "degrees_Celsius": {"degree_Celsius": 1},
    "m": {"cm": 100, "km": 0.001},
    "g m-3": {"kg m-3": 0.001},
    "sverdrup": {"Sv": 1},
    "PW": {"W": 1e15},
}

unit_conversion = generate_reverse_conversions(base_unit_conversion)

# Specify the preferred units, and it will convert if the conversion is available in unit_conversion
preferred_units = [
    "m s-1",
    "dbar",
    "S m-1",
    "sverdrup",
    "degree_Celsius",
    "kg m-3",
    "m",
    "degree_north",
    "degree_east",
    "watt",
]

# String formats for units. The key is the original, the value is the desired format
unit_str_format = {
    "m/s": "m s-1",
    "cm/s": "cm s-1",
    "S/m": "S m-1",
    "meters": "m",
    "degrees_Celsius": "degree_Celsius",
    "g/m^3": "g m-3",
}

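# Illustrative lookups into the generated table (a sketch; both directions
# exist because generate_reverse_conversions fills in the reverses):
#
#     >>> unit_conversion["cm/s"]["m/s"]
#     0.01
#     >>> unit_conversion["m/s"]["cm/s"]
#     100.0
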
def reformat_units_var(
    ds: xr.Dataset,
    var_name: str,
    unit_format: dict[str, str] = unit_str_format,
) -> str:
    """Reformat the units of a variable in the dataset based on a provided mapping.

    Parameters
    ----------
    ds : xarray.Dataset
        The input dataset containing variables with units to be reformatted.
    var_name : str
        The name of the variable whose units need to be reformatted.
    unit_format : dict of {str: str}, optional
        A dictionary mapping old unit strings to new formatted unit strings.
        Defaults to `unit_str_format`.

    Returns
    -------
    str
        The reformatted unit string. If the old unit is not found in
        `unit_format`, the original unit string is returned.
    """
    old_unit = ds[var_name].attrs["units"]
    new_unit = unit_format.get(old_unit, old_unit)
    return new_unit

def convert_units_var(
    var_values: np.ndarray | float,
    current_unit: str,
    new_unit: str,
    unit_conversion: dict[str, dict[str, float]] = unit_conversion,
) -> np.ndarray | float:
    """Convert variable values from one unit to another using a predefined conversion factor.

    Parameters
    ----------
    var_values : numpy.ndarray or float
        The values to be converted.
    current_unit : str
        The current unit of the variable values.
    new_unit : str
        The target unit to which the variable values should be converted.
    unit_conversion : dict of {str: dict of {str: float}}, optional
        A dictionary containing conversion factors between units. The default
        is `unit_conversion`.

    Returns
    -------
    numpy.ndarray or float
        The converted variable values. If no conversion factor is found, the
        original values are returned.

    Notes
    -----
    If the conversion factor for the specified units is not available, a
    message is printed and the original values are returned unconverted;
    the internal `KeyError` is caught rather than raised.
    """
    try:
        conversion_factor = unit_conversion[current_unit][new_unit]
        return var_values * conversion_factor
    except KeyError:
        print(f"No conversion information found for {current_unit} to {new_unit}")
        return var_values

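# A minimal sketch of both code paths (illustrative values; "furlongs" is a
# deliberately unknown unit):
#
#     >>> convert_units_var(np.array([250.0, 500.0]), "cm/s", "m/s")
#     array([2.5, 5. ])
#     >>> convert_units_var(1.0, "furlongs", "m")  # unknown units pass through
#     No conversion information found for furlongs to m
#     1.0
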
def find_best_dtype(var_name: str, da: xr.DataArray) -> np.dtype:
    """Determine the most suitable data type for a given variable.

    Parameters
    ----------
    var_name : str
        The name of the variable.
    da : xarray.DataArray
        The data array containing the variable's values.

    Returns
    -------
    numpy.dtype
        The optimal data type for the variable based on its name and values.
    """
    input_dtype = da.dtype.type
    if "latitude" in var_name.lower() or "longitude" in var_name.lower():
        return np.double
    if var_name[-2:].lower() == "qc":
        return np.int8
    if "time" in var_name.lower():
        return input_dtype
    if var_name[-3:] == "raw" or "int" in str(input_dtype):
        if np.nanmax(da.values) < 2**16 / 2:
            return np.int16
        elif np.nanmax(da.values) < 2**32 / 2:
            return np.int32
    if input_dtype == np.float64:
        return np.float32
    return input_dtype

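# Examples of the naming heuristics above (variable names are hypothetical):
#
#     >>> find_best_dtype("TEMP_QC", xr.DataArray(np.array([0, 1], dtype=np.int64)))
#     <class 'numpy.int8'>
#     >>> find_best_dtype("TEMP", xr.DataArray(np.array([1.0], dtype=np.float64)))
#     <class 'numpy.float32'>
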
def set_fill_value(new_dtype: np.dtype) -> int:
    """Calculate the fill value for a given data type.

    Parameters
    ----------
    new_dtype : numpy.dtype
        The data type for which the fill value is to be calculated.

    Returns
    -------
    int
        The calculated fill value based on the bit-width of the data type.
    """
    fill_val: int = 2 ** (int(re.findall(r"\d+", str(new_dtype))[0]) - 1) - 1
    return fill_val

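# The bit width is parsed out of the dtype's name, so a 16-bit integer gets
# the largest signed 16-bit value:
#
#     >>> set_fill_value(np.int16)  # 2**15 - 1
#     32767
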
def set_best_dtype(ds: xr.Dataset) -> xr.Dataset:
    """Adjust the data types of variables in a dataset to optimize memory usage.

    Parameters
    ----------
    ds : xarray.Dataset
        The input dataset whose variables' data types will be adjusted.

    Returns
    -------
    xarray.Dataset
        The dataset with updated data types for its variables, potentially
        saving memory.

    Notes
    -----
    - The function determines the best data type for each variable using
      `find_best_dtype`.
    - Attributes like `valid_min` and `valid_max` are updated to match the
      new data type.
    - If the new data type is integer-based, NaN values are replaced with a
      fill value.
    - Logs the percentage of memory saved after the data type adjustments.
    """
    bytes_in: int = ds.nbytes
    for var_name in list(ds):
        da: xr.DataArray = ds[var_name]
        input_dtype: np.dtype = da.dtype.type
        new_dtype: np.dtype = find_best_dtype(var_name, da)
        for att in ["valid_min", "valid_max"]:
            if att in da.attrs.keys():
                da.attrs[att] = np.array(da.attrs[att]).astype(new_dtype)
        if new_dtype == input_dtype:
            continue
        log_debug(f"{var_name} input dtype {input_dtype} change to {new_dtype}")
        da_new: xr.DataArray = da.astype(new_dtype)
        ds = ds.drop_vars(var_name)
        if "int" in str(new_dtype):
            fill_val: int = set_fill_value(new_dtype)
            da_new[np.isnan(da)] = fill_val
            da_new.encoding["_FillValue"] = fill_val
        ds[var_name] = da_new
    bytes_out: int = ds.nbytes
    log_info(
        f"Space saved by dtype downgrade: {int(100 * (bytes_in - bytes_out) / bytes_in)} %",
    )
    return ds

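# A sketch of the effect on a hypothetical dataset: float64 science variables
# are downcast to float32 (NaNs in integer-bound variables would instead be
# replaced by the dtype's fill value and recorded as _FillValue):
#
#     >>> ds = xr.Dataset({"TEMP": ("TIME", np.array([3.5, np.nan]))})
#     >>> set_best_dtype(ds)["TEMP"].dtype
#     dtype('float32')
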
# ------------------------------------------------------------------------------------
# Time series filtering and binning functions
# ------------------------------------------------------------------------------------

def to_decimal_year(dates: pd.Series) -> pd.Series:
    """Convert datetime series to decimal years, handling NaN values safely.

    Parameters
    ----------
    dates : pandas.Series or pandas.DatetimeIndex
        Series or Index of datetime objects to convert.

    Returns
    -------
    pandas.Series
        Series of decimal years with NaN preserved for invalid dates.

    Examples
    --------
    >>> import pandas as pd
    >>> dates = pd.Series(['2020-01-01', '2020-07-01', '2021-01-01'])
    >>> dates = pd.to_datetime(dates)
    >>> decimal_years = to_decimal_year(dates)
    """
    # Convert to Series if DatetimeIndex
    if isinstance(dates, pd.DatetimeIndex):
        dates = pd.Series(dates)

    # Drop NaN values and handle them separately
    valid_dates = dates.dropna()
    if len(valid_dates) == 0:
        return pd.Series([np.nan] * len(dates), index=dates.index)

    year = valid_dates.dt.year
    start = pd.to_datetime(year.astype(str) + "-01-01")
    end = pd.to_datetime((year + 1).astype(str) + "-01-01")
    decimal_years = year + (valid_dates - start) / (end - start)

    # Create full series with NaN for invalid dates
    result = pd.Series([np.nan] * len(dates), index=dates.index)
    result.loc[valid_dates.index] = decimal_years
    return result

def extract_time_and_time_num(ds: xr.Dataset, time_var: str = "TIME") -> pd.DataFrame:
    """Extract time coordinates from xarray Dataset and convert to pandas DataFrame.

    Parameters
    ----------
    ds : xarray.Dataset
        Dataset containing time coordinate.
    time_var : str, default "TIME"
        Name of the time variable in the dataset.

    Returns
    -------
    pandas.DataFrame
        DataFrame with 'time' (datetime) and 'time_num' (decimal year) columns.
    """
    time = pd.to_datetime(ds[time_var].values)
    df = pd.DataFrame({"time": time})
    df["time_num"] = to_decimal_year(df["time"])
    return df

def bin_average_5day(
    df: pd.DataFrame, time_column: str = "time", value_column: str = "moc"
) -> pd.DataFrame:
    """Bin-average a time series into 5-day means.

    Parameters
    ----------
    df : pandas.DataFrame
        Input DataFrame with time and value columns.
    time_column : str, default "time"
        Name of the datetime column.
    value_column : str, default "moc"
        Name of the data column to average.

    Returns
    -------
    pandas.DataFrame
        DataFrame with 5-day averaged time and values.
    """
    df = df.copy()
    df[time_column] = pd.to_datetime(df[time_column])

    # Bin by 5-day frequency and take the mean of each bin
    df_binned = (
        df.set_index(time_column).resample("5D")[value_column].mean().reset_index()
    )

    # Drop NaNs and standardise the time column name
    df_binned = df_binned.dropna().rename(columns={time_column: "time"})
    return df_binned

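# A minimal sketch with a hypothetical frame: twenty daily samples collapse
# into four 5-day means.
#
#     >>> raw = pd.DataFrame(
#     ...     {"time": pd.date_range("2020-01-01", periods=20, freq="D"),
#     ...      "moc": np.arange(20.0)}
#     ... )
#     >>> bin_average_5day(raw)["moc"].tolist()
#     [2.0, 7.0, 12.0, 17.0]
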
def bin_average_monthly(df: pd.DataFrame, time_column: str = "time") -> pd.DataFrame:
    """Bin-average a time series into monthly means.

    Parameters
    ----------
    df : pandas.DataFrame
        Input DataFrame with time column.
    time_column : str, default "time"
        Name of the datetime column.

    Returns
    -------
    pandas.DataFrame
        DataFrame with monthly averaged data.
    """
    df = df.copy()
    df[time_column] = pd.to_datetime(df[time_column])
    df_binned = df.set_index(time_column).resample("ME").mean().reset_index()
    df_binned = df_binned.dropna().rename(columns={time_column: "time"})
    return df_binned

def check_and_bin(df: pd.DataFrame, time_column: str = "time") -> pd.DataFrame:
    """Check temporal resolution and bin to monthly if needed.

    Parameters
    ----------
    df : pandas.DataFrame
        Input DataFrame with time column.
    time_column : str, default "time"
        Name of the datetime column.

    Returns
    -------
    pandas.DataFrame
        Original DataFrame if already monthly, or monthly-binned version.
    """
    # Calculate median time difference in days
    time_diffs = df[time_column].sort_values().diff().dt.total_seconds().dropna() / (
        3600 * 24
    )
    median_diff = time_diffs.median()

    if median_diff < 15:
        return bin_average_monthly(df, time_column)
    else:
        return df

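# The 15-day threshold separates sub-monthly records (binned to monthly) from
# records already at monthly or coarser resolution (returned unchanged).
# A sketch with a hypothetical daily series:
#
#     >>> daily = pd.DataFrame(
#     ...     {"time": pd.date_range("2020-01-01", periods=90, freq="D"),
#     ...      "moc": np.zeros(90)}
#     ... )
#     >>> len(check_and_bin(daily))  # 90 daily rows -> 3 monthly means
#     3
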
def apply_tukey_filter(
    df: pd.DataFrame,
    column: str,
    window_months: int = 6,
    samples_per_day: float = 0.2,
    alpha: float = 0.5,
    add_back_mean: bool = False,
    output_column: str | None = None,
) -> pd.DataFrame:
    """Apply a Tukey filter using NumPy convolution (safely handles NaN values).

    This function uses pandas DataFrame input to leverage NumPy's convolution
    capabilities with Tukey windows, which provides more flexibility than
    xarray's built-in rolling operations for this specific filtering approach.

    Parameters
    ----------
    df : pandas.DataFrame
        Input DataFrame containing the column to filter.
    column : str
        Name of the column to apply the filter to.
    window_months : int, default 6
        Filter window size in months.
    samples_per_day : float, default 0.2
        Expected number of samples per day in the data.
    alpha : float, default 0.5
        Tukey window parameter (0=rectangular, 1=Hann).
    add_back_mean : bool, default False
        Whether to remove and add back the overall mean.
    output_column : str, optional
        Name for the filtered output column. If None, uses "{column}_filtered".

    Returns
    -------
    pandas.DataFrame
        Copy of input DataFrame with filtered column added.

    Notes
    -----
    Uses pandas DataFrame rather than xarray Dataset because pandas provides
    better access to convolution operations with custom window functions.
    """
    df = df.copy()
    data = df[column].astype(float).values

    # Replace NaNs with nanmean for stable filtering
    nan_mask = np.isnan(data)
    safe_data = np.where(nan_mask, np.nanmean(data), data)

    if add_back_mean:
        overall_mean = np.nanmean(safe_data)
        safe_data = safe_data - overall_mean
    else:
        overall_mean = 0.0

    samples_per_month = int(round(30.44 * samples_per_day))
    window_len = window_months * samples_per_month
    if window_len % 2 == 0:
        window_len += 1
    half_width = window_len // 2

    # Build normalized Tukey window
    win = tukey(window_len, alpha)
    win /= win.sum()

    # Apply convolution
    filtered = np.convolve(safe_data, win, mode="same")
    filtered += overall_mean

    # Restore original NaNs and edge mask
    filtered[nan_mask] = np.nan
    filtered[:half_width] = np.nan
    filtered[-half_width:] = np.nan

    if output_column is None:
        output_column = f"{column}_filtered"

    df[output_column] = filtered
    return df

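# Worked example of the window sizing above, using the defaults: with
# samples_per_day=0.2 (one sample every 5 days), samples_per_month =
# round(30.44 * 0.2) = 6, so a 6-month window spans 36 samples, padded to the
# odd length 37; half_width is then 18, so the first and last 18 points of
# the filtered output are masked to NaN.
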
def handle_samba_gaps(df: pd.DataFrame, time_column: str = "time") -> pd.DataFrame:
    """Handle temporal gaps in SAMBA MOC data to prevent plotting artifacts.

    SAMBA data has significant gaps (e.g., 2011-2014) that cause plotting
    functions to draw connecting lines across missing periods. This function
    creates a regular monthly grid and masks interpolation so that it only
    occurs within existing data periods, preventing spurious connections
    across large gaps.

    Parameters
    ----------
    df : pandas.DataFrame
        Input DataFrame with time and MOC columns.
    time_column : str, default "time"
        Name of the datetime column.

    Returns
    -------
    pandas.DataFrame
        DataFrame with regular monthly grid and gap-aware data masking.

    Notes
    -----
    PyGMT and other plotting functions connect all valid (non-NaN) data
    points regardless of temporal gaps. This function prevents artifacts by:

    1. Creating a regular monthly time grid
    2. Preserving NaN values where no original data existed
    3. Only interpolating within continuous data segments
    """
    df_input = df.copy()

    # Create a regular monthly time grid covering the range
    monthly_time = pd.date_range(
        start=df_input[time_column].min(), end=df_input[time_column].max(), freq="ME"
    )

    # Reindex to the monthly grid
    df_monthly = df_input.set_index(time_column).reindex(monthly_time)

    # For the MOC column, preserve gaps by masking interpolation
    if "moc" in df_monthly.columns:
        # Track where original data existed
        mask = df_monthly["moc"].notna()

        # Create gap-aware MOC column
        df_monthly["moc_interp"] = np.nan
        df_monthly.loc[mask, "moc_interp"] = df_monthly.loc[mask, "moc"]

        # Mark locations with original data
        df_monthly["had_data"] = mask

        # Replace original moc with gap-aware version
        df_monthly["moc"] = df_monthly["moc_interp"].where(mask)
        df_monthly = df_monthly.drop(columns=["moc_interp"])

    # Reset index for further use
    df_monthly = df_monthly.reset_index().rename(columns={"index": time_column})

    # Recalculate time_num if it exists
    if "time_num" in df_monthly.columns:
        df_monthly["time_num"] = to_decimal_year(df_monthly[time_column])

    return df_monthly

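# Typical usage on a monthly-binned SAMBA series (a sketch; df_monthly is
# assumed to come from bin_average_monthly above). Months inside the
# 2011-2014 gap stay NaN on the regular grid, so plotting tools leave the
# gap blank instead of drawing a line across it.
#
#     >>> gap_aware = handle_samba_gaps(df_monthly)
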