定义新的装配工类#
本节介绍如何向该包中添加新的非线性拟合算法或编写用户定义的拟合器。简而言之,我们需要定义一个错误函数和一个 __call__
方法并定义与此装配工一起使用的约束类型(如果有)。
下面以scipy的SLSQP算法为例进行详细说明。所有装配工的基类是 Fitter
::
class SLSQPFitter(Fitter):
supported_constraints = ['bounds', 'eqcons', 'ineqcons', 'fixed',
'tied']
def __init__(self):
# Most currently defined fitters take no arguments in their
# __init__, but the option certainly exists for custom fitters
super().__init__()
所有的装配工都要取一个模型(他们的 __call__
方法修改模型的参数)作为它们的第一个参数。
接下来,error函数获取拟合算法和输入坐标迭代返回的参数列表,用它们评估模型并返回拟合的某种类型的度量。在本例中,残差平方和用作拟合度量:
def objective_function(self, fps, *args):
model = args[0]
meas = args[-1]
model.fitparams(fps)
res = self.model(*args[1:-1]) - meas
return np.sum(res**2)
这个 __call__
方法执行拟合。作为最低要求,它将所有坐标作为单独的参数。根据需要传递其他参数:
def __call__(self, model, x, y , maxiter=MAXITER, epsilon=EPS):
if model.linear:
raise ModelLinearityException(
'Model is linear in parameters; '
'non-linear fitting methods should not be used.')
model_copy = model.copy()
init_values, _ = model_to_fit_params(model_copy)
self.fitparams = optimize.fmin_slsqp(self.errorfunc, p0=init_values,
args=(y, x),
bounds=self.bounds,
eqcons=self.eqcons,
ineqcons=self.ineqcons)
return model_copy
定义插件装配工#
astropy.modeling
包括一个插件机制,允许在astropy内核外部定义的装配工插入到 astropy.modeling.fitting
namespace through the use of entry points. Entry points are references to importable objects. A tutorial on defining entry points can be found in setuptools' documentation . 插件装配工必须从 Fitter
基类。以便发现并插入装配工 astropy.modeling.fitting
入口点必须插入 astropy.modeling
入口点组
setup(
# ...
entry_points = {'astropy.modeling': 'PluginFitterName = fitter_module:PlugFitterClass'}
)
这将允许用户导入 PlugFitterName
通过 astropy.modeling.fitting
通过
from astropy.modeling.fitting import PlugFitterName
使用此功能的一个项目是 Saba 可作为参考。
使用自定义统计函数#
本节介绍如何使用用户定义的统计函数编写新的装配工。下面的例子显示了一个专门的类,它与两个变量中的不确定性拟合一条直线。
需要以下导入语句::
import numpy as np
from astropy.modeling.fitting import (_validate_model,
fitter_to_model_params,
model_to_fit_params, Fitter,
_convert_input)
from astropy.modeling.optimizers import Simplex
首先需要定义一个统计数据。这可以是函数或可调用类。::
def chi_line(measured_vals, updated_model, x_sigma, y_sigma, x):
"""
Chi^2 statistic for fitting a straight line with uncertainties in x and
y.
Parameters
----------
measured_vals : array
updated_model : `~astropy.modeling.ParametricModel`
model with parameters set by the current iteration of the optimizer
x_sigma : array
uncertainties in x
y_sigma : array
uncertainties in y
"""
model_vals = updated_model(x)
if x_sigma is None and y_sigma is None:
return np.sum((model_vals - measured_vals) ** 2)
elif x_sigma is not None and y_sigma is not None:
weights = 1 / (y_sigma ** 2 + updated_model.parameters[1] ** 2 *
x_sigma ** 2)
return np.sum((weights * (model_vals - measured_vals)) ** 2)
else:
if x_sigma is not None:
weights = 1 / x_sigma ** 2
else:
weights = 1 / y_sigma ** 2
return np.sum((weights * (model_vals - measured_vals)) ** 2)
一般来说,要定义一个新的装配工,只需提供一个统计函数和一个优化器。在本例中,我们将让优化器作为fitter的可选参数,并将统计信息设置为 chi_line
以上:
class LineFitter(Fitter):
"""
Fit a straight line with uncertainties in both variables
Parameters
----------
optimizer : class or callable
one of the classes in optimizers.py (default: Simplex)
"""
def __init__(self, optimizer=Simplex):
self.statistic = chi_line
super().__init__(optimizer, statistic=self.statistic)
最后要定义的是 __call__
方法:
def __call__(self, model, x, y, x_sigma=None, y_sigma=None, **kwargs):
"""
Fit data to this model.
Parameters
----------
model : `~astropy.modeling.core.ParametricModel`
model to fit to x, y
x : array
input coordinates
y : array
input coordinates
x_sigma : array
uncertainties in x
y_sigma : array
uncertainties in y
kwargs : dict
optional keyword arguments to be passed to the optimizer
Returns
------
model_copy : `~astropy.modeling.core.ParametricModel`
a copy of the input model with parameters set by the fitter
"""
model_copy = _validate_model(model,
self._opt_method.supported_constraints)
farg = _convert_input(x, y)
farg = (model_copy, x_sigma, y_sigma) + farg
p0, _, _ = model_to_fit_params(model_copy)
fitparams, self.fit_info = self._opt_method(
self.objective_function, p0, farg, **kwargs)
fitter_to_model_params(model_copy, fitparams)
return model_copy