tesseract  v4.0.0-17-g361f3264
Open Source OCR Engine
lstmrecognizer.h
1 // File: lstmrecognizer.h
3 // Description: Top-level line recognizer class for LSTM-based networks.
4 // Author: Ray Smith
5 // Created: Thu May 02 08:57:06 PST 2013
6 //
7 // (C) Copyright 2013, Google Inc.
8 // Licensed under the Apache License, Version 2.0 (the "License");
9 // you may not use this file except in compliance with the License.
10 // You may obtain a copy of the License at
11 // http://www.apache.org/licenses/LICENSE-2.0
12 // Unless required by applicable law or agreed to in writing, software
13 // distributed under the License is distributed on an "AS IS" BASIS,
14 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15 // See the License for the specific language governing permissions and
16 // limitations under the License.
18 
19 #ifndef TESSERACT_LSTM_LSTMRECOGNIZER_H_
20 #define TESSERACT_LSTM_LSTMRECOGNIZER_H_
21 
22 #include "ccutil.h"
23 #include "helpers.h"
24 #include "imagedata.h"
25 #include "matrix.h"
26 #include "network.h"
27 #include "networkscratch.h"
28 #include "recodebeam.h"
29 #include "series.h"
30 #include "strngs.h"
31 #include "unicharcompress.h"
32 
33 class BLOB_CHOICE_IT;
34 struct Pix;
35 class ROW_RES;
36 class ScrollView;
37 class TBOX;
38 class WERD_RES;
39 
40 namespace tesseract {
41 
42 class Dict;
43 class ImageData;
44 
45 // Enum indicating training mode control flags.
49 };
50 
51 // Top-level line recognizer class for LSTM-based networks.
52 // Note that a sub-class, LSTMTrainer is used for training.
54  public:
57 
58  int NumOutputs() const {
59  return network_->NumOutputs();
60  }
61  int training_iteration() const {
62  return training_iteration_;
63  }
64  int sample_iteration() const {
65  return sample_iteration_;
66  }
67  double learning_rate() const {
68  return learning_rate_;
69  }
71  if (network_ == nullptr) return LT_NONE;
72  StaticShape shape;
73  shape = network_->OutputShape(shape);
74  return shape.loss_type();
75  }
76  bool SimpleTextOutput() const { return OutputLossType() == LT_SOFTMAX; }
77  bool IsIntMode() const { return (training_flags_ & TF_INT_MODE) != 0; }
78  // True if recoder_ is active to re-encode text to a smaller space.
79  bool IsRecoding() const {
81  }
82  // Returns true if the network is a TensorFlow network.
83  bool IsTensorFlow() const { return network_->type() == NT_TENSORFLOW; }
84  // Returns a vector of layer ids that can be passed to other layer functions
85  // to access a specific layer.
87  ASSERT_HOST(network_ != nullptr && network_->type() == NT_SERIES);
88  Series* series = static_cast<Series*>(network_);
89  GenericVector<STRING> layers;
90  series->EnumerateLayers(nullptr, &layers);
91  return layers;
92  }
93  // Returns a specific layer from its id (from EnumerateLayers).
94  Network* GetLayer(const STRING& id) const {
95  ASSERT_HOST(network_ != nullptr && network_->type() == NT_SERIES);
96  ASSERT_HOST(id.length() > 1 && id[0] == ':');
97  Series* series = static_cast<Series*>(network_);
98  return series->GetLayer(&id[1]);
99  }
100  // Returns the learning rate of the layer from its id.
101  float GetLayerLearningRate(const STRING& id) const {
102  ASSERT_HOST(network_ != nullptr && network_->type() == NT_SERIES);
104  ASSERT_HOST(id.length() > 1 && id[0] == ':');
105  Series* series = static_cast<Series*>(network_);
106  return series->LayerLearningRate(&id[1]);
107  } else {
108  return learning_rate_;
109  }
110  }
111  // Multiplies the all the learning rate(s) by the given factor.
112  void ScaleLearningRate(double factor) {
113  ASSERT_HOST(network_ != nullptr && network_->type() == NT_SERIES);
114  learning_rate_ *= factor;
117  for (int i = 0; i < layers.size(); ++i) {
118  ScaleLayerLearningRate(layers[i], factor);
119  }
120  }
121  }
122  // Multiplies the learning rate of the layer with id, by the given factor.
123  void ScaleLayerLearningRate(const STRING& id, double factor) {
124  ASSERT_HOST(network_ != nullptr && network_->type() == NT_SERIES);
125  ASSERT_HOST(id.length() > 1 && id[0] == ':');
126  Series* series = static_cast<Series*>(network_);
127  series->ScaleLayerLearningRate(&id[1], factor);
128  }
129 
130  // Converts the network to int if not already.
131  void ConvertToInt() {
132  if ((training_flags_ & TF_INT_MODE) == 0) {
135  }
136  }
137 
138  // Provides access to the UNICHARSET that this classifier works with.
139  const UNICHARSET& GetUnicharset() const { return ccutil_.unicharset; }
140  // Provides access to the UnicharCompress that this classifier works with.
141  const UnicharCompress& GetRecoder() const { return recoder_; }
142  // Provides access to the Dict that this classifier works with.
143  const Dict* GetDict() const { return dict_; }
144  // Sets the sample iteration to the given value. The sample_iteration_
145  // determines the seed for the random number generator. The training
146  // iteration is incremented only by a successful training iteration.
147  void SetIteration(int iteration) {
148  sample_iteration_ = iteration;
149  }
150  // Accessors for textline image normalization.
151  int NumInputs() const {
152  return network_->NumInputs();
153  }
154  int null_char() const { return null_char_; }
155 
156  // Loads a model from mgr, including the dictionary only if lang is not null.
157  bool Load(const char* lang, TessdataManager* mgr);
158 
159  // Writes to the given file. Returns false in case of error.
160  // If mgr contains a unicharset and recoder, then they are not encoded to fp.
161  bool Serialize(const TessdataManager* mgr, TFile* fp) const;
162  // Reads from the given file. Returns false in case of error.
163  // If mgr contains a unicharset and recoder, then they are taken from there,
164  // otherwise, they are part of the serialization in fp.
165  bool DeSerialize(const TessdataManager* mgr, TFile* fp);
166  // Loads the charsets from mgr.
167  bool LoadCharsets(const TessdataManager* mgr);
168  // Loads the Recoder.
169  bool LoadRecoder(TFile* fp);
170  // Loads the dictionary if possible from the traineddata file.
171  // Prints a warning message, and returns false but otherwise fails silently
172  // and continues to work without it if loading fails.
173  // Note that dictionary load is independent from DeSerialize, but dependent
174  // on the unicharset matching. This enables training to deserialize a model
175  // from checkpoint or restore without having to go back and reload the
176  // dictionary.
177  bool LoadDictionary(const char* lang, TessdataManager* mgr);
178 
179  // Recognizes the line image, contained within image_data, returning the
180  // recognized tesseract WERD_RES for the words.
181  // If invert, tries inverted as well if the normal interpretation doesn't
182  // produce a good enough result. The line_box is used for computing the
183  // box_word in the output words. worst_dict_cert is the worst certainty that
184  // will be used in a dictionary word.
185  void RecognizeLine(const ImageData& image_data, bool invert, bool debug,
186  double worst_dict_cert, const TBOX& line_box,
187  PointerVector<WERD_RES>* words, int lstm_choice_mode = 0);
188 
189  // Helper computes min and mean best results in the output.
190  void OutputStats(const NetworkIO& outputs,
191  float* min_output, float* mean_output, float* sd);
192  // Recognizes the image_data, returning the labels,
193  // scores, and corresponding pairs of start, end x-coords in coords.
194  // Returned in scale_factor is the reduction factor
195  // between the image and the output coords, for computing bounding boxes.
196  // If re_invert is true, the input is inverted back to its original
197  // photometric interpretation if inversion is attempted but fails to
198  // improve the results. This ensures that outputs contains the correct
199  // forward outputs for the best photometric interpretation.
200  // inputs is filled with the used inputs to the network.
201  bool RecognizeLine(const ImageData& image_data, bool invert, bool debug,
202  bool re_invert, bool upside_down, float* scale_factor,
203  NetworkIO* inputs, NetworkIO* outputs);
204 
205  // Converts an array of labels to utf-8, whether or not the labels are
206  // augmented with character boundaries.
207  STRING DecodeLabels(const GenericVector<int>& labels);
208 
209  // Displays the forward results in a window with the characters and
210  // boundaries as determined by the labels and label_coords.
211  void DisplayForward(const NetworkIO& inputs,
212  const GenericVector<int>& labels,
213  const GenericVector<int>& label_coords,
214  const char* window_name,
215  ScrollView** window);
216  // Converts the network output to a sequence of labels. Outputs labels, scores
217  // and start xcoords of each char, and each null_char_, with an additional
218  // final xcoord for the end of the output.
219  // The conversion method is determined by internal state.
220  void LabelsFromOutputs(const NetworkIO& outputs, GenericVector<int>* labels,
221  GenericVector<int>* xcoords);
222 
223  protected:
224  // Sets the random seed from the sample_iteration_;
225  void SetRandomSeed() {
226  int64_t seed = static_cast<int64_t>(sample_iteration_) * 0x10000001;
227  randomizer_.set_seed(seed);
229  }
230 
231  // Displays the labels and cuts at the corresponding xcoords.
232  // Size of labels should match xcoords.
233  void DisplayLSTMOutput(const GenericVector<int>& labels,
234  const GenericVector<int>& xcoords,
235  int height, ScrollView* window);
236 
237  // Prints debug output detailing the activation path that is implied by the
238  // xcoords.
239  void DebugActivationPath(const NetworkIO& outputs,
240  const GenericVector<int>& labels,
241  const GenericVector<int>& xcoords);
242 
243  // Prints debug output detailing activations and 2nd choice over a range
244  // of positions.
245  void DebugActivationRange(const NetworkIO& outputs, const char* label,
246  int best_choice, int x_start, int x_end);
247 
248  // As LabelsViaCTC except that this function constructs the best path that
249  // contains only legal sequences of subcodes for recoder_.
250  void LabelsViaReEncode(const NetworkIO& output, GenericVector<int>* labels,
251  GenericVector<int>* xcoords);
252  // Converts the network output to a sequence of labels, with scores, using
253  // the simple character model (each position is a char, and the null_char_ is
254  // mainly intended for tail padding.)
255  void LabelsViaSimpleText(const NetworkIO& output,
256  GenericVector<int>* labels,
257  GenericVector<int>* xcoords);
258 
259  // Returns a string corresponding to the label starting at start. Sets *end
260  // to the next start and if non-null, *decoded to the unichar id.
261  const char* DecodeLabel(const GenericVector<int>& labels, int start, int* end,
262  int* decoded);
263 
264  // Returns a string corresponding to a given single label id, falling back to
265  // a default of ".." for part of a multi-label unichar-id.
266  const char* DecodeSingleLabel(int label);
267 
268  protected:
269  // The network hierarchy.
271  // The unicharset. Only the unicharset element is serialized.
272  // Has to be a CCUtil, so Dict can point to it.
274  // For backward compatibility, recoder_ is serialized iff
275  // training_flags_ & TF_COMPRESS_UNICHARSET.
276  // Further encode/decode ccutil_.unicharset's ids to simplify the unicharset.
278 
279  // ==Training parameters that are serialized to provide a record of them.==
281  // Flags used to determine the training method of the network.
282  // See enum TrainingFlags above.
284  // Number of actual backward training steps used.
286  // Index into training sample set. sample_iteration >= training_iteration_.
288  // Index in softmax of null character. May take the value UNICHAR_BROKEN or
289  // ccutil_.unicharset.size().
290  int32_t null_char_;
291  // Learning rate and momentum multipliers of deltas in backprop.
293  float momentum_;
294  // Smoothing factor for 2nd moment of gradients.
295  float adam_beta_;
296 
297  // === NOT SERIALIZED.
300  // Language model (optional) to use with the beam search.
302  // Beam search held between uses to optimize memory allocation/use.
304 
305  // == Debugging parameters.==
306  // Recognition debug display window.
308 };
309 
310 } // namespace tesseract.
311 
312 #endif // TESSERACT_LSTM_LSTMRECOGNIZER_H_
bool SimpleTextOutput() const
Definition: lstmrecognizer.h:76
ScrollView * debug_win_
Definition: lstmrecognizer.h:307
void LabelsFromOutputs(const NetworkIO &outputs, GenericVector< int > *labels, GenericVector< int > *xcoords)
Definition: lstmrecognizer.cpp:417
void ScaleLayerLearningRate(const STRING &id, double factor)
Definition: lstmrecognizer.h:123
LossType
Definition: static_shape.h:29
int null_char() const
Definition: lstmrecognizer.h:154
TRand randomizer_
Definition: lstmrecognizer.h:298
bool LoadRecoder(TFile *fp)
Definition: lstmrecognizer.cpp:134
void DisplayLSTMOutput(const GenericVector< int > &labels, const GenericVector< int > &xcoords, int height, ScrollView *window)
Definition: lstmrecognizer.cpp:320
void LabelsViaSimpleText(const NetworkIO &output, GenericVector< int > *labels, GenericVector< int > *xcoords)
Definition: lstmrecognizer.cpp:443
bool IsRecoding() const
Definition: lstmrecognizer.h:79
void ScaleLayerLearningRate(const char *id, double factor)
Definition: plumbing.h:111
bool LoadCharsets(const TessdataManager *mgr)
Definition: lstmrecognizer.cpp:124
Definition: helpers.h:42
int32_t training_flags_
Definition: lstmrecognizer.h:283
virtual void ConvertToInt()
Definition: network.h:191
STRING DecodeLabels(const GenericVector< int > &labels)
Definition: lstmrecognizer.cpp:289
float learning_rate_
Definition: lstmrecognizer.h:292
bool Load(const char *lang, TessdataManager *mgr)
Definition: lstmrecognizer.cpp:69
LSTMRecognizer()
Definition: lstmrecognizer.cpp:49
const char * DecodeSingleLabel(int label)
Definition: lstmrecognizer.cpp:504
Definition: static_shape.h:38
RecodeBeamSearch * search_
Definition: lstmrecognizer.h:303
int NumOutputs() const
Definition: network.h:123
void LabelsViaReEncode(const NetworkIO &output, GenericVector< int > *labels, GenericVector< int > *xcoords)
Definition: lstmrecognizer.cpp:429
Definition: lstmrecognizer.h:47
Definition: imagedata.h:105
void set_seed(uint64_t seed)
Definition: helpers.h:46
int training_iteration() const
Definition: lstmrecognizer.h:61
LossType OutputLossType() const
Definition: lstmrecognizer.h:70
const UnicharCompress & GetRecoder() const
Definition: lstmrecognizer.h:141
Definition: networkscratch.h:36
Definition: rect.h:34
Definition: unicharset.h:146
virtual StaticShape OutputShape(const StaticShape &input_shape) const
Definition: network.h:133
int32_t training_iteration_
Definition: lstmrecognizer.h:285
Definition: static_shape.h:30
Definition: recodebeam.h:179
bool LoadDictionary(const char *lang, TessdataManager *mgr)
Definition: lstmrecognizer.cpp:157
LossType loss_type() const
Definition: static_shape.h:50
Definition: serialis.h:77
CCUtil ccutil_
Definition: lstmrecognizer.h:273
Definition: baseapi.cpp:94
GenericVector< STRING > EnumerateLayers() const
Definition: lstmrecognizer.h:86
bool DeSerialize(const TessdataManager *mgr, TFile *fp)
Definition: lstmrecognizer.cpp:99
Definition: network.h:87
Dict * dict_
Definition: lstmrecognizer.h:301
float momentum_
Definition: lstmrecognizer.h:293
void DebugActivationRange(const NetworkIO &outputs, const char *label, int best_choice, int x_start, int x_end)
Definition: lstmrecognizer.cpp:374
bool TestFlag(NetworkFlags flag) const
Definition: network.h:144
double learning_rate() const
Definition: lstmrecognizer.h:67
float GetLayerLearningRate(const STRING &id) const
Definition: lstmrecognizer.h:101
int32_t null_char_
Definition: lstmrecognizer.h:290
float adam_beta_
Definition: lstmrecognizer.h:295
Definition: unicharcompress.h:128
Network * GetLayer(const char *id) const
Definition: plumbing.cpp:155
int sample_iteration() const
Definition: lstmrecognizer.h:64
void DisplayForward(const NetworkIO &inputs, const GenericVector< int > &labels, const GenericVector< int > &label_coords, const char *window_name, ScrollView **window)
Definition: lstmrecognizer.cpp:304
bool IsIntMode() const
Definition: lstmrecognizer.h:77
Definition: ccutil.h:51
int32_t IntRand()
Definition: helpers.h:56
void ScaleLearningRate(double factor)
Definition: lstmrecognizer.h:112
Definition: dict.h:88
Definition: scrollview.h:102
TrainingFlags
Definition: lstmrecognizer.h:46
void DebugActivationPath(const NetworkIO &outputs, const GenericVector< int > &labels, const GenericVector< int > &xcoords)
Definition: lstmrecognizer.cpp:347
Definition: network.h:105
NetworkType type() const
Definition: network.h:112
Definition: tessdatamanager.h:126
const Dict * GetDict() const
Definition: lstmrecognizer.h:143
Definition: network.h:78
Definition: strngs.h:45
Definition: static_shape.h:32
float LayerLearningRate(const char *id) const
Definition: plumbing.h:105
bool IsTensorFlow() const
Definition: lstmrecognizer.h:83
~LSTMRecognizer()
Definition: lstmrecognizer.cpp:62
const UNICHARSET & GetUnicharset() const
Definition: lstmrecognizer.h:139
bool Serialize(const TessdataManager *mgr, TFile *fp) const
Definition: lstmrecognizer.cpp:80
Definition: pageres.h:169
int size() const
Definition: genericvector.h:71
void OutputStats(const NetworkIO &outputs, float *min_output, float *mean_output, float *sd)
Definition: lstmrecognizer.cpp:194
Definition: network.h:54
Network * network_
Definition: lstmrecognizer.h:270
UnicharCompress recoder_
Definition: lstmrecognizer.h:277
Definition: series.h:27
void RecognizeLine(const ImageData &image_data, bool invert, bool debug, double worst_dict_cert, const TBOX &line_box, PointerVector< WERD_RES > *words, int lstm_choice_mode=0)
Definition: lstmrecognizer.cpp:172
int NumInputs() const
Definition: lstmrecognizer.h:151
void ConvertToInt()
Definition: lstmrecognizer.h:131
int32_t sample_iteration_
Definition: lstmrecognizer.h:287
Definition: lstmrecognizer.h:48
Definition: lstmrecognizer.h:53
NetworkScratch scratch_space_
Definition: lstmrecognizer.h:299
Definition: networkio.h:39
int NumInputs() const
Definition: network.h:120
STRING network_str_
Definition: lstmrecognizer.h:280
int NumOutputs() const
Definition: lstmrecognizer.h:58
UNICHARSET unicharset
Definition: ccutil.h:68
void SetIteration(int iteration)
Definition: lstmrecognizer.h:147
Network * GetLayer(const STRING &id) const
Definition: lstmrecognizer.h:94
void SetRandomSeed()
Definition: lstmrecognizer.h:225
const char * DecodeLabel(const GenericVector< int > &labels, int start, int *end, int *decoded)
Definition: lstmrecognizer.cpp:462
Definition: pageres.h:141