% MATLAB script: fit p_max vs. lambda and a unified logistic model to the data, then plot (and optionally export) the results.

% Dataset selector: 'tag' or 'spec' (swap the commented line to change it).
DATASET = 'tag';
%DATASET = 'spec';
DEBUG = false;   % passed to loop_lambda further down
OUTPUT = true;   % when set, figures are exported to EPS at the end

% Map the dataset name onto a flag; reject anything unrecognised.
switch DATASET
    case 'tag'
        IS_TAG = true;
    case 'spec'
        IS_TAG = false;
    otherwise
        error('Unknown dataset type.');
end

% The export step addresses figures by number, so start from an empty
% figure set; DEBUG is forced off (presumably so no extra figures shift
% the numbering — confirm against loop_lambda).
if OUTPUT
    close all;
    DEBUG = false;
end

data = load_data(DATASET);

% --- Empirical p_max per lambda and a linear fit p_max ~ e*lambda + f ------
% Column 3 of data holds lambda and column 4 the observed probability; for
% each distinct lambda, p_max is the largest observed probability.
num_mods = unique(data(:,3))';      % distinct lambda values, sorted (row vector)
pmaxes = zeros(size(num_mods));     % preallocated rather than grown in the loop
for k = 1:numel(num_mods)
    pmaxes(k) = max(data(data(:,3) == num_mods(k), 4));
end

P = polyfit(num_mods, pmaxes, 1);
e = P(1);  % slope
f = P(2);  % intercept

% Plot observed p_max together with the fitted line (figure 1 when OUTPUT
% has reset the figure set).
lam_grid = min(num_mods):0.01:max(num_mods);
pmax_line = polyval(P, lam_grid);
figure;
hold on;
grid on;
scatter(num_mods, pmaxes);
plot(lam_grid, pmax_line);
xlabel('\lambda');
ylabel('p_{max}');
set(gca, 'FontSize', 20);
legend('obs', 'fit');

% --- Unified fit -------------------------------------------------------------
% Model shared by the fit and the prediction handle:
%   p = (e*lambda + f) / (1 + exp(-(a + b*rs))) / (1 + exp(-(c + d*log_rn)))
% with rs = 1 - mu and log_rn = -log10(rho); e and f come from the p_max fit.
% Both datasets currently use the same model, so IS_TAG is not consulted.
mu = data(:,1);
log_rho = log10(data(:,2));
lambda = data(:,3);
prob = data(:,4);
rs = 1 - mu;
log_rn = -log_rho;

% Single definition of the model; r = [a b c d].
unified_model = @(r, rs_in, logrn_in, lambda_in) (e * lambda_in + f) ...
    ./ (1 + exp(-(r(1) + r(2) * rs_in))) ...
    ./ (1 + exp(-(r(3) + r(4) * logrn_in)));

% Least-squares fit of [a b c d] from a zero start; lsqnonlin returns
% residual = fun(x), i.e. model minus observation.
fun = @(r) unified_model(r, rs, log_rn, lambda) - prob;
[x, resnorm, residual, exitflag, output] = lsqnonlin(fun, [0, 0, 0, 0]);
if exitflag <= 0
    % Non-positive exitflag means the solver stopped without converging;
    % the parameters below may be unreliable.
    warning('lsqnonlin stopped without converging (exitflag = %d).', exitflag);
end

a = x(1);
b = x(2);
c = x(3);
d = x(4);

% pred is global, presumably so plot_separate (defined elsewhere) can read
% it — confirm. It captures the fitted parameters by value.
global pred;
pred = @(rs_in, logrn_in, lambda_in) unified_model(x, rs_in, logrn_in, lambda_in);
loop_lambda(data, @plot_separate, DEBUG);
plot_residual(residual, true);

% Training goodness of fit. resnorm from lsqnonlin is the sum of squared
% residuals, so RMSE is the root of its mean over the observations.
SSE = resnorm;
n_obs = numel(prob);
RMSE = sqrt(SSE / n_obs);
fprintf('Unified Model: RMSE = %s, SSE = %s\n', num2str(RMSE), num2str(SSE));

% --- Held-out evaluation -------------------------------------------------------
% Evaluate the fitted model on the '.test' split and then on the
% '.test.noise' split (per the original note, data outside the simulation
% range). Each split gets one residual figure, in that order.
for t_suffix = {'.test', '.test.noise'}
    t_dataset = [DATASET t_suffix{1}];
    t_data = load_data(t_dataset);
    % Same feature transforms as the training fit.
    t_rs = 1 - t_data(:,1);
    t_logrn = -log10(t_data(:,2));
    t_lambda = t_data(:,3);
    t_prob = t_data(:,4);
    t_pred = pred(t_rs, t_logrn, t_lambda);
    t_residual = t_pred - t_prob;
    plot_residual(t_residual, true);
end

% --- Export figures to EPS -----------------------------------------------------
% Relies on the figure numbering produced above (numbering starts at 1
% because OUTPUT triggers `close all` at the top of the script):
%   1                          p_max vs lambda
%   2 .. N_SEPARATE+1          per-lambda plots from loop_lambda
%   N_SEPARATE+2 .. +4         residuals: training, '.test', '.test.noise'
if OUTPUT
    N_SEPARATE = 16;  % figures expected from loop_lambda — TODO confirm
    figure(1);
    print('pmax_vs_lambda', '-depsc');
    for i = 1:N_SEPARATE
        figure(1 + i);
        title('');  % exported figures carry no title
        print(['pst_output_' num2str(i)], '-depsc');
    end
    % Residual figures follow directly after the per-lambda ones.
    res_names = {'train_abs_res', 'test_abs_res', 'out_abs_res'};
    for k = 1:numel(res_names)
        figure(N_SEPARATE + 1 + k);
        print(res_names{k}, '-depsc');
    end
    % Echo the fitted parameters to the command window.
    a
    b
    c
    d
    e
    f
end
