#if !defined(mixture_model_h)
#define mixture_model_h

#include <string>
#include <iostream>
#include "newmatap.h"
#include "newmatio.h"
#include "newimage/newimageall.h"
#include "utils/tracer_plus.h"
#include "mmoptions.h"
#include "connected_offset.h"
#include "miscmaths/sparse_matrix.h"
#include "miscmaths/minimize.h"
#include "libprob.h"

using namespace NEWMAT;
using namespace NEWIMAGE;
using namespace MISCMATHS;

namespace Mm{

  // Exponential with its argument clamped to the interval [-700, 700].
  // For any finite x this keeps the result finite and strictly positive in
  // double precision (exp(700) < DBL_MAX, exp(-700) > DBL_MIN); a NaN argument
  // is passed through unchanged, so the result is NaN.
  inline double boundexp(double x)
  {
    const double limit = 700;
    const double clamped = (x > limit) ? limit : ((x < -limit) ? -limit : x);
    return std::exp(clamped);
  }

  class Distribution
  {
  public:
    Distribution() : useprop(false){}
    virtual float pdf(float val) const = 0;
    virtual float dpdfdmn(float val) const = 0;
    virtual float dpdfdvar(float val) const = 0;
    virtual ~Distribution(){}

    float getmean() const {return mn;}
    float getvar() const {return var;}
    float getprop() const {return prop;}

    virtual bool setparams(float pmn, float pvar, float pprop){mn=pmn;var=pvar;prop=pprop;return true;}

    void setuseprop(bool puseprop) {useprop = puseprop;}

