Project

General

Profile

Statistics
| Branch: | Tag: | Revision:

pycama / src / pycama / WorldPlot.py @ 834:d989df597b80

History | View | Annotate | Download (39.9 KB)

1
#!/usr/bin/env python3
2
# -*- coding: utf-8 -*-
3

    
4
# Copyright 2016-2017 Maarten Sneep, KNMI
5
#
6
# Redistribution and use in source and binary forms, with or without
7
# modification, are permitted provided that the following conditions are met:
8
#
9
# 1. Redistributions of source code must retain the above copyright notice,
10
#    this list of conditions and the following disclaimer.
11
#
12
# 2. Redistributions in binary form must reproduce the above copyright notice,
13
#    this list of conditions and the following disclaimer in the documentation
14
#    and/or other materials provided with the distribution.
15
#
16
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
17
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
18
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
19
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR
20
# ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
21
# (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
22
# LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON
23
# ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
24
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
25
# SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
26

    
27
## @file WorldPlot.py
28
#  @author Maarten Sneep
29
#
30
#  This file defines a subclass of the pycama.AnalysisAndPlot.AnalysisAndPlot class.
31
#  This subclass implements the level 3 gridder for PyCAMA.
32

    
33

    
34
import math
35
import logging
36

    
37
import numpy as np
38
import netCDF4
39
import h5py
40

    
41
from .AnalysisAndPlot import AnalysisAndPlot
42

    
43
from .utilities import *
44

    
45
## Grid & plot onto a worldmap
46
#
47
#  This is a subclass of the pycama.AnalysisAndPlot.AnalysisAndPlot class intended
48
#  for gridding onto a level 3 grid and plotting the result. As for all
49
#  pycama.AnalysisAndPlot.AnalysisAndPlot subclasses the extracted data can be
50
#  stored into/read from a netCDF4 file.
51
#
52
#  The level 3 grid used here is a reduced grid, similar to the data structure used by
53
#  ECMWF for its reduced Gaussian grid storage. Note that our grid is not a Gaussian grid,
54
#  we simply scale the number of longitudes for each latitude with \f$\cos(\delta_{\mbox{geo}})\f$.
55
#
56
#  The __init__ method is inherited.
57
class WorldPlot(AnalysisAndPlot):
58
    ## Class specific preparation.
59
    #
60
    #  @param resolution Keyword parameter, float. Set the spatial resolution \f$\Delta\delta\f$ in degrees. Defaults to 0.5.
61
    #
62
    def setup(self, **kwargs):
63
        if 'resolution' in kwargs:
64
            self.resolution = math.ceil(360.0*kwargs['resolution'])/360.0
65
        else:
66
            self.resolution = 0.5
67

    
68
        ## The latitude grid is linear (equidistant) from south to north, with pixel centers at
69
        #  \f$\delta_c = -90.0 + \Delta\delta/2 + i\Delta\delta\f$, with \f$i=0, 1, \ldots, (180/\Delta\delta)-1\f$.
70
        #  This instance variable is an array with just the pixel centers.
71
        self.latitude_centers = None
72
        ## The number of latitudes in the grid, a simple integer, with value the
73
        #  lenght of the `latitude_centers` array.
74
        self.n_latitude = -1
75
        ## Array of length `n_latitude`, with each element containing the number of
76
        #  longitudes in that latitude band.
77
        #  the value in each element \f$i\f$ is:
78
        #  \f[
79
        #     N[i] = 2 \left\lceil \left\lceil \frac{180}{\Delta\delta} \right\rceil \cos\left(\frac{\pi \delta_c[i]}{180})\right\rceil
80
        #  \f]
81
        self.grid_length = None
82
        ## The cummulative sum of grid_length, starting from 0. This
83
        #  gives the starting point of each latitude band in the overall data set.
84
        self.grid_length_cummulative = None
85
        ## The sum of grid_length, the total number of points in the overall grid.
86
        self.n_total = -1
87
        ## For each latitude the longitude grid is linear (equidistant) from -180 degrees longitude going east, with pixel centers at
88
        #  \f$\vartheta_c = -180 + \Delta\vartheta/2 + j\Delta\vartheta \f$,
89
        #  with \f$j=0, 1, \ldots, (360/\Delta\vartheta)\f$ and \f$\Delta\vartheta = 360/(N[i]+1)\f$.
90
        self.longitude_centers = None
91
        ## latitude boundaries, an [n_latitude, 2] array.
92
        self.latitude_bounds = None
93
        ## longitude boundaries, an [n_total, 2] array.
94
        self.longitude_bounds = None
95
        ## Dict with reconstructed graphical representation
96
        self.graph_rep = {'name':None, 'data':None}
97

    
98
        self.define_grid()
99

    
100
    ## Define the level 3 grid
101
    #
102
    #  This creates the grid following the definitions for each of the instance variables.
