from __future__ import absolute_import
#
# Copyright (C) 2011, 2016 Smithsonian Astrophysical Observatory
#
#
# This program is free software; you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation; either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License along
# with this program; if not, write to the Free Software Foundation, Inc.,
# 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.
#
from six import iteritems
from .parameter import Parameter
from .model import ArithmeticModel, modelCacher1d
from .basic import TableModel
import numpy, operator
from sherpa.utils.err import ModelErr
__all__ = ('create_template_model', 'TemplateModel', 'KNNInterpolator', 'Template')
def create_template_model(modelname, names, parvals, templates,
                          template_interpolator_name='default'):
    """
    Create a TemplateModel model class from template input

    `modelname` - name of the template model.

    `names` - list of strings that define the order of the
    named parameters.

    `parvals` - 2-D ndarray of parameter vectors, index corresponds
    to the spectrum in `templates`. The parameter grid.

    `templates` - list of TableModel objects that contain a spectrum
    at a specific parameter vector (corresponds to a row
    in `parvals`).

    `template_interpolator_name` - name of the template interpolator, or None
    for disabling interpolation *between* templates.

    Returns the plain TemplateModel when `template_interpolator_name` is
    None, otherwise an instance of the interpolator class registered
    under that name in the module-level `interpolators` table, wrapping
    the TemplateModel.

    Raises ModelErr when `template_interpolator_name` is not None and is
    not a registered interpolator name.

    See load_template_model for more information.
    """
    # One Parameter per column of the parameter grid. The soft and hard
    # limits both span the values present in that column.
    pars = []
    for ii, name in enumerate(names):
        column = parvals[:, ii]
        minimum = min(column)
        maximum = max(column)
        # Initial parameter value is always the first parameter value listed
        initial = column[0]
        pars.append(Parameter(modelname, name, initial,
                              minimum, maximum,
                              minimum, maximum))

    # Create the templates table from input
    tm = TemplateModel(modelname, pars, parvals, templates)

    if template_interpolator_name is None:
        return tm

    if template_interpolator_name not in interpolators:
        # Previously this case fell through and returned None silently.
        raise ModelErr("Unknown template interpolator '%s'; available "
                       "interpolators are: %s"
                       % (template_interpolator_name,
                          ', '.join(sorted(interpolators))))

    interp_class, default_args = interpolators[template_interpolator_name]
    # Copy the registered defaults so this call does not write the model
    # (and its name) back into the shared module-level registry.
    args = dict(default_args)
    args['template_model'] = tm
    args['name'] = modelname
    return interp_class(**args)
class InterpolatingTemplateModel(ArithmeticModel):
    """
    Base class for models that interpolate between the templates held by
    a TemplateModel.

    The model shares the parameters (and parameter grid) of the wrapped
    `template_model`. Subclasses must implement
    `interpolate(point, x_out)`, returning a callable model that `calc`
    then evaluates on the requested grid.
    """

    def __init__(self, name, template_model):
        self.template_model = template_model
        # Expose each template parameter as an attribute of this model.
        # NOTE(review): presumably done before ArithmeticModel.__init__
        # because the base class restricts new attributes after init --
        # confirm.
        for par in template_model.pars:
            self.__dict__[par.name] = par
        self.parvals = template_model.parvals
        ArithmeticModel.__init__(self, name, template_model.pars)

    def fold(self, data):
        """Fold every template of the wrapped TemplateModel with `data`."""
        for template in self.template_model.templates:
            template.fold(data)

    def interpolate(self, point, x_out):
        """
        Return a callable model for parameter vector `point` on grid `x_out`.

        Must be overridden by subclasses. `calc` calls the result as
        `model(x0, x1, *args, **kwargs)`.
        """
        raise NotImplementedError(
            "%s must implement interpolate(point, x_out)"
            % type(self).__name__)

    @modelCacher1d
    def calc(self, p, x0, x1=None, *args, **kwargs):
        """Evaluate the template interpolated at parameters `p` on the grid."""
        interpolated_template = self.interpolate(p, x0)
        return interpolated_template(x0, x1, *args, **kwargs)
