-
Notifications
You must be signed in to change notification settings - Fork 1
Expand file tree
/
Copy pathLSA.py
More file actions
58 lines (51 loc) · 1.98 KB
/
Copy pathLSA.py
File metadata and controls
58 lines (51 loc) · 1.98 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
#####
# Least Square Approximation
#####
import numpy,math
import scipy.optimize as optimization
import matplotlib.pyplot as plt
# Chose a model that will create bimodality.
def func(x, a, b):
return a + b*b*x # Term b*b will create bimodality.
# Create toy data for curve_fit.
xdata = numpy.array([0.0,1.0,2.0,3.0,4.0,5.0])
ydata = numpy.array([0.1,0.9,2.2,2.8,3.9,5.1])
sigma = numpy.array([1.0,1.0,1.0,1.0,1.0,1.0])
# Compute chi-square manifold.
Steps = 101 # grid size
Chi2Manifold = numpy.zeros([Steps,Steps]) # allocate grid
amin = -7.0 # minimal value of a covered by grid
amax = +5.0 # maximal value of a covered by grid
bmin = -4.0 # minimal value of b covered by grid
bmax = +4.0 # maximal value of b covered by grid
for s1 in range(Steps):
for s2 in range(Steps):
# Current values of (a,b) at grid position (s1,s2).
a = amin + (amax - amin)*float(s1)/(Steps-1)
b = bmin + (bmax - bmin)*float(s2)/(Steps-1)
# Evaluate chi-squared.
chi2 = 0.0
for n in range(len(xdata)):
residual = (ydata[n] - func(xdata[n], a, b))/sigma[n]
chi2 = chi2 + residual*residual
Chi2Manifold[Steps-1-s2,s1] = chi2 # write result to grid.
# Plot grid.
plt.figure(1, figsize=(8,4.5))
plt.subplots_adjust(left=0.09, bottom=0.09, top=0.97, right=0.99)
# Plot chi-square manifold.
image = plt.imshow(Chi2Manifold, vmax=50.0,
extent=[amin, amax, bmin, bmax])
# Plot where curve-fit is going to for a couple of initial guesses.
for a_initial in -6.0, -4.0, -2.0, 0.0, 2.0, 4.0:
# Initial guess.
x0 = numpy.array([a_initial, -3.5])
xFit = optimization.curve_fit(func, xdata, ydata, x0, sigma)[0]
plt.plot([x0[0], xFit[0]], [x0[1], xFit[1]], 'o-', ms=4,
markeredgewidth=0, lw=2, color='orange')
plt.colorbar(image) # make colorbar
plt.xlim(amin, amax)
plt.ylim(bmin, bmax)
plt.xlabel(r'$a$', fontsize=24)
plt.ylabel(r'$b$', fontsize=24)
plt.savefig('demo-robustness-curve-fit.png')
plt.show()