# Source code for swiftest.data

from __future__ import annotations
import xarray as xr
import numpy as np
from scipy.spatial.transform import Rotation as R
from .constants import *

class SwiftestDataArray(xr.DataArray):
    """
    N-dimensional ``xarray.DataArray``-like array. Inherits from ``xarray.DataArray`` and has its own set of methods and attributes specific to the Swiftest project

    Parameters
    ----------
    *args:
        Arguments for the ``xarray.DataArray`` class
    **kwargs:
        Keyword arguments for the ``xarray.DataArray`` class

    Notes
    -----
    See `xarray.DataArray <https://docs.xarray.dev/en/stable/generated/xarray.DataArray.html>`__ for further information about DataArrays.
    """

    __slots__ = ()

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    @classmethod
    def _construct_direct(cls, *args, **kwargs):
        """
        Override to make the result a ``swiftest.SwiftestDataArray`` class.

        xarray's ``_construct_direct`` allocates the new object from ``cls``, so delegating through
        ``super()`` yields a ``SwiftestDataArray`` directly, without first building a plain
        ``DataArray`` and copying it through ``__init__``. The ``isinstance`` guard keeps the old
        wrapping behaviour as a fallback should a plain object ever come back.
        """
        obj = super()._construct_direct(*args, **kwargs)
        if not isinstance(obj, cls):
            obj = cls(obj)
        return obj

    def to_dataset(self, *args, **kwargs):
        """
        Converts a ``SwiftestDataArray`` into a ``SwiftestDataset`` with a single data variable.

        All arguments (``dim``, ``name``, ``promote_attrs``) are forwarded unchanged to
        ``xarray.DataArray.to_dataset``. Accepting them matters because xarray itself calls
        ``to_dataset(name=...)`` on arrays in some code paths (e.g. writing an unnamed array to disk).

        Returns
        -------
        SwiftestDataset
            Dataset holding this array as its data variable.
        """
        return SwiftestDataset(super().to_dataset(*args, **kwargs))

    def magnitude(self, name: str | None = None):
        """
        Computes the magnitude of a vector quantity.

        Note: The DataArray must have the "space" dimension.

        Parameters
        ----------
        name : str, optional
            Name of the new DataArray. By default, the string "_mag" is appended to the original name.

        Returns
        -------
        mag : SwiftestDataArray
            DataArray containing the Euclidean magnitude of the vector quantity. A vector with any NaN
            component has a NaN magnitude.

        Raises
        ------
        ValueError
            If the DataArray has no "space" dimension.
        """
        dim = "space"
        if dim not in self.dims:
            raise ValueError(f"Dimension {dim} not found in DataArray")
        if name is None and isinstance(self.name, str):
            name = self.name + "_mag"
        # apply_ufunc moves the core "space" dimension to the last axis, so the norm is taken there.
        da = xr.apply_ufunc(
            np.linalg.norm,
            self,
            input_core_dims=[[dim]],
            kwargs={"ord": None, "axis": -1},
            dask="allowed",
        )
        return SwiftestDataArray(da.rename(name))

    def rotate(self, rotation):
        """
        Rotates a vector quantity.

        The DataArray must have the "space" dimension (of length 3).

        Parameters
        ----------
        rotation : scipy.spatial.transform.Rotation or (3,) / (N, 3) float array
            Rotation to apply. An array is interpreted as rotation vector(s) and converted with
            ``Rotation.from_rotvec``. When ``rotation`` holds a stack of rotations, they are applied
            one after another, first to last.

        Returns
        -------
        SwiftestDataArray
            The rotated vectors. Floating-point inputs keep their dtype; any other dtype is returned as
            float64 so rotated values are not truncated.

        Raises
        ------
        ValueError
            If the DataArray has no "space" dimension.
        """
        if "space" not in self.dims:
            raise ValueError("DataArray must have a 'space' dimension")
        if not isinstance(rotation, R):
            rotation = R.from_rotvec(np.asarray(rotation, dtype=np.float64))

        # Casting rotated data back to an integer dtype would silently truncate it.
        if np.issubdtype(self.dtype, np.floating):
            out_dtype = self.dtype
        else:
            out_dtype = np.dtype(np.float64)

        def apply_rotation(vectors):
            # "space" is the last axis here; flatten all other axes so scipy rotates every vector of
            # the block in one call instead of one Python call per vector.
            vectors = np.asarray(vectors)
            flat = vectors.reshape(-1, 3).astype(np.float64)
            if flat.shape[0] == 0:
                return vectors.astype(out_dtype)
            if rotation.single:
                flat = rotation.apply(flat)
            else:
                # A stack of rotations is applied sequentially to every vector.
                for single_rotation in rotation:
                    flat = single_rotation.apply(flat)
            return flat.reshape(vectors.shape).astype(out_dtype, copy=False)

        da = xr.apply_ufunc(
            apply_rotation,
            self,
            input_core_dims=[["space"]],
            output_core_dims=[["space"]],
            dask="parallelized",
            output_dtypes=[out_dtype],
        )
        return SwiftestDataArray(da)
class SwiftestDataset(xr.Dataset):
    """
    A ``xarray.Dataset``-like, multi-dimensional, in memory, array database. Inherits from ``xarray.Dataset`` and has its own set of methods and attributes specific to the Swiftest project.

    Parameters
    ----------
    *args:
        Arguments for the ``xarray.Dataset`` class
    **kwargs:
        Keyword arguments for the ``xarray.Dataset`` class

    Notes
    -----
    See `xarray.Dataset <https://docs.xarray.dev/en/stable/generated/xarray.Dataset.html>`__ for further information about Datasets.
    """

    __slots__ = ()

    def __init__(self, *args, **kwargs):
        if len(args) == 1 and not kwargs and isinstance(args[0], xr.Dataset):
            # ``xarray.Dataset(ds)`` treats ``ds`` only as a mapping of data variables, which drops the
            # dataset-level attrs and any coordinate not attached to a data variable. Forward both
            # explicitly so wrapping an existing Dataset is lossless.
            source = args[0]
            super().__init__(source.data_vars, coords=source.coords, attrs=source.attrs)
        else:
            super().__init__(*args, **kwargs)

    def __getitem__(self, key):
        """Override to make sure the result is an instance of ``swiftest.SwiftestDataArray`` or ``swiftest.SwiftestDataset``."""
        value = super().__getitem__(key)
        # Only wrap when needed; _construct_dataarray/_construct_direct usually already return our types.
        if isinstance(value, xr.DataArray) and not isinstance(value, SwiftestDataArray):
            value = SwiftestDataArray(value)
        elif isinstance(value, xr.Dataset) and not isinstance(value, SwiftestDataset):
            value = SwiftestDataset(value)
        return value

    def _calculate_binary_op(self, *args, **kwargs):
        """Override to make the result a complete instance of ``swiftest.SwiftestDataset``."""
        ds = super()._calculate_binary_op(*args, **kwargs)
        if not isinstance(ds, SwiftestDataset):
            ds = SwiftestDataset(ds)
        return ds

    def _construct_dataarray(self, name) -> SwiftestDataArray:
        """Override to make the result an instance of ``swiftest.SwiftestDataArray``."""
        xarr = super()._construct_dataarray(name)
        return SwiftestDataArray(xarr)

    @classmethod
    def _construct_direct(cls, *args, **kwargs):
        """
        Override to make the result an ``swiftest.SwiftestDataset`` class.

        xarray's ``_construct_direct`` allocates the new object from ``cls``, so delegating through
        ``super()`` yields a ``SwiftestDataset`` directly (attrs and encoding intact) instead of
        building a plain ``Dataset`` and copying it through ``__init__``. The ``isinstance`` guard
        keeps the old wrapping behaviour as a fallback.
        """
        obj = super()._construct_direct(*args, **kwargs)
        if not isinstance(obj, cls):
            obj = cls(obj)
        return obj

    def _replace(self, *args, **kwargs):
        """Override to make the result a complete instance of ``swiftest.SwiftestDataset``."""
        ds = super()._replace(*args, **kwargs)
        if not isinstance(ds, SwiftestDataset):
            ds = SwiftestDataset(ds)
        return ds

    @classmethod
    def from_dataframe(cls, dataframe):
        """
        Override to make the result a ``swiftest.SwiftestDataset`` class.

        Every column becomes a data variable along a single "index" dimension whose coordinate is
        the DataFrame's index.
        """
        return cls(
            {col: ("index", dataframe[col].values) for col in dataframe.columns},
            coords={"index": dataframe.index},
        )

    @classmethod
    def from_dict(cls, data, **kwargs):
        """
        Override to make the result a ``swiftest.SwiftestDataset`` class.

        Every value of ``data`` becomes a data variable along a single "index" dimension numbered
        ``0..len-1``, using the length of the first value. Extra keyword arguments are passed to the
        constructor.
        """
        # An empty mapping has no value to take the length from: treat it as zero rows.
        nrows = len(next(iter(data.values()))) if data else 0
        return cls(
            {key: ("index", val) for key, val in data.items()},
            coords={"index": range(nrows)},
            **kwargs,
        )

    def rotate(self, rotvec=None, pole=None, skip_vars=None):
        """
        Rotates the coordinate system such that the z-axis is aligned with an input pole.

        The new pole is defined by the input vector. This changes, in place, every variable in the
        Dataset that has the "space" dimension, except for those passed to the skip_vars parameter.

        Parameters
        ----------
        rotvec : (N,3) or (3,) float array
            Rotation vector(s). A stack of N rotation vectors is applied to each variable in sequence.
        pole : (3) float array
            New pole vector. It does not need to be normalized but must be finite and non-zero.
        skip_vars : list of str, optional
            List of variable names to skip. The default is ['space','Ip'].

        Returns
        -------
        ds : SwiftestDataset
            This Dataset (modified in place), with the rotation applied to all variables with the
            "space" dimension.

        Raises
        ------
        ValueError
            If both or neither of rotvec and pole are passed, if the Dataset has no "space"
            dimension, or if the vector passed has the wrong size (or, for pole, zero length).

        Notes
        -----
        You can pass either rotvec or pole, but not both. If both, or none, are passed, the function will raise an exception.
        """
        if skip_vars is None:
            skip_vars = ['space', 'Ip']
        if rotvec is not None and pole is not None:
            raise ValueError("You can only pass either rotvec or pole, but not both")
        if rotvec is None and pole is None:
            raise ValueError("You must pass either rotvec or pole")
        if 'space' not in self.dims:
            raise ValueError("Dataset must have a 'space' dimension")

        if pole is not None:
            # asarray so that plain lists/tuples support the normalization below.
            pole = np.asarray(pole, dtype=np.float64).ravel()
            if pole.size != 3:
                raise ValueError("Pole vector must be a 3-element array")
            pole_mag = np.linalg.norm(pole)
            if not np.isfinite(pole_mag) or pole_mag == 0.0:
                raise ValueError("Pole vector must have a finite, non-zero magnitude")
            unit_pole = pole / pole_mag
            target_vector = np.array([[0.0, 0.0, 1.0]])
            # align_vectors(a, b) estimates the rotation that maps b onto a, i.e. the new pole onto +z.
            rotation, _ = R.align_vectors(target_vector, unit_pole.reshape(1, 3))
        else:
            rotvec = np.asarray(rotvec)
            if rotvec.ndim == 0 or rotvec.shape[-1] != 3:
                raise ValueError("Rotation vector must be a 3-element array")
            rotation = R.from_rotvec(rotvec)

        # Snapshot the names first because the loop reassigns variables on self.
        for var in list(self.variables):
            if 'space' in self.variables[var].dims and var not in skip_vars:
                self[var] = self[var].rotate(rotation)
        return self

    def _index_dim(self) -> str:
        """Return the body-index dimension: 'id' if present, else 'name'."""
        if 'id' in self.dims:
            return 'id'
        if 'name' in self.dims:
            return 'name'
        raise ValueError("Dataset must have an 'id' or 'name' dimension")

    def _central_body_mu(self, GMcb, index_dim, dims):
        """
        Build the gravitational parameter ``mu`` used by the element/state-vector conversions.

        Parameters
        ----------
        GMcb : xr.DataArray, float or None
            Gravitational parameter of the central body. When None it is taken from ``Gmass``, selecting
            the body whose ``particle_type`` equals ``CB_TYPE_NAME`` (or whose ``id`` is 0 when there is
            no ``particle_type`` variable).
        index_dim : str
            Name of the body-index dimension.
        dims : iterable of str
            Dimensions ``mu`` must span; "space" is ignored.

        Returns
        -------
        xr.DataArray
            ``GMcb + Gmass`` where ``Gmass`` is positive and ``GMcb`` elsewhere; just ``GMcb`` broadcast
            over ``dims`` when the Dataset has no ``Gmass`` variable.

        Raises
        ------
        ValueError
            If ``GMcb`` is None and there is no ``Gmass`` variable or not exactly one central body.
        """
        if GMcb is None:
            if 'Gmass' not in self:
                raise ValueError("Dataset must have a 'Gmass' variable for the central body")
            if 'particle_type' in self.variables:
                GMcb = self['Gmass'].where(self['particle_type'] == CB_TYPE_NAME, drop=True)
            else:
                GMcb = self['Gmass'].where(self['id'] == 0, drop=True)
            # Count central bodies along the body index only; other dimensions (e.g. time) may have any length.
            if GMcb.sizes.get(index_dim, 1) != 1:
                raise ValueError("Dataset must have a single central body")
        if isinstance(GMcb, xr.DataArray):
            if 'id' in GMcb.dims:
                GMcb = GMcb.isel(id=0)
            elif 'name' in GMcb.dims:
                GMcb = GMcb.isel(name=0)
        else:
            GMcb = xr.DataArray(data=GMcb)
        for dim in dims:
            if dim != 'space' and dim not in GMcb.dims:
                GMcb = GMcb.expand_dims(dim={dim: self[dim]})
        if 'Gmass' in self:
            return xr.where(self['Gmass'] > 0.0, GMcb + self['Gmass'], GMcb)
        return GMcb

    def _merge_computed(self, new_vars):
        """Return a new SwiftestDataset with ``new_vars`` merged into this one, the new values taking precedence."""
        dataset = xr.Dataset(new_vars)
        if "name" in dataset.variables:
            dataset = dataset.drop_vars("name")
        # compat="override" picks variables from the first object, so freshly computed values replace stale ones.
        merged = SwiftestDataset(xr.merge([dataset, self], compat="override"))
        # xr.merge copies global attrs from the first (attribute-less) object; restore the input's attrs.
        merged.attrs = dict(self.attrs)
        return merged

    def el2xv(self, GMcb: xr.DataArray | float | None = None) -> SwiftestDataset:
        """
        Converts a Dataset's orbital elements to Cartesian state vectors.

        The Dataset must have the "space" dimension, an 'id' or 'name' dimension, and the variables
        'a', 'e', 'inc', 'capom', 'omega' and 'capm'.

        Parameters
        ----------
        GMcb : xr.DataArray or float, optional
            Gravitational parameter of the central body. If omitted it is read from the 'Gmass' variable.

        Returns
        -------
        SwiftestDataset
            Dataset containing the computed state vectors (position 'rh' and velocity 'vh') merged with
            the original variables and attributes.

        Raises
        ------
        ValueError
            If required dimensions or variables are missing, or the central body cannot be identified.
        """
        from .core import el2xv

        if 'space' not in self.dims:
            raise ValueError("Dataset must have a 'space' dimension")
        required_vars = ['a', 'e', 'inc', 'capom', 'omega', 'capm']
        for var in required_vars:
            if var not in self.variables:
                raise ValueError(f"Dataset must have '{var}' variables")
        index_dim = self._index_dim()

        mu = self._central_body_mu(GMcb, index_dim, self['a'].dims).astype(np.float64)
        elements = [self[var].astype(np.float64) for var in required_vars]
        rh, vh = xr.apply_ufunc(
            el2xv,
            mu,
            *elements,
            input_core_dims=[[index_dim] for _ in range(1 + len(elements))],
            output_core_dims=[[index_dim, 'space'], [index_dim, 'space']],
            vectorize=True,
            dask="parallelized",
            output_dtypes=[np.float64, np.float64],
        )
        return self._merge_computed({'rh': rh, 'vh': vh})

    def xv2el(self, GMcb: xr.DataArray | float | None = None) -> SwiftestDataset:
        """
        Converts a Dataset's Cartesian state vectors to orbital elements.

        The Dataset must have the "space" dimension, an 'id' or 'name' dimension, and the variables
        'rh' and 'vh'.

        Parameters
        ----------
        GMcb : xr.DataArray or float, optional
            Gravitational parameter of the central body. If omitted it is read from the 'Gmass' variable.

        Returns
        -------
        SwiftestDataset
            Dataset containing the computed orbital elements ('a', 'e', 'inc', 'capom', 'omega', 'capm',
            'varpi', 'lam', 'f', 'cape', 'capf') merged with the original variables and attributes.

        Raises
        ------
        ValueError
            If required dimensions or variables are missing, or the central body cannot be identified.
        """
        from .core import xv2el

        if 'space' not in self.dims:
            raise ValueError("Dataset must have a 'space' dimension")
        if 'rh' not in self.variables or 'vh' not in self.variables:
            raise ValueError("Dataset must have 'rh' and 'vh' variables")
        index_dim = self._index_dim()

        mu = self._central_body_mu(GMcb, index_dim, self['rh'].dims).astype(np.float64)
        # Prepare the Cartesian state vectors for the function call
        rh = self['rh'].astype(np.float64)
        vh = self['vh'].astype(np.float64)
        varnames = ['a', 'e', 'inc', 'capom', 'omega', 'capm', 'varpi', 'lam', 'f', 'cape', 'capf']
        result = xr.apply_ufunc(
            xv2el,
            mu,
            rh,
            vh,
            input_core_dims=[[index_dim], [index_dim, 'space'], [index_dim, 'space']],
            output_core_dims=[[index_dim] for _ in varnames],
            vectorize=True,
            dask="parallelized",
            output_dtypes=[np.float64] * len(varnames),
        )
        return self._merge_computed(dict(zip(varnames, result)))