function parNew = egm(par,R,Rnext,wage,wagenext,ppinext,ppinextnext,labormarketstatus,betashock)
%parNew = egm(par,R,Rnext,wage,wagenext,ppinext,ppinextnext,labormarketstatus,betashock)
%Iterate once on the savings, labor-supply and value-function rules using the
%endogenous grid method (EGM).
%
%Inputs:
%  par        - coefficient matrix with one column per income state. Rows
%               Params.par_sind are passed to savingspline, rows
%               Params.par_nind to nspline and rows Params.par_vind to vspline.
%  R, Rnext   - interest factors for this / next period; the code uses R-1
%               and Rnext-1 as the (taxable) net rate.
%  wage, wagenext       - wage this / next period.
%  ppinext, ppinextnext - scalars that multiply savings in the budget
%               constraint this / next period (presumably gross inflation --
%               confirm); ppinext also deflates the Euler-equation return.
%  labormarketstatus    - passed through to transProb_nonLinear.
%  betashock  - optional multiplicative shock to Params.betah (default 1).
%Output:
%  parNew     - updated coefficients, same size as par.
%
%Errors (assert) if any implied consumption is non-positive or if the
%transition probabilities out of an income state do not sum to one.

global Params;


npp = Params.npp;


% optional argument: no discount-factor shock by default
if ~exist('betashock','var')
    betashock = 1;
end

%% Section 1: solve for savings rule using EGM

% asset grid: zero plus the savings-spline knots. It is also used below as
% the grid of end-of-period assets b' on which the Euler equation is solved.
xthis = [0; Params.knotXi];


nassets = length(xthis);
nc = Params.nc;
assert(nc == nassets);

%compute consumption and marg util for each asset level and income level
%(next period's decisions, evaluated with the current rules stored in par)
MU = NaN(nassets,npp);
margTax = NaN(nassets,npp);
nthis = NaN(nassets,npp);

for jp=1:npp
    
    
    S = savingspline(par(Params.par_sind,jp));
    N = nspline(par(Params.par_nind,jp));
    
    sthis = interp_savspline(S,xthis);
    
    %we have end of period assets in last period and this period,
    %and we have to figure out this period's consumption:
    %(next-period prices Rnext/wagenext and scaling ppinextnext are used)
    [cthis, nthis(:,jp), margTax(:,jp)] = get_cnt(xthis,sthis*ppinextnext,N,Rnext,wagenext,jp);
    
    
    % debug output printed just before the assert below fails
    if ~ all(cthis>0)
        disp('negative c')
        jp
        I = cthis <=0;
        find(I)
        cthis(I)
        nthis(I,jp)
        xthis(I)
        sthis(I)
    end
    assert(all(cthis>0));
    
    MU(:,jp) = margutilC(cthis);
end


%compute expected marg util
%Euler right-hand side conditional on the previous income state ip:
%  sum_jp pp(ip,jp) * u'(c_jp) * (1 + (1-T'_jp)*(Rnext-1)) / ppinext
MUexp = zeros(nassets,npp);
for ip=1:npp  %loop over previous income states
    ppsum = 0;
    for jp = 1:npp %loop over this period income state
        pp = transProb_nonLinear(ip,jp,labormarketstatus);
        if(pp>0)
            MUexp(:,ip) = MUexp(:,ip) + pp .* MU(:,jp) .* (1+(1-margTax(:,jp))*(Rnext-1))/ppinext;
            ppsum = ppsum  + pp;
        end
    end
    % renormalize by the total probability visited (asserted to be 1)
    MUexp(:,ip) = MUexp(:,ip)/ppsum;
    assert(abs(ppsum-1) < 1e-9)
end


clear MU;

% invert the Euler equation: previous-period consumption on the b' grid
Cprev = invmargutilC(Params.betah*betashock*MUexp);
assert(all(Cprev(:) > 0))