class KNNInterpolator(InterpolatingTemplateModel):
    """
    Inverse-distance-weighted k-nearest-neighbour template interpolator.

    Parameters
    ----------
    name : str
        Name of the model.
    template_model : TemplateModel
        The template library; its `parvals` rows are the grid points and
        its `templates` the corresponding models.
    k : int or None
        Number of nearest templates combined. When None, twice the number
        of parameters (the size of the first `parvals` row) is used.
    order :
        Passed as the `ord` argument of `numpy.linalg.norm` when measuring
        the distance between parameter vectors.
    """

    def __init__(self, name, template_model, k=None, order=2):
        # (template index, distance) pairs sorted by increasing distance,
        # filled in by _calc_distances().
        self._distances = []
        if k is None:
            self.k = 2 * template_model.parvals[0].size
        else:
            self.k = k
        self.order = order
        InterpolatingTemplateModel.__init__(self, name, template_model)

    def _calc_distances(self, point):
        """
        Store in `self._distances` the (index, distance) pair of every
        template, sorted by increasing distance from `point`.

        Raises ModelErr if the template library is empty.
        """
        grid = numpy.asarray(self.template_model.parvals)
        if len(grid) == 0:
            raise ModelErr("The template library is empty; "
                           "cannot interpolate between templates")
        # One norm per grid row, computed in a single vectorised call.
        dists = numpy.linalg.norm(grid - numpy.asarray(point),
                                  ord=self.order, axis=1)
        # A stable sort keeps equidistant templates in index order.
        ranked = numpy.argsort(dists, kind='mergesort')
        self._distances = [(int(idx), dists[idx]) for idx in ranked]

    def interpolate(self, point, x_out):
        """
        Return a model for the template library at `point` on grid `x_out`.

        If `point` coincides exactly with a library grid point, that
        template is returned unchanged. Otherwise the k nearest templates
        are evaluated on `x_out` with weights 1/distance. The sum is
        divided by the total weight and wrapped in a new TableModel named
        'interpolated'.
        """
        self._calc_distances(point)

        nearest_idx, nearest_dist = self._distances[0]
        if nearest_dist == 0:
            return self.template_model.templates[nearest_idx]

        # Inverse-distance weights of the k closest templates.
        weights = [(idx, 1.0 / dist) for idx, dist in self._distances[:self.k]]
        total_weight = sum(weight for _, weight in weights)

        templates = self.template_model.templates
        y_out = numpy.zeros(len(x_out))
        for idx, weight in weights:
            # The weight is passed as the TableModel parameter vector,
            # presumably its amplitude -- TODO confirm against TableModel.calc.
            y_out += templates[idx].calc((weight,), x_out)
        y_out /= total_weight

        tm = TableModel('interpolated')
        tm.load(x_out, y_out)
        return tm
class Template(KNNInterpolator):
    """
    Thin alias of KNNInterpolator.

    All positional and keyword arguments are forwarded unchanged to
    KNNInterpolator's constructor.
    """

    def __init__(self, *args, **kwargs):
        super(Template, self).__init__(*args, **kwargs)
class TemplateModel(ArithmeticModel):
    """
    A discrete model backed by a library of templates.

    Row ii of `parvals` is a parameter vector; `templates[ii]` is the
    model returned when the model parameters equal that row exactly.
    No interpolation between templates is performed: requesting a
    parameter vector that is not in the library raises ModelErr.
    """

    def __init__(self, name='templatemodel', pars=(), parvals=None,
                 templates=None):
        # None defaults avoid one mutable list being shared by every
        # instance created without these arguments.
        if parvals is None:
            parvals = []
        if templates is None:
            templates = []
        self.parvals = parvals
        self.templates = templates
        # Exact-match lookup table: tuple(parameter vector) -> template.
        self.index = {}
        for par in pars:
            self.__dict__[par.name] = par
        for ii, parval in enumerate(parvals):
            self.index[tuple(parval)] = templates[ii]
        ArithmeticModel.__init__(self, name, pars)
        self.is_discrete = True

    def fold(self, data):
        """Fold every template in the library with `data`."""
        for template in self.templates:
            template.fold(data)

    def _current_template(self):
        """Return the template matching the current parameter values."""
        return self.query(tuple(par.val for par in self.pars))

    def get_x(self):
        """Return the x grid of the template at the current parameters."""
        return self._current_template().get_x()

    def get_y(self):
        """Return the y values of the template at the current parameters."""
        return self._current_template().get_y()

    def query(self, p):
        """
        Return the template whose parameter vector equals `p` exactly.

        Raises ModelErr if `p` is not a grid point of the library.
        """
        try:
            return self.index[tuple(p)]
        except (KeyError, TypeError):
            # Only lookup failures (missing or unhashable key) are
            # reported as "not in the library"; the previous bare except
            # also swallowed KeyboardInterrupt/SystemExit.
            raise ModelErr("Interpolation of template parameters was disabled for this model, but parameter values not in the template library have been requested. Please use gridsearch method and make sure the sequence option is consistent with the template library")

    @modelCacher1d
    def calc(self, p, x0, x1=None, *args, **kwargs):
        """Evaluate the library template at parameters `p` on the grid."""
        table_model = self.query(p)
        # return interpolated the spectrum according to the input grid (x0, [x1])
        return table_model(x0, x1, *args, **kwargs)
# Registry of template interpolators used by create_template_model:
# name -> (model class, default keyword arguments for its constructor).
interpolators = dict(
    default=(Template, dict(k=2, order=2)),
)