Project

General

Profile

Statistics
| Branch: | Tag: | Revision:

pycama / src / pycama / ScatterPlot.py @ 834:d989df597b80

History | View | Annotate | Download (37.1 KB)

1
#!/usr/bin/env python3
2
# -*- coding: utf-8 -*-
3

    
4
# Copyright 2016-2017 Maarten Sneep, KNMI
5
#
6
# Redistribution and use in source and binary forms, with or without
7
# modification, are permitted provided that the following conditions are met:
8
#
9
# 1. Redistributions of source code must retain the above copyright notice,
10
#    this list of conditions and the following disclaimer.
11
#
12
# 2. Redistributions in binary form must reproduce the above copyright notice,
13
#    this list of conditions and the following disclaimer in the documentation
14
#    and/or other materials provided with the distribution.
15
#
16
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
17
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
18
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
19
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR
20
# ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
21
# (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
22
# LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON
23
# ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
24
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
25
# SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
26

    
27
## @file ScatterPlot.py
28
#
29
# This file defines a subclass of the pycama.AnalysisAndPlot.AnalysisAndPlot class.
30
# This subclass extracts the scatter density data from the input data.
31
#
32
# @author Maarten Sneep
33

    
34
import math
35
import logging
36
import warnings
37
warnings.filterwarnings("ignore", category=FutureWarning)
38

    
39
import numpy as np
40
import netCDF4
41
import h5py
42

    
43
from .AnalysisAndPlot import AnalysisAndPlot
44

    
45
from .utilities import *
46

    
47
## Extract scatter density data (2-parameter histograms) and the covariances (and correlations)
48
#
49
#  Also includes netCDF write & read routines. Plotting should be possible from
50
#  a restored object.
51
#
52
class ScatterPlot(AnalysisAndPlot):
53
    ## Define a few variables to collect the number of bins in the histograms
54
    #
55
    #  There is one default (via the `number_bins` keyword argument, 100 if that value isn't given)
56
    #  @param number_bins keyword argument, the default number of bind in the 2D histograms.
57
    def setup(self, **kwargs):
58
        if 'number_bins' in kwargs:
59
            self.number_bins = kwargs['number_bins']
60
        else:
61
            self.number_bins = 100
62

    
63
    ## Return the keys (tuples with both variable names included). This is a property.
64
    @property
65
    def key_pairs(self):
66
        return list(self.variables.keys())
67

    
68
    ## The actual underlying variable names.
69
    #
70
    #  differs from superclass, because we have tuples as keys in self.variables
71
    @property
72
    def variable_names(self):
73
        return [v1 for v1 in list(self.variables_meta.keys()) if v1 != 'longitude']
74

    
75
    ## Return True if the variable must be included here.
76
    #
77
    #  Flags, internal variables, variables without scanline and variables explicitly excluded from the scatter plots
78
    #  are thrown out here.
79
    def include_var(self, variable):
80
        return not (variable.flag or variable.internal_only or variable.noscanline or (not variable.include_scatter))
81

    
82
    ## Add the input variables if they should be plotted.
83
    #
84
    #  This method does not perform the extraction, but prepares all histogram axes and generates the key-pairs.
85
    def add_raw_variables(self):
86
        number_bins_collection = set()
87
        for var in list(self.input_variables.variables.values()):
88
            if not self.include_var(var):
89
                continue
90
            self.add_variable(var)
91

    
92
            if var.data_range is None:
93
                var.data_range = [np.nanmin(var.aggregate_data), np.nanmax(var.aggregate_data)]
94

    
95
            if var.number_bins is None:
96
                var.number_bins = self.number_bins
97

    
98
            number_bins_collection.add(var.number_bins)
99

    
100
            if var.log_range:
101
                b = np.logspace(math.log10(var.data_range[0]), math.log10(var.data_range[1]), num=var.number_bins+1, endpoint=True, base=10.0, dtype=np.float32)
102
            else:
103
                b = np.linspace(var.data_range[0], var.data_range[1], num=var.number_bins+1, endpoint=True, dtype=np.float32)
104
            self.variables_meta[var.name]['bins'] = b
105
            self.variables_meta[var.name]['number_bins'] = var.number_bins
106

    
107
        self.number_bins_collection = sorted(list(number_bins_collection))
108
        N = len(self.variable_names)
109
        self.correlation_matrix = np.zeros((N, N), dtype=np.float32)
110
        self.covariance_matrix = np.zeros((N, N), dtype=np.float32)
111
        self.index_list = []
112
        self.pair_list = []
113
        for i1, n1 in enumerate(self.variable_names[:-1]):
114
            for i, n2 in enumerate(self.variable_names[i1+1:]):
115
                key = (n1, n2)
116
                var1 = self.input_variables.variables[n1]
117
                var2 = self.input_variables.variables[n2]
118
                self.variables[key] = np.zeros((var1.number_bins, var2.number_bins), dtype=np.uint32)
119
                self.pair_list.append(key)
120
                self.index_list.append((i1, i1+1+i))
121

    
122
    ## Extract the required information from the input data.
123
    #
124
    #  A straight loop over the prepared pairs is used. The double loop takes place in add_raw_variables().
125
    #  The covariance and correlation elements are calculated piecewise.
126
    #  This should work with non-synchornized variables, but there are other issues with those.
127
    #
128
    #
129
    def process(self):
130
        n_total = len(self.pair_list)
131
        if n_total == 0:
132
            self.logger.warning("No valid data at all for scatter plots")
133
            return
134
        i = 0
135
        for (n1, n2), (i1, i2) in zip(self.pair_list, self.index_list):
136
            if i > 0 and i % (n_total//min(30, n_total)) == 0:
137
                self.progress(100*i/n_total)
138
            i += 1
139
            v1 = self.input_variables.variables[n1].aggregate_data
140
            v2 = self.input_variables.variables[n2].aggregate_data
141
            if len(v1) != len(v2):
142
                self.logger.warning("%s: No valid data for the combination '%s' and '%s' (length differs).", self.__class__.__name__, n1, n2)
143
                continue
144

    
145
            # include mask for non-synchronized variables.
146
            try:
147
                cidx = np.logical_and(np.logical_and(np.isfinite(v1), np.isfinite(v2)), np.logical_and(np.logical_not(v1.mask), np.logical_not(v2.mask)))
148
            except:
149
                cidx = np.logical_and(np.isfinite(v1), np.isfinite(v2))
150

    
151
            if not np.any(cidx):
152
                self.logger.warning("%s: No valid data for the combination '%s' and '%s'.", self.__class__.__name__, n1, n2)
153
                continue
154
            b1 = self.variables_meta[n1]['bins']
155
            b2 = self.variables_meta[n2]['bins']
156
            self.variables[(n1, n2)][:, :] += np.asarray(np.histogram2d(v1[cidx], v2[cidx], bins=[b1, b2])[0], dtype=np.uint32)
157
            self.correlation_matrix[i1, i2] = np.corrcoef(v1[cidx], v2[cidx])[0,1]
158
            self.covariance_matrix[i1, i2] = np.cov(v1[cidx], v2[cidx])[0,1]
159
            self.correlation_matrix[i2, i1] = self.correlation_matrix[i1, i2]
160
            self.covariance_matrix[i2, i1] = self.covariance_matrix[i1, i2]
161
            self.logger.debug("correlation[{0}, {1}] ({2}, {3}) = {4}".format(n1, n2, i1, i2, self.correlation_matrix[i1, i2]))
162
        # add diagonal.
163
        for i, name in enumerate(self.variable_names):
164
            self.correlation_matrix[i, i] = 1.0
165
            self.covariance_matrix[i, i] = np.var(self.input_variables.variables[name].aggregate_data, ddof=1)
166

    
167
    ## Merge data into a combined dataset.
168
    #
169
    #  @param other The object to be added to self,
170
    #               also an instance of the pycama.ScatterPlot.ScatterPlot class.
171
    def __iadd__(self, other):
172
        for (n1, n2), (i1, i2) in zip(self.pair_list, self.index_list):
173
            count = np.sum(self.variables[(n1, n2)][:, :])
174
            ocount = np.sum(other.variables[(n1,n2)][:, :])
175
            self.variables[(n1, n2)][:, :] += other.variables[(n1, n2)][:, :]
176
            if count + ocount > 0:
177
                self.covariance_matrix[i1, i2] = (count**2*self.covariance_matrix[i1, i2] + ocount**2*other.covariance_matrix[i1, i2])/(count**2+ocount**2)
178
                self.covariance_matrix[i2, i1] = self.covariance_matrix[i1, i2]
179
        for (i1, i2) in self.index_list:
180
            self.correlation_matrix[i1, i2] = self.covariance_matrix[i2, i1]/(math.sqrt(self.covariance_matrix[i1,i1]*self.covariance_matrix[i2,i2]))
181
            self.correlation_matrix[i1, i2] = self.correlation_matrix[i2, i1]
182

    
183
    ## Write processed data to output file.
184
    #
185
    #  @param fname File to write to
186
    #  @param mode  Writing mode, defaults to append.
187
    #
188
    #  Write data (including extraction specific dimensions) to the group with
189
    #  the name given in the `storage_group_name` property ("`scatterplot_data`" for this class).
190
    #
191
    #  Storage is like that of the pycama.HistogramPlot.HistogramPlot class, except that
192
    #  the histograms themselves are now 2D arrays.
193
    def dump(self, fname, mode='a'):
194
        compress={'compression':'gzip', 'compression_opts':3, 'shuffle':True, 'fletcher32':True}
195
        with h5py.File(fname, 'a') as ref:
196
            # dimensions
197
            time_step = self.time_index_in_output # set in __init__.
198
            try:
199
                grp = ref.create_group(self.storage_group_name)
200
                grp.attrs['comment'] = 'Coincidence plots, two dimensional histograms'
201
            except:
202
                grp = ref[self.storage_group_name]
203

    
204
            number_bins_collection = sorted(list(self.number_bins_collection))
205
            for number_bins in number_bins_collection:
206
                try:
207
                    dim_name = 'scatter_bins_{0}'.format(number_bins)
208
                    var = grp.create_dataset(dim_name, (number_bins,), dtype=np.int32, **compress)
209
                    var[:] = np.arange(number_bins, dtype=np.int32)
210
                    var.attrs['long_name'] = "scatterplot indices for length {0}".format(number_bins)
211
                except:
212
                    self.logger.debug("'{0}' not created".format(dim_name))
213
                    var = grp[dim_name]
214
                    var[:] = np.arange(number_bins, dtype=np.int32)
215

    
216
            try:
217
                dimset = grp.create_dataset('variable_scatter_index',
218
                                            (len(self.variable_names),), dtype=np.int32)
219
                dimset[:] = np.arange(len(self.variable_names), dtype=np.int32)
220
                dimset.attrs['long_name'] = 'index of variables in dataset for scatter'
221
            except:
222
                self.logger.debug("variable_scatter_index not created")
223

    
224
            try:
225
                dt = h5py.special_dtype(vlen=str) # implicit VLEN type for strings.
226
                var = grp.create_dataset('variable_names_scatter', (len(self.variable_names),),
227
                                         dtype=dt, **compress)
228
                var.dims.create_scale(grp['variable_scatter_index'], 'variable_scatter_index')
229
                var.dims[0].attach_scale(grp['variable_scatter_index'])
230

    
231
                for i, name in enumerate(self.variable_names):
232
                    self.logger.debug("Setting name {0} in index {1}".format(name, i))
233
                    var[i] = name
234
                var.attrs['long_name'] = "variable names in dataset for scatter"
235
            except:
236
                self.logger.debug("variable_names_scatter not created")
237

    
238
            self.logger.debug("{0}.dump(): number of pairs: {1}".format(self.__class__.__name__, len(self.pair_list)))
239

    
240
            for n1, n2 in self.pair_list:
241
                self.logger.debug("{0}.dump(): storing pair ({1}, {2})".format(self.__class__.__name__, n1, n2))
242
                try:
243
                    var = grp["{0}_{1}_2d_histogram".format(n1, n2)]
244
                    var.resize(self.time_index_in_output+1, axis=0)
245
                except:
246
                    if self.time_index_in_output > 0:
247
                        continue
248
                    scatter_bin_dim1 = 'scatter_bins_{0}'.format(self.variables_meta[n1]['number_bins'])
249
                    scatter_bin_dim2 = 'scatter_bins_{0}'.format(self.variables_meta[n2]['number_bins'])
250
                    n_scatter_bin_dim1 = self.variables_meta[n1]['number_bins']
251
                    n_scatter_bin_dim2 = self.variables_meta[n2]['number_bins']
252
                    var = grp.create_dataset("{0}_{1}_2d_histogram".format(n1, n2),
253
                                             (1, n_scatter_bin_dim1, n_scatter_bin_dim2),
254
                                             dtype=self.variables[(n1, n2)].dtype,
255
                                             maxshape=(None, n_scatter_bin_dim1, n_scatter_bin_dim2),
256
                                             **compress)
257
                    try:
258
                        var.dims.create_scale(ref['time'], 'time')
259
                        var.dims[0].attach_scale(ref['time'])
260
                    except RuntimeError:
261
                        pass
262
                    var.dims.create_scale(grp[scatter_bin_dim1], scatter_bin_dim1)
263
                    var.dims[1].attach_scale(grp[scatter_bin_dim1])
264
                    var.dims.create_scale(grp[scatter_bin_dim2], scatter_bin_dim2)
265
                    var.dims[2].attach_scale(grp[scatter_bin_dim2])
266
                    var.attrs['bins'] = "{0}_scatter_bins {1}_scatter_bins".format(n1, n2)
267
                    var.attrs['variables'] = "{0} {1}".format(n1, n2)
268

    
269
                var[time_step, :, :] = self.variables[(n1, n2)][:, :]
270

    
271
                for k in (n1, n2):
272
                    scatter_bin_dim = 'scatter_bins_{0}'.format(self.variables_meta[k]['number_bins'])
273
                    try:
274
                        var = grp.create_dataset(k+'_scatter_bins',
275
                                                 (self.variables_meta[k]['number_bins'],),
276
                                                 dtype=np.float32,
277
                                                 **compress)
278
                        var.dims.create_scale(grp[scatter_bin_dim], scatter_bin_dim)
279
                        var.dims[0].attach_scale(grp[scatter_bin_dim])
280
                        var.attrs['bounds'] = k+'_bounds'
281
                        var[:] = (self.variables_meta[k]['bins'][0:-1] + self.variables_meta[k]['bins'][1:])/2.
282
                        var.attrs['long_name'] = self.variables_meta[k]['title']
283
                        var.attrs['range'] = [float(v) for v in self.variables_meta[k]['data_range']]
284
                        var.attrs['log_range'] = str(self.variables_meta[k]['log_range'])
285
                        var.attrs['units'] = self.variables_meta[k]['units']
286
                    except:
287
                        pass
288

    
289
                    try:
290
                        var = grp.create_dataset(k+'_bounds',
291
                                                 (self.variables_meta[k]['number_bins'],2),
292
                                                 dtype=np.float32, **compress)
293
                        var[:, 0] = self.variables_meta[k]['bins'][0:-1]
294
                        var[:, 1] = self.variables_meta[k]['bins'][1:]
295
                        var.dims.create_scale(grp[scatter_bin_dim], scatter_bin_dim)
296
                        var.dims[0].attach_scale(grp[scatter_bin_dim])
297
                        var.dims.create_scale(ref['nv'], 'nv')
298
                        var.dims[1].attach_scale(ref['nv'])
299
                    except:
300
                        pass
301

    
302
            N = len(self.variable_names)
303
            try:
304
                var = grp.create_dataset('correlation_matrix',
305
                                         (1, N, N),
306
                                         maxshape=(None, N, N),
307
                                         dtype=np.float32,
308
                                        **compress)
309
                var.dims.create_scale(ref['time'], 'time')
310
                var.dims[0].attach_scale(ref['time'])
311
                var.dims.create_scale(grp['variable_scatter_index'], 'variable_scatter_index')
312
                var.dims[1].attach_scale(grp['variable_scatter_index'])
313
                var.dims.create_scale(grp['variable_scatter_index'], 'variable_scatter_index')
314
                var.dims[2].attach_scale(grp['variable_scatter_index'])
315
            except:
316
                var = grp['correlation_matrix']
317
                var.resize(self.time_index_in_output+1, axis=0)
318

    
319
            try:
320
                var[time_step, :, :] = self.correlation_matrix
321
            except:
322
                self.logger.warning("Not adding correlation matrix because dimensions have changed.")
323

    
324
            try:
325
                var = grp.create_dataset('covariance_matrix',
326
                                         (1, N, N),
327
                                         maxshape=(None, N, N),
328
                                         dtype=np.float32,
329
                                         **compress)
330
                var.dims.create_scale(ref['time'], 'time')
331
                var.dims[0].attach_scale(ref['time'])
332
                var.dims.create_scale(grp['variable_scatter_index'], 'variable_scatter_index')
333
                var.dims[1].attach_scale(grp['variable_scatter_index'])
334
                var.dims.create_scale(grp['variable_scatter_index'], 'variable_scatter_index')
335
                var.dims[2].attach_scale(grp['variable_scatter_index'])
336
            except:
337
                var = grp['covariance_matrix']
338
                var.resize(self.time_index_in_output+1, axis=0)
339
            try:
340
                var[time_step, :, :] = self.covariance_matrix
341
            except:
342
                self.logger.warning("Not adding covariance matrix because dimensions have changed.")
343

    
344
    ## Read processed data from the input file, for specified time index.
345
    #
346
    #  @param fname NetCDF file with input data.
347
    #  @param time_index Time slice to read.
348
    def ingest(self, fname, time_index, exclude=None):
349
        self.time_index_in_output = time_index
350

    
351
        with netCDF4.Dataset(fname, 'r') as ref:
352
            try:
353
                grp = ref.groups[self.storage_group_name]
354
            except:
355
                self.logger.error("scatter data not in '%s'.", os.path.basename(fname))
356
                return False
357
            dimension_lengths = [len(grp.dimensions[k]) for k in grp.dimensions.keys() if k.startswith('scatter_bins_')]
358
            self.number_bins_collection = dimension_lengths
359
            self.pair_list = []
360
            self.index_list = []
361
            variable_names = grp.variables['variable_names_scatter'][:]
362
            if exclude is not None:
363
                variable_names = [v for v in variable_names if v not in exclude]
364
                
365
            self.logger.debug("{0}.ingest(): variable names [{1}]".format(self.__class__.__name__, ", ".join(variable_names)))
366
            for i1, n1 in enumerate(variable_names[:-1]):
367
                for i, n2 in enumerate(variable_names[i1+1:]):
368
                    key = (n1, n2)
369
                    self.pair_list.append(key)
370
                    self.index_list.append((i1, i1+1+i))
371

    
372
        with h5py.File(fname, 'r') as ref:
373
            try:
374
                grp = ref[self.storage_group_name]
375
            except:
376
                self.logger.error("scatter data not in '%s'.", os.path.basename(fname))
377
                return False
378

    
379
            vlist = [v for v in list(grp.keys()) if v.endswith('_2d_histogram')]
380

    
381
            for k in vlist:
382
                var = grp[k]
383
                try:
384
                    n1, n2 = [n.replace('_scatter_bins', '') for n in var.attrs['bins'].split()]
385
                except KeyError as err:
386
                    self.logger.warning("Attribute 'bins' missing from variable '{0}'".format(k))
387
                    continue
388
                self.variables[(n1, n2)] = var[time_index, :, :]
389

    
390
            vlist = grp['variable_names_scatter'][:] # [v for v in list(grp.keys()) if v.endswith('_bounds')]
391
            for n in vlist:
392
                k = n+'_bounds'
393
                if n not in self.variables_meta:
394
                    self.variables_meta[n] = {}
395
                try:
396
                    var = grp[k]
397
                    self.variables_meta[n]['bins'] = np.concatenate([var[:, 0], [var[-1, 1]]])
398
                except KeyError:
399
                    continue
400

    
401
            vlist = grp['variable_names_scatter'][:] #[v for v in list(grp.keys()) if v.endswith('_scatter_bins')]
402
            for k in vlist:
403
                var = grp[k+'_scatter_bins']
404
                if n not in self.variables_meta:
405
                    self.variables_meta[k] = {}
406
                try:
407
                    self.variables_meta[k]['title'] = var.attrs['long_name']
408
                    if isinstance(self.variables_meta[k]['title'], bytes):
409
                        self.variables_meta[k]['title'] = str(self.variables_meta[k]['title'], 'utf-8')
410
                except (AttributeError, KeyError):
411
                    self.variables_meta[k]['title'] = k
412
                try:
413
                    self.variables_meta[k]['data_range'] = var.attrs['range']
414
                except (AttributeError, KeyError):
415
                    self.variables_meta[k]['data_range'] = [self.variables_meta[k]['bins'].min(), self.variables_meta[k]['bins'].max()]
416
                try:
417
                    self.variables_meta[k]['log_range'] = (var.attrs['log_range'].lower() == b'true'
418
                                                           if isinstance(var.attrs['log_range'], bytes) else
419
                                                           var.attrs['log_range'].lower() == 'true')
420
                except (AttributeError, KeyError):
421
                    self.variables_meta[k]['log_range'] = False
422
                try:
423
                    self.variables_meta[k]['units'] = var.attrs['units']
424
                    if isinstance(self.variables_meta[k]['units'], bytes):
425
                        self.variables_meta[k]['units'] = str(self.variables_meta[k]['units'], 'utf-8')
426
                except:
427
                    self.variables_meta[k]['units'] = "1"
428

    
429
                self.variables_meta[k]['number_bins'] = grp[k+'_scatter_bins'].shape[0]
430

    
431
            var = grp['correlation_matrix']
432
            self.correlation_matrix = var[time_index, :, :]
433
            var = grp['covariance_matrix']
434
            self.covariance_matrix = var[time_index, :, :]
435
        return True
436

    
437
    ## Make a plot of a specified variable
438
    #
439
    #  @param varname The name of the variable to plot, here a list or tuple with two variable names.
440
    #  @param figure  The matplotlib.figure.Figure instance to plot to.
441
    #  @param ax      The matplotlib.axes.Axes object to use. Default is None, in that case `ax = matplotlib.pyplot.gca()` shall be used.
442
    #  @param log     Keyword parameter, boolean. When `True` a logarithmic colour scale will be used.
443
    #  @param time    Keyword parameter, either a tuple/list of datetime.datetime objects or a single datetime.datetime object indicating the
444
    #                 date or date-range of the underlying data.
445
    #  @param colorscale Keyword parameter, a string with teh name of a matplotlib colour scale. Defaults to 'viridis'.
446
    #  @param hide_R  Keyword parameter, boolean. When `True` the correlation coefficient will _not_ be added to the top-left of the plot.
447
    #  @return `False` if no plot could be created, `True` otherwise.
448
    def plot(self, varnames, figure, ax=None, **kwargs):
449
        import matplotlib
450
        matplotlib.use('svg')
451
        matplotlib.rc('font', family='DejaVu Sans')
452

    
453
        import matplotlib.pyplot as plt
454
        import matplotlib.cm as cm
455
        from matplotlib.colors import Normalize, LogNorm
456
        import matplotlib.ticker as ticker
457

    
458
        self.logger.info("Plotting scatter density of '%s' and '%s' (%s)", varnames[0], varnames[1],
459
                         "log scale" if ('log' in kwargs and kwargs['log']) else "linear scale")
460

    
461
        if ax is None:
462
            ax = plt.gca()
463

    
464
        if tuple(varnames[0:2]) not in self.key_pairs:
465
            if tuple(varnames[0:2:-1]) in self.key_pairs:
466
                varnames = varnames[0:2:-1]
467
            else:
468
                self.logger.error("scatterplot data for pair '%s' and '%s' not found.",
469
                                  varnames[0], varnames[1])
470
                return False
471

    
472
        key = tuple(varnames[0:2])
473
        i1, i2 = [self.variable_names.index(n) for n in key]
474

    
475
        xdata = self.variables_meta[key[0]]
476
        ydata = self.variables_meta[key[1]]
477
        image = np.ma.asarray(self.variables[key])
478
        image.mask = (image == 0)
479
        im_max = np.nanmax(image)
480
        if im_max > 0:
481
            scale_divisor = 10**(math.floor(math.log10(im_max)))
482
        else:
483
            scale_divisor = im_max
484
            if scale_divisor == 0.0: scale_divisor = 1.0
485

    
486
        correlation_coefficient = self.correlation_matrix[i1, i2]
487

    
488
        try:
489
            if ('log' in kwargs and kwargs['log'] and im_max > 0):
490
                normalizer = LogNorm(vmin=1, vmax=math.ceil(im_max/scale_divisor)*scale_divisor)
491
            else:
492
                normalizer = Normalize(vmin=0, vmax=math.ceil(im_max/scale_divisor)*scale_divisor)
493
        except ValueError:
494
            self.logger.warning("Aborting scatter density plot of '%s' and '%s' (%s/%s)", varnames[0], varnames[1], str(im_max), str(scale_divisor))
495
            return False
496

    
497
        if 'time' in kwargs:
498
            if isinstance(kwargs['time'], (list, tuple)):
499
                tt = kwargs['time']
500
            else:
501
                tt = [kwargs['time']]
502
            timestr = "".join([t.strftime("%Y-%m-%d") for t in tt])
503
            ax.set_title(timestr)
504

    
505
        if 'colorscale' in kwargs:
506
            try:
507
                colormap = cm.get_cmap(kwargs['colorscale'])
508
            except:
509
                self.logger.warning("Colorscale '{0}' was not found, using 'viridis' instead".format(cmapname))
510
                colormap = cm.get_cmap("viridis")
511
        else:
512
            colormap = cm.get_cmap("gray_r")
513

    
514
        im = plt.pcolormesh(xdata['bins'], ydata['bins'], image.T, cmap=colormap, norm=normalizer, shading='flat', edgecolors='None', axes=ax)
515
        if ('log' in kwargs and kwargs['log'] and im_max > 0):
516
            fmt = ticker.LogFormatterMathtext(base=10.0, labelOnlyBase=False)
517
        else:
518
            fmt = ticker.ScalarFormatter(useMathText=True)
519
            fmt.set_powerlimits((-3,3))
520
        cbar = plt.colorbar(im, ax=ax, orientation='vertical', aspect=25, extend="neither", format=fmt)
521
        cbar.ax.xaxis.set_tick_params(which='major', width=1.0, length=6.0, direction='out')
522
        cbar.set_label("Number of observations")
523
        ax.set_xlim(xdata['data_range'])
524
        ax.set_ylim(ydata['data_range'])
525

    
526
        if xdata['units'] is not None and xdata['units'] != "1":
527
            plt.xlabel("{0} [{1}]".format(xdata['title'], xdata['units']))
528
        else:
529
            plt.xlabel(xdata['title'])
530

    
531
        if ydata['units'] is not None and ydata['units'] != "1":
532
            plt.ylabel("{0} [{1}]".format(ydata['title'], ydata['units']))
533
        else:
534
            plt.ylabel(ydata['title'])
535

    
536
        if not ('hide_R' in kwargs and kwargs['hide_R']):
537
            plt.text(0.02, 0.98, "$R = {0:.3f}$".format(correlation_coefficient),
538
                     alpha=1.0, axes=ax, size='small',
539
                     verticalalignment='top', horizontalalignment='left',
540
                     transform=ax.transAxes)
541
        return True
542

    
543
    ## Generate a LaTeX table with either the covariance matrix or the correlation matrix for inclusion in the report.
544
    #  @param covariance Keyword argument, boolean. When True the covariance is printed, otherwise the correlation is printed.
545
    #  @return String with table.
546
    #
547
    #  The table uses a LaTeX boxes called `\correlationtablebox`.
548
    #  (this box is already created in the latex template for the report),
549
    #  and determines (within LaTeX) which orientation of the table gives the best result.
550
    #
551
    def report(self, **kwargs):
552
        print_cov = ('covariance' in kwargs and kwargs['covariance'])
553
        rval = []
554
        nvar = len(self.variable_names)
555
        rval.append(r'\savebox{\correlationtablebox}{')
556
        rval.append(r'\begin{{tabular}}{{{0}}}'.format(nvar*"c"))
557
        fmt = "{0} \\\\".format(" & ".join(["${{values[{0}]}}$".format(i) for i in range(nvar)]))
558
        ## Replace some unicode strings with LaTeX equivalent because these particular
559
        # code points aren't included at this moment in the LaTeX unicode support.
560
        rval.append(r"{0} \\".format(" & ".join([r"\rotatebox{{270}}{{{0}}}".format(self.variables_meta[n]['title'].replace('χ²', r'$\chi^2$').replace('λ', r'$\lambda$')) for n in self.variable_names])))
561
        for i, name in enumerate(self.variable_names):
562
            rval.append(fmt.format(values=[self.number_pretty_printer(v) for v in (self.covariance_matrix[i, :] if print_cov else self.correlation_matrix[i, :])]))
563
        rval.append(r'''\end{tabular}%
564
}''')
565
        caption_text = '{c} matrix'.format(c="Covariance" if print_cov else "Correlation")
566
        rval.append(r"""
567
\setlength{\naturaltablewidth}{\widthof{\usebox{\correlationtablebox}}}
568
\setlength{\naturaltableheight}{\heightof{\usebox{\correlationtablebox}}}
569
\addtolength{\naturaltableheight}{\depthof{\usebox{\correlationtablebox}}}
570
\setlength{\naturaltextratio}{1pt*\ratio{0.92\textheight}{\textwidth}}
571
\setlength{\naturaltableratio}{1pt*\ratio{\naturaltableheight}{\naturaltablewidth}}
572
\setlength{\realtextheight}{\textheight}
573
\setlength{\realtextwidth}{\textwidth}
574
%%
575
\ifthenelse{\lengthtest{0.92\textheight>\naturaltableheight}%%
576
            \AND \lengthtest{\textwidth>\naturaltablewidth}}%%
577
{%%(then)
578
  \begin{table}[htbp]%%
579
  \caption{%s}%%
580
  \centering\usebox{\correlationtablebox}%%
581
  \end{table}%%
582
}%%
583
{%%
584
  \begin{landscape}%%
585
  \begin{table}%%
586
    \caption{%s}%%
587
    \ifthenelse{\lengthtest{0.92\realtextheight>\naturaltableheight}%%
588
                \AND \lengthtest{\realtextwidth>\naturaltablewidth}}%%
589
      {\centering\usebox{\correlationtablebox}}%%
590
      {%%
591
        \ifthenelse{\lengthtest{\naturaltableratio<\naturaltextratio}}%%
592
          {\resizebox*{\realtextheight}{!}{\usebox{\correlationtablebox}}}%%
593
          {\resizebox*{!}{0.92\realtextwidth}{\usebox{\correlationtablebox}}}%%
594
      }
595
  \end{table}%%
596
  \end{landscape}%%
597
}
598

599
""" % (caption_text, caption_text))
600
        return "\n".join(rval)
601

    
602
    ## Add a table to the report using matplotlib alone. The preferred output uses LaTeX.
603
    #
604
    #  @param ax The axis to use, defaults to `ax = matplotlib.pyplot.gca()`
605
    #  @param figure The figure instance to use, defaults to `figure = matplotlib.pyplot.gcf()`.
606
    #  @param covariance Keywor argument, boolean. When absent or `False` the correlation is written.
607
    def table(self, ax=None, figure=None, **kwargs):
608
        import matplotlib.pyplot as plt
609
        import matplotlib.cm as cm
610

    
611
        print_cov = ('covariance' in kwargs and kwargs['covariance'])
612
        self.logger.info("Adding table for %s", 'covariance' if print_cov else 'correlation')
613
        if ax is None:
614
            ax = plt.gca()
615
        if figure is None:
616
            figure = plt.gcf()
617

    
618
        colLabels = [self.variables_meta[n]['title'] for n in self.variable_names]
619
        ax.axis('off')
620
        nrows, ncols = len(colLabels)+1, len(colLabels)
621
        hcell, wcell = 0.2, 0.9
622
        hpad, wpad = 0.0, 0.0
623
        figure.set_size_inches(ncols*wcell+wpad, nrows*hcell+hpad)
624

    
625
        contents = []
626
        d = self.covariance_matrix if print_cov else self.correlation_matrix
627
        for i, name in enumerate(self.variable_names):
628
            line = []
629
            for item in d[i, :]:
630
                line.append("{0:.4g}".format(item))
631
            contents.append(line)
632

    
633
        #do the table
634
        the_table = ax.table(cellText=contents, colLabels=colLabels, loc='center')
635
        the_text = plt.text(0.0, 1.02, 'Covariance' if print_cov else 'Correlation',
636
                            alpha=1.0, axes=ax,
637
                            verticalalignment='bottom', horizontalalignment='left',
638
                            transform=ax.transAxes)
639
        return [the_table, the_text]
640

    
641
    ## Make a plot of the correlation matrix on a rectangular grid.
642
    #
643
    #  @param figure  The matplotlib.figure.Figure instance to plot to.
644
    #  @param ax      The matplotlib.axes.Axes object to use. Default is None, in that case `ax = matplotlib.pyplot.gca()` shall be used.
645
    #  @param colorscale Keyword argument, string: The color scale to use. Defaults to 'seismic' (blue-white-red).
646
    #  @param covariance Keyword argument, boolean: when `True` the covariance matrix is plotted. Defaults to `False` (correlation matrix).
647
    #
648
    #  @return `False` if no plot could be created, `True` otherwise.
649
    def correlation_graph(self, ax=None, figure=None, **kwargs):
650
        import matplotlib.pyplot as plt
651
        import matplotlib.cm as cm
652

    
653
        print_cov = ('covariance' in kwargs and kwargs['covariance'])
654
        self.logger.info("Adding plot for %s", 'covariance' if print_cov else 'correlation')
655
        if ax is None:
656
            ax = plt.gca()
657
        if figure is None:
658
            figure = plt.gcf()
659

    
660
        colLabels = [self.variables_meta[n]['title'] for n in self.variable_names]
661
        d = self.covariance_matrix if print_cov else self.correlation_matrix
662

    
663
        if 'colorscale' in kwargs:
664
            try:
665
                colormap = cm.get_cmap(kwargs['colorscale'])
666
            except:
667
                self.logger.warning("Colorscale '{0}' was not found, using 'seismic' instead".format(cmapname))
668
                colormap = cm.get_cmap("seismic")
669
        else:
670
            colormap = cm.get_cmap("seismic")
671

    
672
        im = plt.imshow(d, cmap=colormap, aspect='equal', interpolation='nearest',
673
                   vmin=(d.min() if print_cov else -1.0), vmax=(d.max() if print_cov else 1.0),
674
                   origin='upper')
675
        cbar = plt.colorbar(im, ax=ax, orientation='vertical', aspect=20, extend="neither")
676
        cbar.set_label("Covariance" if print_cov else "Correlation coefficient")
677

    
678
        plt.xticks(np.arange(d.shape[0]), colLabels, rotation='vertical', fontsize='small')
679
        plt.yticks(np.arange(d.shape[1]), colLabels, rotation='horizontal', fontsize='small')
680

    
681
        plt.margins(0.0)
682
        #plt.subplots_adjust(bottom=0.15, left=0.15)
683
        return True
684

    
685
    ## Make a plot of the correlation matrix on a circle, with the thickness of the connecting lines indicating the
686
    # strength of the correlation, and the color indicating the sign (red positive, blue negative).
687
    #
688
    #  @param figure  The matplotlib.figure.Figure instance to plot to.
689
    #  @param ax      The matplotlib.axes.Axes object to use. Default is None, in that case `ax = matplotlib.pyplot.gca()` shall be used.
690
    #  @param show_values Keyword argument, boolean: when `True` numerical values are printed on the connecting lines. Defaults to `False`.
691
    #  @param show_value_threshold Keyword argument, float: the absolute threshold of values to print, when `show_values` is true. Defaults to 0.25.
692
    #
693
    #  @return `False` if no plot could be created, `True` otherwise.
694
    def correlation_matrix_plot(self, ax=None, figure=None, **kwargs):
695
        import matplotlib.pyplot as plt
696
        self.logger.info("Adding graph of correlations")
697
        R = self.correlation_matrix
698
        names = [self.variables_meta[n]['title'] for n in self.variable_names]
699

    
700
        show_values = ('show_values' in kwargs and kwargs['show_values'])
701
        if show_values and 'show_value_threshold' in kwargs:
702
            show_value_threshold = kwargs['show_value_threshold']
703
        else:
704
            show_value_threshold = 0.25
705

    
706
        theta = np.linspace(0, 2*np.pi, num=R.shape[0], endpoint=False)
707
        x, y = np.cos(theta), np.sin(theta)
708
        ii, jj = np.triu_indices(R.shape[0], 1)
709
        if ax is None:
710
            ax = plt.gca()
711
        if figure is None:
712
            figure = plt.gcf()
713

    
714
        plt.axis([-1.25, 1.25, -1.25, 1.25])
715

    
716
        phi = np.linspace(0, 2*np.pi, num=720)
717
        plt.plot(np.cos(phi), np.sin(phi), '0.5', linewidth=0.5)
718

    
719
        for i, j in zip(ii, jj):
720
            c = ('r' if (R[i,j] > 0) else 'b')
721
            if not np.isfinite(R[i,j]):
722
                continue
723
            plt.plot([x[i], x[j]], [y[i], y[j]],
724
                     color=c, alpha=(np.fabs(R[i,j])/2 + 0.5),
725
                     linewidth=(4.0*(np.fabs(R[i,j])+0.125)),
726
                     solid_capstyle='round',
727
                     solid_joinstyle='round')
728
            if show_values and np.fabs(R[i,j]) > show_value_threshold:
729
                plt.text((x[i]+x[j])/2, (y[i]+y[j])/2,
730
                         "{0:.3f}".format(R[i,j]),
731
                         horizontalalignment='center',
732
                         verticalalignment='center',
733
                         size='small',
734
                         color='k')
735

    
736
        for i, name in enumerate(names):
737
            xx, yy = 1.1*np.cos(theta[i]), 1.1*np.sin(theta[i])
738
            if -0.5 < xx < 0.5:
739
                ha = 'center'
740
                if yy > 0:
741
                    va = 'bottom'
742
                    yy += 0.2*(1-2*np.fabs(xx))
743
                else:
744
                    va = 'top'
745
                    yy -= 0.2*(1-2*np.fabs(xx))
746
            elif xx < -0.5:
747
                ha = 'right'
748
                va = 'center'
749
            else:
750
                ha = 'left'
751
                va = 'center'
752
            plt.text(xx, yy, name,
753
                     horizontalalignment=ha,
754
                     verticalalignment=va)
755
        plt.plot(x, y, 'ko')
756

    
757
        axes = figure.get_axes()
758

    
759
        for ax in axes:
760
            ax.set_axis_off()
761

    
762
        axes[0].set_aspect('equal')
763
        plt.tight_layout()
764
        return True