%at this stage we know the previous c and n on the b' grid, but we
%don't know n or b
%egm_bn solves the budget constraint and labor FOC for previous-period
%assets bprev (its labor output is not used here)
assert(all(size(Cprev) == [Params.nc, Params.npp]))
bprev = egm_bn(repmat(xthis,1,Params.npp),nthis,Cprev,repmat(xthis,1,Params.npp),R-1,ppinext,wage);



%pack results back into par
%first savings coefficient: the lowest implied previous asset level bprev(1,:);
%remaining coefficients: b' evaluated at bprev(1,ip)+knotXi, obtained by
%linear interpolation of the (bprev, xthis) pairs
parNew =  NaN(size(par));

parNew(Params.par_sind(1),:) = bprev(1,:);
for ip = 1:npp
    
    % append a far-right point at bprev = 1e8, continuing the slope of the
    % last segment, so interp1 can cover query points beyond bprev(end,ip)
    XX = [bprev(:,ip); 1e8];
    tmp = (xthis(end)-xthis(end-1))/(bprev(end,ip)-bprev(end-1,ip))*(1e8-bprev(end,ip)) + xthis(end);
    YY = [xthis; tmp];
    
    parNew(Params.par_sind(2:end),ip) = interp1(XX,YY,bprev(1,ip) + Params.knotXi,'linear');
    
    %if ip == 3
    %    [bprev(end-8:end,ip)'; xthis(end-8:end)';bprev(1,ip) + Params.knotXi(end-8:end)'; parNew(Params.par_sind(end-8:end),ip)']
    %end
end


% consistency check
% the new savings rule, combined with the OLD labor rule from par, must
% imply positive consumption everywhere on the value-function grid
xthis_chv = Params.vgrid;
xthis = exp(xthis_chv)-1;  %change of variables
for ip = 1:npp
    
    S = savingspline(parNew(Params.par_sind,ip));
    N = nspline(par(Params.par_nind,ip));
    
    sthis = interp_savspline(S,xthis);
    [cthis, nthis, margTax, taxpaid, taxableIncome] = get_cnt(xthis,sthis*ppinext,N,R,wage,ip);
    % debug output printed just before the assert below fails
    if ~all(cthis > 0)
        [bprev(end,ip) xthis(end) bprev(1,ip)+Params.knotXi(end) parNew(Params.par_sind(end),ip)]
        ip
        I = cthis <= 0;
        find(I)
        xthis(I)
        nthis(I)
        sthis(I)
        R
        wage
        taxableIncome(I)
    end
    assert(all(cthis > 0), 'error in egm: cthis is zero or negative.')
    end

%% Section 2 --  Solve for n rule using static first order condition
% we have already done this before, but we may not have solved for the
% decision rule in all parts of the state space
% Labor is solved on Params.ngrid for employed states only; columns of
% unemployed states are left at zero.

xthis = Params.ngrid;
savings = zeros(Params.nn,npp);

nnew = zeros(Params.nn,npp);
for ip=1:npp
    if Params.employed(ip)
        S = savingspline(parNew(Params.par_sind,ip));
        savings(:,ip) = interp_savspline(S,xthis);
        
        N = nspline(par(Params.par_nind,ip));  %for an initial guess
        nnew(:,ip) = interp_nspline(N,xthis,true);
    end
end


nemp = sum(Params.employed);
nnew(:,Params.employed) = solveN(nnew(:,Params.employed), repmat(xthis,1,nemp),savings(:,Params.employed)*ppinext,R-1,wage);

parNew(Params.par_nind,:) = nnew;

%% Section 3 -- Solve for V using Bellman equation
% V(x,ip) = log(c) - psi1/(1+psi2)*n^(1+psi2) + betah*betashock*E[V'(s)],
% with c, n from the NEW rules and V' taken from the OLD coefficients in par
vnew = NaN(Params.nv,npp);


xthis_chv = Params.vgrid;
xthis = exp(xthis_chv)-1;  %change of variables

%loop over this period's income state
for ip = 1:npp
    
    S = savingspline(parNew(Params.par_sind,ip));
    N = nspline(parNew(Params.par_nind,ip));
    
    sthis = interp_savspline(S,xthis);
    [cthis, nthis, margTax, taxpaid, taxableIncome] = get_cnt(xthis,sthis*ppinext,N,R,wage,ip);
    % debug output printed just before the assert below fails
    if ~all(cthis > 0)
        ip
        I = cthis <= 0;
        find(I)
        xthis(I)
        nthis(I)
        sthis(I)
        R
        wage
        taxableIncome(I)
    end
    assert(all(cthis > 0), 'error in egm: cthis is zero or negative.')
    
    
    %build the expected continuation value at end-of-period assets sthis
    assets = sthis;
    Vexp = 0;
    for jp=1:npp
        
        pp = transProb_nonLinear(ip,jp,labormarketstatus);
        
        if any(pp>0)
            Vn = vspline(par(Params.par_vind,jp));
            Vnext = interp_vspline(Vn,assets);
            Vexp = Vexp + pp.*Vnext;
        end
    end
    
    
    vnew(:,ip) = log(cthis) - (Params.psi1/(1+Params.psi2))*(nthis.^(1+Params.psi2)) + Params.betah*betashock*Vexp;
    
    
    
end  %close loop over this period's income state


%pack results back into par
parNew(Params.par_vind,:) = vnew;




end


function [b , n  ] = egm_bn(b0,n0,c,bprime,ii,ppi_prime,wage)
% EGM_BN  Solve for previous-period assets (and labor) given previous c.
%   [b, n] = egm_bn(b0, n0, c, bprime, ii, ppi_prime, wage)
%
% Newton iteration on, for every row and income-state column,
%   f1 = b + x - T(x) + To - (1+tauC)*c - bprime*ppi_prime = 0    (budget)
%   f2 = (1-T'(x))*wage*skill*u'(c)/(1+tauC) - psi1*n^psi2 = 0     (labor FOC)
% with taxable income x = ii*b + skill*wage*n + Tu, and T(x), T'(x), T''(x)
% returned by interp_tax(x, Params.incometax). c is held fixed.
% f2 is imposed only in columns where Params.employed is true; in the other
% columns n stays at n0 (asserted to be zero) and only b is updated.
% The employed columns may appear in any order.
%
% Inputs:
%   b0, n0    - initial guesses for b and n (na x npp)
%   c         - consumption (na x npp), held fixed
%   bprime    - end-of-period assets (na x npp)
%   ii        - net interest rate (the caller in this file passes R-1)
%   ppi_prime - scalar multiplying bprime in the budget constraint
%   wage      - wage
% Outputs:
%   b, n      - solution of the system (na x npp)
%
% Errors (assert) if the residuals are not below 1e-10 within 100 iterations.


global Params;

maxit = 100;
tol = 1e-10;

na = size(b0,1);
emp = Params.employed;
empCols = find(emp);     % full-state column of each employed column
unempCols = find(~emp);

skill = repmat(Params.skill',na,1);
Tu = repmat(Params.Tu',na,1);
To = repmat(Params.To',na,1);

% c is fixed during the iteration, so its marginal utility is loop-invariant
muC = margutilC(c(:,emp));

b = b0;
n = n0;
converged = false;

for it = 1:maxit
    
    x = ii * b + skill.*wage.*n + Tu;
    [margTax, taxpaid,dmarg] = interp_tax(x, Params.incometax);
    f1 = b + x - taxpaid + To - (1+Params.tauC)*c - bprime*ppi_prime;
    
    % f2 has one column per EMPLOYED state (column k <-> state empCols(k))
    f2 = (1-margTax(:,emp)).*wage.*skill(:,emp).*muC/(1+Params.tauC) - Params.psi1*n(:,emp).^(Params.psi2);
    
    if all(abs([f1(:);f2(:)]) < tol)
        converged = true;
        break
    end
    
    % Jacobian of f1. Both b and n enter f1 directly through x and through
    % T(x):  df1/db = 1 + ii*(1-T'),  df1/dn = (1-T')*skill*wage
    J1b = 1 + ii*(1-margTax);
    J1n = (1-margTax).*skill.*wage;
    
    % Jacobian of f2 (c fixed, so only T'(x) varies with b and n);
    % same employed-only column layout as f2
    J2b = -dmarg(:,emp).*ii.*wage.*skill(:,emp).*muC/(1+Params.tauC);
    J2n = -dmarg(:,emp).*(wage.*skill(:,emp)).^2.*muC/(1+Params.tauC) - Params.psi1*Params.psi2.*n(:,emp).^(Params.psi2-1);
    
    D = NaN(na,Params.npp,2);
    
    % employed states: 2x2 Newton step in (b, n)
    for k = 1:numel(empCols)
        ip = empCols(k);
        for ia = 1:na
            J = [J1b(ia,ip) J1n(ia,ip); J2b(ia,k) J2n(ia,k)];
            if cond(J) > 1e10
                J
                n(ia,ip)
                skill(ia,ip)
                pause;
            end
            D(ia,ip,:) = -J\[f1(ia,ip);f2(ia,k)];
        end
        n(:,ip) = n(:,ip) + D(:,ip,2);
    end
    
    % non-employed states: scalar Newton step on the budget constraint only
    D(:,unempCols,1) = -f1(:,unempCols)./J1b(:,unempCols);
    
    b = b + D(:,:,1);
    
    assert(all(all(abs(n(:,~Params.employed)) < 1e-10)))
    
end

if ~converged
   max(abs(f1))
   max(abs(f2))
end
assert(converged, 'egm_bn: Newton iteration did not converge.')

end

function [n] = solveN(n0, b,savings,ii,wage)
% SOLVEN  Newton iteration for labor supply given assets and savings.
%   n = solveN(n0, b, savings, ii, wage)
%
% Columns correspond to the Params.employed subset of income states, in
% order. For each element the intratemporal condition
%   (1-T'(x))*wage*skill*u'(c)/(1+tauC) - psi1*n^psi2 = 0
% is solved for n, where x = ii*b + skill*wage*n + Tu is taxable income,
%   c = (b + x - T(x) + To - savings)/(1+tauC),
% and T(x), T'(x), T''(x) come from interp_tax(x, Params.incometax).
%
% Inputs:
%   n0      - initial guess for labor (na x nemp)
%   b       - assets (na x nemp)
%   savings - savings term subtracted in the budget constraint (na x nemp)
%   ii      - net interest rate
%   wage    - wage
% Output:
%   n       - labor after at most 100 Newton steps. The iteration stops
%             early once every residual is below 1e-14; no error is raised
%             if that tolerance is never reached.


global Params;

isEmp = Params.employed;
nRows = size(n0,1);

skillMat = repmat(Params.skill(isEmp)',nRows,1);
TuMat = repmat(Params.Tu(isEmp)',nRows,1);
ToMat = repmat(Params.To(isEmp)',nRows,1);
taxC = 1+Params.tauC;   % consumption-tax factor

n = n0;
iter = 0;
while iter < 100
    iter = iter + 1;
    
    % taxable income and tax schedule at the current labor guess
    income = ii * b + skillMat.*wage.*n + TuMat;
    [mtr, tax, dmtr] = interp_tax(income, Params.incometax);
    
    % consumption implied by the budget constraint
    cons = (b + income - tax + ToMat  - savings)/ taxC;
    mu = margutilC(cons);
    
    % residual of the labor first-order condition
    resid = (1-mtr).*wage.*skillMat.*mu/taxC - Params.psi1*n.^(Params.psi2);
    
    if all(abs(resid(:)) < 1e-14)
        break
    end
    
    % derivative of the residual in n: consumption channel (through u''),
    % marginal-tax channel (through T''), and disutility of labor
    slope = margutilC(cons,2) .* (skillMat.*wage.*(1-mtr)./taxC).^2 ...
        -dmtr.*(wage.*skillMat).^2.*mu/taxC...
        - Params.psi1*Params.psi2.*n.^(Params.psi2-1);
    
    n = n - resid./slope;
end

end