103
    def define_grid(self):
        """Define the reduced level 3 grid from `self.resolution`.

        Fills in all grid-related instance variables: latitude centers and
        bounds, per-band longitude counts, and the concatenated longitude
        centers and bounds for every band.
        """
        half_step = self.resolution / 2
        centers = np.arange(-90.0 + half_step, 90.0, self.resolution)
        edges = np.arange(-90.0, 90.0 + half_step, self.resolution)

        self.latitude_centers = centers
        self.n_latitude = len(centers)
        self.latitude_bounds = np.stack([edges[:-1], edges[1:]], axis=1)

        # Scale the number of longitude cells per band with cos(latitude),
        # rounding up and keeping the count even.
        n_lon_max = int(math.ceil(360.0 / self.resolution))
        per_band = 2 * np.ceil((n_lon_max / 2) * np.cos(np.pi * centers / 180.0))
        self.grid_length = np.asarray(per_band, dtype=np.int32)
        self.grid_length_cummulative = np.concatenate([[0], np.cumsum(self.grid_length)])
        self.n_total = np.sum(self.grid_length)

        self.longitude_centers = np.zeros((self.n_total,), dtype=np.float32)
        self.longitude_bounds = np.zeros((self.n_total, 2), dtype=np.float32)

        for band, nlon in enumerate(self.grid_length):
            start = self.grid_length_cummulative[band]
            stop = self.grid_length_cummulative[band + 1]
            # Cell edges span [-180, 180]; centers sit half a cell inward.
            cell_edges, width = np.linspace(-180.0, 180.0, num=nlon + 1,
                                            endpoint=True, retstep=True, dtype=np.float32)
            self.longitude_centers[start:stop] = np.linspace(-180.0 + width / 2, 180.0 - width / 2,
                                                             num=nlon, endpoint=True, dtype=np.float32)
            self.longitude_bounds[start:stop, 0] = cell_edges[:-1]
            self.longitude_bounds[start:stop, 1] = cell_edges[1:]
125

    
126
    ## get or add data from/to a variable.
127
    #
128
    #  if add_data is None, then data for latitude band i is returned,
129
    #  otherwise the data in add_data is added to that latitude band
130
    #
131
    #  @param i latitude index
132
    #  @param name Variable to extract or write to.
133
    #  @param add_data When `None`, return the data for variable `name`,
134
    #         otherwise add the incoming data to the storage. The data is already averaged.
135
    #  @param count The number of observations in each bin, required to combine different fields later.
136
    #         The count is kept per variable because variables do not have to be synchronized.
137
    #         If `add_data` is `None` and `count` resolves to `True`, then return the `count`.
138
    # @return `None` when writing, `data` or `count` when reading.
139
    def data_for_latitude_band(self, i, name, add_data=None, count=None):
        """Read from or accumulate into the storage for one latitude band.

        @param i Latitude band index.
        @param name Variable to read from or write to.
        @param add_data When `None`, the stored data for `name` is returned;
               otherwise this array is added to the band's storage.
        @param count Number of observations per bin. When reading, a truthy
               value selects the count array instead of the data.
        @return `None` when writing, `data` or `count` when reading.
        """
        data = self.variables[name]
        counts = self.variables[name + '_count']
        start = self.grid_length_cummulative[i]
        stop = self.grid_length_cummulative[i + 1]

        if add_data is None:
            # Read access: hand back the counts when asked, the data otherwise.
            if count:
                return counts[start:stop]
            return data[start:stop]

        try:
            data[start:stop] += add_data
            if count is not None:
                counts[start:stop] += count
        except ValueError:
            self.logger.error("New data for variable {0} has wrong length (received {1}, expected {2})".format(name, len(add_data), self.grid_length[i]))
        return None
155

    
156
    ## Add an input variable.
157
    #
158
    #  @param var a pycama.Variable.Variable instance.
159
    #
160
    #  This method reserves memory for the gridded level 3 data, both the averaged data and the observation counter.
161
    def add_variable(self, var):
        """Register an input variable for level 3 gridding.

        @param var A pycama.Variable.Variable instance.

        Reserves memory for the gridded level 3 data: one float32 array for
        the (summed, later averaged) data and one int32 array for the
        observation counter. Variables without the level3 flag are ignored.
        """
        if not var.level3:
            return
        self.variables[var.name] = np.zeros((self.n_total,), dtype=np.float32)
        self.variables[var.name + '_count'] = np.zeros((self.n_total,), dtype=np.int32)
        # Zero-argument super(): the original super(self.__class__, self) form
        # recurses infinitely when this method is inherited by a subclass.
        super().add_variable(var)
        self.variables_meta[var.name]['noscanline'] = var.noscanline
        if hasattr(var, 'map_range'):
            self.variables_meta[var.name]['data_range'] = var.map_range
170

    
171
    ## Extract the required information from the input data.
172
    #
173
    #  Using the latitude bounds we find for each observation into which latitude band it should be sorted.
174
    #  In a loop over the latitude bands we select the observations for each band. If there are none we skip to the next latitude band.
175
    #  The longitudes are brought into the [-180, 180) range, and longitude bin boundaries are extracted for the latitude band.
176
    #
177
    #  The inner loop is over the variables. For each variable the valid data is selected (taking unsynchronized variable into account - I think),
178
    #  and the `np.histogram()` function is used to calculate the value for each longitude bin.
179
    #  The longitudes are used as primary histogram variable, with the data we're actually interested
180
    #  in the weights parameter. A separate call to `np.histogram()` is used to count the number of
181
    #  observations in each grid cell. The data_for_latitude_band() method is used to store the extracted data.
182
    #
183
    #  After extraction of the data the mean is taken by dividing the sum of the observations
184
    #  by the number of observations (where the latter is not 0).