  protected:
    float mn; float var; float prop; float useprop; }; class GaussianDistribution : public Distribution { public: GaussianDistribution() : Distribution() { } virtual float pdf(float val) const { float ret = premult*std::exp(-0.5/var*Sqr(val-mn)); if(useprop) ret *= prop; return ret; } virtual float dpdfdmn(float val) const { float ret = premult*(val-mn)/var*std::exp(-0.5/var*Sqr(val-mn)); return ret; } virtual float dpdfdvar(float val) const { float ret = premult*0.5*(Sqr(val-mn)-var)/std::pow(var,2)*std::exp(-0.5/var*Sqr(val-mn)); return ret; } virtual ~GaussianDistribution(){} virtual bool setparams(float pmean, float pvar, float pprop) { Distribution::setparams(pmean, pvar, pprop); if(pvar<=0) return false; useprop ? premult=pprop/std::sqrt(2*M_PI*pvar) : premult=1.0/std::sqrt(2*M_PI*pvar); return true; } private: float premult; }; class GammaDistribution : public Distribution { public: GammaDistribution(float pminmode = 0) : Distribution(), minmode(pminmode) {} virtual float pdf(float val) const { float ret = 1e-32; if(val > 0){ ret = boundexp(preadd + (a-1) * std::log(val) - b*val); if(useprop) ret *= prop; } return ret; } virtual float dpdfdmn(float val) const { float ret = 0; if(val > 0) ret = dpdfda(val)*(2*mn/var)+dpdfdb(val)*(1/var); return ret; } virtual float dpdfdvar(float val) const { float ret = 0; if(val > 0) ret = dpdfda(val)*(-Sqr(mn)/Sqr(var))+dpdfdb(val)*(-mn/Sqr(var)); return ret; } virtual ~GammaDistribution(){} virtual bool setparams(float pmn, float pvar, float pprop) { Distribution::setparams(pmn, pvar, pprop); bool ret = validate(); a = std::pow(mn,2)/var; b = mn/var; useprop ? preadd= log(prop)+a*std::log(b)-lgam(a) : preadd=a*std::log(b)-lgam(a); useprop ? 
premult=pprop/std::sqrt(2*M_PI*pvar) : premult=1.0/std::sqrt(2*M_PI*pvar); digama = digamma(a); if(!ret) OUT("invalid gamma"); return ret; } void setminmode(float pminmode) {minmode=pminmode; validate();} private: float dpdfdb(float val) const { return std::pow(b,a-1)*std::pow(val,a-1)/std::exp(lgam(a))*std::exp(-b*val)*(a-b*val); } float dpdfda(float val) const { return std::pow(b,a)*std::pow(val,a-1)/std::exp(lgam(a))*std::exp(-b*val)*(std::log(b)+std::log(val)-digama); } bool validate(); float digama; float preadd; float a; float b; float minmode; float premult; }; class FlippedGammaDistribution : public Distribution { public: FlippedGammaDistribution(float pminmode = 0) : Distribution(), minmode(pminmode) {} virtual float pdf(float val) const { float ret = 1e-32; val = -val; if(val > 0) { ret = boundexp(preadd + (a-1) * std::log(val) - b*val); //ret = boundexp(preadd + (a-1) * std::log(std::abs(val)) - b*std::abs(val)); if(useprop) ret *= prop; } return ret; } virtual float dpdfdmn(float val) const { // flip val val = -val; float pmn = -mn; float ret = 0; if(val > 0) ret = dpdfda(val)*(2*pmn/var)+dpdfdb(val)*(1/var); return -ret; } virtual float dpdfdvar(float val) const { // flip val val = -val; float pmn = -mn; float ret = 0; if(val > 0) ret = dpdfda(val)*(-Sqr(pmn)/Sqr(var))+dpdfdb(val)*(-pmn/Sqr(var)); return ret; } virtual ~FlippedGammaDistribution(){} virtual bool setparams(float pmn, float pvar, float pprop) { Distribution::setparams(pmn, pvar, pprop); bool ret = validate(); a = std::pow(mn,2)/var; b = -(mn)/var; // OUT(a); // OUT(b); // OUT(prop); useprop ? preadd= log(prop)+a*std::log(b)-lgam(a) : preadd=a*std::log(b)-lgam(a); useprop ? 
premult=pprop/std::sqrt(2*M_PI*pvar) : premult=1.0/std::sqrt(2*M_PI*pvar); digama = digamma(a); if(!ret) OUT("invalid gamma"); return ret; } void setminmode(float pminmode) {minmode=pminmode; validate();} private: bool validate(); float dpdfdb(float val) const { return std::pow(b,a-1)*std::pow(val,a-1)/std::exp(lgam(a))*std::exp(-b*val)*(a-b*val); } float dpdfda(float val) const { return std::pow(b,a)*std::pow(val,a-1)/std::exp(lgam(a))*std::exp(-b*val)*(std::log(b)+std::log(val)-digama); } float digama; float preadd; float a; float b; float minmode; float premult; }; class SmmVoxelFunction : public EvalFunction { public: SmmVoxelFunction(float pdata, vector& pdists, float plambda, float plog_bound) : EvalFunction(), data(pdata), dists(pdists), nclasses(pdists.size()), lambda(plambda), log_bound(plog_bound) {} float evaluate(const ColumnVector& x) const; //evaluate the function virtual ~SmmVoxelFunction(){}; private: SmmVoxelFunction(); const SmmVoxelFunction& operator=(SmmVoxelFunction& par); SmmVoxelFunction(const SmmVoxelFunction&); float data; vector& dists; int nclasses; float lambda; float log_bound; }; class SmmFunction : public gEvalFunction { public: SmmFunction(const ColumnVector& pdata, vector& pdists, const float& pmrf_precision, const volume& pmask, const vector& pconnected_offsets, const volume& pindices, const SparseMatrix& pD, float plambda, float plog_bound); float evaluate(const ColumnVector& x) const; //evaluate the function ReturnMatrix g_evaluate(const ColumnVector& x) const; //evaluate the gradient function virtual ~SmmFunction(){}; private: SmmFunction(); const SmmFunction& operator=(SmmFunction& par); SmmFunction(const SmmFunction&); const ColumnVector& data; vector& dists; const float& mrf_precision; const volume& mask; const vector& connected_offsets; const volume& indices; const SparseMatrix& D; int num_superthreshold; int nclasses; float lambda; float log_bound; }; class SmmFunctionDists : public gEvalFunction //class SmmFunctionDists : 
public EvalFunction { public: SmmFunctionDists(const ColumnVector& pdata, vector& pdists, const float& pmrf_precision, const volume& pmask, const vector& pconnected_offsets, const volume& pindices, float plambda, float plog_bound, const ColumnVector& m_tildew); float evaluate(const ColumnVector& x) const; //evaluate the function ReturnMatrix g_evaluate(const ColumnVector& x) const; //evaluate the gradient function virtual ~SmmFunctionDists(){}; private: SmmFunctionDists(); const SmmFunctionDists& operator=(SmmFunctionDists& par); SmmFunctionDists(const SmmFunctionDists&); const ColumnVector& data; vector& dists; const float& mrf_precision; const volume& mask; const vector& connected_offsets; const volume& indices; vector w; int num_superthreshold; int nclasses; float lambda; float log_bound; const ColumnVector& m_tildew; }; class Mixture_Model { public: // Constructor Mixture_Model(const volume& pspatial_data, const volume& pmask, const volume& pepi_example_data, float pepibt, vector& pdists, vector >& pw_means, ColumnVector& pY, MmOptions& popts); Mixture_Model(const volume& pspatial_data, const volume& pmask, const volume& pepi_example_data, float pepibt, vector& pdists, vector >& pw_means, ColumnVector& pY, bool pnonspatial=false, int pniters=10, bool pupdatetheta=true, int pdebuglevel=0, float pphi=0.015, float pmrfprecstart=10.0, int pntracesamps=10, float pmrfprecmultiplier=10.0, float pinitmultiplier=6.0, bool pfixmrfprec=false); // setup void setup(); // run void run(); // save data to logger dir void save() ; // Destructor virtual ~Mixture_Model(){} private: Mixture_Model(); const Mixture_Model& operator=(Mixture_Model&); Mixture_Model(Mixture_Model&); void update_theta(); void update_mrf_precision(); void update_tildew_scg(); void update_voxel_tildew_vb(); void calculate_taylor_lik(); void calculate_trace_tildew_D(); void get_weights(vector& weights, const ColumnVector& pmtildew); void get_weights2(vector& weights, vector > >& weights_samps, vector > >& 
tildew_samps, int nsamps, const ColumnVector& pmtildew); void save_weights(const ColumnVector& pmtildew, const char* affix, bool usesamples = true); int xsize; int ysize; int zsize; int num_superthreshold; int nclasses; const volume& spatial_data; const volume& mask; const volume& epi_example_data; float epibt; volume4D localweights; vector connected_offsets; volume indices; ColumnVector& Y; SparseMatrix D; ColumnVector m_tildew; vector prec_tildew; vector cov_tildew; SparseMatrix precision_lik; ColumnVector derivative_lik; float mrf_precision; // float mrf_precision_old; bool nonspatial; int niters; bool stopearly; bool updatetheta; int debuglevel; // logistic transform params: float lambda; float log_bound; float trace_covariance_tildew_D; int it; vector& dists; vector >& w_means; int ntracesamps; float mrfprecmultiplier; float initmultiplier; bool fixmrfprec; float trace_tol; float scg_tol; vector meanhist; vector mrf_precision_hist; }; ReturnMatrix sum_transform(const RowVector& wtilde, float log_bound); ReturnMatrix logistic_transform(const RowVector& wtilde,float lambda,float log_bound); ReturnMatrix inv_transform(const RowVector& w,float lambda,float log_bound,float initmultiplier); void ggmfit(const RowVector& data, vector& pdists, bool useprops); void plot_ggm(const vector >& w_means, const vector& dists, const volume& mask, const ColumnVector& Y); void make_ggmreport(const vector >& w_means, const vector& dists, const volume& mask, const volume& spatial_data, bool zfstatmode, bool overlay, const volume& epivol, float thresh, bool nonspatial, bool updatetheta, const string& data_name); void calculate_props(const vector >& w_means, vector& dists, const volume& mask); } #endif