tesseract  v4.0.0-17-g361f3264
Open Source OCR Engine
tesseract::LSTM Class Reference

#include <lstm.h>

Inheritance diagram for tesseract::LSTM:
Collaboration diagram for tesseract::LSTM:

Public Types

enum  WeightType {
  CI, GI, GF1, GO,
  GFS, WT_COUNT
}
 

Public Member Functions

 LSTM (const STRING &name, int num_inputs, int num_states, int num_outputs, bool two_dimensional, NetworkType type)
 
virtual ~LSTM ()
 
StaticShape OutputShape (const StaticShape &input_shape) const override
 
STRING spec () const override
 
void SetEnableTraining (TrainingState state) override
 
int InitWeights (float range, TRand *randomizer) override
 
int RemapOutputs (int old_no, const std::vector< int > &code_map) override
 
void ConvertToInt () override
 
void DebugWeights () override
 
bool Serialize (TFile *fp) const override
 
bool DeSerialize (TFile *fp) override
 
void Forward (bool debug, const NetworkIO &input, const TransposedArray *input_transpose, NetworkScratch *scratch, NetworkIO *output) override
 
bool Backward (bool debug, const NetworkIO &fwd_deltas, NetworkScratch *scratch, NetworkIO *back_deltas) override
 
void Update (float learning_rate, float momentum, float adam_beta, int num_samples) override
 
void CountAlternators (const Network &other, double *same, double *changed) const override
 
void PrintW ()
 
void PrintDW ()
 
bool Is2D () const
 
- Public Member Functions inherited from tesseract::Network
 Network ()
 
 Network (NetworkType type, const STRING &name, int ni, int no)
 
virtual ~Network ()=default
 
NetworkType type () const
 
bool IsTraining () const
 
bool needs_to_backprop () const
 
int num_weights () const
 
int NumInputs () const
 
int NumOutputs () const
 
virtual StaticShape InputShape () const
 
const STRING & name () const
 
bool TestFlag (NetworkFlags flag) const
 
virtual bool IsPlumbingType () const
 
virtual void SetNetworkFlags (uint32_t flags)
 
virtual void SetRandomizer (TRand *randomizer)
 
virtual bool SetupNeedsBackprop (bool needs_backprop)
 
virtual int XScaleFactor () const
 
virtual void CacheXScaleFactor (int factor)
 
void DisplayForward (const NetworkIO &matrix)
 
void DisplayBackward (const NetworkIO &matrix)
 

Private Member Functions

void ResizeForward (const NetworkIO &input)
 

Private Attributes

int32_t na_
 
int32_t ns_
 
int32_t nf_
 
bool is_2d_
 
WeightMatrix gate_weights_ [WT_COUNT]
 
FullyConnected * softmax_
 
NetworkIO source_
 
NetworkIO state_
 
GENERIC_2D_ARRAY< int8_t > which_fg_
 
NetworkIO node_values_ [WT_COUNT]
 
StrideMap input_map_
 
int input_width_
 

Additional Inherited Members

- Static Public Member Functions inherited from tesseract::Network
static Network * CreateFromFile (TFile *fp)
 
static void ClearWindow (bool tess_coords, const char *window_name, int width, int height, ScrollView **window)
 
static int DisplayImage (Pix *pix, ScrollView *window)
 
- Protected Member Functions inherited from tesseract::Network
double Random (double range)
 
- Protected Attributes inherited from tesseract::Network
NetworkType type_
 
TrainingState training_
 
bool needs_to_backprop_
 
int32_t network_flags_
 
int32_t ni_
 
int32_t no_
 
int32_t num_weights_
 
STRING name_
 
ScrollView * forward_win_
 
ScrollView * backward_win_
 
TRand * randomizer_
 
- Static Protected Attributes inherited from tesseract::Network
static char const *const kTypeNames [NT_COUNT]
 

Member Enumeration Documentation

◆ WeightType

Enumerator
CI 
GI 
GF1 
GO 
GFS 
WT_COUNT 

