domino.core

   1import xarray as xr
   2import numpy as np
   3import pandas as pd
   4import datetime as dt
   5import itertools as it
   6from scipy.ndimage import label
   7import dateutil.relativedelta as durel
   8
   9from domino import agg
  10from domino.categorical_analysis import get_transmat, synthetic_states_from_transmat
  11from domino.util import holm_bonferroni_correction, split_to_contiguous, is_time_type, make_all_dims_coords, drop_scalar_coords, squeeze_da,offset_time_dim
  12from domino.filtering import ds_large_regions, convolve_pad_ds
  13from domino.deseasonaliser import Agg_Deseasonaliser
  14
  15class LaggedAnalyser(object):
  16    """Computes lagged composites of variables with respect to a categorical categorical event series, with support for bootstrap resampling to provide a non-parametric assessment of composite significance, and for deseasonalisation of variables.
  17    
  18    **Arguments:**
  19        
  20    *event*
  21            
  22    An xarray.DataArray with one dimension taking on categorical values, each defining a class of event (or non-event).
  23            
  24    **Optional arguments**
  25        
  26    *variables, name, is_categorical*
  27        
  28    Arguments for adding variables to the LaggedAnalyser. Identical behaviour to calling *LaggedAnalyser.add_variable* directly.
  29    """
  30    
  31    def __init__(self,event,variables=None,name=None,is_categorical=None):
  32        """Initialise a new LaggedAnalyser object."""
  33        
  34        #The categorical event series, stored as a DataArray
  35        self.event=xr.DataArray(event)
  36        """@private"""
  37        
  38        #variables are stored in a dataset, and can be added later,
  39        #or passed as a DataArray, a Dataset or as a dict of DataArrays
  40        self.variables=xr.Dataset(coords=event.coords)
  41        """@private"""
  42
  43        if variables is not None:
  44            self.add_variable(variables,name,is_categorical,False)
  45            
  46        #Time lagged versions of the dataset self.variables will be stored here, with a key
  47        #equal to the lag applied. Designed to be accessed by the self.lagged_variables function
  48        self._lagged_variables={}
  49        self.lagged_means=None
  50        """@private"""
  51
  52        #variables that are a linear combination of other variables are more efficiently
  53        #computed after compositing using the self.add_derived_composite method
  54        self._derived_variables={}
  55        self._deseasonalisers={}
  56        
  57        self.composite_mask=None
  58        """@private"""
  59
  60        self.boot_indices=None
  61        """@private"""
  62
  63        return
  64    
  65    def __repr__(self):
  66        l1='A LaggedAnalyser object\n'
  67        l2='event:\n\n'
  68        da_string=self.event.__str__().split('\n')[0]
  69        l3='\n\nvariables:\n\n'
  70        ds_string=self.variables.__str__().split('\n')
  71        ds_string=ds_string[0]+' '+ds_string[1]
  72        ds_string2='\n'+self.variables.data_vars.__str__()
  73        if list(self._lagged_variables.keys())!=[]:
  74            lag_string=f'\n Lagged variables at time intervals:\n {list(self._lagged_variables.keys())}'
  75        else:
  76            lag_string=""
  77        return "".join([l1,l2,da_string,l3,ds_string,ds_string2,lag_string])
  78    
  79    def __str__(self):
  80        return self.__repr__()
  81
  82    def add_variable(self,variables,name=None,is_categorical=None,overwrite=False,join_type='outer'):
  83        """Adds an additional variable to LaggedAnalyser.variables.
  84        
  85        **Arguments**
  86        
  87        *variables* 
  88        
  89    An xarray.DataArray, xarray.Dataset or dictionary of xarray.DataArrays, containing data to be composited with respect to *event*. One of the coordinates of *variables* should have the same name as the coordinate of *event*. Stored internally as an xarray.Dataset. If a dictionary is passed, the DataArrays are joined according to the method *join_type*, which defaults to 'outer'.
  90            
  91        **Optional Arguments**
  92        
  93        *name* 
  94        
  95    A string. If *variables* is a single xarray.DataArray then *name* will be used as the name of the array in the LaggedAnalyser.variables Dataset. Otherwise ignored.
  96        
  97        *is_categorical* 
  98        
  99    An integer, if *variables* is an xarray.DataArray, or else a dictionary of integers with keys corresponding to DataArrays in the xarray.Dataset/dictionary. 0 indicates that the variable is continuous, and 1 indicates that it is categorical. Note that continuous and categorical variables are by default composited differently (see LaggedAnalyser.compute_composites). The default assumption is that all DataArrays are continuous, unless a DataArray contains an 'is_categorical' key in its DataArray.attrs, in which case that value is used.
 100            
 101        *overwrite*
 102        
 103    A boolean. If False, then attempts to assign a variable whose name is already in *LaggedAnalyser.variables* will raise a KeyError.
 104        
 105        *join_type*
 106        
 107        A string setting the rules for how differences in the coordinate indices of different variables are handled:
 108        "outer": use the union of object indexes
 109        "inner": use the intersection of object indexes
 110
 111        "left": use indexes from the pre-existing *LaggedAnalyser.variables* with each dimension
 112
 113        "right": use indexes from the new *variables* with each dimension
 114
 115        "exact": instead of aligning, raise ValueError when indexes to be aligned are not equal
 116
 117        "override": if indexes are of same size, rewrite indexes to be those of the pre-existing *LaggedAnalyser.variables*. Indexes for the same dimension must have the same size in all objects.
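
        **Examples**

        A minimal sketch of typical usage, assuming *da1* and *da2* are hypothetical xarray.DataArrays sharing a time coordinate with *event*:

            analyser=LaggedAnalyser(event)
            #Add a single continuous variable:
            analyser.add_variable(da1,name='temperature')
            #Add several variables at once, flagging one as categorical:
            analyser.add_variable({'temperature':da1,'regime':da2},
                is_categorical={'temperature':0,'regime':1},overwrite=True)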
 118        """
 119        if isinstance(variables,dict):
 120            
 121            if is_categorical is None:
 122                is_categorical={v:None for v in variables}
 123                
 124            [self._add_variable(da,v,is_categorical[v],overwrite,join_type) for v,da in variables.items()]
 125            
 126        elif isinstance(variables,xr.Dataset):
 127            self.add_variable({v:variables[v] for v in variables.data_vars},None,is_categorical,overwrite,join_type)
 128            
 129        else:
 130            
 131            self._add_variable(variables,name,is_categorical,overwrite,join_type)            
 132        return
 133    
 134    def _more_mergable(self,ds):
 135        
 136        return drop_scalar_coords(make_all_dims_coords(ds))
 137    
 138    def _add_variable(self,da,name,is_categorical,overwrite,join_type):
 139        
 140        if name is None:
 141            name=da.name
 142        if (name in self.variables) and (not overwrite):
 143            raise(KeyError(f'Key "{name}" is already in variables.'))
 144        
 145        try:
 146            self.variables=self.variables.merge(squeeze_da(da).to_dataset(name=name),join=join_type)
 147        except Exception:
 148            #Trying to make the merge work:
 149            self.variables=self._more_mergable(self.variables).merge(self._more_mergable(squeeze_da(da).to_dataset(name=name)),join=join_type)
 150
 151        if (is_categorical is None) and ('is_categorical' not in da.attrs):
 152            self.variables[name].attrs['is_categorical']=0
 153        elif is_categorical is not None:
 154            self.variables[name].attrs['is_categorical']=is_categorical
 155
 156    def lagged_variables(self,t):
 157        """A convenience function that retrieves variables at lag *t* from the *LaggedAnalyser*"""
 158        if t in self._lagged_variables:
 159            return self._lagged_variables[t]
 160        elif t==0:
 161            return self.variables
 162        else:
 163            raise(KeyError(f'Lag {t} is not in self._lagged_variables.'))
 164
 165    def _lag_variables(self,offset,offset_unit='days',offset_dim='time',mode='any',overwrite=False):
 166        
 167        if offset==0:
 168            return
 169        if (offset in self._lagged_variables) and (not overwrite):
 170            raise(KeyError(f'Key "{offset}" is already in lagged_variables.'))
 171            
 172        #We are really paranoid about mixing up our lags. So we implement this safety check
 173        self._check_offset_is_valid(offset,mode)
 174        
 175        #REPLACED PREVIOUS IMPLEMENTATION WITH EQUIVALENT UTIL IMPORT.
 176        self._lagged_variables[offset]=offset_time_dim(self.variables,-offset,offset_unit=offset_unit,offset_dim=offset_dim)
 177
 178        return
 179    
 180    #For coords not in a time format
 181    def _ilag_variables(self,offset,*args,overwrite=False):
 182        raise(NotImplementedError('Only lagging along timelike dimensions is currently supported.'))
 183        
 184    def lag_variables(self,offsets,offset_unit='days',offset_dim='time',mode='any',overwrite=False):
 185        """Produces time lags of *LaggedAnalyser.variables*, which can be used to produce lagged composites.
 186        
 187        **Arguments**
 188        
 189        *offsets*
 190        
 191        An iterable of integers which represent time lags at which to lag *LaggedAnalyser.variables* in the units specified by *offset_unit*. Positive offsets denote variables *preceding* the event.
 192            
 193        **Optional arguments**
 194        
 195        *offset_unit*
 196        
 197        A string, defining the units of *offsets*. Valid options are weeks, days, hours, minutes, seconds, milliseconds, and microseconds.
 198            
 199        *offset_dim*
 200        
 201        A string, defining the coordinate of *LaggedAnalyser.variables* along which offsets are to be calculated.
 202            
 203        *mode*
 204        
 205        One of 'any', 'past', or 'future'. If 'past' or 'future' is used then only positive or negative lags are valid, respectively.
 206            
 207        *overwrite*
 208        
 209        A boolean. If False, then attempts to produce a lag which already exists will raise a KeyError.
 210        
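        **Example**

        A minimal sketch, assuming *analyser* is a LaggedAnalyser whose variables have a daily-resolution 'time' coordinate:

            #Produce lagged copies of the variables from 5 days before to 5 days after the event:
            analyser.lag_variables(np.arange(-5,6),offset_unit='days')
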
 211        """
 212        time_type=int(is_time_type(self.variables[offset_dim][0].values))
 213        self.offset_unit=offset_unit
 214        lag_funcs=[self._ilag_variables,self._lag_variables]
 215        offsets=np.atleast_1d(offsets)
 216        for o in offsets:
 217            lag_funcs[time_type](int(o),offset_unit,offset_dim,mode,overwrite)
 218        
 219    def _check_offset_is_valid(self,offset,mode):
 220        
 221        valid_modes=['any','past','future']
 222        if mode not in valid_modes:
 223            raise(ValueError(f'mode must be one of {valid_modes}'))
 224        if offset>0 and mode == 'past':
 225            raise(ValueError(f'Positive offset {offset} given, but mode is "{mode}"'))
 226        if offset<0 and mode == 'future':
 227            raise(ValueError(f'Negative offset {offset} given, but mode is "{mode}"'))
 228        return
 229    
 230    """
 231        COMPOSITE COMPUTATION FUNCTIONS 
 232        Composite computation is split over 4 function layers:
 233        compute_composites(): calls
 234            _compute_aggregate_over_lags(): calls
 235                _composite_from_ix(): splits data into cat vars
 236                and con vars and then calls
 237                    _aggregate_from_ix(): applies an operation to
 238                    subsets of ds, where the ix takes unique values
 239                then merges them together.
 240             then loops over lags and merges those.
 241        And then subtracts any anomalies and returns the data.
 242
 243
 244        Example usage of the aggregate funcs:
 245        e.g. self._aggregate_from_ix(ds,ix,'time',self._mean_ds)
 246        self._aggregate_from_ix(ds,ix,'time',self._std_ds)
 247        self._aggregate_from_ix(ds,ix,'time',self._cat_occ_ds,s=reg_ds)
 248    """
 249    
 250    def _aggregate_from_ix(self,ds,ix,dim,agg_func,*agg_args):
 251        return xr.concat([agg_func(ds.isel({dim:ix==i}),dim,*agg_args) for i in np.unique(ix)],'index_val')
 252    
 253    
 254    #Splits variables into cat and con and then combines the two different kinds of composites.
 255    #Used with a synthetic 'ix' for bootstrapping by self._compute_bootstraps.
 256    def _composite_from_ix(self,ix,ds,dim,con_func,cat_func,lag=0):
 257                
 258        ix=ix.values #passed in as a da
 259        cat_vars=[v for v in ds if ds[v].attrs['is_categorical']]
 260        con_vars=[v for v in ds if not v in cat_vars]
 261        cat_ds=ds[cat_vars]
 262        con_ds=ds[con_vars]
 263        cat_vals=cat_ds.map(np.unique)
 264
 265        if (con_vars!=[]) and (con_func is not None):
 266            if (cat_vars!=[]) and (cat_func is not None):
 267                con_comp=self._aggregate_from_ix(con_ds,ix,dim,con_func)
 268                cat_comp=self._aggregate_from_ix(cat_ds,ix,dim,cat_func,cat_vals)
 269                comp=con_comp.merge(cat_comp)
 270            else:
 271                comp=self._aggregate_from_ix(con_ds,ix,dim,con_func)
 272        else:
 273            comp=self._aggregate_from_ix(cat_ds,ix,dim,cat_func,cat_vals)
 274        comp.attrs=ds.attrs
 275        return comp.assign_coords({'lag':[lag]})    
 276    
 277    #loops over all lags, calling _composite_from_ix, and assembles composites into a single dataset
 278    def _compute_aggregate_over_lags(self,da,dim,lag_vals,con_func,cat_func):
 279            
 280        if lag_vals=='all':
 281            lag_vals=list(self._lagged_variables)
 282                    
 283        composite=self._composite_from_ix(*xr.align(da,self.variables),dim,con_func,cat_func)
 284              
 285        if lag_vals is not None:
 286            lag_composites=[]
 287            for t in lag_vals:
 288                lag_composites.append(self._composite_from_ix(*xr.align(da,self.lagged_variables(t)),dim,con_func,cat_func,lag=t))
 289            composite=xr.concat([composite,*lag_composites],'lag').sortby('lag')
 290            
 291        return composite
 292
 293    #The top level wrapper for compositing
 294    def compute_composites(self,dim='time',lag_vals='all',as_anomaly=False,con_func=agg.mean_ds,cat_func=agg.cat_occ_ds,inplace=True):
 295        
 296        """
 297        Partitions *LaggedAnalyser.variables*, and any time-lagged equivalents, into subsets depending on the value of *LaggedAnalyser.event*, and then computes a bulk summary metric for each.
 298
 299        **Optional arguments**
 300        
 301        *dim*
 302        
 303        A string, the coordinate along which to compute composites.
 304            
 305        *lag_vals*
 306        
 307        Either 'All', or a list of integers, denoting the time lags for which composites should be computed.
 308            
 309        *as_anomaly*
 310        
 311        A Boolean, defining whether composites should be given as absolute values or differences from the unpartitioned value.
 312            
 313        *con_func*
 314        
 315        The summary metric to use for continuous variables. Defaults to a standard mean average. If None, then continuous variables will be ignored
 316            
 317        *cat_func*
 318        
 319        The summary metric to use for categorical variables. Defaults to the occurrence probability of each categorical value. If None, then categorical variables will be ignored
 320            
 321        *inplace*
 322    
 323        A boolean, defining whether the composite should be stored in *LaggedAnalyser.composites*
 324        
 325        **returns**
 326        
 327        An xarray.Dataset like  *LaggedAnalyser.variables* but summarised according to *con_func* and *cat_func*, and with an additional coordinate *index_val*, which indexes over the values taken by *LaggedAnalyser.event*.
 328            
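        **Example**

        A minimal sketch, assuming variables and lags have already been added to *analyser*:

            #Composite anomalies of all variables, for each event class and lag:
            composites=analyser.compute_composites(dim='time',as_anomaly=True)
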
 329        """
 330        composite=self._compute_aggregate_over_lags(self.event,dim,lag_vals,con_func,cat_func)
 331        lagged_means=self.aggregate_variables(dim,lag_vals,con_func,cat_func)
 332
 333        if as_anomaly:
 334            composite=composite-lagged_means
 335            
 336        composite=make_all_dims_coords(composite)
 337        for v in list(composite.data_vars):
 338            composite[v].attrs=self.variables[v].attrs
 339        if inplace:
 340            self.composites=composite
 341            self.composite_func=(con_func,cat_func)
 342            self.composites_are_anomaly=as_anomaly
 343            self.lagged_means=lagged_means
 344        return composite
 345
 346    #Aggregates variables over all time points where event is defined, regardless of its value
 347    def aggregate_variables(self,dim='time',lag_vals='all',con_func=agg.mean_ds,cat_func=agg.cat_occ_ds):
 348        
 349        """Calculates a summary metric from *LaggedAnalyser.variables* at all points where *LaggedAnalyser.event* is defined, regardless of its value.
 350        
 351        **Optional arguments**
 352        
 353        *dim*
 354        
 355        A string, the name of the shared coordinate between *LaggedAnalyser.variables* and *LaggedAnalyser.event*.
 356        
 357        *lag_vals*
 358        
 359        'all' or an iterable of integers, specifying for which lag values to compute the summary metric.
 360        
 361        *con_func*
 362        
 363        The summary metric to use for continuous variables. Defaults to a standard mean average. If None, then continuous variables will be ignored
 364            
 365        *cat_func*
 366        
 367        The summary metric to use for categorical variables. Defaults to the occurrence probability of each categorical value. If None, then categorical variables will be ignored.
 368
 369        **returns**
 370        
 371        An xarray.Dataset like  *LaggedAnalyser.variables* but summarised according to *con_func* and *cat_func*.
 372
 373"""
 374        fake_event=self.event.copy(data=np.zeros_like(self.event))
 375        return self._compute_aggregate_over_lags(fake_event,dim,lag_vals,con_func,cat_func).isel(index_val=0)
 376
 377    def add_derived_composite(self,name,func,composite_vars,as_anomaly=False):
 378        """Applies *func* to one or multiple composites to calculate composites of derived quantities, and additionally, stores *func* to allow derived bootstrap composites to be calculated. For linear quantities, where Ex[f(x)]==f(Ex[x]), then this can minimise redundant memory use.
 379        
 380        **Arguments**
 381        
 382        *name*
 383        
 384        A string, providing the name of the new variable to add.
 385            
 386        *func*
 387        
 388        A callable which must take one or more xarray.DataArrays as inputs.
 389            
 390        *composite_vars*
 391        
 392        An iterable of strings, of the same length as the number of arguments taken by *func*. Each string must be the name of a variable in *LaggedAnalyser.variables* which will be passed into *func* in order.
 393        
 394        **Optional arguments**
 395        
 396        *as_anomaly*
 397        
 398        A boolean. Whether anomaly composites or full composites should be passed in to func.
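
        **Example**

        A minimal sketch, assuming composites have been computed for wind component variables named 'u' and 'v':

            #Add a wind speed composite derived from the u and v composites:
            def wind_speed(u,v):
                return np.sqrt(u**2+v**2)
            analyser.add_derived_composite('speed',wind_speed,['u','v'])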
 399        """
 400        
 401        if np.ndim(as_anomaly)==1:
 402            raise(NotImplementedError('variable-specific anomalies not yet implemented'))
 403
 404        self._derived_variables[name]=(func,composite_vars,as_anomaly)
 405        self.composites[name]=self._compute_derived_da(self.composites,func,composite_vars,as_anomaly)
 406        
 407        if self.lagged_means is not None:
 408            self.lagged_means[name]=self._compute_derived_da(self.lagged_means,func,composite_vars,as_anomaly)
 409            
 410        return
 411
 412    ### Compute bootstraps ###
 413    
 414    #Top level func
 415    def compute_bootstraps(self,bootnum,dim='time',con_func=agg.mean_ds,cat_func=agg.cat_occ_ds,lag=0,synth_mode='markov',data_vars=None,reuse_ixs=False):
 416        
 417        """Computes composites from synthetic event indices, which can be used to assess whether composites are insignificant.
 418        
 419        **Arguments**
 420        
 421        *bootnum*
 422        
 423        An integer, the number of bootstrapped composites to compute
 424            
 425        **Optional arguments**
 426        
 427        *dim*
 428        
 429        A string, the name of the shared coordinate between *LaggedAnalyser.variables* and *LaggedAnalyser.event*.
 430            
 431        *con_func*
 432        
 433        The summary metric to use for continuous variables. Defaults to a standard mean average. If None, then continuous variables will be ignored
 434            
 435        *cat_func*
 436        
 437        The summary metric to use for categorical variables. Defaults to the occurrence probability of each categorical value. If None, then categorical variables will be ignored.
 438
 439        *lag*
 440        
 441        An integer, specifying which lagged variables to use for the bootstraps. For example, bootstraps for lag=90 will be drawn from a completely different season than those for lag=0.
 442            
 443        *synth_mode*
 444        
 445        A string, specifying how synthetic event indices are to be computed. Valid options are:
 446            
 447        "random": 
 448        
 449        categorical values are randomly chosen with the same probability of occurrence as those found in *LaggedAnalyser.event*, but with no autocorrelation.
 450
 451        "markov": 
 452        
 453        A first order Markov chain is fitted to *LaggedAnalyser.event*, producing some autocorrelation and state dependence in the synthetic series. Generally a better approximation than "random" and so should normally be used.
 454
 455        "shuffle": 
 456        
 457        The values are randomly reordered. This means that each value will occur exactly the same number of times as in the original index, and so is ideal for particularly rare events or short series.
 458            
 459        *data_vars*
 460        
 461        An iterable of strings, specifying for which variables bootstraps should be computed.

        *reuse_ixs*

        A boolean. If True, then synthetic event indices stored from a previous call are reused rather than recomputed, and the new sampling parameters are ignored.
 463        **returns**
 464        
 465        An xarray.Dataset like *LaggedAnalyser.variables* but summarised according to *con_func* and *cat_func*, and with a new coordinate 'boot_num' of length *bootnum*.
 466
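        **Example**

        A minimal sketch, assuming *LaggedAnalyser.compute_composites* has already been called on *analyser*:

            #100 bootstrapped composites at lag 0, using a Markov chain null model:
            boots=analyser.compute_bootstraps(100,synth_mode='markov')
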
 467        """
 468        if data_vars is None:
 469            data_vars=list(self.variables.data_vars)
 470
 471        boots=self._add_derived_boots(self._compute_bootstraps(bootnum,dim,con_func,cat_func,lag,synth_mode,data_vars,reuse_ixs))
 472        if self.composites_are_anomaly:
 473            boots=boots-self.lagged_means.sel(lag=lag)
 474        return make_all_dims_coords(boots)
 475    
 476    
 477    def _compute_derived_da(self,ds,func,varnames,as_anomaly):
 478        if as_anomaly:
 479            input_vars=[ds[v]-self.lagged_means[v] for v in varnames]
 480        else:
 481            input_vars=[ds[v] for v in varnames]
 482        return make_all_dims_coords(func(*input_vars))
 483    
 484    
 485    def _add_derived_boots(self,boots):
 486        for var in self._derived_variables:
 487            func,input_vars,as_anomaly=self._derived_variables[var]
 488            boots[var]=self._compute_derived_da(boots,func,input_vars,as_anomaly)
 489        return boots
 490
 491    def _compute_bootstraps(self,bootnum,dim,con_func,cat_func,lag,synth_mode,data_vars,reuse_ixs):
 492
 493        da,ds=xr.align(self.event,self.lagged_variables(lag))
 494        ds=ds[data_vars]
 495        
 496        if (self.boot_indices is None) or (not reuse_ixs):
 497            
 498            ix_vals,ix_probs,L=self._get_bootparams(da)
 499            ixs=self._get_synth_indices(da.values,bootnum,synth_mode,da,dim,ix_vals,ix_probs,L)
 500            self.boot_indices=ixs
 501        else:
 502            ixs=self.boot_indices
 503            print('Reusing stored boot_indices, ignoring new boot parameters.')
 504        
 505        boots=[make_all_dims_coords(\
 506                self._composite_from_ix(ix,ds,dim,con_func,cat_func,lag)\
 507             ) for ix in ixs]
 508        return xr.concat(boots,'boot_num')
 509    
 510    #Gets some necessary variables
 511    def _get_bootparams(self,da):
 512        ix_vals,ix_probs=np.unique(da.values,return_counts=True)
 513        return ix_vals,ix_probs/len(da),len(da)
 514    
 515    #compute indices
 516    def _get_synth_indices(self,index,bootnum,mode,da,dim,ix_vals,ix_probs,L):
 517        
 518        ixs=[]
 519        if mode=='markov':
 520            xs=split_to_contiguous(da[dim].values,x_arr=da)
 521            T=get_transmat(xs)
 522            for n in range(bootnum):
 523                ixs.append(synthetic_states_from_transmat(T,L-1))
 524                
 525        elif mode=='random':
 526            for n in range(bootnum):
 527                ixs.append(np.random.choice(ix_vals,size=L,p=list(ix_probs)))
 528                
 529        elif mode=='shuffle':
 530            for n in range(bootnum):
 531                ixv=index.copy()
 532                np.random.shuffle(ixv)
 533                
 534                ixs.append(ixv)
 535        else:
 536            raise(ValueError(f'synth_mode={mode} is not valid.'))
 537            
 538        return [xr.DataArray(ix) for ix in ixs]
 539        
 540    ### apply significance test ###
 541    
 542    def get_significance(self,bootstraps,comp,p,data_vars=None,hb_correction=False):
 543        
 544        """Computes whether a composite is significant with respect to a given distribution of bootstrapped composites. 
 545        
 546        **Arguments**
 547        
 548        *bootstraps*
 549
 550        An xarray.Dataset with a coordinate 'boot_num', such as produced by *LaggedAnalyser.compute_bootstraps*.
 551
 552        *comp*
 553
 554        An xarray Dataset of the same shape as *bootstraps* but without a 'boot_num' coordinate. Missing or additional variables are allowed, and are simply ignored.

 555        *p*
 556
 557        A float, specifying the p-value of the 2-sided significance test (values in the range 0 to 1). 
 558            
 559        **Optional arguments**
 560
 561        *data_vars*
 562            
 563        An iterable of strings, specifying for which variables significance should be computed.
 564            
 565        *hb_correction*
 566        
 567        A Boolean, specifying whether a Holm-Bonferroni correction should be applied to *p*, in order to reduce the family-wide error rate. Note that this correction is currently only applied to each variable in *comp* independently, and so will have no impact on scalar variables.
 568        
 569        **returns**
 570        
 571        An xarray.Dataset like *comp* but with boolean data, specifying whether each feature of each variable passed the significance test.
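
        **Example**

        A minimal sketch, assuming *boots* was produced by *LaggedAnalyser.compute_bootstraps*:

            #Keep only features outside the central 95% of the bootstrap distribution:
            sig_mask=analyser.get_significance(boots,analyser.composites,p=0.05)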
 572        """
 573        if data_vars is None:
 574            data_vars=list(bootstraps.data_vars)
 575
 576        bootnum=len(bootstraps.boot_num)
 577        comp=comp[data_vars]
 578        bootstraps=bootstraps[data_vars]
 579        frac=(comp<bootstraps).sum('boot_num')/bootnum
 580        pval_ds=1-2*np.abs(frac-0.5)
 581        if hb_correction:
 582            for var in pval_ds:
 583                corrected_pval=holm_bonferroni_correction(pval_ds[var].values.reshape(-1),p)\
 584                            .reshape(pval_ds[var].shape)
 585                pval_ds[var].data=corrected_pval
 586        else:
 587            pval_ds=pval_ds<p
 588            
 589        self.composite_sigs=pval_ds.assign_coords(lag=comp.lag)
 590        return self.composite_sigs
 591    
 592    def bootstrap_significance(self,bootnum,p,dim='time',synth_mode='markov',reuse_lag0_boots=False,data_vars=None,hb_correction=False):
 593        
 594        """A wrapper around *compute_bootstraps* and *get_significance*, that calculates bootstraps and applies a significance test to a number of time lagged composites simulataneously.
 595        
 596    **Arguments**
 597
 598    *bootnum*
 599
 600    An integer, the number of bootstrapped composites to compute
 601
 602    *p*
 603
 604    A float, specifying the p-value of the 2-sided significance test (values in the range 0 to 1). 
 605
 606    **Optional arguments**
 607
 608    *dim*
 609
 610    A string, the name of the shared coordinate between *LaggedAnalyser.variables* and *LaggedAnalyser.event*.
 611
 612    *synth_mode*
 613
 614    A string, specifying how synthetic event indices are to be computed. Valid options are:
 615    "random": categorical values are randomly chosen with the same probability of occurrence as those found in *LaggedAnalyser.event*, but with no autocorrelation.
 616    'markov': A first order Markov chain is fitted to *LaggedAnalyser.event*, producing some autocorrelation and state dependence in the synthetic series. Generally a better approximation than "random" and so should normally be used.
 617
 618    *reuse_lag0_boots*
 619        A Boolean. If True, bootstraps are only computed for lag=0, and then used as a null distribution to assess all lagged composites. For variables which are approximately stationary across the lag timescale, this is a good approximation and can increase performance. However, if used incorrectly it may lead to 'significant composites' which simply reflect the seasonal cycle. If False, separate bootstraps are computed for all time lags.
 620
 621    *data_vars*
 622        An iterable of strings, specifying for which variables significance should be computed.
 623
 624    *hb_correction*
 625        A Boolean, specifying whether a Holm-Bonferroni correction should be applied to *p*, in order to reduce the family-wide error rate. Note that this correction is currently only applied to each variable in *comp* independently, and so will have no impact on scalar variables.
 626        
 627    **returns**
 628
 629    An xarray.Dataset like *LaggedAnalyser.variables* but with the *dim* dimension summarised according to *con_func* and *cat_func*, an additional *lag* coordinate, and with boolean data specifying whether each feature of each variable passed the significance test.
 630
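    **Example**

    A minimal sketch, assuming composites and lagged variables have already been computed for *analyser*:

        #Test all lagged composites against 200 Markov-chain bootstraps at the 5% level:
        sig_mask=analyser.bootstrap_significance(200,p=0.05,synth_mode='markov')
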
 631        """
 632        lag_vals=list(self._lagged_variables)
 633        
 634        con_func,cat_func=self.composite_func
 635        
 636        boots=self.compute_bootstraps(bootnum,dim,con_func,cat_func,0,synth_mode,data_vars)
 637        
 638        #reuse_lag0_boots=True can substantially reduce run time!
 639        if not reuse_lag0_boots:
 640            boots=xr.concat([boots,*[self.compute_bootstraps(bootnum,dim,con_func,cat_func,t,synth_mode,data_vars)\
 641                for t in lag_vals]],'lag').sortby('lag')
 642                
 643        sig_composite=self.get_significance(boots,self.composites,p,data_vars,hb_correction=hb_correction)
 644        
 645        self.composite_sigs=sig_composite
 646        return self.composite_sigs
 647    
 648    
 649    def deseasonalise_variables(self,variable_list=None,dim='time',agg='dayofyear',smooth=1,coeffs=None):
 650        """Computes a seasonal cycle for each variable in *LaggedAnalyser.variables* and subtracts it inplace, turning *LaggedAnalyser.variables* into deseasonalised anomalies. The seasonal cycle is computed via temporal aggregation of each variable over a given period - by default the calendar day of the year. This cycle can then be smoothed with an n-point rolling average.
 651
 652                **Optional arguments**
 653
 654                *variable_list*
 655                
 656                A list of variables to deseasonalise. Defaults to all variables in the *LaggedAnalyser.variables*
 657
 658                *dim*
 659                
 660                A string, the name of the shared coordinate between *LaggedAnalyser.variables* and *LaggedAnalyser.event*, along which the seasonal cycle is computed. Currently, only timelike coordinates are supported.
 661                
 662                *agg*
 663                
 664                A string specifying the datetime-like field to aggregate over. Useful and supported values are 'season', 'month', 'weekofyear', and 'dayofyear'
 665                    
 666                *smooth*
 667                
 668                An integer, specifying the size of the n-timestep centred rolling mean applied to the aggregated seasonal cycle. By default *smooth*=1 results in no smoothing.
 669
 670                *coeffs*
 671                
 672                A Dataset containing a precomputed seasonal cycle, which, if *LaggedAnalyser.variables* has coordinates (*dim*,[X,Y,...,Z]), has coords (*agg*,[X,Y,...,Z]), and has the same data variables as *LaggedAnalyser.variables*. If *coeffs* is provided, no seasonal cycle is fitted to *LaggedAnalyser.variables*, *coeffs* is used instead.
 673
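                **Example**

                A minimal sketch, assuming the variables have a daily-resolution 'time' coordinate:

                    #Remove a day-of-year seasonal cycle, smoothed with a 15-day rolling mean:
                    analyser.deseasonalise_variables(agg='dayofyear',smooth=15)
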
 674        """        
 675
 676        if variable_list is None:
 677            variable_list=list(self.variables)
 678        for var in variable_list:
 679            da=self.variables[var]
 680            dsnlsr=Agg_Deseasonaliser()
 681            if coeffs is None:
 682                dsnlsr.fit_cycle(da,dim=dim,agg=agg)
 683            else:
 684                dsnlsr.cycle_coeffs=coeffs[var]
 685
 686            cycle=dsnlsr.evaluate_cycle(data=da[dim],smooth=smooth)
 687            self.variables[var]=da.copy(data=da.data-cycle.data)
 688            dsnlsr.data=None #Prevents excess memory storage
 689            self._deseasonalisers[var]=dsnlsr
 690        return   
 691    
 692    def get_seasonal_cycle_coeffs(self):
 693        """ Retrieve seasonal cycle coeffs computed with *LaggedAnalyser.deseasonalise_variables*, suitable for passing into *coeffs* in other *LaggedAnalyser.deseasonalise_variables* function calls as a precomputed cycle.
 694        
 695        **Returns**
 696        An xarray.Dataset, as specified in the *LaggedAnalyser.deseasonalise_variables* *coeffs* optional keyword.
 697        """
 698        coeffs=xr.Dataset({v:dsnlsr.cycle_coeffs for v,dsnlsr in self._deseasonalisers.items()})
 699        return coeffs
 700
 701    #If deseasonalise_variables has been called, then this func can be used to compute the
 702    #seasonal mean state corresponding to a given composite. This mean state+ the composite
 703    # produced by self.compute_composites gives the full field composite pattern.
 704    def get_composite_seasonal_mean(self):
 705        """
 706        If *LaggedAnalyser.deseasonalise_variables* has been called, then this function returns the seasonal mean state corresponding to a given composite, given by a sum of the seasonal cycle weighted by the time-varying occurrence of each categorical value in *LaggedAnalyser.event*. This mean state + the deseasonalised anomaly composite
 707        produced by *LaggedAnalyser.compute_composites* then retrieves the full composite pattern.
 708    
 709    **Returns**
 710        An xarray.Dataset containing the composite seasonal mean values.
 711        """
 712        variable_list=list(self._deseasonalisers)
 713        ts={e:self.event[self.event==e].time for e in np.unique(self.event)}
 714        lags=np.unique([0,*list(self._lagged_variables)])
 715        
 716        mean_states={}
 717        for var in variable_list:
 718            dsnlsr=self._deseasonalisers[var]
 719            agg=dsnlsr.agg
 720            mean_states[var]=xr.concat([\
 721                                 xr.concat([\
 722                                    self._lag_average_cycle(dsnlsr,agg,l,t,i)\
 723                                for l in lags],'lag')\
 724                            for i,t in ts.items()],'index_val')
 725            
 726        return xr.Dataset(mean_states)
 727        
 728    def _lag_average_cycle(self,dsnlsr,agg,l,t,i):
 729        
 730        delta=durel.relativedelta(**{self.offset_unit:int(l)})
 731        tvals=pd.to_datetime([pd.to_datetime(tt)+delta for tt in t.values])
 732        cycle_eval=dsnlsr.cycle_coeffs.sel({agg:getattr(tvals,agg)})
 733        cycle_mean=cycle_eval.mean(agg).assign_coords({'lag':l,'index_val':i})
 734        return cycle_mean
 735    
 736class PatternFilter(object):
 737    """Provides filtering methods to refine n-dimensional boolean masks, and apply them to an underlying dataset.
 738    
 739        **Optional arguments:**
 740        
 741        *mask_ds*
 742        
 743        An xarray boolean Dataset of arbitrary dimensions which provides the initial mask dataset. If *mask_ds*=None and *analyser*=None, then *mask_ds* will be initialised as a Dataset of the same dimensions and data_vars as *val_ds*, with all values = 1 (i.e. initially unmasked).
 744        
 745        *val_ds*
 746        
 747        An xarray Dataset with the same dimensions as *mask_ds* if provided, otherwise arbitrary, consisting of an underlying dataset to which the mask is applied. If *val_ds*=None and *analyser*=None, then *PatternFilter.apply_value_mask* will raise a ValueError.
 748            
 749        *analyser*
 750        
 751        An instance of a core.LaggedAnalyser class for which both composites and significance masks have been computed, used to infer the *val_ds* and *mask_ds* arguments respectively. This overrides any values passed explicitly to *mask_ds* and *val_ds*.
 752            
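        **Example**

        A minimal sketch, assuming *analyser* is a LaggedAnalyser for which composites and significance masks have been computed:

            pfilter=PatternFilter(analyser=analyser)
            pfilter.apply_area_mask(30,dims=('lat','lon'),area_type='spherical')
            filtered_composites=pfilter.filter()
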
 753    """
 754    def __init__(self,mask_ds=None,val_ds=None,analyser=None):
 755        """Initialise a new PatternFilter object"""
 756        self.mask_ds=mask_ds
 757        """@private"""
 758        self.val_ds=val_ds
 759        """@private"""
 760
 761        if analyser is not None:
 762            self._parse_analyser(analyser)
 763            
 764        else:
 765            if mask_ds is None:
 766                self.mask_ds=self._mask_ds_like_val_ds()
 767                
 768    def __repr__(self):
 769        return 'A PatternFilter object'
 770        
 771    def __str__(self):
 772        return self.__repr__()
 773        
 774    def _parse_analyser(self,analyser):
 775        self.mask_ds=analyser.composite_sigs
 776        self.val_ds=analyser.composites
 777        
 778    def _mask_ds_like_val_ds(self):
 779        if self.val_ds is None:
 780            raise(ValueError('At least one of "mask_ds", "val_ds" and "analyser" must be provided.'))
 781        
 782        x=self.val_ds
 783        y=x.where(x!=0).fillna(1) #replace nans and 0s with 1
 784        y=(y/y).astype(int) #make everything 1 via division and assert integer type.
 785        self.mask_ds=y
 786        return
 787    
 788    def update_mask(self,new_mask,mode):
 789        """ Update *PatternFilter.mask_ds* with a new mask, either taking their union or intersection, or replacing the current mask with new_mask.
 790        
 791        **Arguments**
 792        
 793        *new_mask*
 794
 795        An xarray.Dataset with the same coords and variables as *PatternFilter.mask_ds*.
 796
 797        *mode*
 798
 799        A string, one of 'replace','intersection' or 'union', defining how *new_mask* should be used to update the mask.
 800        """
 801        new_mask=new_mask.astype(int)
 802        if mode=='replace':
 803            self.mask_ds=new_mask
 804        elif mode=='intersection':
 805            self.mask_ds=self.mask_ds*new_mask
 806        elif mode == 'union':
 807            self.mask_ds=self.mask_ds|new_mask
 808        else:
 809            raise(ValueError(f'Invalid mode, {mode}'))
 810        return
 811                  
 812    def apply_value_mask(self,truth_function,*args,mode='intersection'):
 813        """ Apply a filter to *PatternFilter.mask_ds* based on a user-specified truth function which is applied to *PatternFilter.val_ds. 
 814        
 815        **Examples**
 816        
 817            #Mask values beneath a threshold:
 818            def larger_than_thresh(ds,thresh):
 819                return ds>thresh
 820            patternfilter.apply_value_mask(larger_than_thresh,thresh)
 821
 822            #Mask values where absolute value is less than a reference field:
 823            def amp_greater_than_reference(ds,ref_ds):
 824                return np.abs(ds)>ref_ds
 825            pattern_filter.apply_value_mask(amp_greater_than_reference,ref_ds)
 826
 827        **Arguments**
 828
 829        *truth_function*
 830        
 831        A function with inputs (val_ds,*args) that returns a boolean dataset with the same coords and data variables as *PatternFilter.val_ds*.
 832
 833        **Optional arguments**
 834        
 835        *mode*
 836            
 837        A string, one of 'replace','intersection' or 'union', defining how the value filter should be used to update the *PatternFilter.mask_ds*.
 838        """        
 839        if self.val_ds is None:
 840            raise(ValueError('val_ds must be provided to apply value mask.'))
 841        value_mask=truth_function(self.val_ds,*args)
 842        self.update_mask(value_mask,mode)
 843        return
 844    
 845    def apply_area_mask(self,n,dims=None,mode='intersection',area_type='gridpoint'):
 846        """ Apply a filter to *PatternFilter.mask_ds* that identifies connected groups of True values within a subspace of the Dataset's dimensions specified by *dims*, and masks out groups which are beneath a threshold size *n*. This is done through the application of *scipy.ndimage.label* using the default structuring element (https://docs.scipy.org/doc/scipy/reference/generated/scipy.ndimage.label.html). 
 847    
 848        When *area_type*='gridpoint', *n* specifies the number of connected datapoints within each connected region. For the special case where *dims* consists of a latitude- and longitude-like coordinate, area_type='spherical' applies a cosine-latitude weighting, such that *n* can be interpreted as a measure of area, where a datapoint with lat=0 would have area 1. 
 849        
 850        **Examples**
 851        
 852            #Keep groups of True values consisting of an area >=30 square equatorial gridpoints
 853            patternfilter.apply_area_mask(30,dims=('lat','lon'),area_type='spherical')
 854            
 855            #Keep groups of True values that are consistent for at least 3 neighbouring time lags
 856            patternfilter.apply_area_mask(3,dims=('time',))
 857            
 858            #Keep groups of true values consisting of >=10 longitudinal values, or >=30 values in longitude and altitude if the variables have an altitude coord:
 859            patternfilter.apply_area_mask(10,dims=('longitude',))
 860            patternfilter.apply_area_mask(30,dims=('longitude','altitude'),mode='union')
 861
 862        **Arguments**
 863
 864        *n*
 865            
 866        A scalar indicating the minimum size of an unmasked group, in terms of number of gridpoints (for *area_type*=gridpoint) or the weighted area (for *area_type*=spherical), beneath which the group will be masked.
 867
 868        **Optional arguments**
 869        
 870        *dims*
 871            
 872        An iterable of strings specifying coords in *PatternFilter.mask_ds* which define the subspace in which groups of connected True values are identified. Other dims will be iterated over. DataArrays within *PatternFilter.mask_ds* that do not contain all the *dims* will be ignored. If *dims*=None, all dims in each DataArray will be used.
 873            
 874        *mode*
 875
 876        A string, one of 'replace','intersection' or 'union', defining how the area filter should be used to update the *PatternFilter.mask_ds*.
 877            
 878        *area_type*
 879
 880        A string, one of 'gridpoint' or 'spherical' as specified above. 'spherical' is currently only supported for len-2 *dims* kwargs, with the first assumed to be latitude-like. 
 881            
 882        """        
 883        if area_type=='gridpoint':
 884            area_based=False
 885        elif area_type=='spherical':
 886            area_based=True
 887        else:
 888            raise(ValueError(f"Unknown area_type {area_type}. Valid options are 'gridpoint' and 'spherical'"))
 889        area_mask=ds_large_regions(self.mask_ds,n,dims=dims,area_based=area_based)
 890        self.update_mask(area_mask,mode)
 891        return
 892    
 893    
 894    def apply_convolution(self,n,dims,mode='replace'):
 895        """ Apply a square n-point convolution filter to *PatternFilter.mask_ds* in one or two dimensions specified by *dims*, iterated over remaining dimensions. This has the effect of extending the unmasked regions and smoothing the mask overall.
 896        
 897        **Arguments**
 898        
 899        *n*
 900            
 901        A positive integer specifying the size of the convolution filter. *n*=1 leaves the mask unchanged. Even *n* are asymmetric and shifted right. 
 902
 903        *dims*
 904
 905        A length 1 or 2 iterable of strings specifying the dims in which the convolution is applied. Other dims will be iterated over. DataArrays within *PatternFilter.mask_ds* that do not contain all the *dims* will be ignored. 
 906
 907        *mode*
 908
 909        A string, one of 'replace','intersection' or 'union', defining how the convolution should be used to update the *PatternFilter.mask_ds*.
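
        **Example**

        A minimal sketch, smoothing and extending the mask with a 3x3 convolution in two spatial dims (assumed here to be named 'lat' and 'lon'):

            patternfilter.apply_convolution(3,dims=('lat','lon'))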
 910        """
 911        
 912        if len(dims) not in [1,2]:
 913            raise(ValueError('Only 1 and 2D dims currently supported'))
 914            
 915        convolution=convolve_pad_ds(self.mask_ds,n,dims=dims)
 916        self.update_mask(convolution,mode)
 917        return
 918    
 919    def get_mask(self):
 920        """" Retrieve the mask with all filters applied.
 921        **Returns**
 922        An xarray.Dataset of boolean values.
 923        """
 924        return self.mask_ds
 925    
 926    def filter(self,ds=None,drop_empty=True,fill_val=np.nan):
 927        """ Apply the current mask to *ds* or to *PatternFilter.val_ds* (if *ds* is None), replacing masked gridpoints with *fill_val*.
 928        **Optional arguments**
 929        
 930        *ds*
 931        
 932        An xarray.Dataset to apply the mask to. Should have the same coords and data_vars as *PatternFilter.mask_ds*. If None, the mask is applied to *PatternFilter.val_ds*.
 933        
 934        *drop_empty*
 935        
 936        A boolean value. If True, then completely masked variables are dropped from the returned masked Dataset.
 937        
 938        *fill_val*
 939        
 940        A scalar that defaults to np.nan. The value with which masked gridpoints in the Dataset are replaced.
 941        
 942        **Returns**
 943        
 944        A Dataset with masked values replaced by *fill_val*.
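
        **Example**

        A minimal sketch, masking the stored *PatternFilter.val_ds*:

            masked_ds=patternfilter.filter(drop_empty=True,fill_val=np.nan)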
 945        """
 946        if ds is None:
 947            ds=self.val_ds.copy(deep=True)
 948            
 949        ds=ds.where(self.mask_ds)
 950        if drop_empty:
 951            drop_vars=((~np.isnan(ds)).sum()==0).to_array('vars')
 952            ds=ds.drop_vars(drop_vars[drop_vars].vars.values)
 953        return ds.fillna(fill_val)
 954    
 955def _DEFAULT_RENAME_FUNC(v,d):
 956    
 957    for k,x in d.items():
 958        v=v+f'_{k}{x}'
 959    return v
 960    
 961def _Dataset_to_dict(ds):
 962    return {v:d['data'] for v,d in ds.to_dict()['data_vars'].items()}
 963
 964class IndexGenerator(object):
 965    
 966    """ Computes dot-products between a Dataset of patterns and a Dataset of variables, reducing them to standardised scalar indices.
 967    """
 968    def __init__(self):
 969        self._means=[]
 970        self._stds=[]
 971        self._rename_function=_DEFAULT_RENAME_FUNC
 972        
 973    def __repr__(self):
 974        return 'An IndexGenerator object'
 975        
 976    def __str__(self):
 977        return self.__repr__()
 978    
 979    
 980    def centre(self,x,dim='time',ref=None):
 981        """@private"""
 982
 983        if ref is None:
 984            ref=x.mean(dim=dim)
 985        return x-ref
 986    
 987    def normalise(self,x,dim='time',ref=None):
 988        """@private"""
 989
 990        if ref is None:
 991            ref=x.std(dim=dim)
 992        return x/ref
 993    
 994    def standardise(self,x,dim='time',mean_ref=None,std_ref=None):
 995        """@private"""
 996        centred_x=self.centre(x,dim,mean_ref)
 997        standardised_x=self.normalise(centred_x,dim,std_ref)
 998        return standardised_x
 999        
1000    def collapse_index(self,ix,dims):
1001        """@private"""
1002        lat_coords=['lat','latitude','grid_latitude']
1003        if not np.any(np.isin(lat_coords,dims)):
1004            return ix.sum(dims)
1005        
1006        else:
1007            #assumes only one lat coord: seems safe.
1008            lat_dim=lat_coords[np.where(np.isin(lat_coords,dims))[0][0]]
1009            weights=np.cos(np.deg2rad(ix[lat_dim]))
1010            return ix.weighted(weights).sum(dims)
1011            
1012    def generate(self,pattern_ds,series_ds,dim='time',slices=None,ix_means=None,ix_stds=None,drop_blank=False,in_place=True,strict_metadata=False):
1013        """Compute standardised indices from an xarray.Dataset of patterns and an xarray.Dataset of arbitrary dimension variables.
1014        
1015        **Arguments**
1016        
1017        *pattern_ds*
1018        
1019        An xarray.Dataset of patterns to project onto with arbitrary dimensions.
1020        
1021        *series_ds*
1022        
1023        An xarray.Dataset of variables to project onto the patterns. Coordinates of *series_ds* once subsetted using *slices* must match the dimensions of *pattern_ds* + the extra coord *dim*.
1024        
1025        **Optional arguments**
1026        
1027        *dim*:
1028        
1029        A string specifying the remaining coord of the scalar indices. Defaults to 'time', which should be the choice for most use cases.
1030        
1031        *slices*
1032        
1033        A dictionary or iterable of dictionaries, each specifying a subset of *pattern_ds* to take before computing an index, with one index returned for each dictionary and for each variable. Subsetting is based on the *xr.Dataset.sel* method: e.g. *slices*=[dict(lag=0,index_val=1)] will produce 1 set of indices based on pattern_ds.sel(lag=0,index_val=1). If *slices*=None, no subsets are computed.
1034        
1035        *ix_means*
1036        
1037        If None, the mean of each index is calculated and subtracted, resulting in centred indices. Otherwise, *ix_means* should be a dictionary of index names and predefined mean values which are subtracted instead. Of most use for online computations, updating a precomputed index in a new dataset.
1038        
1039        *ix_stds*
1040        
1041        If None, the standard deviation of each index is calculated and is divided by, resulting in standardised indices. Otherwise, *ix_stds* should be a dictionary of index names and predefined std values which are divided by instead. Of most use for online computations, updating a precomputed index in a new dataset.
1042
1043        *drop_blank*
1044        
1045        A boolean. If True, drop indices where the corresponding pattern is entirely blank. If False, returns an all-np.nan time series.

        *in_place*

        A boolean, defining whether the computed indices are also stored in *IndexGenerator.indices*.
1048        *strict_metadata*
1049        
1050        If False, indices will be merged into a common dataset regardless of metadata. If True, nonmatching metadata will raise a ValueError.
1051        
1052        **Returns**
1053        
1054        An xarray.Dataset of indices with a single coordinate (*dim*).
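
        **Example**

        A minimal sketch, assuming *filtered_composites* is a Dataset of (possibly masked) patterns and *variables* is a Dataset of the underlying fields with a 'time' coordinate:

            ixgen=IndexGenerator()
            indices=ixgen.generate(filtered_composites,variables,dim='time',
                slices=[dict(lag=0,index_val=1)])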
1055        """
1056        #Parse inputs
1057        
1058        if slices is None:
1059            self.slices=[{}]
1060        elif type(slices) is dict:
1061            self.slices=[slices]
1062        else:
1063            self.slices=slices
1064            
1065        if ix_means is not None or ix_stds is not None:
1066            self.user_params=True
1067            self.means=ix_means
1068            self.stds=ix_stds
1069        else:
1070            self.user_params=False
1071            self.means={}
1072            self.stds={}
1073            
1074        self.indices=None
1075        
1076        #Compute indices
1077        indices=[self._generate_index(pattern_ds,series_ds,dim,sl)\
1078                for sl in self.slices]
1079        try:
1080            indices=xr.merge(indices)
1081        except Exception as e:
1082            if strict_metadata:
1083                print("Merge of indices failed. Consider 'strict_metadata=False'")
1084                raise e
1085            else:
1086                indices=xr.merge(indices,compat='override')
1087            
1088        #Optionally remove indices which are all nan    
1089        if drop_blank:
1090            drop=(~indices.isnull()).sum()==0
1091            drop=[k for k,d in drop.to_dict()['data_vars'].items() if d['data']]
1092            indices=indices.drop_vars(drop)
1093            _=[(self.means.pop(x),self.stds.pop(x)) for x in drop]
1094        if in_place:
1095            self.indices=indices
1096        return indices
1097    
1098    def _generate_index(self,pattern_ds,series_ds,dim,sl):
1099                
1100        pattern_ds,series_ds=xr.align(pattern_ds,series_ds)
1101        pattern_ds=pattern_ds.sel(sl)
1102        dims=list(pattern_ds.dims)
1103
1104        index=pattern_ds*series_ds
1105        #coslat weights lat coords
1106        index=self.collapse_index(index,dims)
1107        index=self._rename_index_vars(index,sl)
1108
1109        if self.user_params:
1110            mean=self.means
1111            std=self.stds
1112        else:
1113            mean=_Dataset_to_dict(index.mean(dim))
1114            std=_Dataset_to_dict(index.std(dim))
1115            for v in mean:
1116                self.means[v]=mean[v]
1117            for v in std:
1118                self.stds[v]=std[v]
1119                
1120        index=self.standardise(index,dim,mean_ref=mean,std_ref=std)
1121        index=self._add_index_attrs(index,sl,mean,std)
1122
1123        
1124        self.generated_index=index
1125        return index
1126    
1127    def _add_index_attrs(self,index,sl,mean,std):
1128        for v in index:
1129            ix=index[v]
1130            ix.attrs['mean']=np.array(mean[v])
1131            ix.attrs['std']=np.array(std[v])
1132            for k,i in sl.items():
1133                ix.attrs[k]=i
1134            index[v]=ix
1135        return index
1136    
1137    def _rename_index_vars(self,index,sl):
1138        func=self._rename_function
1139        return index.rename({v:func(v,sl) for v in index.data_vars})
1140    
1141    def get_standardisation_params(self,as_dict=False):
1142        
1143        """ Retrieve index means and stds for computed indices, for use as future inputs into index_means or index_stds in *IndexGenerator.Generate*
1144        """
1145        if as_dict:
1146            return self.means,self.stds
1147        else:
1148            params=[xr.Dataset(self.means),xr.Dataset(self.stds)]
1149            return xr.concat(params,'param').assign_coords({'param':['mean','std']})
109        “outer”: use the union of object indexes
110        “inner”: use the intersection of object indexes
111
112        “left”: use indexes from the pre-existing *LaggedAnalyser.variables* with each dimension
113
114        “right”: use indexes from the new *variables* with each dimension
115
116        “exact”: instead of aligning, raise ValueError when indexes to be aligned are not equal
117
118        “override”: if indexes are of same size, rewrite indexes to be those of the pre-existing *LaggedAnalyser.variables*. Indexes for the same dimension must have the same size in all objects.
119        """
120        if isinstance(variables,dict):
121            
122            if is_categorical is None:
123                is_categorical={v:None for v in variables}
124                
125            [self._add_variable(da,v,is_categorical[v],overwrite,join_type) for v,da in variables.items()]
126            
127        elif isinstance(variables,xr.Dataset):
128            self.add_variable({v:variables[v] for v in variables.data_vars},None,is_categorical,overwrite,join_type)
129            
130        else:
131            
132            self._add_variable(variables,name,is_categorical,overwrite,join_type)            
133        return
134    
135    def _more_mergable(self,ds):
136        
137        return drop_scalar_coords(make_all_dims_coords(ds))
138    
139    def _add_variable(self,da,name,is_categorical,overwrite,join_type):
140        
141        if name is None:
142            name=da.name
143        if (name in self.variables)&(not overwrite):
144            raise(KeyError(f'Key "{name}" is already in variables.'))
145        
146        try:
147            self.variables=self.variables.merge(squeeze_da(da).to_dataset(name=name),join=join_type)
148        except:
149            #Trying to make the merge work:
150            self.variables=self._more_mergable(self.variables).merge(self._more_mergable(squeeze_da(da).to_dataset(name=name)),join=join_type)
151
152        if (is_categorical is None) and (not 'is_categorical' in da.attrs):
153            self.variables[name].attrs['is_categorical']=0
154        elif is_categorical is not None:
155            self.variables[name].attrs['is_categorical']=is_categorical
156
157    def lagged_variables(self,t):
158        """A convenience function that retrieves variables at lag *t* from the *LaggedAnalyser*"""
159        if t in self._lagged_variables:
160            return self._lagged_variables[t]
161        elif t==0:
162            return self.variables
163        else:
164            raise(KeyError(f'Lag {t} is not in self._lagged_variables.'))
165
166    def _lag_variables(self,offset,offset_unit='days',offset_dim='time',mode='any',overwrite=False):
167        
168        if offset==0:
169            return
170        if (offset in self._lagged_variables)&(not overwrite):
171            raise(KeyError(f'Key "{offset}" is already in lagged_variables.'))
172            
173        #We are really paranoid about mixing up our lags. So we implement this safety check
174        self._check_offset_is_valid(offset,mode)
175        
176        #REPLACED PREVIOUS IMPLEMENTATION WITH EQUIVALENT UTIL IMPORT.
177        self._lagged_variables[offset]=offset_time_dim(self.variables,-offset,offset_unit=offset_unit,offset_dim=offset_dim)
178
179        return
180    
181    #For coords not in a time format
182    def _ilag_variables(self,offset,*args,overwrite=False):
183        raise(NotImplementedError('Only lagging along timelike dimensions is currently supported.'))
184        
185    def lag_variables(self,offsets,offset_unit='days',offset_dim='time',mode='any',overwrite=False):
186        """Produces time lags of *LaggedAnalyser.variables*, which can be used to produce lagged composites.
187        
188        **Arguments**
189        
190        *offsets*
191        
192        An iterable of integers which represent time lags at which to lag *LaggedAnalyser.variables* in the units specified by *offset_unit*. Positive offsets denote variables *preceding* the event.
193            
194        **Optional arguments**
195        
196        *offset_unit*
197        
198        A string, defining the units of *offsets*. Valid options are weeks, days, hours, minutes, seconds, milliseconds, and microseconds.
199            
200        *offset_dim*
201        
202        A string, defining the coordinate of *LaggedAnalyser.variables* along which offsets are to be calculated.
203            
204        *mode*
205        
206        One of 'any', 'past', or 'future'. If 'past' or 'future' is used then only positive or negative lags are valid, respectively.
207            
208        *overwrite*
209        
210        A boolean. If False, then attempts to produce a lag which already exist will raise a ValueError.
211        
212        """
213        time_type=int(is_time_type(self.variables[offset_dim][0].values))
214        self.offset_unit=offset_unit
215        lag_funcs=[self._ilag_variables,self._lag_variables]
216        offsets=np.atleast_1d(offsets)
217        for o in offsets:
218            lag_funcs[time_type](int(o),offset_unit,offset_dim,mode,overwrite)
219        
220    def _check_offset_is_valid(self,offset,mode):
221        
222        valid_modes=['any','past','future']
223        if not mode in valid_modes:
224            raise(ValueError(f'mode must be one of {valid_modes}'))
225        if offset>0 and mode == 'past':
226            raise(ValueError(f'Positive offset {offset} given, but mode is "{mode}"'))
227        if offset<0 and mode == 'future':
228            raise(ValueError(f'Negative offset {offset} given, but mode is "{mode}"'))
229        return
230    
231    """
232        COMPOSITE COMPUTATION FUNCTIONS 
233        Composite computation is split over 4 function layers:
234        compute_composites(): calls
235            _compute_aggregate_over_lags(): calls
236                _composite_from_ix(): splits data into cat vars
237                and con vars and then calls
238                    _aggregate_from_ix(): applies an operation to
239                    subsets of ds, where the ix takes unique values
240                then merges them together.
241             then loops over lags and merges those.
242        And then substracts any anomalies and returns the data.
243
244
245        Example usage of the aggregate funcs:
246        i.e. self._aggregate_from_ix(ds,ix,'time',self._mean_ds)
247        self._aggregate_from_ix(ds,ix,'time',self._std_ds)
248        self._aggregate_from_ix(ds,ix,'time',self._cat_occ_ds,s=reg_ds)
249    """
250    
251    def _aggregate_from_ix(self,ds,ix,dim,agg_func,*agg_args):
252        return xr.concat([agg_func(ds.isel({dim:ix==i}),dim,*agg_args) for i in np.unique(ix)],'index_val')
253    
254    
255    #Splits variables into cat and con and then combines the two different kinds of composites.
256    #Used with a synthetic 'ix' for bootstrapping by self._compute_bootstraps.
257    def _composite_from_ix(self,ix,ds,dim,con_func,cat_func,lag=0):
258                
259        ix=ix.values #passed in as a da
260        cat_vars=[v for v in ds if ds[v].attrs['is_categorical']]
261        con_vars=[v for v in ds if not v in cat_vars]
262        cat_ds=ds[cat_vars]
263        con_ds=ds[con_vars]
264        cat_vals=cat_ds.map(np.unique)
265
266        if (con_vars!=[]) and (con_func is not None):
267            if (cat_vars!=[]) and (cat_func is not None):
268                con_comp=self._aggregate_from_ix(con_ds,ix,dim,con_func)
269                cat_comp=self._aggregate_from_ix(cat_ds,ix,dim,cat_func,cat_vals)
270                comp=con_comp.merge(cat_comp)
271            else:
272                comp=self._aggregate_from_ix(con_ds,ix,dim,con_func)
273        else:
274                comp=self._aggregate_from_ix(cat_ds,ix,dim,cat_func,cat_vals)
275        comp.attrs=ds.attrs
276        return comp.assign_coords({'lag':[lag]})    
277    
278    #loops over all lags, calling _composite_from_ix, and assembles composites into a single dataset
279    def _compute_aggregate_over_lags(self,da,dim,lag_vals,con_func,cat_func):
280            
281        if lag_vals=='all':
282            lag_vals=list(self._lagged_variables)
283                    
284        composite=self._composite_from_ix(*xr.align(da,self.variables),dim,con_func,cat_func)
285              
286        if lag_vals is not None:
287            lag_composites=[]
288            for t in lag_vals:
289                lag_composites.append(self._composite_from_ix(*xr.align(da,self.lagged_variables(t)),dim,con_func,cat_func,lag=t))
290            composite=xr.concat([composite,*lag_composites],'lag').sortby('lag')
291            
292        return composite
293
294    #The top level wrapper for compositing
295    def compute_composites(self,dim='time',lag_vals='all',as_anomaly=False,con_func=agg.mean_ds,cat_func=agg.cat_occ_ds,inplace=True):
296        
297        """
298        Partitions *LaggedAnalyser.variables*, and any time-lagged equivalents, into subsets depending on the value of *LaggedAnalyser.event*, and then computes a bulk summary metric for each.
299
300        **Optional arguments**
301        
302        *dim*
303        
304        A string, the coordinate along which to compute composites.
305            
306        *lag_vals*
307        
308        Either 'All', or a list of integers, denoting the time lags for which composites should be computed.
309            
310        *as_anomaly*
311        
312        A Boolean, defining whether composites should be given as absolute values or differences from the unpartitioned value.
313            
314        *con_func*
315        
316        The summary metric to use for continuous variables. Defaults to a standard mean average. If None, then continuous variables will be ignored
317            
318        *cat_func*
319        
320        The summary metric to use for categorical variables. Defaults to the occurrence probability of each categorical value. If None, then categorical variables will be ignored
321            
322        *inplace*
323    
324        A boolean, defining whether the composite should be stored in *LaggedAnalyser.composites*
325        
326        **returns**
327        
328        An xarray.Dataset like  *LaggedAnalyser.variables* but summarised according to *con_func* and *cat_func*, and with an additional coordinate *index_val*, which indexes over the values taken by *LaggedAnalyser.event*.
329            
330        """
331        composite=self._compute_aggregate_over_lags(self.event,dim,lag_vals,con_func,cat_func)
332        lagged_means=self.aggregate_variables(dim,lag_vals,con_func,cat_func)
333
334        if as_anomaly:
335            composite=composite-lagged_means
336            
337        composite=make_all_dims_coords(composite)
338        for v in list(composite.data_vars):
339            composite[v].attrs=self.variables[v].attrs
340        if inplace:
341            self.composites=composite
342            self.composite_func=(con_func,cat_func)
343            self.composites_are_anomaly=as_anomaly
344            self.lagged_means=lagged_means
345        return composite
346
347    #Aggregates variables over all time points where event is defined, regardless of its value
348    def aggregate_variables(self,dim='time',lag_vals='all',con_func=agg.mean_ds,cat_func=agg.cat_occ_ds):
349        
350        """Calculates a summary metric from *LaggedAnalyser.variables* at all points where *LaggedAnalyser.event* is defined, regardless of its value.
351        
352        **Optional arguments**
353        
354        *dim*
355        
356        A string, the name of the shared coordinate between *LaggedAnalyser.variables* and *LaggedAnalyser.event*.
357        
358        *lag_vals*
359        
360        'all' or a iterable of integers, specifying for which lag values to compute the summary metric.
361        
362        *con_func*
363        
364        The summary metric to use for continuous variables. Defaults to a standard mean average. If None, then continuous variables will be ignored
365            
366        *cat_func*
367        
368        The summary metric to use for categorical variables. Defaults to the occurrence probability of each categorical value. If None, then continuous variables will be ignored
369
370        **returns**
371        
372        An xarray.Dataset like  *LaggedAnalyser.variables* but summarised according to *con_func* and *cat_func*.
373
374"""
375        fake_event=self.event.copy(data=np.zeros_like(self.event))
376        return self._compute_aggregate_over_lags(fake_event,dim,lag_vals,con_func,cat_func).isel(index_val=0)
377
378    def add_derived_composite(self,name,func,composite_vars,as_anomaly=False):
379        """Applies *func* to one or multiple composites to calculate composites of derived quantities, and additionally, stores *func* to allow derived bootstrap composites to be calculated. For linear quantities, where Ex[f(x)]==f(Ex[x]), then this can minimise redundant memory use.
380        
381        **Arguments**
382        
383        *name*
384        
385        A string, providing the name of the new variable to add.
386            
387        *func*
388        
389         A callable which must take 1 or more xarray.DataArrays as inputs
390            
391        *composite_vars*
392        
393        An iterable of strings, of the same length as the number of arguments taken by *func*. Each string must be the name of a variable in *LaggedAnalyser.variables* which will be passed into *func* in order.
394        
395        **Optional arguments**
396        
397        *as_anomaly*
398        
399        A boolean. Whether anomaly composites or full composites should be passed in to func.
400        """
401        
402        if np.ndim(as_anomaly)==1:
403            raise(NotImplementedError('variable-specific anomalies not yet implemented'))
404
405        self._derived_variables[name]=(func,composite_vars,as_anomaly)
406        self.composites[name]=self._compute_derived_da(self.composites,func,composite_vars,as_anomaly)
407        
408        if self.lagged_means is not None:
409            self.lagged_means[name]=self._compute_derived_da(self.lagged_means,func,composite_vars,as_anomaly)
410            
411        return
412
413    ### Compute bootstraps ###
414    
415    #Top level func
416    def compute_bootstraps(self,bootnum,dim='time',con_func=agg.mean_ds,cat_func=agg.cat_occ_ds,lag=0,synth_mode='markov',data_vars=None,reuse_ixs=False):
417        
418        """Computes composites from synthetic event indices, which can be used to assess whether composites are insignificant.
419        
420        **Arguments**
421        
422        *bootnum*
423        
424        An integer, the number of bootstrapped composites to compute
425            
426        **Optional arguments**
427        
428        *dim*
429        
430        A string, the name of the shared coordinate between *LaggedAnalyser.variables* and *LaggedAnalyser.event*.
431            
432        *con_func*
433        
434        The summary metric to use for continuous variables. Defaults to a standard mean average. If None, then continuous variables will be ignored
435            
436        *cat_func*
437        
438        The summary metric to use for categorical variables. Defaults to the occurrence probability of each categorical value. If None, then continuous variables will be ignored
439
440        *lag*
441        
442        An integer, specifying which lagged variables to use for the bootstraps. i.e. bootstraps for lag=90 will be from a completely different season than those for lag=0.
443            
444        *synth_mode*
445        
446        A string, specifying how synthetic event indices are to be computed. Valid options are:
447            
448        "random": 
449        
450        categorical values are randomly chosen with the same probability of occurrence as those found in *LaggedAnalyser.event*, but with no autocorrelation.
451
452        "markov": 
453        
454        A first order Markov chain is fitted to *LaggedAnalyser.event*, producing some autocorrelation and state dependence in the synthetic series. Generally a better approximation than "random" and so should normally be used.
455
456        "shuffle": 
457        
458        The values are randomly reordered. This means that each value will occur exactly the same amount of times as in the original index, and so is ideal for particularly rare events or short series.
459            
460        *data_vars*
461        
462        An iterable of strings, specifying for which variables bootstraps should be computed.
463                
464        **returns**
465        
466        An xarray.Dataset like *LaggedAnalyser.variables* but summarised according to *con_func* and *cat_func*, and with a new coordinate 'bootnum' of length *bootnum*.
467
468        """
469        if data_vars==None:
470            data_vars=list(self.variables.data_vars)
471
472        boots=self._add_derived_boots(self._compute_bootstraps(bootnum,dim,con_func,cat_func,lag,synth_mode,data_vars,reuse_ixs))
473        if self.composites_are_anomaly:
474            boots=boots-self.lagged_means.sel(lag=lag)
475        return make_all_dims_coords(boots)
476    
477    
478    def _compute_derived_da(self,ds,func,varnames,as_anomaly):
479        if as_anomaly:
480            input_vars=[ds[v]-self.lagged_means[v] for v in varnames]
481        else:
482            input_vars=[ds[v] for v in varnames]
483        return make_all_dims_coords(func(*input_vars))
484    
485    
486    def _add_derived_boots(self,boots):
487        for var in self._derived_variables:
488            func,input_vars,as_anomaly=self._derived_variables[var]
489            boots[var]=self._compute_derived_da(boots,func,input_vars,as_anomaly)
490        return boots
491
492    def _compute_bootstraps(self,bootnum,dim,con_func,cat_func,lag,synth_mode,data_vars,reuse_ixs):
493
494        da,ds=xr.align(self.event,self.lagged_variables(lag))
495        ds=ds[data_vars]
496        
497        if (self.boot_indices is None)|(not reuse_ixs):
498            
499            ix_vals,ix_probs,L=self._get_bootparams(da)
500            ixs=self._get_synth_indices(da.values,bootnum,synth_mode,da,dim,ix_vals,ix_probs,L)
501            self.boot_indices=ixs
502        else:
503            ixs=self.boot_indices
504            print('Reusing stored boot_indices, ignoring new boot parameters.')
505        
506        boots=[make_all_dims_coords(\
507                self._composite_from_ix(ix,ds,dim,con_func,cat_func,lag)\
508             ) for ix in ixs]
509        return xr.concat(boots,'boot_num')
510    
511    #Gets some necessary variables
512    def _get_bootparams(self,da):
513        ix_vals,ix_probs=np.unique(da.values,return_counts=True)
514        return ix_vals,ix_probs/len(da),len(da)
515    
516    #compute indices
517    def _get_synth_indices(self,index,bootnum,mode,da,dim,ix_vals,ix_probs,L):
518        
519        ixs=[]
520        if mode=='markov':
521            xs=split_to_contiguous(da[dim].values,x_arr=da)
522            T=get_transmat(xs)
523            for n in range(bootnum):
524                ixs.append(synthetic_states_from_transmat(T,L-1))
525                
526        elif mode=='random':
527            for n in range(bootnum):
528                ixs.append(np.random.choice(ix_vals,size=L,p=list(ix_probs)))
529                
530        elif mode=='shuffle':
531            for n in range(bootnum):
532                ixv=index.copy()
533                np.random.shuffle(ixv)
534                
535                ixs.append(ixv)
536        else:
537            raise(ValueError(f'synth_mode={synth_mode} is not valid.'))
538            
539        return [xr.DataArray(ix) for ix in ixs]
540        
541    ### apply significance test ###
542    
543    def get_significance(self,bootstraps,comp,p,data_vars=None,hb_correction=False):
544        
545        """Computes whether a composite is significant with respect to a given distribution of bootstrapped composites. 
546        
547        **Arguments**
548        
549        *bootstraps*
550
551        An xarray.Dataset with a coordinate 'bootnum', such as produced by *LaggedAnalyser.compute_bootstraps*
552
553        *comp*
554
555        An xarray Dataset of the same shape as *bootstraps* but without a 'bootnum' coordinate. Missing or additional variables are allowed, and are simply ignored.
556        *p*
557
558        A float, specifying the p-value of the 2-sided significance test (values in the range 0 to 1). 
559            
560        **Optional arguments**
561
562        *data_vars*
563            
564        An iterable of strings, specifying for which variables significance should be computed.
565            
566        *hb_correction*
567        
568        A Boolean, specifying whether a Holm-Bonferroni correction should be applied to *p*, in order to reduce the family-wide error rate. Note that this correction is currently only applied to each variable in *comp* independently, and so will have no impact on scalar variables.
569        
570        **returns**
571        
572        An xarray.Dataset like *comp* but with boolean data, specifying whether each feature of each variable passed the significance test.
573        """
574        if data_vars==None:
575            data_vars=list(bootstraps.data_vars)
576
577        bootnum=len(bootstraps.boot_num)
578        comp=comp[data_vars]
579        bootstraps=bootstraps[data_vars]
580        frac=(comp<bootstraps).sum('boot_num')/bootnum
581        pval_ds=1-2*np.abs(frac-0.5)
582        if hb_correction:
583            for var in pval_ds:
584                corrected_pval=holm_bonferroni_correction(pval_ds[var].values.reshape(-1),p)\
585                            .reshape(pval_ds[var].shape)
586                pval_ds[var].data=corrected_pval
587        else:
588            pval_ds=pval_ds<p
589            
590        self.composite_sigs=pval_ds.assign_coords(lag=comp.lag)
591        return self.composite_sigs
592    
593    def bootstrap_significance(self,bootnum,p,dim='time',synth_mode='markov',reuse_lag0_boots=False,data_vars=None,hb_correction=False):
594        
595        """A wrapper around *compute_bootstraps* and *get_significance*, that calculates bootstraps and applies a significance test to a number of time lagged composites simulataneously.
596        
597    **Arguments**
598
599    *bootnum*
600
601    An integer, the number of bootstrapped composites to compute
602
603    *p*
604
605    A float, specifying the p-value of the 2-sided significance test (values in the range 0 to 1). 
606
607    **Optional arguments**
608
609    *dim*
610
611    A string, the name of the shared coordinate between *LaggedAnalyser.variables* and *LaggedAnalyser.event*.
612
613    *synth_mode*
614
615    A string, specifying how synthetic event indices are to be computed. Valid options are:
616    "random": categorical values are randomly chosen with the same probability of occurrence as those found in *LaggedAnalyser.event*, but with no autocorrelation.
617    'markov': A first order Markov chain is fitted to *LaggedAnalyser.event*, producing some autocorrelation and state dependence in the synthetic series. Generally a better approximation than "random" and so should normally be used.
618
619    *reuse_lag0_boots*
620        A Boolean. If True, bootstraps are only computed for lag=0, and then used as a null distribution to assess all lagged composites. For variables which are approximately stationary across the lag timescale, then this is a good approximation and can increase performance. However if used incorrectly, it may lead to 'significant composites' which simply reflect the seasonal cycle. if False, separate bootstraps are computed for all time lags.
621
622    *data_vars*
623        An iterable of strings, specifying for which variables significance should be computed.
624
625    *hb_correction*
626        A Boolean, specifying whether a Holm-Bonferroni correction should be applied to *p*, in order to reduce the family-wide error rate. Note that this correction is currently only applied to each variable in *comp* independently, and so will have no impact on scalar variables.
627        
628    **returns**
629
630    An xarray.Dataset like *LaggedAnalyser.variables* but with the *dim* dimension summarised according to *con_func* and *cat_func*, an additional *lag* coordinate, and with boolean data specifying whether each feature of each variable passed the significance test.
631
632        """
633        lag_vals=list(self._lagged_variables)
634        
635        con_func,cat_func=self.composite_func
636        
637        boots=self.compute_bootstraps(bootnum,dim,con_func,cat_func,0,synth_mode,data_vars)
638        
639        #reuse_lag0_boots=True can substantially reduce run time!
640        if not reuse_lag0_boots:
641                    boots=xr.concat([boots,*[self.compute_bootstraps(bootnum,dim,con_func,cat_func,t,synth_mode,data_vars)\
642                        for t in lag_vals]],'lag').sortby('lag')
643                
644        sig_composite=self.get_significance(boots,self.composites,p,data_vars,hb_correction=hb_correction)
645        
646        self.composite_sigs=sig_composite
647        return self.composite_sigs
648    
649    
650    def deseasonalise_variables(self,variable_list=None,dim='time',agg='dayofyear',smooth=1,coeffs=None):
651        """Computes a seasonal cycle for each variable in *LaggedAnalyser.variables* and subtracts it inplace, turning *LaggedAnalyser.variables* into deseasonalised anomalies. The seasonal cycle is computed via temporal aggregation of each variable over a given period - by default the calendar day of the year. This cycle can then be smoothed with an n-point rolling average.
652
653                **Optional arguments**
654
655                *variable_list*
656                
657                A list of variables to deseasonalise. Defaults to all variables in the *LaggedAnalyser.variables*
658
659                *dim*
660                
661                A string, the name of the shared coordinate between *LaggedAnalyser.variables* and *LaggedAnalyser.event*, along which the seasonal cycle is computed. Currently, only timelike coordinates are supported.
662                
663                *agg*
664                
665                A string specifying the datetime-like field to aggregate over. Useful and supported values are 'season', 'month', 'weekofyear', and 'dayofyear'
666                    
667                *smooth*
668                
669                An integer, specifying the size of the n-timestep centred rolling mean applied to the aggregated seasonal cycle. By default *smooth*=1 results in no smoothing.
670
671                *coeffs*
672                
673                A Dataset containing a precomputed seasonal cycle, which, if *LaggedAnalyser.variables* has coordinates (*dim*,[X,Y,...,Z]), has coords (*agg*,[X,Y,...,Z]), and has the same data variables as *LaggedAnalyser.variables*. If *coeffs* is provided, no seasonal cycle is fitted to *LaggedAnalyser.variables*, *coeffs* is used instead.
674
675        """        
676
677        if variable_list is None:
678            variable_list=list(self.variables)
679        for var in variable_list:
680            da=self.variables[var]
681            dsnlsr=Agg_Deseasonaliser()
682            if coeffs is None:
683                dsnlsr.fit_cycle(da,dim=dim,agg=agg)
684            else:
685                dsnslr.cycle_coeffs=coeffs[var]
686
687            cycle=dsnlsr.evaluate_cycle(data=da[dim],smooth=smooth)
688            self.variables[var]=da.copy(data=da.data-cycle.data)
689            dsnlsr.data=None #Prevents excess memory storage
690            self._deseasonalisers[var]=dsnlsr
691        return   
692    
693    def get_seasonal_cycle_coeffs(self):
694        """ Retrieve seasonal cycle coeffs computed with *LaggedAnalyser.deseasonalise_variables*, suitable for passing into *coeffs* in other *LaggedAnalyser.deseasonalise_variables* function calls as a precomputed cycle.
695        
696        **Returns**
697        An xarray.Dataset, as specified in  the *LaggedAnalyser.deseasonalise_variables* *coeff* optional keyword.
698        """
699        coeffs=xr.Dataset({v:dsnlsr.cycle_coeffs for v,dsnlsr in self._deseasonalisers.items()})
700        return coeffs
701
702    #If deseasonalise_variables has been called, then this func can be used to compute the
703    #seasonal mean state corresponding to a given composite. This mean state+ the composite
704    # produced by self.compute_composites gives the full field composite pattern.
705    def get_composite_seasonal_mean(self):
706        """
707        If *LaggedAnalyser.deseasonalise_variables* has been called, then this function returns the seasonal mean state corresponding to a given composite, given by a sum of the seasonal cycle weighted by the time-varying occurrence of each categorical value in *LaggedAnalyser.events*. This mean state + the deseasonalised anomaly composite
708    produced by *LaggedAnalyser.compute_composites* then retrieves the full composite pattern.
709    
710    **Returns**
711        An xarray.Dataset containing the composite seasonal mean values.
712        """
713        variable_list=list(self._deseasonalisers)
714        ts={e:self.event[self.event==e].time for e in np.unique(self.event)}
715        lags=np.unique([0,*list(self._lagged_variables)])
716        
717        mean_states={}
718        for var in variable_list:
719            dsnlsr=self._deseasonalisers[var]
720            agg=dsnlsr.agg
721            mean_states[var]=xr.concat([\
722                                 xr.concat([\
723                                    self._lag_average_cycle(dsnlsr,agg,l,t,i)\
724                                for l in lags],'lag')\
725                            for i,t in ts.items()],'index_val')
726            
727        return xr.Dataset(mean_states)
728        
729    def _lag_average_cycle(self,dsnlsr,agg,l,t,i):
730        
731        dt=durel.relativedelta(**{self.offset_unit:int(l)})
732        tvals=pd.to_datetime([pd.to_datetime(tt)+dt for tt in t.values])
733        cycle_eval=dsnlsr.cycle_coeffs.sel({agg:getattr(tvals,agg)})
734        cycle_mean=cycle_eval.mean(agg).assign_coords({'lag':l,'index_val':i})
735        return cycle_mean

Computes lagged composites of variables with respect to a categorical categorical event series, with support for bootstrap resampling to provide a non-parametric assessment of composite significance, and for deseasonalisation of variables.

Arguments:

event

An xarray.DataArray with one dimension taking on categorical values, each defining a class of event (or non-event).

Optional arguments

variables, name, is_categorical

Arguments for adding variables to the LaggedAnalyser. Identical behaviour to calling LaggedAnalyser.add_variables directly.

LaggedAnalyser(event, variables=None, name=None, is_categorical=None)
32    def __init__(self,event,variables=None,name=None,is_categorical=None):
33        """Initialise a new LaggedAnalyser object."""
34        
35        #: event is a dataarray
36        self.event=xr.DataArray(event)#: This is a docstring?
37        """@private"""
38        
39        #variables are stored in a dataset, and can be added later,
40        #or passed as a DataArray, a Dataset or as a dict of DataArrays
41        self.variables=xr.Dataset(coords=event.coords)
42        """@private"""
43
44        if variables is not None:
45            self.add_variable(variables,name,is_categorical,False)
46            
47        #Time lagged versions of the dataset self.variables will be stored here, with a key
48        #equal to the lag applied. Designed to be accessed by the self.lagged_variables function
49        self._lagged_variables={}
50        self.lagged_means=None
51        """@private"""
52
53        #variables that are a linear combination of other variables are more efficiently
54        #computed after compositing using the self.add_derived_composite method
55        self._derived_variables={}
56        self._deseasonalisers={}
57        
58        self.composite_mask=None
59        """@private"""
60
61        self.boot_indices=None
62        """@private"""
63
64        return

Initialise a new LaggedAnalyser object.

def add_variable( self, variables, name=None, is_categorical=None, overwrite=False, join_type='outer'):
 83    def add_variable(self,variables,name=None,is_categorical=None,overwrite=False,join_type='outer'):
 84        """Adds an additional variable to LaggedAnalyser.variables.
 85        
 86        **Arguments**
 87        
 88        *variables* 
 89        
 90        An xarray.DataArray, xarray.Dataset or dictionary of xarray.DataArrays, containing data to be composited with respect to *event*. One of the coordinates of *variables* should have the same name as the coordinate of *events*. Stored internally as an xarray.Dataset. If a dictionary is passed, the DataArrays are joined according to the method *join_type* which defaults to 'outer'.
 91            
 92        **Optional Arguments**
 93        
 94        *name* 
 95        
 96        A string. If *variables* is a single xarray.DataArray then *name* will be used as the name of the array in the LaggedAnalyser.variables DataArray. Otherwise ignored.
 97        
 98        *is_categorical* 
 99        
100        An integer, if *variables* is an xarray.DataArray, or else a dictionary of integers with keys corresponding to DataArrays in the xarray.Dataset/dictionary. 0 indicates that the variable is continuous, and 1 indicates that it is categorical. Note that continuous and categorical variables are by default composited differently (see LaggedAnalyser.compute_composites). Default assumption is all DataArrays are continuous, unless a DataAarray contains an 'is_categorical' key in its DataArray.attrs, in which case this value is used.
101            
102        *overwrite*
103        
104        A boolean. If False then attempts to assign a variable who's name is already in *LaggedAnalyser.variables* will raise a ValueError
105        
106        *join_type*
107        
108        A string setting the rules for how differences in the coordinate indices of different variables are handled:
109        “outer”: use the union of object indexes
110        “inner”: use the intersection of object indexes
111
112        “left”: use indexes from the pre-existing *LaggedAnalyser.variables* with each dimension
113
114        “right”: use indexes from the new *variables* with each dimension
115
116        “exact”: instead of aligning, raise ValueError when indexes to be aligned are not equal
117
118        “override”: if indexes are of same size, rewrite indexes to be those of the pre-existing *LaggedAnalyser.variables*. Indexes for the same dimension must have the same size in all objects.
119        """
120        if isinstance(variables,dict):
121            
122            if is_categorical is None:
123                is_categorical={v:None for v in variables}
124                
125            [self._add_variable(da,v,is_categorical[v],overwrite,join_type) for v,da in variables.items()]
126            
127        elif isinstance(variables,xr.Dataset):
128            self.add_variable({v:variables[v] for v in variables.data_vars},None,is_categorical,overwrite,join_type)
129            
130        else:
131            
132            self._add_variable(variables,name,is_categorical,overwrite,join_type)            
133        return

Adds an additional variable to LaggedAnalyser.variables.

Arguments

variables

An xarray.DataArray, xarray.Dataset or dictionary of xarray.DataArrays, containing data to be composited with respect to event. One of the coordinates of variables should have the same name as the coordinate of events. Stored internally as an xarray.Dataset. If a dictionary is passed, the DataArrays are joined according to the method join_type which defaults to 'outer'.

Optional Arguments

name

A string. If variables is a single xarray.DataArray then name will be used as the name of the array in the LaggedAnalyser.variables DataArray. Otherwise ignored.

is_categorical

An integer, if variables is an xarray.DataArray, or else a dictionary of integers with keys corresponding to DataArrays in the xarray.Dataset/dictionary. 0 indicates that the variable is continuous, and 1 indicates that it is categorical. Note that continuous and categorical variables are by default composited differently (see LaggedAnalyser.compute_composites). Default assumption is all DataArrays are continuous, unless a DataAarray contains an 'is_categorical' key in its DataArray.attrs, in which case this value is used.

overwrite

A boolean. If False then attempts to assign a variable who's name is already in LaggedAnalyser.variables will raise a ValueError

join_type

A string setting the rules for how differences in the coordinate indices of different variables are handled: “outer”: use the union of object indexes “inner”: use the intersection of object indexes

“left”: use indexes from the pre-existing LaggedAnalyser.variables with each dimension

“right”: use indexes from the new variables with each dimension

“exact”: instead of aligning, raise ValueError when indexes to be aligned are not equal

“override”: if indexes are of same size, rewrite indexes to be those of the pre-existing LaggedAnalyser.variables. Indexes for the same dimension must have the same size in all objects.

def lagged_variables(self, t):
157    def lagged_variables(self,t):
158        """A convenience function that retrieves variables at lag *t* from the *LaggedAnalyser*"""
159        if t in self._lagged_variables:
160            return self._lagged_variables[t]
161        elif t==0:
162            return self.variables
163        else:
164            raise(KeyError(f'Lag {t} is not in self._lagged_variables.'))

A convenience function that retrieves variables at lag t from the LaggedAnalyser

def lag_variables( self, offsets, offset_unit='days', offset_dim='time', mode='any', overwrite=False):
185    def lag_variables(self,offsets,offset_unit='days',offset_dim='time',mode='any',overwrite=False):
186        """Produces time lags of *LaggedAnalyser.variables*, which can be used to produce lagged composites.
187        
188        **Arguments**
189        
190        *offsets*
191        
192        An iterable of integers which represent time lags at which to lag *LaggedAnalyser.variables* in the units specified by *offset_unit*. Positive offsets denote variables *preceding* the event.
193            
194        **Optional arguments**
195        
196        *offset_unit*
197        
198        A string, defining the units of *offsets*. Valid options are weeks, days, hours, minutes, seconds, milliseconds, and microseconds.
199            
200        *offset_dim*
201        
202        A string, defining the coordinate of *LaggedAnalyser.variables* along which offsets are to be calculated.
203            
204        *mode*
205        
206        One of 'any', 'past', or 'future'. If 'past' or 'future' is used then only positive or negative lags are valid, respectively.
207            
208        *overwrite*
209        
210        A boolean. If False, then attempts to produce a lag which already exist will raise a ValueError.
211        
212        """
213        time_type=int(is_time_type(self.variables[offset_dim][0].values))
214        self.offset_unit=offset_unit
215        lag_funcs=[self._ilag_variables,self._lag_variables]
216        offsets=np.atleast_1d(offsets)
217        for o in offsets:
218            lag_funcs[time_type](int(o),offset_unit,offset_dim,mode,overwrite)

Produces time lags of LaggedAnalyser.variables, which can be used to produce lagged composites.

Arguments

offsets

An iterable of integers which represent time lags at which to lag LaggedAnalyser.variables in the units specified by offset_unit. Positive offsets denote variables preceding the event.

Optional arguments

offset_unit

A string, defining the units of offsets. Valid options are weeks, days, hours, minutes, seconds, milliseconds, and microseconds.

offset_dim

A string, defining the coordinate of LaggedAnalyser.variables along which offsets are to be calculated.

mode

One of 'any', 'past', or 'future'. If 'past' or 'future' is used then only positive or negative lags are valid, respectively.

overwrite

A boolean. If False, then attempts to produce a lag which already exist will raise a ValueError.

def compute_composites( self, dim='time', lag_vals='all', as_anomaly=False, con_func=<function mean_ds>, cat_func=<function cat_occ_ds>, inplace=True):
295    def compute_composites(self,dim='time',lag_vals='all',as_anomaly=False,con_func=agg.mean_ds,cat_func=agg.cat_occ_ds,inplace=True):
296        
297        """
298        Partitions *LaggedAnalyser.variables*, and any time-lagged equivalents, into subsets depending on the value of *LaggedAnalyser.event*, and then computes a bulk summary metric for each.
299
300        **Optional arguments**
301        
302        *dim*
303        
304        A string, the coordinate along which to compute composites.
305            
306        *lag_vals*
307        
308        Either 'All', or a list of integers, denoting the time lags for which composites should be computed.
309            
310        *as_anomaly*
311        
312        A Boolean, defining whether composites should be given as absolute values or differences from the unpartitioned value.
313            
314        *con_func*
315        
316        The summary metric to use for continuous variables. Defaults to a standard mean average. If None, then continuous variables will be ignored
317            
318        *cat_func*
319        
320        The summary metric to use for categorical variables. Defaults to the occurrence probability of each categorical value. If None, then categorical variables will be ignored
321            
322        *inplace*
323    
324        A boolean, defining whether the composite should be stored in *LaggedAnalyser.composites*
325        
326        **returns**
327        
328        An xarray.Dataset like  *LaggedAnalyser.variables* but summarised according to *con_func* and *cat_func*, and with an additional coordinate *index_val*, which indexes over the values taken by *LaggedAnalyser.event*.
329            
330        """
331        composite=self._compute_aggregate_over_lags(self.event,dim,lag_vals,con_func,cat_func)
332        lagged_means=self.aggregate_variables(dim,lag_vals,con_func,cat_func)
333
334        if as_anomaly:
335            composite=composite-lagged_means
336            
337        composite=make_all_dims_coords(composite)
338        for v in list(composite.data_vars):
339            composite[v].attrs=self.variables[v].attrs
340        if inplace:
341            self.composites=composite
342            self.composite_func=(con_func,cat_func)
343            self.composites_are_anomaly=as_anomaly
344            self.lagged_means=lagged_means
345        return composite

Partitions LaggedAnalyser.variables, and any time-lagged equivalents, into subsets depending on the value of LaggedAnalyser.event, and then computes a bulk summary metric for each.

Optional arguments

dim

A string, the coordinate along which to compute composites.

lag_vals

Either 'All', or a list of integers, denoting the time lags for which composites should be computed.

as_anomaly

A Boolean, defining whether composites should be given as absolute values or differences from the unpartitioned value.

con_func

The summary metric to use for continuous variables. Defaults to a standard mean average. If None, then continuous variables will be ignored

cat_func

The summary metric to use for categorical variables. Defaults to the occurrence probability of each categorical value. If None, then categorical variables will be ignored

inplace

A boolean, defining whether the composite should be stored in LaggedAnalyser.composites

returns

An xarray.Dataset like LaggedAnalyser.variables but summarised according to con_func and cat_func, and with an additional coordinate index_val, which indexes over the values taken by LaggedAnalyser.event.

def aggregate_variables( self, dim='time', lag_vals='all', con_func=<function mean_ds>, cat_func=<function cat_occ_ds>):
348    def aggregate_variables(self,dim='time',lag_vals='all',con_func=agg.mean_ds,cat_func=agg.cat_occ_ds):
349        
350        """Calculates a summary metric from *LaggedAnalyser.variables* at all points where *LaggedAnalyser.event* is defined, regardless of its value.
351        
352        **Optional arguments**
353        
354        *dim*
355        
356        A string, the name of the shared coordinate between *LaggedAnalyser.variables* and *LaggedAnalyser.event*.
357        
358        *lag_vals*
359        
360        'all' or a iterable of integers, specifying for which lag values to compute the summary metric.
361        
362        *con_func*
363        
364        The summary metric to use for continuous variables. Defaults to a standard mean average. If None, then continuous variables will be ignored
365            
366        *cat_func*
367        
368        The summary metric to use for categorical variables. Defaults to the occurrence probability of each categorical value. If None, then continuous variables will be ignored
369
370        **returns**
371        
372        An xarray.Dataset like  *LaggedAnalyser.variables* but summarised according to *con_func* and *cat_func*.
373
374"""
375        fake_event=self.event.copy(data=np.zeros_like(self.event))
376        return self._compute_aggregate_over_lags(fake_event,dim,lag_vals,con_func,cat_func).isel(index_val=0)

Calculates a summary metric from LaggedAnalyser.variables at all points where LaggedAnalyser.event is defined, regardless of its value.

Optional arguments

dim

A string, the name of the shared coordinate between LaggedAnalyser.variables and LaggedAnalyser.event.

lag_vals

'all' or a iterable of integers, specifying for which lag values to compute the summary metric.

con_func

The summary metric to use for continuous variables. Defaults to a standard mean average. If None, then continuous variables will be ignored

cat_func

The summary metric to use for categorical variables. Defaults to the occurrence probability of each categorical value. If None, then continuous variables will be ignored

returns

An xarray.Dataset like LaggedAnalyser.variables but summarised according to con_func and cat_func.

def add_derived_composite(self, name, func, composite_vars, as_anomaly=False):
378    def add_derived_composite(self,name,func,composite_vars,as_anomaly=False):
379        """Applies *func* to one or multiple composites to calculate composites of derived quantities, and additionally, stores *func* to allow derived bootstrap composites to be calculated. For linear quantities, where Ex[f(x)]==f(Ex[x]), then this can minimise redundant memory use.
380        
381        **Arguments**
382        
383        *name*
384        
385        A string, providing the name of the new variable to add.
386            
387        *func*
388        
389         A callable which must take 1 or more xarray.DataArrays as inputs
390            
391        *composite_vars*
392        
393        An iterable of strings, of the same length as the number of arguments taken by *func*. Each string must be the name of a variable in *LaggedAnalyser.variables* which will be passed into *func* in order.
394        
395        **Optional arguments**
396        
397        *as_anomaly*
398        
399        A boolean. Whether anomaly composites or full composites should be passed in to func.
400        """
401        
402        if np.ndim(as_anomaly)==1:
403            raise(NotImplementedError('variable-specific anomalies not yet implemented'))
404
405        self._derived_variables[name]=(func,composite_vars,as_anomaly)
406        self.composites[name]=self._compute_derived_da(self.composites,func,composite_vars,as_anomaly)
407        
408        if self.lagged_means is not None:
409            self.lagged_means[name]=self._compute_derived_da(self.lagged_means,func,composite_vars,as_anomaly)
410            
411        return

Applies func to one or multiple composites to calculate composites of derived quantities, and additionally, stores func to allow derived bootstrap composites to be calculated. For linear quantities, where Ex[f(x)]==f(Ex[x]), then this can minimise redundant memory use.

Arguments

name

A string, providing the name of the new variable to add.

func

A callable which must take 1 or more xarray.DataArrays as inputs

composite_vars

An iterable of strings, of the same length as the number of arguments taken by func. Each string must be the name of a variable in LaggedAnalyser.variables which will be passed into func in order.

Optional arguments

as_anomaly

A boolean. Whether anomaly composites or full composites should be passed in to func.

def compute_bootstraps( self, bootnum, dim='time', con_func=<function mean_ds>, cat_func=<function cat_occ_ds>, lag=0, synth_mode='markov', data_vars=None, reuse_ixs=False):
416    def compute_bootstraps(self,bootnum,dim='time',con_func=agg.mean_ds,cat_func=agg.cat_occ_ds,lag=0,synth_mode='markov',data_vars=None,reuse_ixs=False):
417        
418        """Computes composites from synthetic event indices, which can be used to assess whether composites are insignificant.
419        
420        **Arguments**
421        
422        *bootnum*
423        
424        An integer, the number of bootstrapped composites to compute
425            
426        **Optional arguments**
427        
428        *dim*
429        
430        A string, the name of the shared coordinate between *LaggedAnalyser.variables* and *LaggedAnalyser.event*.
431            
432        *con_func*
433        
434        The summary metric to use for continuous variables. Defaults to a standard mean average. If None, then continuous variables will be ignored
435            
436        *cat_func*
437        
438        The summary metric to use for categorical variables. Defaults to the occurrence probability of each categorical value. If None, then continuous variables will be ignored
439
440        *lag*
441        
442        An integer, specifying which lagged variables to use for the bootstraps. i.e. bootstraps for lag=90 will be from a completely different season than those for lag=0.
443            
444        *synth_mode*
445        
446        A string, specifying how synthetic event indices are to be computed. Valid options are:
447            
448        "random": 
449        
450        categorical values are randomly chosen with the same probability of occurrence as those found in *LaggedAnalyser.event*, but with no autocorrelation.
451
452        "markov": 
453        
454        A first order Markov chain is fitted to *LaggedAnalyser.event*, producing some autocorrelation and state dependence in the synthetic series. Generally a better approximation than "random" and so should normally be used.
455
456        "shuffle": 
457        
458        The values are randomly reordered. This means that each value will occur exactly the same amount of times as in the original index, and so is ideal for particularly rare events or short series.
459            
460        *data_vars*
461        
462        An iterable of strings, specifying for which variables bootstraps should be computed.
463                
464        **returns**
465        
466        An xarray.Dataset like *LaggedAnalyser.variables* but summarised according to *con_func* and *cat_func*, and with a new coordinate 'bootnum' of length *bootnum*.
467
468        """
469        if data_vars==None:
470            data_vars=list(self.variables.data_vars)
471
472        boots=self._add_derived_boots(self._compute_bootstraps(bootnum,dim,con_func,cat_func,lag,synth_mode,data_vars,reuse_ixs))
473        if self.composites_are_anomaly:
474            boots=boots-self.lagged_means.sel(lag=lag)
475        return make_all_dims_coords(boots)

543    def get_significance(self,bootstraps,comp,p,data_vars=None,hb_correction=False):
544        
545        """Computes whether a composite is significant with respect to a given distribution of bootstrapped composites. 
546        
547        **Arguments**
548        
549        *bootstraps*
550
551        An xarray.Dataset with a coordinate 'boot_num', such as produced by *LaggedAnalyser.compute_bootstraps*
552
553        *comp*
554
555        An xarray Dataset of the same shape as *bootstraps* but without a 'boot_num' coordinate. Missing or additional variables are allowed, and are simply ignored.
556        
557        *p*
558        A float, specifying the p-value of the 2-sided significance test (values in the range 0 to 1). 
559            
560        **Optional arguments**
561
562        *data_vars*
563            
564        An iterable of strings, specifying for which variables significance should be computed.
565            
566        *hb_correction*
567        
568        A Boolean, specifying whether a Holm-Bonferroni correction should be applied to *p*, in order to reduce the family-wise error rate. Note that this correction is currently only applied to each variable in *comp* independently, and so will have no impact on scalar variables.
569        
570        **Returns**
571        
572        An xarray.Dataset like *comp* but with boolean data, specifying whether each feature of each variable passed the significance test.
573        """
574        if data_vars is None:
575            data_vars=list(bootstraps.data_vars)
576
577        bootnum=len(bootstraps.boot_num)
578        comp=comp[data_vars]
579        bootstraps=bootstraps[data_vars]
580        frac=(comp<bootstraps).sum('boot_num')/bootnum
581        pval_ds=1-2*np.abs(frac-0.5)
582        if hb_correction:
583            for var in pval_ds:
584                corrected_pval=holm_bonferroni_correction(pval_ds[var].values.reshape(-1),p)\
585                            .reshape(pval_ds[var].shape)
586                pval_ds[var].data=corrected_pval
587        else:
588            pval_ds=pval_ds<p
589            
590        self.composite_sigs=pval_ds.assign_coords(lag=comp.lag)
591        return self.composite_sigs
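
Continuing the sketch above (an editor's illustration): the bootstrapped composites provide the null distribution against which the true composites are tested.

    #Two-sided test at p=0.05; 'boots' and 'comp' as computed above.
    sig = analyser.get_significance(boots, comp, p=0.05)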

593    def bootstrap_significance(self,bootnum,p,dim='time',synth_mode='markov',reuse_lag0_boots=False,data_vars=None,hb_correction=False):
594        
595        """A wrapper around *compute_bootstraps* and *get_significance*, that calculates bootstraps and applies a significance test to a number of time lagged composites simulataneously.
596        
597    **Arguments**
598
599    *bootnum*
600
601    An integer, the number of bootstrapped composites to compute
602
603    *p*
604
605    A float, specifying the p-value of the 2-sided significance test (values in the range 0 to 1). 
606
607    **Optional arguments**
608
609    *dim*
610
611    A string, the name of the shared coordinate between *LaggedAnalyser.variables* and *LaggedAnalyser.event*.
612
613    *synth_mode*
614
615    A string, specifying how synthetic event indices are to be computed. Valid options are:
616    "random": categorical values are randomly chosen with the same probability of occurrence as those found in *LaggedAnalyser.event*, but with no autocorrelation.
617    'markov': A first order Markov chain is fitted to *LaggedAnalyser.event*, producing some autocorrelation and state dependence in the synthetic series. Generally a better approximation than "random" and so should normally be used.
618
619    *reuse_lag0_boots*
620        A Boolean. If True, bootstraps are only computed for lag=0, and then used as a null distribution to assess all lagged composites. For variables which are approximately stationary across the lag timescale, then this is a good approximation and can increase performance. However if used incorrectly, it may lead to 'significant composites' which simply reflect the seasonal cycle. if False, separate bootstraps are computed for all time lags.
621
622    *data_vars*
623        An iterable of strings, specifying for which variables significance should be computed.
624
625    *hb_correction*
626        A Boolean, specifying whether a Holm-Bonferroni correction should be applied to *p*, in order to reduce the family-wide error rate. Note that this correction is currently only applied to each variable in *comp* independently, and so will have no impact on scalar variables.
627        
628    **returns**
629
630    An xarray.Dataset like *LaggedAnalyser.variables* but with the *dim* dimension summarised according to *con_func* and *cat_func*, an additional *lag* coordinate, and with boolean data specifying whether each feature of each variable passed the significance test.
631
632        """
633        lag_vals=list(self._lagged_variables)
634        
635        con_func,cat_func=self.composite_func
636        
637        boots=self.compute_bootstraps(bootnum,dim,con_func,cat_func,0,synth_mode,data_vars)
638        
639        #reuse_lag0_boots=True can substantially reduce run time!
640        if not reuse_lag0_boots:
641            boots=xr.concat([boots,*[self.compute_bootstraps(bootnum,dim,con_func,cat_func,t,synth_mode,data_vars)\
642                for t in lag_vals]],'lag').sortby('lag')
643                
644        sig_composite=self.get_significance(boots,self.composites,p,data_vars,hb_correction=hb_correction)
645        
646        self.composite_sigs=sig_composite
647        return self.composite_sigs
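
The same workflow in a single call (an editor's sketch; any lagged variables are assumed to have been added already):

    #Bootstraps plus significance testing for every stored lag,
    #with a Holm-Bonferroni correction applied within each variable.
    sig = analyser.bootstrap_significance(100, p=0.05, synth_mode='markov',
                                          hb_correction=True)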

650    def deseasonalise_variables(self,variable_list=None,dim='time',agg='dayofyear',smooth=1,coeffs=None):
651        """Computes a seasonal cycle for each variable in *LaggedAnalyser.variables* and subtracts it inplace, turning *LaggedAnalyser.variables* into deseasonalised anomalies. The seasonal cycle is computed via temporal aggregation of each variable over a given period - by default the calendar day of the year. This cycle can then be smoothed with an n-point rolling average.
652
653        **Optional arguments**
654        
655        *variable_list*
656        
657        A list of variables to deseasonalise. Defaults to all variables in *LaggedAnalyser.variables*.
658        
659        *dim*
660        
661        A string, the name of the shared coordinate between *LaggedAnalyser.variables* and *LaggedAnalyser.event*, along which the seasonal cycle is computed. Currently, only timelike coordinates are supported.
662        
663        *agg*
664        
665        A string specifying the datetime-like field to aggregate over. Useful and supported values are 'season', 'month', 'weekofyear', and 'dayofyear'.
666        
667        *smooth*
668        
669        An integer, specifying the size of the n-timestep centred rolling mean applied to the aggregated seasonal cycle. The default *smooth*=1 results in no smoothing.
670        
671        *coeffs*
672        
673        A Dataset containing a precomputed seasonal cycle which, if *LaggedAnalyser.variables* has coordinates (*dim*,[X,Y,...,Z]), has coords (*agg*,[X,Y,...,Z]) and the same data variables as *LaggedAnalyser.variables*. If *coeffs* is provided, no seasonal cycle is fitted to *LaggedAnalyser.variables*; *coeffs* is used instead.
674        
675        """        
676
677        if variable_list is None:
678            variable_list=list(self.variables)
679        for var in variable_list:
680            da=self.variables[var]
681            dsnlsr=Agg_Deseasonaliser()
682            if coeffs is None:
683                dsnlsr.fit_cycle(da,dim=dim,agg=agg)
684            else:
685                dsnlsr.cycle_coeffs=coeffs[var]
686
687            cycle=dsnlsr.evaluate_cycle(data=da[dim],smooth=smooth)
688            self.variables[var]=da.copy(data=da.data-cycle.data)
689            dsnlsr.data=None #Prevents excess memory storage
690            self._deseasonalisers[var]=dsnlsr
691        return   
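
A hypothetical call (an editor's illustration): remove a smoothed day-of-year climatology in place, so that subsequent composites are of deseasonalised anomalies.

    #Fit a day-of-year cycle, smooth it with a 15-day rolling mean,
    #and subtract it from every variable.
    analyser.deseasonalise_variables(agg='dayofyear', smooth=15)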

693    def get_seasonal_cycle_coeffs(self):
694        """ Retrieve seasonal cycle coeffs computed with *LaggedAnalyser.deseasonalise_variables*, suitable for passing into *coeffs* in other *LaggedAnalyser.deseasonalise_variables* function calls as a precomputed cycle.
695        
696        **Returns**
697        An xarray.Dataset, as specified in the *LaggedAnalyser.deseasonalise_variables* *coeffs* optional keyword.
698        """
699        coeffs=xr.Dataset({v:dsnlsr.cycle_coeffs for v,dsnlsr in self._deseasonalisers.items()})
700        return coeffs
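
A hypothetical reuse of a fitted cycle (an editor's illustration; *other_analyser* is a second LaggedAnalyser holding the same variables over a different period):

    #Deseasonalise new data with a previously fitted cycle, avoiding a refit.
    coeffs = analyser.get_seasonal_cycle_coeffs()
    other_analyser.deseasonalise_variables(coeffs=coeffs)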

705    def get_composite_seasonal_mean(self):
706        """
707        If *LaggedAnalyser.deseasonalise_variables* has been called, this function returns the seasonal mean state corresponding to a given composite, given by a sum of the seasonal cycle weighted by the time-varying occurrence of each categorical value in *LaggedAnalyser.event*. Adding this mean state to the deseasonalised anomaly composite
708        produced by *LaggedAnalyser.compute_composites* recovers the full composite pattern.
709        
710        **Returns**
711        An xarray.Dataset containing the composite seasonal mean values.
712        """
713        variable_list=list(self._deseasonalisers)
714        ts={e:self.event[self.event==e].time for e in np.unique(self.event)}
715        lags=np.unique([0,*list(self._lagged_variables)])
716        
717        mean_states={}
718        for var in variable_list:
719            dsnlsr=self._deseasonalisers[var]
720            agg=dsnlsr.agg
721            mean_states[var]=xr.concat([\
722                                 xr.concat([\
723                                    self._lag_average_cycle(dsnlsr,agg,l,t,i)\
724                                for l in lags],'lag')\
725                            for i,t in ts.items()],'index_val')
726            
727        return xr.Dataset(mean_states)
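
As the docstring notes, the full composite pattern can then be recovered by addition (an editor's sketch, assuming the two Datasets align on their shared coords):

    #Anomaly composites + occurrence-weighted seasonal mean state:
    full_composite = analyser.composites + analyser.get_composite_seasonal_mean()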

737class PatternFilter(object):
738    """Provides filtering methods to refine n-dimensional boolean masks, and apply them to an underlying dataset.
739    
740        **Optional arguments:**
741        
742        *mask_ds*
743        
744        An xarray boolean Dataset of arbitrary dimensions which provides the initial mask dataset. If *mask_ds*=None and *analyser*=None, then *mask_ds* will be initialised as a Dataset of the same dimensions and data_vars as *val_ds*, with all values = 1 (i.e. initially unmasked).
745        
746        *val_ds*
747        
748        An xarray Dataset with the same dimensions as *mask_ds* if provided, otherwise arbitrary, consisting of an underlying dataset to which the mask is applied. If *val_ds*=None and *analyser*=None, then *PatternFilter.apply_value_mask* will raise an error.
749            
750        *analyser*
751        
752        An instance of a core.LaggedAnalyser class for which both composites and significance masks have been computed, used to infer the *val_ds* and *mask_ds* arguments respectively. This overrides any values passed explicitly to *mask_ds* and *val_ds*.
753            
754    """
755    def __init__(self,mask_ds=None,val_ds=None,analyser=None):
756        """Initialise a new PatternFilter object"""
757        self.mask_ds=mask_ds
758        """@private"""
759        self.val_ds=val_ds
760        """@private"""
761
762        if analyser is not None:
763            self._parse_analyser(analyser)
764            
765        else:
766            if mask_ds is None:
767                self.mask_ds=self._mask_ds_like_val_ds()
768                
769    def __repr__(self):
770        return 'A PatternFilter object'
771        
772    def __str__(self):
773        return self.__repr__()
774        
775    def _parse_analyser(self,analyser):
776        self.mask_ds=analyser.composite_sigs
777        self.val_ds=analyser.composites
778        
779    def _mask_ds_like_val_ds(self):
780        if self.val_ds is None:
781            raise(ValueError('At least one of "mask_ds", "val_ds" and "analyser" must be provided.'))
782        
783        x=self.val_ds
784        y=x.where(x!=0).fillna(1) #replace nans and 0s with 1
785        y=(y/y).astype(int) #make everything 1 via division and assert integer type.
786        self.mask_ds=y
787        return self.mask_ds
788    
789    def update_mask(self,new_mask,mode):
790        """ Update *PatternFilter.mask_ds* with a new mask, either taking their union or intersection, or replacing the current mask with new_mask.
791        
792        **Arguments**
793        
794        *new_mask*
795
796        An xarray.Dataset with the same coords and variables as *PatternFilter.mask_ds*.
797
798        *mode*
799
800        A string, one of 'replace','intersection' or 'union', defining how *new_mask* should be used to update the mask.
801        """
802        new_mask=new_mask.astype(int)
803        if mode=='replace':
804            self.mask_ds=new_mask
805        elif mode=='intersection':
806            self.mask_ds=self.mask_ds*new_mask
807        elif mode == 'union':
808            self.mask_ds=self.mask_ds|new_mask
809        else:
810            raise(ValueError(f'Invalid mode, {mode}'))
811        return
812                  
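    #Editor's note (not part of the source): the three update_mask modes,
    #for two aligned boolean Datasets, behave as
    #  'replace'      -> new_mask
    #  'intersection' -> mask_ds*new_mask (logical AND via integer product)
    #  'union'        -> mask_ds|new_mask (logical OR)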
813    def apply_value_mask(self,truth_function,*args,mode='intersection'):
814        """ Apply a filter to *PatternFilter.mask_ds* based on a user-specified truth function which is applied to *PatternFilter.val_ds. 
815        
816        **Examples**
817        
818            #Mask values beneath a threshold:
819            def larger_than_thresh(ds,thresh):
820                return ds>thresh
821            patternfilter.apply_value_mask(larger_than_thresh,thresh)
822
823            #Mask values where absolute value is less than a reference field:
824            def amp_greater_than_reference(ds,ref_ds):
825                return np.abs(ds)>ref_ds
826            pattern_filter.apply_value_mask(amp_greater_than_reference,ref_ds)
827
828        **Arguments**
829
830        *truth_function*
831        
832        A function with inputs (val_ds,*args) that returns a boolean dataset with the same coords and data variables as *PatternFilter.val_ds*.
833
834        **Optional arguments**
835        
836        *mode*
837            
838        A string, one of 'replace','intersection' or 'union', defining how the value filter should be used to update the *PatternFilter.mask_ds*.
839        """        
840        if self.val_ds is None:
841            raise(ValueError('val_ds must be provided to apply value mask.'))
842        value_mask=truth_function(self.val_ds,*args)
843        self.update_mask(value_mask,mode)
844        return
845    
846    def apply_area_mask(self,n,dims=None,mode='intersection',area_type='gridpoint'):
847        """ Apply a filter to *PatternFilter.mask_ds* that identifies connected groups of True values within a subspace of the Dataset's dimensions specified by *dims*, and masks out groups which are beneath a threshold size *n*. This is done through the application of *scipy.ndimage.label* using the default structuring element (https://docs.scipy.org/doc/scipy/reference/generated/scipy.ndimage.label.html). 
848    
849        When *area_type*='gridpoint', *n* specifies the number of connected datapoints within each connected region. For the special case where *dims* consists of a latitude- and longitude-like coordinate, area_type='spherical' applies a cosine-latitude weighting, such that *n* can be interpreted as a measure of area, where a datapoint with lat=0 would have area 1. 
850        
851        **Examples**
852        
853            #Keep groups of True values consisting of an area >=30 square equatorial gridpoints
854            patternfilter.apply_area_mask(30,dims=('lat','lon'),area_type='spherical')
855            
856            #Keep groups of True values that are consistent for at least 3 neighbouring time lags
857            patternfilter.apply_area_mask(3,dims=('time',))
858            
859            #Keep groups of True values consisting of >=10 longitudinal values, or >=30 values in longitude and altitude if the variables have an altitude coord:
860            patternfilter.apply_area_mask(10,dims=('longitude',))
861            patternfilter.apply_area_mask(30,dims=('longitude','altitude'),mode='union')
862
863        **Arguments**
864
865        *n*
866            
867        A scalar indicating the minimum size of an unmasked group, in terms of number of gridpoints (for *area_type*=gridpoint) or the weighted area (for *area_type*=spherical), beneath which the group will be masked.
868
869        **Optional arguments**
870        
871        *dims*
872            
873        An iterable of strings specifying coords in *PatternFilter.mask_ds* which define the subspace in which groups of connected True values are identified. Other dims will be iterated over. DataArrays within *PatternFilter.mask_ds* that do not contain all the *dims* will be ignored. If *dims*=None, all dims in each DataArray will be used.
874            
875        *mode*
876
877        A string, one of 'replace','intersection' or 'union', defining how the area filter should be used to update the *PatternFilter.mask_ds*.
878            
879        *area_type*
880
881        A string, one of 'gridpoint' or 'spherical' as specified above. 'spherical' is currently only supported for len-2 *dims* kwargs, with the first assumed to be latitude-like. 
882            
883        """        
884        if area_type=='gridpoint':
885            area_based=False
886        elif area_type=='spherical':
887            area_based=True
888        else:
889            raise(ValueError(f"Unknown area_type {area_type}. Valid options are 'gridpoint' and 'spherical'"))
890        area_mask=ds_large_regions(self.mask_ds,n,dims=dims,area_based=area_based)
891        self.update_mask(area_mask,mode)
892        return
893    
894    
895    def apply_convolution(self,n,dims,mode='replace'):
896        """ Apply a square n-point convolution filter to *PatternFilter.mask_ds* in one or two dimensions specified by *dims*, iterated over remaining dimensions. This has the effect of extending the unmasked regions and smoothing the mask overall.
897        
898        **Arguments**
899        
900        *n*
901            
902        A positive integer specifying the size of the convolution filter. *n*=1 leaves the mask unchanged. Even values of *n* produce an asymmetric filter shifted to the right.
903
904        *dims*
905
906        A length 1 or 2 iterable of strings specifying the dims in which the convolution is applied. Other dims will be iterated over. DataArrays within *PatternFilter.mask_ds* that do not contain all the *dims* will be ignored. 
907
908        *mode*
909
910        A string, one of 'replace','intersection' or 'union', defining how the area filter should be used to update the *PatternFilter.mask_ds*.
911        """
912        
913        if len(dims) not in [1,2]:
914            raise(ValueError('Only 1 and 2D dims currently supported'))
915            
916        convolution=convolve_pad_ds(self.mask_ds,n,dims=dims)
917        self.update_mask(convolution,mode)
918        return
919    
920    def get_mask(self):
921        """" Retrieve the mask with all filters applied.
922        **Returns**
923        An xarray.Dataset of boolean values.
924        """
925        return self.mask_ds
926    
927    def filter(self,ds=None,drop_empty=True,fill_val=np.nan):
928        """ Apply the current mask to *ds* or to *PatternFilter.val_ds* (if *ds* is None), replacing masked gridpoints with *fill_val*.
929        **Optional arguments**
930        
931        *ds*
932        
933        An xarray.Dataset to apply the mask to. Should have the same coords and data_vars as *PatternFilter.mask_ds*. If None, the mask is applied to *PatternFilter.val_ds*.
934        
935        *drop_empty*
936        
937        A boolean value. If True, then completely masked variables are dropped from the returned masked Dataset.
938        
939        *fill_val*
940        
941        A scalar that defaults to np.nan. The value with which masked gridpoints in the Dataset are replaced.
942        
943        **Returns**
944        
945        A Dataset with masked values replaced by *fill_val*.
946        """
947        if ds is None:
948            ds=self.val_ds.copy(deep=True)
949            
950        ds=ds.where(self.mask_ds)
951        if drop_empty:
952            drop_vars=((~np.isnan(ds)).sum()==0).to_array('vars')
953            ds=ds.drop_vars(drop_vars[drop_vars].vars.values)
954        return ds.fillna(fill_val)
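
A hypothetical end-to-end filtering sketch (an editor's illustration; *analyser* is assumed to be a LaggedAnalyser for which both composites and significance masks have been computed):

    #Build a filter from the analyser, keep significant values whose
    #magnitude exceeds 0.5, retain only large connected regions,
    #smooth the mask, and apply it to the composites.
    pf = PatternFilter(analyser=analyser)
    pf.apply_value_mask(lambda ds, thresh: np.abs(ds) > thresh, 0.5)
    pf.apply_area_mask(30, dims=('lat', 'lon'), area_type='spherical')
    pf.apply_convolution(3, dims=('lat', 'lon'))
    filtered_comps = pf.filter()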

 965class IndexGenerator(object):
 966    
 967    """ Computes dot-products between a Dataset of patterns and a Dataset of variables, reducing them to standardised scalar indices.
 968    """
 969    def __init__(self):
 970        self._means=[]
 971        self._stds=[]
 972        self._rename_function=_DEFAULT_RENAME_FUNC
 973        
 974    def __repr__(self):
 975        return 'An IndexGenerator object'
 976        
 977    def __str__(self):
978        return self.__repr__()
 979    
 980    
 981    def centre(self,x,dim='time',ref=None):
 982        """@private"""
 983
 984        if ref is None:
 985            ref=x.mean(dim=dim)
 986        return x-ref
 987    
 988    def normalise(self,x,dim='time',ref=None):
 989        """@private"""
 990
 991        if ref is None:
 992            ref=x.std(dim=dim)
 993        return x/ref
 994    
 995    def standardise(self,x,dim='time',mean_ref=None,std_ref=None):
 996        """@private"""
 997        centred_x=self.centre(x,dim,mean_ref)
 998        standardised_x=self.normalise(centred_x,dim,std_ref)
 999        return standardised_x
1000        
1001    def collapse_index(self,ix,dims):
1002        """@private"""
1003        lat_coords=['lat','latitude','grid_latitude']
1004        if not np.any(np.isin(lat_coords,dims)):
1005            return ix.sum(dims)
1006        
1007        else:
1008            #assumes only one lat coord: seems safe.
1009            lat_dim=lat_coords[np.where(np.isin(lat_coords,dims))[0][0]]
1010            weights=np.cos(np.deg2rad(ix[lat_dim]))
1011            return ix.weighted(weights).sum(dims)
1012            
1013    def generate(self,pattern_ds,series_ds,dim='time',slices=None,ix_means=None,ix_stds=None,drop_blank=False,in_place=True,strict_metadata=False):
1014        """Compute standardised indices from an xarray.Dataset of patterns and an xarray.Dataset of arbitrary dimension variables.
1015        
1016        **Arguments**
1017        
1018        *pattern_ds*
1019        
1020        An xarray.Dataset of patterns to project onto with arbitrary dimensions.
1021        
1022        *series_ds*
1023        
1024        An xarray.Dataset of variables to project onto the patterns. Once *pattern_ds* has been subsetted using *slices*, the coordinates of *series_ds* must match the dimensions of *pattern_ds*, plus the extra coord *dim*.
1025        
1026        **Optional arguments**
1027        
1028        *dim*
1029        
1030        A string specifying the remaining coord of the scalar indices. Defaults to 'time', which is appropriate for most use cases.
1031        
1032        *slices*
1033        
1034        A dictionary or iterable of dictionaries, each specifying a subset of *pattern_ds* to take before computing an index, with one index returned for each dictionary and for each variable. Subsetting is based on the *xr.Dataset.sel* method: e.g. *slices*=[dict(lag=0,index_val=1)] will produce 1 set of indices based on pattern_ds.sel(lag=0,index_val=1). If *slices*=None, no subsets are computed.
1035        
1036        *ix_means*
1037        
1038        If None, the mean of each index is calculated and subtracted, resulting in centred indices. Otherwise, *ix_means* should be a dictionary of index names and predefined mean values which are subtracted instead. Of most use for online computations, updating a precomputed index in a new dataset.
1039        
1040        *ix_stds*
1041        
1042        If None, the standard deviation of each index is calculated and divided out, resulting in standardised indices. Otherwise, *ix_stds* should be a dictionary of index names and predefined std values to divide by instead. Of most use for online computations, updating a precomputed index in a new dataset.
1043
1044        *drop_blank*
1045        
1046        A boolean. If True, drop indices where the corresponding pattern is entirely blank. If False, an all-np.nan time series is returned.
1047        *in_place*
1048        A boolean. If True, the computed indices are also stored in *IndexGenerator.indices*.
1049        *strict_metadata*
1050        
1051        If False, indices will be merged into a common dataset regardless of metadata. If True, nonmatching metadata will raise a ValueError.
1052        
1053        **Returns**
1054        
1055        An xarray.Dataset of indices with a single coordinate (*dim*).
1056        """
1057        #Parse inputs
1058        
1059        if slices is None:
1060            self.slices=[{}]
1061        elif type(slices) is dict:
1062            self.slices=[slices]
1063        else:
1064            self.slices=slices
1065            
1066        if ix_means is not None or ix_stds is not None:
1067            self.user_params=True
1068            self.means=ix_means
1069            self.stds=ix_stds
1070        else:
1071            self.user_params=False
1072            self.means={}
1073            self.stds={}
1074            
1075        self.indices=None
1076        
1077        #Compute indices
1078        indices=[self._generate_index(pattern_ds,series_ds,dim,sl)\
1079                for sl in self.slices]
1080        try:
1081            indices=xr.merge(indices)
1082        except Exception as e:
1083            if strict_metadata:
1084                print("Merge of indices failed. Consider 'strict_metadata=False'")
1085                raise e
1086            else:
1087                indices=xr.merge(indices,compat='override')
1088            
1089        #Optionally remove indices which are all nan    
1090        if drop_blank:
1091            drop=(~indices.isnull()).sum()==0
1092            drop=[k for k,d in drop.to_dict()['data_vars'].items() if d['data']]
1093            indices=indices.drop_vars(drop)
1094            _=[(self.means.pop(x),self.stds.pop(x)) for x in drop]
1095        if in_place:
1096            self.indices=indices
1097        return indices
1098    
1099    def _generate_index(self,pattern_ds,series_ds,dim,sl):
1100                
1101        pattern_ds,series_ds=xr.align(pattern_ds,series_ds)
1102        pattern_ds=pattern_ds.sel(sl)
1103        dims=list(pattern_ds.dims)
1104
1105        index=pattern_ds*series_ds
1106        #coslat weights lat coords
1107        index=self.collapse_index(index,dims)
1108        index=self._rename_index_vars(index,sl)
1109
1110        if self.user_params:
1111            mean=self.means
1112            std=self.stds
1113        else:
1114            mean=_Dataset_to_dict(index.mean(dim))
1115            std=_Dataset_to_dict(index.std(dim))
1116            for v in mean:
1117                self.means[v]=mean[v]
1118            for v in std:
1119                self.stds[v]=std[v]
1120                
1121        index=self.standardise(index,dim,mean_ref=mean,std_ref=std)
1122        index=self._add_index_attrs(index,sl,mean,std)
1123
1124        
1125        self.generated_index=index
1126        return index
1127    
1128    def _add_index_attrs(self,index,sl,mean,std):
1129        for v in index:
1130            ix=index[v]
1131            ix.attrs['mean']=np.array(mean[v])
1132            ix.attrs['std']=np.array(std[v])
1133            for k,i in sl.items():
1134                ix.attrs[k]=i
1135            index[v]=ix
1136        return index
1137    
1138    def _rename_index_vars(self,index,sl):
1139        func=self._rename_function
1140        return index.rename({v:func(v,sl) for v in index.data_vars})
1141    
1142    def get_standardisation_params(self,as_dict=False):
1143        
1144        """ Retrieve index means and stds for computed indices, for use as future inputs into index_means or index_stds in *IndexGenerator.Generate*
1145        """
1146        if as_dict:
1147            return self.means,self.stds
1148        else:
1149            params=[xr.Dataset(self.means),xr.Dataset(self.stds)]
1150            return xr.concat(params,'param').assign_coords({'param':['mean','std']})
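
A hypothetical projection sketch (an editor's illustration): *patterns* could be the filtered composites above, and *series* a Dataset of deseasonalised variables sharing the patterns' spatial coords plus 'time'.

    #Project the variables onto the lag-0 patterns to obtain standardised
    #scalar indices, then keep the standardisation parameters for reuse.
    ixgen = IndexGenerator()
    indices = ixgen.generate(patterns, series, dim='time',
                             slices=dict(lag=0), drop_blank=True)
    params = ixgen.get_standardisation_params()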