185
    def process(self):
        """Grid the aggregated level 2 data onto the reduced level 3 grid.

        Observations are sorted into latitude bands; per band each variable is
        histogrammed over the longitude bins (values as weights, plus a plain
        count), and afterwards the accumulated sums are turned into means.
        """
        lat = self.input_variables.variables['latitude'].aggregate_data
        lon = self.input_variables.variables['longitude'].aggregate_data
        # Bring longitudes into the [-180, 180) range used by the grid.
        lon = np.where(lon > 180.0, lon-360.0, lon)
        # Latitude band index for each observation.
        latidx = np.searchsorted(self.latitude_bounds[:, 0], lat) - 1
        for i in range(self.grid_length.shape[0]):
            # Report progress roughly 20 times over the band loop.
            if i > 0 and i % (self.grid_length.shape[0]//min(20, self.grid_length.shape[0])) == 0:
                self.progress(100*self.grid_length_cummulative[i]/self.grid_length_cummulative[-1])
            selected = (latidx == i)
            # Longitude bin edges for this band: all lower bounds, plus the
            # upper bound of the last cell.
            bins = np.concatenate([self.longitude_bounds[self.grid_length_cummulative[i]:self.grid_length_cummulative[i+1], 0],
                                   [self.longitude_bounds[self.grid_length_cummulative[i+1]-1, 1]] ] )
            if not np.any(selected):
                continue
            for v in self.input_variables.variables.values():
                # Skip variables that are hidden, not gridded, lack a scanline
                # dimension, or are not synchronized with the lat/lon arrays.
                if not v.show or not v.level3 or v.noscanline or len(v.aggregate_data) != len(latidx):
                    continue
                try:
                    # Valid data: in this band, finite, unmasked, and below
                    # float.fromhex('1.ep+122') ~= 9.97e36 -- just under the
                    # netCDF default float fill value, presumably to reject
                    # fill values (TODO confirm).
                    vselected = np.logical_and(np.logical_and(np.logical_and(latidx == i, np.isfinite(v.aggregate_data)), np.logical_not(v.aggregate_data.mask)), v.aggregate_data < float.fromhex('1.ep+122'))
                except (AttributeError, IndexError):
                    # No usable mask attribute: same selection without the mask term.
                    vselected = np.logical_and(np.logical_and(latidx == i, np.isfinite(v.aggregate_data)), v.aggregate_data < float.fromhex('1.ep+122'))
                if np.any(vselected):
                    # Sum of the values per longitude bin (data as histogram weights) ...
                    h = np.histogram(lon[vselected], bins=bins, weights=v.aggregate_data[vselected])[0]
                    # ... and the number of observations per bin.
                    c = np.histogram(lon[vselected], bins=bins)[0]
                    self.data_for_latitude_band(i, v.name, add_data=h, count=c)

        # Divide the accumulated sums by the number of points to get the mean,
        # only where the count is non-zero.
        for name in [n for n in self.variables.keys() if not n.endswith('_count')]:
            cidx = (self.variables[name + '_count'] != 0)
            self.variables[name][cidx] /= self.variables[name + '_count'][cidx]
214

    
215
    ## Merge data into a combined dataset.
216
    #
217
    #  @param other The object to be added to self,
218
    #               also an instance of the pycama.WorldPlot.WorldPlot class.
219
    def __iadd__(self, other):
220
        for name in [n for n in self.variables.keys() if not n.endswith('_count')]:
221
            self.logger.debug("{0}.__iadd__(): processing '{1}'".format(self.__class__.__name__, name))
222

    
223
            S = self.variables[name]
224
            C = self.variables[name + '_count']
225
            oS = other.variables[name]
226
            oC = other.variables[name + '_count']
227
            idx = C+oC > 0
228
            self.variables[name][idx] = (S[idx]*C[idx] + oS[idx]*oC[idx])/(C[idx]+oC[idx])
229
            self.variables[name + '_count'] = C + oC
230

    
231
    ## Read processed data from the input file, for specified time index.
232
    #
233
    #  @param fname NetCDF file with input data.
234
    #  @param time_index Time slice to read.
235
    def ingest(self, fname, time_index, exclude=None):
        """Read processed world-plot data from a PyCAMA output file.

        Restores the reduced grid definition from the `worldplot_data` group
        and, for the requested time slice, the gridded means, counts and
        per-variable metadata (units, title, range, color scale, log flag).

        @param fname NetCDF/HDF5 file with input data.
        @param time_index Time slice to read.
        @param exclude Optional collection of variable names to skip.
        @return `False` when the world-plot group is missing, `True` otherwise.
        """
        self.logger.debug("{0}.ingest(): reading {1}".format(self.__class__.__name__, fname))
        self.time_index_in_output = time_index
        with h5py.File(fname, 'r') as ref:
            if self.storage_group_name not in ref.keys():
                # NOTE(review): `os` is not imported by name in this module;
                # presumably it comes in via `from .utilities import *` -- verify.
                self.logger.error("Worldplot data not in '%s'", os.path.basename(fname))
                return False
            grp = ref[self.storage_group_name]
            # Restore the grid definition; the cumulative lengths are rebuilt
            # rather than stored.
            self.resolution = grp.attrs['resolution']
            self.n_latitude = len(grp['latitude'])
            self.n_total = len(grp['rgrid'])
            self.latitude_centers = grp['latitude'][:]
            self.latitude_bounds = grp['latitude_bounds'][...]
            self.grid_length = grp['grid_length'][:]
            self.grid_length_cummulative = np.concatenate([[0], np.cumsum(self.grid_length)])
            self.longitude_centers = grp['longitude'][:]
            self.longitude_bounds = grp['longitude_bounds'][...]

            self.variables = {}
            self.variables_meta = {}
            # The file stores variable names as bytes; lat/lon are grid axes,
            # not data variables, so they are filtered out here.
            for k in [name.decode("ASCII") for name in ref['variable_names'][:] if name.decode("ASCII") not in ('latitude', 'longitude')]:
                if k not in grp.keys():
                    self.logger.debug("Variable name {0} requested but not available".format(k))
                    continue
                if exclude is not None and k in exclude:
                    self.logger.debug("Variable name {0} is excluded".format(k))
                    continue

                self.variables_meta[k] = {}
                self.variables_meta[k]['noscanline'] = False
                var = grp[k]
                # Only the requested time slice is loaded into memory.
                self.variables[k] = var[self.time_index_in_output, :]
                count_var = grp[k+'_count']
                self.variables[k+'_count'] = count_var[self.time_index_in_output, :]

                self.logger.debug("{0}.ingest(): processing {1}".format(self.__class__.__name__, k))

                # Each metadata item falls back to a default when the attribute
                # is missing; bytes attributes are decoded to str.
                try:
                    self.variables_meta[k]['units'] = var.attrs['units']
                    if isinstance(self.variables_meta[k]['units'], bytes):
                        self.variables_meta[k]['units'] = str(self.variables_meta[k]['units'], 'utf-8')
                except (AttributeError, KeyError):
                    self.variables_meta[k]['units'] = "1"
                try:
                    self.variables_meta[k]['title'] = var.attrs['long_name']
                    if isinstance(self.variables_meta[k]['title'], bytes):
                        self.variables_meta[k]['title'] = str(self.variables_meta[k]['title'], 'utf-8')
                except (AttributeError, KeyError):
                    self.variables_meta[k]['title'] = k

                # 'range' may be a pair of numbers, or the literal string
                # 'false' (bytes or str) meaning "no fixed range".
                try:
                    if isinstance(var.attrs['range'], (bytes, str)) and var.attrs['range'].lower() in (b'false', 'false'):
                        self.variables_meta[k]['data_range'] = False
                    else:
                        self.variables_meta[k]['data_range'] = var.attrs['range']
                except (AttributeError, KeyError):
                    self.variables_meta[k]['data_range'] = None

                try:
                    self.variables_meta[k]['color_scale'] = var.attrs['color_scale']
                    if isinstance(self.variables_meta[k]['color_scale'], bytes):
                        self.variables_meta[k]['color_scale'] = str(self.variables_meta[k]['color_scale'], 'utf-8')
                except (AttributeError, KeyError):
                    self.variables_meta[k]['color_scale'] = 'nipy_spectral'
                # 'log_range' is stored as a 'true'/'false' string (bytes or str).
                try:
                    self.variables_meta[k]['log_range'] = (var.attrs['log_range'].lower() == b'true'
                                                           if isinstance(var.attrs['log_range'], bytes) else
                                                           var.attrs['log_range'].lower() == 'true')
                except (AttributeError, KeyError):
                    self.variables_meta[k]['log_range'] = False
        # Invalidate any cached graphical representation from a previous read.
        self.graph_rep = {'name':None, 'data':None}
        return True
307

    
308
    ## Write processed data to output netcdf file.
309
    #
310
    #  @param fname File to write to
311
    #  @param mode  Writing mode, defaults to append.
312
    #
313
    #  Write data (including extraction specific dimensions) to the group with
314
    #  the name given in the `storage_group_name` property ("`worldplot_data`" for this class).
315
    def dump(self, fname, mode='a'):
        """Write processed data to the output netCDF/HDF5 file.

        Writes the grid definition and the per-variable mean and count arrays
        (one row per time slice) to the group named by `storage_group_name`.

        @param fname File to write to.
        @param mode  Writing mode, defaults to append.
              NOTE(review): `mode` is accepted but not used -- the file is
              always opened with 'a' below; confirm whether this is intended.

        The try/except pattern throughout creates a dataset on first use and
        falls back to reopening it on subsequent calls (e.g. later time slices).
        NOTE(review): `dims.create_scale` is the old h5py dimension-scale API
        (removed in h5py 3.x in favour of `make_scale`) -- verify the pinned
        h5py version.
        """
        compress={'compression':'gzip', 'compression_opts':3, 'shuffle':True, 'fletcher32':True}
        with h5py.File(fname, 'a') as ref:
            # group: create on first call, reopen afterwards.
            try:
                grp = ref.create_group(self.storage_group_name)
                grp.attrs['comment'] = 'Gridded (level 3) data'
            except:
                grp = ref[self.storage_group_name]

            # dimensions
            try:
                lat_var = grp.create_dataset('latitude', (self.n_latitude,), dtype=np.float32, **compress)
            except:
                lat_var = grp['latitude']
            lat_var[:] = self.latitude_centers
            lat_var.attrs['units'] = 'degrees_north'
            lat_var.attrs['standard_name'] = 'latitude'

            # This is a pure dimension: the index into the reduced grid.
            try:
                rgrid_var = grp.create_dataset('rgrid', (self.n_total,), dtype=np.int32, **compress)
                rgrid_var[:] = np.arange(self.n_total, dtype=np.int32)
                rgrid_var.attrs['long_name'] = 'reduced grid index'
            except:
                pass

            # Record the resolution on first write; reread it otherwise so it
            # can be compared against the current run below.
            try:
                res = grp.attrs['resolution']
            except (AttributeError, KeyError):
                res = None
                grp.attrs['resolution'] = self.resolution

            # Refuse to mix grids of different resolutions in one file.
            if res is not None and res != self.resolution:
                self.logger.error("Spatial resolution in run and spatial resolution in file do not match.")
                return

            try:
                lat_bounds_var = grp.create_dataset('latitude_bounds', (self.n_latitude,2), dtype=np.float32, **compress)
                lat_bounds_var.dims.create_scale(grp['latitude'], 'latitude')
                lat_bounds_var.dims[0].attach_scale(grp['latitude'])
                lat_bounds_var.dims.create_scale(ref['nv'], 'nv')
                lat_bounds_var.dims[1].attach_scale(ref['nv'])
            except:
                lat_bounds_var = grp['latitude_bounds']
            lat_bounds_var[...] = self.latitude_bounds

            try:
                lon_var = grp.create_dataset('longitude', (self.n_total,), dtype=np.float32, **compress)
                lon_var.dims.create_scale(grp['rgrid'], 'rgrid')
                lon_var.dims[0].attach_scale(grp['rgrid'])
            except:
                lon_var = grp['longitude']

            lon_var[:] = self.longitude_centers
            lon_var.attrs['units'] = 'degrees_east'
            lon_var.attrs['standard_name'] = 'longitude'

            try:
                lon_bounds_var = grp.create_dataset('longitude_bounds', (self.n_total,2), dtype=np.float32, **compress)
                lon_bounds_var.dims.create_scale(grp['rgrid'], 'rgrid')
                lon_bounds_var.dims[0].attach_scale(grp['rgrid'])
                lon_bounds_var.dims.create_scale(ref['nv'], 'nv')
                lon_bounds_var.dims[1].attach_scale(ref['nv'])
            except:
                lon_bounds_var = grp['longitude_bounds']
            lon_bounds_var[...] = self.longitude_bounds

            try:
                grd_len = grp.create_dataset('grid_length', (self.n_latitude,), dtype=np.int32, **compress)
                grd_len.dims.create_scale(grp['latitude'], 'latitude')
                grd_len.dims[0].attach_scale(grp['latitude'])
            except:
                grd_len = grp['grid_length']
            grd_len[:] = self.grid_length
            grd_len.attrs['long_name'] = 'Number of points in each latitude band'

            try:
                grd_start = grp.create_dataset('grid_start', (self.n_latitude,), dtype=np.int32, **compress)
                grd_start.dims.create_scale(grp['latitude'], 'latitude')
                grd_start.dims[0].attach_scale(grp['latitude'])
            except:
                grd_start = grp['grid_start']
            # The last cumulative entry is the total length, not a start index.
            grd_start[:] = self.grid_length_cummulative[:-1]
            grd_start.attrs['long_name'] = 'Start index for each latitude band'

            time_step = self.time_index_in_output

            # if self.count is not None:
            #     try:
            #         count_var = grp.create_dataset('count', (1, self.n_total), dtype=np.float32, maxshape=(None, self.n_total), **compress)
            #         count_var.dims.create_scale(ref['time'], 'time')
            #         count_var.dims[0].attach_scale(ref['time'])
            #         count_var.dims.create_scale(grp['rgrid'], 'rgrid')
            #         count_var.dims[1].attach_scale(grp['rgrid'])
            #     except:
            #         count_var = grp['count']
            #         count_var.resize(self.time_index_in_output+1, axis=0)
            #
            #     count_var[time_step, :] = self.count

            for k in [name for name in self.variable_names if not name.endswith('_count')]:
                if self.variables_meta[k]['noscanline']:
                    continue

                # Reuse (and grow along time) an existing dataset; otherwise
                # create it, but only for the first time slice.
                try:
                    var = grp[k]
                    var.resize(self.time_index_in_output+1, axis=0)
                except:
                    if self.time_index_in_output > 0:
                        continue
                    var = grp.create_dataset(k, (1, self.n_total), dtype=np.float32, maxshape=(None, self.n_total), **compress)
                    var.dims.create_scale(ref['time'], 'time')
                    var.dims[0].attach_scale(ref['time'])
                    var.dims.create_scale(grp['rgrid'], 'rgrid')
                    var.dims[1].attach_scale(grp['rgrid'])

                # Attach metadata with per-item defaults when missing/invalid.
                try:
                    var.attrs['units'] = self.variables_meta[k]['units']
                except (AttributeError, TypeError):
                    var.attrs['units'] = "1"

                try:
                    var.attrs['long_name'] = self.variables_meta[k]['title']
                except (AttributeError, TypeError):
                    var.attrs['long_name'] = k
                # data_range may be a pair of numbers (stored as floats), a
                # truthy/falsy flag (data min/max or the string 'false'), or
                # absent (data min/max).
                try:
                    var.attrs['range'] = [float(v) for v in self.variables_meta[k]['data_range']]
                except TypeError:
                    var.attrs['range'] = [np.min(self.variables[k]), np.max(self.variables[k])] if self.variables_meta[k]['data_range'] else 'false'
                except AttributeError:
                    var.attrs['range'] = [np.min(self.variables[k]), np.max(self.variables[k])]
                try:
                    var.attrs['color_scale'] = self.variables_meta[k]['color_scale']
                except (AttributeError, TypeError):
                    var.attrs['color_scale'] = "nipy_spectral"
                try:
                    var.attrs['log_range'] = str(self.variables_meta[k]['log_range'])
                except:
                    var.attrs['log_range'] = "false"
                var.attrs['comment'] = 'Mean of variable {0}'.format(k)

                var[time_step, :] = self.variables[k]

                # Same create-or-resize dance for the matching count dataset.
                try:
                    cvar = grp[k+'_count']
                    cvar.resize(self.time_index_in_output+1, axis=0)
                except:
                    if self.time_index_in_output > 0:
                        continue
                    cvar = grp.create_dataset(k+'_count', (1, self.n_total), dtype=np.int32, maxshape=(None, self.n_total), **compress)
                    cvar.dims.create_scale(ref['time'], 'time')
                    cvar.dims[0].attach_scale(ref['time'])
                    cvar.dims.create_scale(grp['rgrid'], 'rgrid')
                    cvar.dims[1].attach_scale(grp['rgrid'])

                cvar[time_step, :] = self.variables[k+'_count']
472

    
473
    ## Make a plot of a specified variable
474
    #
475
    #  @param varname The name of the variable to plot.
476
    #  @param figure  The matplotlib.figure.Figure instance to plot to.
477
    #  @param ax      The matplotlib.axes.Axes object to use. Default is None, in that case `ax = matplotlib.pyplot.gca()` shall be used.
478
    #  @param projection     Keyword parameter, boolean. Which part of the world to plot. One of 'north', 'south' or other.
479
    #         'north' maps to `npstere`, 'south' maps to `spstere` and the rest is mapped to `mbtfpq`, a worldwide projection.
480
    #  @param resolution     Keyword parameter, string. One of 'c' (crude), 'l' (low), 'i' (intermediate), 'h' (high), 'f' (full). Defaults to 'l'.
481
    #  @param add_colorbar   Keyword parameter, boolean. Add a colorbar to this plot.
482
    #  @param linewidth      Keyword parameter, float. Width in points of the lines used for the coastlines and other map-related line art.
483
    #  @param color          Keyword parameter, string. Color used for the coastlines and other map-related line art.
484
    #  @param time    Keyword parameter, either a tuple/list of datetime.datetime objects or a single datetime.datetime object indicating the
485
    #                 date or date-range of the underlying data.
486
    #
487
    #  @return `False` if no plot could be created, `True` otherwise.
488
    #
489
    #  The world plot is made by looping over the latitude bands, with a call to `pcolor()` for each latitude band.
490
    def plot(self, varname, figure, ax=None, ax_idx=None, projection='world',
491
            resolution='110m', add_colorbar=True, **kwargs):
492
        import matplotlib as mpl
493
        mpl.use('svg') # to allow running without X11
494
        mpl.rc('font', family='DejaVu Sans')
495
        import cartopy.crs as ccrs
496
        import matplotlib.pyplot as plt
497
        import matplotlib.path as mpath
498
        import matplotlib.cm as cm
499
        from matplotlib.colors import Normalize, LogNorm
500
        import matplotlib.ticker as ticker
501
        import temis_color_tables
502

    
503
        if 'zonal' in kwargs and kwargs['zonal']:
504
            return self.plot_zonal_mean(varname, figure, ax=ax, **kwargs)
505

    
506
        plot_count =  'count' in kwargs and kwargs['count']
507
        preliminary_unvalidated_data = ('preliminary_unvalidated_data' in kwargs and kwargs['preliminary_unvalidated_data'])
508

    
509
        if varname not in self.variables:
510
            self.logger.debug("Worldplot data for '%s' not found.", varname)
511
            return False
512

    
513
        if ax is not None:
514
            self.logger.error("For geolocation plots the axis object should not be supplied. Axis control can use the ax_idx parameter.")
515
            return False
516

    
517
        if ax_idx is None or len(ax_idx) != 3:
518
            ax_idx = (1,1,1)
519

    
520
        if 'linewidth' in kwargs:
521
            linewidth = kwargs['linewidth']
522
        else:
523
            linewidth = 0.5
524

    
525
        if 'color' in kwargs:
526
            linecolor = kwargs['color']
527
        else:
528
            linecolor = 'black'
529

    
530
        pcar = ccrs.PlateCarree()
531
        if 'north' in projection.lower():
532
            m = ccrs.NorthPolarStereo(central_longitude=0.0)
533
            ax = figure.add_subplot(ax_idx[0], ax_idx[1], ax_idx[2], projection=m)
534
            ax.set_extent([0, 359.99999, 35, 90], crs=pcar)
535
            alpha=1.0
536
        elif 'south' in projection.lower():
537
            m = ccrs.SouthPolarStereo(central_longitude=0.0)
538
            ax = figure.add_subplot(ax_idx[0], ax_idx[1], ax_idx[2], projection=m)
539
            ax.set_extent([0, 359.99999, -90, -35], crs=pcar)
540
            alpha=1.0
541
        else:
542
            if plot_count:
543
                self.logger.info("Plotting map of number of observations")
544
            else:
545
                self.logger.info("Plotting map of '%s'", varname)
546
            # moll -> Mollweide (global, elliptical, equal-area projection), lon_0=0.0
547
            # hammer -> Hammer (global, elliptical, equal-area projection), lon_0=0.0 (not in cartopy)
548
            # robin -> Robinson (global projection once used by the National Geographic Society for world maps), lon_0=0.0
549
            # eck4 -> Eckert IV (global equal-area projection), lon_0=0.0 (not in cartopy)
550
            # kav7 -> Kavrayskiy VII (global projection similar to Robinson, used widely in the former Soviet Union), lon_0=0.0 (not in cartopy)
551
            # mbtfpq -> McBryde-Thomas Flat Polar Quartic (global equal-area projection), lon_0=0.0 (not in cartopy)
552
            # cyl -> Equidistant Cylindrical (simplest projection, just displays the world in latitude/longitude coordinates), llcrnrlat=-90,urcrnrlat=90,llcrnrlon=-180,urcrnrlon=180
553

    
554
            # m = ccrs.Mollweide(central_longitude=0)
555
            # m = ccrs.EqualEarth(central_longitude=0)
556
            # m = ccrs.Robinson(central_longitude=0)
557
            # Emergency work-around for issues in cartopy.
558
            m = pcar
559
            
560
            alpha=1.0
561
            ax = figure.add_subplot(ax_idx[0], ax_idx[1], ax_idx[2], projection=m)
562

    
563
        if 'time' in kwargs:
564
            if isinstance(kwargs['time'], (list, tuple)):
565
                tt = kwargs['time']
566
            else:
567
                tt = [kwargs['time']]
568
            timestr = "".join([t.strftime("%Y-%m-%d") for t in tt])
569
            ax.set_title(timestr)
570

    
571

    
572
        if 'north' in projection.lower() or 'south' in projection.lower():
573
            circle_path = mpath.Path.circle(radius=6.111e6) # circle of 6111 km radius, which is approximately down to 35 degrees.
574
            ax.set_boundary(circle_path, transform=m)
575
        else:
576
            # make path that corresponds with extent of the imshow call below.
577
            bbox_path = mpath.Path(np.asarray([[-180.0, -90], [180.0, -90.0], [180.0, 90.0], [-180.0, 90.0], [-180.0, -90]]),
578
                            codes=np.asarray([1,  2,  2,  2, 79], dtype=np.uint8), closed=True)
579
            ax.set_boundary(bbox_path, transform=pcar)
580

    
581
        # ax.outline_patch.set_visible(False)
582

    
583
        if preliminary_unvalidated_data:
584
            plt.text(0.5, 0.5, "Preliminary and unvalidated data\nNot for publication",
585
                     horizontalalignment='center', verticalalignment='center',
586
                     fontsize="large", color="red", rotation=30, stretch="condensed",
587
                     transform=plt.gca().transAxes, family='sans-serif',
588
                     style="italic", weight="semibold", zorder=4)
589
        try:
590
            cmapname = self.variables_meta[varname]['color_scale']
591
        except KeyError:
592
            self.logger.warning("Colorscale for '{0}' was not found, using 'nipy_spectral' instead".format(varname))
593
            cmapname = "nipy_spectral"
594
        try:
595
            colormap = cm.get_cmap(cmapname)
596
        except:
597
            self.logger.warning("Colorscale '{0}' was not found, using 'nipy_spectral' instead".format(cmapname))
598
            colormap = cm.get_cmap("nipy_spectral")
599

    
600
        if plot_count:
601
            colormap = cm.get_cmap("nipy_spectral")
602

    
603
        if plot_count:
604
            data_range = [0, np.max(self.variables[varname + '_count'])]
605
            if data_range[1] == 0:
606
                return False
607
            scale = 10**(math.floor(math.log10(data_range[1])))
608
            i = 1
609
            while scale*i < data_range[1]: i+=1
610
            data_range[1] = i*scale
611
            if data_range[1] == 0: data_range[1] = 1
612
        else:
613
            try:
614
                data_range = self.variables_meta[varname]['data_range']
615
            except KeyError:
616
                data_range = [np.min(self.variables[varname]), np.max(self.variables[varname])]
617

    
618
        if isinstance(data_range, bool) and not data_range:
619
            data_range = [np.min(self.variables[varname]), np.max(self.variables[varname])]
620

    
621
        colormap.set_over(color='0.25')
622
        colormap.set_under(color='0.75')
623
        colormap.set_bad(color='0.5')
624

    
625
        try:
626
            is_log = self.variables_meta[varname]['log_range']
627
        except KeyError:
628
            is_log = False
629
        if plot_count:
630
            is_log = False
631

    
632
        if is_log:
633
            self.logger.debug("{0}.plot(): Logarithmic range for '{1}' [{2[0]}, {2[1]}]".format(self.__class__.__name__, varname, data_range))
634
            normalizer = LogNorm(vmin=data_range[0], vmax=data_range[1])
635
        else:
636
            self.logger.debug("{0}.plot(): Linear range for '{1}' [{2[0]}, {2[1]}]".format(self.__class__.__name__, varname, data_range))
637
            normalizer = Normalize(vmin=data_range[0], vmax=data_range[1])
638

    
639
        mapper = cm.ScalarMappable(norm=normalizer, cmap=colormap)
640

    
641
        if varname == self.graph_rep['name']:
642
            self.logger.debug("{0}.plot(): Reusing platecarree gridded data for '{1}'".format(self.__class__.__name__, varname))
643
            img = self.graph_rep['data']
644
        else:
645
            self.logger.debug("{0}.plot(): Mapping to platecarree gridded data for '{1}'".format(self.__class__.__name__, varname))
646

    
647
            n_target = self.n_latitude*2*36
648
            full_data_array = np.zeros((self.n_latitude, n_target), dtype=np.float64)
649

    
650
            for i in range(self.latitude_centers.shape[0]):
651
                n = self.grid_length[i]
652
                ratio = n_target/n
653

    
654
                if plot_count:
655
                    try:
656
                        data_array = np.ma.asarray(self.data_for_latitude_band(i, varname, count=True))
657
                        data_array.mask = (self.data_for_latitude_band(i, varname, count=True) == 0)
658
                    except KeyError:
659
                        self.logger.warning("Variable '{0}' could not be plotted".format(varname))
660
                        return False
661
                else:
662
                    try:
663
                        data_array = np.ma.asarray(self.data_for_latitude_band(i, varname))
664
                        data_array.mask = (self.data_for_latitude_band(i, varname, count=True) == 0)
665
                    except KeyError:
666
                        self.logger.warning("Variable '{0}' could not be plotted".format(varname))
667
                        return False
668

    
669
                if np.all(data_array.mask):
670
                    full_data_array[i, :] = np.nan
671
                    continue
672

    
673
                for j in range(n):
674
                    start_idx = int(math.floor(ratio*j+0.5))
675
                    end_idx = int(math.floor(ratio*(j+1)+0.5))
676
                    if data_array.mask[j]:
677
                        full_data_array[i, start_idx:end_idx] = np.nan
678
                    else:
679
                        full_data_array[i, start_idx:end_idx] = data_array[j]
680

    
681
            img = mapper.to_rgba(full_data_array)
682

    
683
            self.logger.debug("{0}.plot(): data array shape '{1}' ({2} bytes)".format(self.__class__.__name__, full_data_array.shape, np.prod(full_data_array.shape)*full_data_array.dtype.itemsize))
684
            self.logger.debug("{0}.plot(): img array shape '{1}' ({2} bytes)".format(self.__class__.__name__, full_data_array.shape, np.prod(img.shape)*img.dtype.itemsize))
685

    
686
            self.graph_rep['data'] = img
687
            self.graph_rep['name'] = varname
688

    
689
        extent = [-180.0, 180.0, -90.0, 90.0]
690
        mappable = ax.imshow(img[::2, ::2], origin='lower', extent=extent, transform=pcar)
691

    
692
        if 'magic' not in projection.lower():
693
            ax.coastlines(linewidth=linewidth, color=linecolor, resolution=('110m' if resolution == 'l' else '50m'), zorder=3)
694

    
695
            if 'north' in projection.lower():
696
                yticks = np.arange(45, 90, 15)
697
                xticks = np.arange(-180, 181, 60)
698
            elif 'south' in projection.lower():
699
                yticks = np.arange(-75, -44, 15)
700
                xticks = np.arange(-180, 181, 60)
701
            else:
702
                yticks = np.arange(-90, 91, 30)
703
                xticks = np.arange(-180, 181, 30)
704

    
705
            ax.gridlines(xlocs=xticks, ylocs=yticks, crs=pcar, linewidth=linewidth, color=linecolor, zorder=4)
706

    
707
        if add_colorbar:
708
            if self.variables_meta[varname]['log_range']:
709
                fmt = ticker.LogFormatterMathtext(base=10.0, labelOnlyBase=False)
710
            else:
711
                fmt = ticker.ScalarFormatter(useMathText=True)
712
                fmt.set_powerlimits((-3,3))
713

    
714
            cbar = figure.colorbar(mappable, ax=ax, cmap=colormap,
715
                                   norm=normalizer, extend='both',
716
                                   orientation='horizontal', shrink=0.6, aspect=30,
717
                                   format=fmt)
718
            ax1 = cbar.ax
719
            ax1.clear()
720
            cbar = mpl.colorbar.ColorbarBase(ax1, cmap=colormap, norm=normalizer, extend='both',
721
                                             orientation='horizontal', format=fmt)
722
            cbar.ax.xaxis.set_tick_params(which='major', width=1.0, length=6.0, direction='out')
723

    
724
            if plot_count:
725
                cbar.set_label("Number of observations per cell")
726
            elif self.variables_meta[varname]['units'] != "1":
727
                cbar.set_label("{0} [{1}]".format(self.variables_meta[varname]['title'], self.variables_meta[varname]['units']))
728
            else:
729
                cbar.set_label("{0}".format(self.variables_meta[varname]['title']))
730
        return True
731

    
732
    def zonal_mean(self, varname):
        """Compute the count-weighted zonal mean of a gridded variable.

        Parameters
        ----------
        varname : str
            Name of the variable; must be a key of ``self.variables``.

        Returns
        -------
        dict or None
            ``None`` when the variable is unknown or no latitude band
            contains any observations. Otherwise a dict with keys
            ``zonal_mean`` (masked array, one value per latitude band),
            ``lat``, ``lat_bounds``, and plot metadata ``title``,
            ``units``, ``range``, ``log_range``, ``colorscale`` (each
            falling back to a default when not configured).
        """
        if varname not in self.variables:
            self.logger.debug("Worldplot data for '%s' not found.", varname)
            return None

        n_bands = self.grid_length.shape[0]
        zonal = np.ma.zeros((n_bands,), dtype=np.float64)
        # NOTE: `np.bool` was deprecated in NumPy 1.20 and removed in 1.24;
        # the builtin `bool` is the correct dtype here.
        zonal_mask = np.zeros((n_bands,), dtype=bool)

        for i in range(n_bands):
            data_array = self.data_for_latitude_band(i, varname)
            data_count = self.data_for_latitude_band(i, varname, count=True)
            if data_count.sum() > 0:
                # Mean weighted by the number of observations in each cell.
                zonal[i] = np.sum(data_array * data_count) / np.sum(data_count)
            else:
                # No observations anywhere in this band: mask it out.
                zonal_mask[i] = True

        zonal.mask = zonal_mask

        if np.all(zonal_mask):
            # Every band is empty; nothing meaningful to return.
            return None

        rval = {'zonal_mean': zonal,
                'lat': self.latitude_centers,
                'lat_bounds': self.latitude_bounds}

        # Plot metadata with fall-back defaults when a key (or the whole
        # variable entry) is absent from the configuration.
        try:
            meta = self.variables_meta[varname]
        except KeyError:
            meta = {}
        rval['title'] = meta.get('title', varname)
        rval['units'] = meta.get('units', "1")
        rval['range'] = meta.get('data_range', False)
        rval['log_range'] = meta.get('log_range', False)
        rval['colorscale'] = meta.get('color_scale', "nipy_spectral")

        return rval
781

    
782
    def plot_zonal_mean(self, varname, figure, ax=None, autorange=False, **kwargs):
        """Plot the zonal mean of `varname` as a latitude-vs-value line plot.

        Parameters
        ----------
        varname : str
            Variable to plot; forwarded to :meth:`zonal_mean`.
        figure : matplotlib.figure.Figure
            Figure to draw into (used when `ax` is not given, and not
            otherwise touched).
        ax : matplotlib.axes.Axes, optional
            Axes to plot on; a new single subplot is created when ``None``.
        autorange : bool, optional
            When True, let matplotlib choose the x-range instead of using
            the configured data range.

        Returns
        -------
        bool
            True when a plot was produced, False when no data is available.
        """
        import matplotlib as mpl
        mpl.use('svg') # to allow running without X11
        mpl.rc('font', family='DejaVu Sans')
        # Imported for its side effect (presumably registers the TEMIS
        # colormaps) — TODO confirm; no names from it are used here.
        import temis_color_tables

        data = self.zonal_mean(varname)
        if data is None:
            return False
        self.logger.info("Plotting zonal average of %s", varname)

        if ax is None:
            ax = figure.add_subplot(1, 1, 1)

        ax.plot(data['zonal_mean'], data['lat'], 'k-')
        ax.set_xlabel(data['title'] if data['units'] == "1" else "{0} [{1}]".format(data['title'], data['units']))
        ax.set_ylabel("Latitude [degrees]")
        if data['log_range']:
            if np.any(data['zonal_mean'] <= 0):
                self.logger.warning("Log scale requested on negative data in zonal mean plot of '%s' (ignoring request)", varname)
            else:
                ax.set_xscale('log')
        # zonal_mean() returns 'range': False when no data_range is
        # configured for this variable; only apply an explicit x-limit
        # when a real range is available (matches the boolean-range guard
        # used elsewhere in this class).
        if not autorange and data['range']:
            ax.set_xlim(data['range'])
        ax.set_ylim(-90.0, 90.0)

        return True