Constructor & Destructor Documentation

◆ LSTM()

tesseract::LSTM::LSTM ( const STRING & name,
int  num_inputs,
int  num_states,
int  num_outputs,
bool  two_dimensional,
NetworkType  type 
)

◆ ~LSTM()

tesseract::LSTM::~LSTM ( )
virtual

Member Function Documentation

◆ Backward()

bool tesseract::LSTM::Backward ( bool  debug,
const NetworkIO & fwd_deltas,
NetworkScratch * scratch,
NetworkIO * back_deltas 
)
override virtual

Reimplemented from tesseract::Network.

◆ ConvertToInt()

void tesseract::LSTM::ConvertToInt ( )
override virtual

Reimplemented from tesseract::Network.

◆ CountAlternators()

void tesseract::LSTM::CountAlternators ( const Network & other,
double *  same,
double *  changed 
) const
override virtual

Reimplemented from tesseract::Network.

◆ DebugWeights()

void tesseract::LSTM::DebugWeights ( )
override virtual

Reimplemented from tesseract::Network.

◆ DeSerialize()

bool tesseract::LSTM::DeSerialize ( TFile * fp)
override virtual

Reimplemented from tesseract::Network.

◆ Forward()

void tesseract::LSTM::Forward ( bool  debug,
const NetworkIO & input,
const TransposedArray * input_transpose,
NetworkScratch * scratch,
NetworkIO * output 
)
override virtual

Reimplemented from tesseract::Network.

◆ InitWeights()

int tesseract::LSTM::InitWeights ( float  range,
TRand * randomizer 
)
override virtual

Reimplemented from tesseract::Network.

◆ Is2D()

bool tesseract::LSTM::Is2D ( ) const
inline

◆ OutputShape()

StaticShape tesseract::LSTM::OutputShape ( const StaticShape & input_shape) const
override virtual

Reimplemented from tesseract::Network.

◆ PrintDW()

void tesseract::LSTM::PrintDW ( )

◆ PrintW()

void tesseract::LSTM::PrintW ( )

◆ RemapOutputs()

int tesseract::LSTM::RemapOutputs ( int  old_no,
const std::vector< int > &  code_map 
)
override virtual

Reimplemented from tesseract::Network.

◆ ResizeForward()

void tesseract::LSTM::ResizeForward ( const NetworkIO & input)
private

◆ Serialize()

bool tesseract::LSTM::Serialize ( TFile * fp) const
override virtual

Reimplemented from tesseract::Network.

◆ SetEnableTraining()

void tesseract::LSTM::SetEnableTraining ( TrainingState  state)
override virtual

Reimplemented from tesseract::Network.

◆ spec()

STRING tesseract::LSTM::spec ( ) const
inline override virtual

Reimplemented from tesseract::Network.

◆ Update()

void tesseract::LSTM::Update ( float  learning_rate,
float  momentum,
float  adam_beta,
int  num_samples 
)
override virtual

Reimplemented from tesseract::Network.

Member Data Documentation

◆ gate_weights_

WeightMatrix tesseract::LSTM::gate_weights_[WT_COUNT]
private

◆ input_map_

StrideMap tesseract::LSTM::input_map_
private

◆ input_width_

int tesseract::LSTM::input_width_
private

◆ is_2d_

bool tesseract::LSTM::is_2d_
private

◆ na_

int32_t tesseract::LSTM::na_
private

◆ nf_

int32_t tesseract::LSTM::nf_
private

◆ node_values_

NetworkIO tesseract::LSTM::node_values_[WT_COUNT]
private

◆ ns_

int32_t tesseract::LSTM::ns_
private

◆ softmax_

FullyConnected* tesseract::LSTM::softmax_
private

◆ source_

NetworkIO tesseract::LSTM::source_
private

◆ state_

NetworkIO tesseract::LSTM::state_
private

◆ which_fg_

GENERIC_2D_ARRAY<int8_t> tesseract::LSTM::which_fg_
private

The documentation for this class was generated from the following files: