tesseract  v4.0.0-17-g361f3264
Open Source OCR Engine
tesseract::WeightMatrix Class Reference

#include <weightmatrix.h>

Collaboration diagram for tesseract::WeightMatrix:

Public Member Functions

 WeightMatrix ()
 
int InitWeightsFloat (int no, int ni, bool use_adam, float weight_range, TRand *randomizer)
 
int RemapOutputs (const std::vector< int > &code_map)
 
void ConvertToInt ()
 
int RoundInputs (int size) const
 
bool is_int_mode () const
 
int NumOutputs () const
 
const double * GetWeights (int index) const
 
double GetDW (int i, int j) const
 
void InitBackward ()
 
bool Serialize (bool training, TFile *fp) const
 
bool DeSerialize (bool training, TFile *fp)
 
bool DeSerializeOld (bool training, TFile *fp)
 
void MatrixDotVector (const double *u, double *v) const
 
void MatrixDotVector (const int8_t *u, double *v) const
 
void MultiplyAccumulate (const double *v, double *inout)
 
void VectorDotMatrix (const double *u, double *v) const
 
void SumOuterTransposed (const TransposedArray &u, const TransposedArray &v, bool parallel)
 
void Update (double learning_rate, double momentum, double adam_beta, int num_samples)
 
void AddDeltas (const WeightMatrix &other)
 
void CountAlternators (const WeightMatrix &other, double *same, double *changed) const
 
void Debug2D (const char *msg)
 

Static Public Member Functions

static double DotProduct (const double *u, const double *v, int n)
 
static void FloatToDouble (const GENERIC_2D_ARRAY< float > &wf, GENERIC_2D_ARRAY< double > *wd)
 

Static Private Member Functions

static void MatrixDotVectorInternal (const GENERIC_2D_ARRAY< double > &w, bool add_bias_fwd, bool skip_bias_back, const double *u, double *v)
 

Private Attributes

GENERIC_2D_ARRAY< double > wf_
 
GENERIC_2D_ARRAY< int8_t > wi_
 
TransposedArray wf_t_
 
bool int_mode_
 
bool use_adam_
 
GenericVector< double > scales_
 
GENERIC_2D_ARRAY< double > dw_
 
GENERIC_2D_ARRAY< double > updates_
 
GENERIC_2D_ARRAY< double > dw_sq_sum_
 
std::unique_ptr< IntSimdMatrix > multiplier_
 

Constructor & Destructor Documentation

◆ WeightMatrix()

tesseract::WeightMatrix::WeightMatrix ( )
inline

Member Function Documentation

◆ AddDeltas()

void tesseract::WeightMatrix::AddDeltas ( const WeightMatrix &  other)

◆ ConvertToInt()

void tesseract::WeightMatrix::ConvertToInt ( )

◆ CountAlternators()

void tesseract::WeightMatrix::CountAlternators ( const WeightMatrix &  other,
double *  same,
double *  changed 
) const

◆ Debug2D()

void tesseract::WeightMatrix::Debug2D ( const char *  msg)

◆ DeSerialize()

bool tesseract::WeightMatrix::DeSerialize ( bool  training,
TFile *  fp 
)

◆ DeSerializeOld()

bool tesseract::WeightMatrix::DeSerializeOld ( bool  training,
TFile *  fp 
)

◆ DotProduct()

double tesseract::WeightMatrix::DotProduct ( const double *  u,
const double *  v,
int  n 
)
static

◆ FloatToDouble()

void tesseract::WeightMatrix::FloatToDouble ( const GENERIC_2D_ARRAY< float > &  wf,
GENERIC_2D_ARRAY< double > *  wd 
)
static

◆ GetDW()

double tesseract::WeightMatrix::GetDW ( int  i,
int  j 
) const
inline

◆ GetWeights()

const double* tesseract::WeightMatrix::GetWeights ( int  index) const
inline

◆ InitBackward()

void tesseract::WeightMatrix::InitBackward ( )

◆ InitWeightsFloat()

int tesseract::WeightMatrix::InitWeightsFloat ( int  no,
int  ni,
bool  use_adam,
float  weight_range,
TRand *  randomizer 
)

◆ is_int_mode()

bool tesseract::WeightMatrix::is_int_mode ( ) const
inline

◆ MatrixDotVector() [1/2]

void tesseract::WeightMatrix::MatrixDotVector ( const double *  u,
double *  v 
) const

◆ MatrixDotVector() [2/2]

void tesseract::WeightMatrix::MatrixDotVector ( const int8_t *  u,
double *  v 
) const

◆ MatrixDotVectorInternal()

void tesseract::WeightMatrix::MatrixDotVectorInternal ( const GENERIC_2D_ARRAY< double > &  w,
bool  add_bias_fwd,
bool  skip_bias_back,
const double *  u,
double *  v 
)
static, private

◆ MultiplyAccumulate()

void tesseract::WeightMatrix::MultiplyAccumulate ( const double *  v,
double *  inout 
)

◆ NumOutputs()

int tesseract::WeightMatrix::NumOutputs ( ) const
inline

◆ RemapOutputs()

int tesseract::WeightMatrix::RemapOutputs ( const std::vector< int > &  code_map)

◆ RoundInputs()

int tesseract::WeightMatrix::RoundInputs ( int  size) const
inline

◆ Serialize()

bool tesseract::WeightMatrix::Serialize ( bool  training,
TFile *  fp 
) const

◆ SumOuterTransposed()

void tesseract::WeightMatrix::SumOuterTransposed ( const TransposedArray &  u,
const TransposedArray &  v,
bool  parallel 
)

◆ Update()

void tesseract::WeightMatrix::Update ( double  learning_rate,
double  momentum,
double  adam_beta,
int  num_samples 
)

◆ VectorDotMatrix()

void tesseract::WeightMatrix::VectorDotMatrix ( const double *  u,
double *  v 
) const

Member Data Documentation

◆ dw_

GENERIC_2D_ARRAY<double> tesseract::WeightMatrix::dw_
private

◆ dw_sq_sum_

GENERIC_2D_ARRAY<double> tesseract::WeightMatrix::dw_sq_sum_
private

◆ int_mode_

bool tesseract::WeightMatrix::int_mode_
private

◆ multiplier_

std::unique_ptr<IntSimdMatrix> tesseract::WeightMatrix::multiplier_
private

◆ scales_

GenericVector<double> tesseract::WeightMatrix::scales_
private

◆ updates_

GENERIC_2D_ARRAY<double> tesseract::WeightMatrix::updates_
private

◆ use_adam_

bool tesseract::WeightMatrix::use_adam_
private

◆ wf_

GENERIC_2D_ARRAY<double> tesseract::WeightMatrix::wf_
private

◆ wf_t_

TransposedArray tesseract::WeightMatrix::wf_t_
private

◆ wi_

GENERIC_2D_ARRAY<int8_t> tesseract::WeightMatrix::wi_
private

The documentation for this class was generated from the following files: