import matplotlib.pyplot as plt
import numpy as np
from skued import gaussian, spectrum_colors
from skued import baseline_dwt

# Load the measured powder pattern as two arrays: the abscissa ``s`` and the
# corresponding intensity.
# NOTE(review): assumes powder.npy stores a (2, N) array of rows
# [s, intensity] so it unpacks into two 1-D arrays — confirm.
s, intensity = np.load('data/powder.npy')

# Synthetic background to be superimposed on the measured pattern:
#   * a sum of two decaying exponentials standing in for inelastic/diffuse
#     scattering, and
#   * two Gaussian bumps standing in for substrate contributions, placed at
#     the mean of ``s`` and at 1/2.5 of it, both with FWHM = mean(s) / 4.
s_mean = s.mean()
diffuse = 75 * np.exp(-7 * s) + 55 * np.exp(-2 * s)
substrate1 = 0.8 * gaussian(s, center=s_mean, fwhm=s_mean / 4)
substrate2 = 0.9 * gaussian(s, center=s_mean / 2.5, fwhm=s_mean / 4)

# Composite pattern from which the baselines are later estimated.
signal = intensity + diffuse + substrate1 + substrate2

# DWT decomposition levels to try (1 through 6), one plot colour per level.
levels = list(range(1, 7))
colors = spectrum_colors(levels)

# One row, two panels: left shows the pattern and the true background,
# right shows the DWT baseline estimates.
fig, (ax1, ax2) = plt.subplots(nrows=1, ncols=2, figsize=(9, 3))

# Left panel: the composite signal (black) over the background that was
# artificially added to it (red).
true_background = diffuse + substrate1 + substrate2
ax1.plot(s, signal, 'k-', label='Diffraction')
ax1.plot(s, true_background, 'r', label='Background')
ax1.set_title('Diffraction pattern of rutile VO$_2$')

# Right panel: estimate the baseline of ``signal`` once per decomposition
# level (sym6 wavelet, up to 150 iterations) and overlay each estimate in
# its own colour.
for level, color in zip(levels, colors):
    baseline = baseline_dwt(signal, level=level, max_iter=150, wavelet='sym6')
    ax2.plot(s, baseline, color=color, label=f'Level {level}')
ax2.set_title('Baseline examples (DWT)')

# Cosmetics shared by both panels: axis labels, a common zoomed-in window
# (s in [0.2, 0.5], intensity in [30, 80]) and a legend, then display.
for ax in (ax1, ax2):
    # Raw string: the backslashes are mathtext commands (\pi, \AA), not
    # Python escapes. In a plain literal '\p' and '\A' are invalid escape
    # sequences (DeprecationWarning, SyntaxWarning since Python 3.12).
    ax.set_xlabel(r's ($4 \pi / \AA$)')
    ax.set_ylabel('Diffracted intensity (counts)')
    ax.set_xlim([0.2, 0.5])
    ax.set_ylim([30, 80])
    ax.legend()
plt.show()