tesseract  v4.0.0-17-g361f3264
Open Source OCR Engine
lstmtrainer.h
1 // File: lstmtrainer.h
3 // Description: Top-level line trainer class for LSTM-based networks.
4 // Author: Ray Smith
5 // Created: Fri May 03 09:07:06 PST 2013
6 //
7 // (C) Copyright 2013, Google Inc.
8 // Licensed under the Apache License, Version 2.0 (the "License");
9 // you may not use this file except in compliance with the License.
10 // You may obtain a copy of the License at
11 // http://www.apache.org/licenses/LICENSE-2.0
12 // Unless required by applicable law or agreed to in writing, software
13 // distributed under the License is distributed on an "AS IS" BASIS,
14 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15 // See the License for the specific language governing permissions and
16 // limitations under the License.
18 
19 #ifndef TESSERACT_LSTM_LSTMTRAINER_H_
20 #define TESSERACT_LSTM_LSTMTRAINER_H_
21 
22 #include "imagedata.h"
23 #include "lstmrecognizer.h"
24 #include "rect.h"
25 #include "tesscallback.h"
26 
27 namespace tesseract {
28 
29 class LSTM;
30 class LSTMTrainer;
31 class Parallel;
32 class Reversed;
33 class Softmax;
34 class Series;
35 
36 // Enum for the types of errors that are counted; values index per-type arrays.
37 enum ErrorTypes {
38  ET_RMS, // RMS activation error.
39  ET_DELTA, // Number of big errors in deltas (reported by ActivationError()).
40  ET_WORD_RECERR, // Output text string word recall error.
41  ET_CHAR_ERROR, // Output text string total char error (reported by CharError()).
42  ET_SKIP_RATIO, // Fraction of samples skipped.
43  ET_COUNT // For array sizing, e.g. error_rates_[ET_COUNT]. Keep last.
44 };
45 
46 // Enum for the trainability_ flags.
48  TRAINABLE, // Non-zero delta error.
49  PERFECT, // Zero delta error.
50  UNENCODABLE, // Not trainable due to coding/alignment trouble.
51  HI_PRECISION_ERR, // Hi confidence disagreement.
52  NOT_BOXED, // Early in training and has no character boxes.
53 };
54 
55 // Enum to define the amount of data to get serialized.
57  LIGHT, // Minimal data for remote training.
58  NO_BEST_TRAINER, // Save an empty vector in place of best_trainer_.
59  FULL, // All data including best_trainer_.
60 };
61 
62 // Enum to indicate how the sub_trainer_ training went.
64  STR_NONE, // Did nothing as not good enough.
65  STR_UPDATED, // Subtrainer was updated, but didn't replace *this.
66  STR_REPLACED // Subtrainer replaced *this.
67 };
68 
70 // Function to restore the trainer state from a given checkpoint.
71 // Returns false on failure.
74 // Function to save a checkpoint of the current trainer state.
75 // Returns false on failure. SerializeAmount determines the amount of the
76 // trainer to serialize, typically used for saving the best state.
77 typedef TessResultCallback3<bool, SerializeAmount, const LSTMTrainer*,
79 // Function to compute and record error rates on some external test set(s).
80 // Args are, in order: iteration (int), mean errors (const double*), model
80 // (const TessdataManager&), training stage (int).
81 // Returns a STRING containing logging information about the tests.
81 // Note that TestCallback is a pointer to the callback object.
82 typedef TessResultCallback4<STRING, int, const double*, const TessdataManager&,
83  int>* TestCallback;
84 
85 // Trainer class for LSTM networks. Most of the effort is in creating the
86 // ideal target outputs from the transcription. A box file is used if it is
87 // available, otherwise estimates of the char widths from the unicharset are
88 // used to guide a DP search for the best fit to the transcription.
89 class LSTMTrainer : public LSTMRecognizer {
90  public:
91  LSTMTrainer();
92  // Callbacks may be null, in which case defaults are used.
93  LSTMTrainer(FileReader file_reader, FileWriter file_writer,
94  CheckPointReader checkpoint_reader,
95  CheckPointWriter checkpoint_writer,
96  const char* model_base, const char* checkpoint_name,
97  int debug_interval, int64_t max_memory);
98  virtual ~LSTMTrainer();
99 
100  // Tries to deserialize a trainer from the given file and silently returns
101  // false in case of failure. If old_traineddata is not null, then it is
102  // assumed that the character set is to be re-mapped from old_traineddata to
103  // the new, with consequent change in weight matrices etc.
104  bool TryLoadingCheckpoint(const char* filename, const char* old_traineddata);
105 
106  // Initializes the character set encode/decode mechanism directly from a
107  // previously set-up traineddata containing dawgs, UNICHARSET and
108  // UnicharCompress. Note: Call before InitNetwork!
109  void InitCharSet(const std::string& traineddata_path) {
110  ASSERT_HOST(mgr_.Init(traineddata_path.c_str())); // Init mgr_ from the file; asserted to succeed.
111  InitCharSet(); // The no-argument overload finishes the job from mgr_.
112  }
113  void InitCharSet(const TessdataManager& mgr) {
114  mgr_ = mgr; // Keep a copy of the caller's manager in mgr_.
115  InitCharSet(); // The no-argument overload finishes the job from mgr_.
116  }
117 
118  // Initializes the trainer with a network_spec in the network description
119  // net_flags control network behavior according to the NetworkFlags enum.
120  // There isn't really much difference between them - only where the effects
121  // are implemented.
122  // For other args see NetworkBuilder::InitNetwork.
123  // Note: Be sure to call InitCharSet before InitNetwork!
124  bool InitNetwork(const STRING& network_spec, int append_index, int net_flags,
125  float weight_range, float learning_rate, float momentum,
126  float adam_beta);
127  // Initializes a trainer from a serialized TFNetworkModel proto.
128  // Returns the global step of TensorFlow graph or 0 if failed.
129  // Building a compatible TF graph: See tfnetwork.proto.
130  int InitTensorFlowNetwork(const std::string& tf_proto);
131  // Resets all the iteration counters for fine tuning or training a head,
132  // where we want the error reporting to reset.
133  void InitIterations();
134 
135  // Accessors.
136  double ActivationError() const {
137  return error_rates_[ET_DELTA];
138  }
139  double CharError() const { return error_rates_[ET_CHAR_ERROR]; }
140  const double* error_rates() const {
141  return error_rates_;
142  }
143  double best_error_rate() const {
144  return best_error_rate_;
145  }
146  int best_iteration() const {
147  return best_iteration_;
148  }
149  int learning_iteration() const { return learning_iteration_; }
150  int32_t improvement_steps() const { return improvement_steps_; }
151  void set_perfect_delay(int delay) { perfect_delay_ = delay; }
152  const GenericVector<char>& best_trainer() const { return best_trainer_; }
153  // Returns the error that was just calculated by PrepareForBackward.
154  double NewSingleError(ErrorTypes type) const {
156  }
157  // Returns the error that was just calculated by TrainOnLine. Since
158  // TrainOnLine rolls the error buffers, this is one further back than
159  // NewSingleError.
160  double LastSingleError(ErrorTypes type) const {
161  return error_buffers_[type]
164  }
165  const DocumentCache& training_data() const {
166  return training_data_;
167  }
169 
170  // If the training sample is usable, grid searches for the optimal
171  // dict_ratio/cert_offset, and returns the results in a string of space-
172  // separated triplets of ratio,offset=worderr.
174  const ImageData* trainingdata, int iteration, double min_dict_ratio,
175  double dict_ratio_step, double max_dict_ratio, double min_cert_offset,
176  double cert_offset_step, double max_cert_offset, STRING* results);
177 
178  // Provides output on the distribution of weight values.
179  void DebugNetwork();
180 
181  // Loads a set of lstmf files that were created using the lstm.train config to
182  // tesseract into memory ready for training. Returns false if nothing was
183  // loaded.
184  bool LoadAllTrainingData(const GenericVector<STRING>& filenames,
185  CachingStrategy cache_strategy,
186  bool randomly_rotate);
187 
188  // Keeps track of best and locally worst error rate, using internally computed
189  // values. See MaintainCheckpointsSpecific for more detail.
190  bool MaintainCheckpoints(TestCallback tester, STRING* log_msg);
191  // Keeps track of best and locally worst error_rate (whatever it is) and
192  // launches tests using rec_model, when a new min or max is reached.
193  // Writes checkpoints using train_model at appropriate times and builds and
194  // returns a log message to indicate progress. Returns false if nothing
195  // interesting happened.
196  bool MaintainCheckpointsSpecific(int iteration,
197  const GenericVector<char>* train_model,
198  const GenericVector<char>* rec_model,
199  TestCallback tester, STRING* log_msg);
200  // Builds a string containing a progress message with current error rates.
201  void PrepareLogMsg(STRING* log_msg) const;
202  // Appends <intro_str> iteration learning_iteration()/training_iteration()/
203  // sample_iteration() to the log_msg.
204  void LogIterations(const char* intro_str, STRING* log_msg) const;
205 
206  // TODO(rays) Add curriculum learning.
207  // Returns true and increments the training_stage_ if the error rate has just
208  // passed through the given threshold for the first time.
209  bool TransitionTrainingStage(float error_threshold);
210  // Returns the current training stage.
211  int CurrentTrainingStage() const { return training_stage_; }
212 
213  // Writes to the given file. Returns false in case of error.
214  bool Serialize(SerializeAmount serialize_amount,
215  const TessdataManager* mgr, TFile* fp) const;
216  // Reads from the given file. Returns false in case of error.
217  bool DeSerialize(const TessdataManager* mgr, TFile* fp);
218 
219  // De-serializes the saved best_trainer_ into sub_trainer_, and adjusts the
220  // learning rates (by scaling reduction, or layer specific, according to
221  // NF_LAYER_SPECIFIC_LR).
222  void StartSubtrainer(STRING* log_msg);
223  // While the sub_trainer_ is behind the current training iteration and its
224  // training error is at least kSubTrainerMarginFraction better than the
225  // current training error, trains the sub_trainer_, and returns STR_UPDATED if
226  // it did anything. If it catches up, and has a better error rate than the
227  // current best, as well as a margin over the current error rate, then the
228  // trainer in *this is replaced with sub_trainer_, and STR_REPLACED is
229  // returned. STR_NONE is returned if the subtrainer wasn't good enough to
230  // receive any training iterations.
231  SubTrainerResult UpdateSubtrainer(STRING* log_msg);
232  // Reduces network learning rates, either for everything, or for layers
233  // independently, according to NF_LAYER_SPECIFIC_LR.
234  void ReduceLearningRates(LSTMTrainer* samples_trainer, STRING* log_msg);
235  // Considers reducing the learning rate independently for each layer down by
236  // factor(<1), or leaving it the same, by double-training the given number of
237  // samples and minimizing the number of sign changes in the weight updates.
238  // Even if it looks like all weights should remain the same, an adjustment
239  // will be made to guarantee a different result when reverting to an old best.
240  // Returns the number of layer learning rates that were reduced.
241  int ReduceLayerLearningRates(double factor, int num_samples,
242  LSTMTrainer* samples_trainer);
243 
244  // Converts the string to integer class labels, with appropriate null_char_s
245  // in between if not in SimpleTextOutput mode. Returns false on failure.
246  bool EncodeString(const STRING& str, GenericVector<int>* labels) const {
247  return EncodeString(str, GetUnicharset(), IsRecoding() ? &recoder_ : nullptr,
248  SimpleTextOutput(), null_char_, labels);
249  }
250  // Static version operates on supplied unicharset, encoder, simple_text.
251  static bool EncodeString(const STRING& str, const UNICHARSET& unicharset,
252  const UnicharCompress* recoder, bool simple_text,
253  int null_char, GenericVector<int>* labels);
254 
255  // Performs forward-backward on the given trainingdata.
256  // Returns the sample that was used or nullptr if the next sample was deemed
257  // unusable. samples_trainer could be this or an alternative trainer that
258  // holds the training samples.
259  const ImageData* TrainOnLine(LSTMTrainer* samples_trainer, bool batch) {
260  int sample_index = sample_iteration();
261  const ImageData* image =
262  samples_trainer->training_data_.GetPageBySerial(sample_index);
263  if (image != nullptr) {
264  Trainability trainable = TrainOnLine(image, batch);
265  if (trainable == UNENCODABLE || trainable == NOT_BOXED) {
266  return nullptr; // Sample was unusable.
267  }
268  } else {
270  }
271  return image;
272  }
273  Trainability TrainOnLine(const ImageData* trainingdata, bool batch);
274 
275  // Prepares the ground truth, runs forward, and prepares the targets.
276  // Returns a Trainability enum to indicate the suitability of the sample.
277  Trainability PrepareForBackward(const ImageData* trainingdata,
278  NetworkIO* fwd_outputs, NetworkIO* targets);
279 
280  // Writes the trainer to memory, so that the current training state can be
281  // restored. *this must always be the master trainer that retains the only
282  // copy of the training data and language model. trainer is the model that is
283  // actually serialized.
284  bool SaveTrainingDump(SerializeAmount serialize_amount,
285  const LSTMTrainer* trainer,
286  GenericVector<char>* data) const;
287 
288  // Reads previously saved trainer from memory. *this must always be the
289  // master trainer that retains the only copy of the training data and
290  // language model. trainer is the model that is restored.
292  LSTMTrainer* trainer) const {
293  if (data.empty()) return false;
294  return ReadSizedTrainingDump(&data[0], data.size(), trainer);
295  }
296  bool ReadSizedTrainingDump(const char* data, int size,
297  LSTMTrainer* trainer) const {
298  return trainer->ReadLocalTrainingDump(&mgr_, data, size); // Restores into trainer, passing this object's mgr_.
299  }
300  // Restores the model to *this.
301  bool ReadLocalTrainingDump(const TessdataManager* mgr, const char* data,
302  int size);
303 
304  // Sets up the data for MaintainCheckpoints from a light ReadTrainingDump.
305  void SetupCheckpointInfo();
306 
307  // Writes the full recognition traineddata to the given filename.
308  bool SaveTraineddata(const STRING& filename);
309 
310  // Writes the recognizer to memory, so that it can be used for testing later.
311  void SaveRecognitionDump(GenericVector<char>* data) const;
312 
313  // Returns a suitable filename for a training dump, based on the model_base_,
314  // the iteration and the error rates.
315  STRING DumpFilename() const;
316 
317  // Fills the whole error buffer of the given type with the given value.
318  void FillErrorBuffer(double new_error, ErrorTypes type);
319  // Helper generates a map from each current recoder_ code (ie softmax index)
320  // to the corresponding old_recoder code, or -1 if there isn't one.
321  std::vector<int> MapRecoder(const UNICHARSET& old_chset,
322  const UnicharCompress& old_recoder) const;
323 
324  protected:
325  // Protected version of InitCharSet above; finishes the job after the
326  // mgr_ data member has been initialized.
327  void InitCharSet();
328  // Helper computes and sets the null_char_.
329  void SetNullChar();
330 
331  // Factored sub-constructor sets up reasonable default values.
332  void EmptyConstructor();
333 
334  // Outputs the string and periodically displays the given network inputs
335  // as an image in the given window, and the corresponding labels at the
336  // corresponding x_starts.
337  // Returns false if the truth string is empty.
338  bool DebugLSTMTraining(const NetworkIO& inputs,
339  const ImageData& trainingdata,
340  const NetworkIO& fwd_outputs,
341  const GenericVector<int>& truth_labels,
342  const NetworkIO& outputs);
343  // Displays the network targets as a line graph.
344  void DisplayTargets(const NetworkIO& targets, const char* window_name,
345  ScrollView** window);
346 
347  // Builds a no-compromises target where the first positions should be the
348  // truth labels and the rest is padded with the null_char_.
349  bool ComputeTextTargets(const NetworkIO& outputs,
350  const GenericVector<int>& truth_labels,
351  NetworkIO* targets);
352 
353  // Builds a target using standard CTC. truth_labels should be pre-padded with
354  // nulls wherever desired. They don't have to be between all labels.
355  // outputs is input-output, as it gets clipped to minimum probability.
356  bool ComputeCTCTargets(const GenericVector<int>& truth_labels,
357  NetworkIO* outputs, NetworkIO* targets);
358 
359  // Computes network errors, and stores the results in the rolling buffers,
360  // along with the supplied char_error and word_error.
361  // Returns the delta error of the current sample (not running average.)
362  double ComputeErrorRates(const NetworkIO& deltas, double char_error,
363  double word_error);
364 
365  // Computes the network activation RMS error rate.
366  double ComputeRMSError(const NetworkIO& deltas);
367 
368  // Computes network activation winner error rate. (Number of values that are
369  // in error by >= 0.5 divided by number of time-steps.) More closely related
370  // to final character error than RMS, but still directly calculable from
371  // just the deltas. Because of the binary nature of the targets, zero winner
372  // error is a sufficient but not necessary condition for zero char error.
373  double ComputeWinnerError(const NetworkIO& deltas);
374 
375  // Computes a very simple bag of chars char error rate.
376  double ComputeCharError(const GenericVector<int>& truth_str,
377  const GenericVector<int>& ocr_str);
378  // Computes a very simple bag of words word recall error rate.
379  // NOTE that this is destructive on both input strings.
380  double ComputeWordError(STRING* truth_str, STRING* ocr_str);
381 
382  // Updates the error buffer and corresponding mean of the given type with
383  // the new_error.
384  void UpdateErrorBuffer(double new_error, ErrorTypes type);
385 
386  // Rolls error buffers and reports the current means.
387  void RollErrorBuffers();
388 
389  // Given that error_rate is either a new min or max, updates the best/worst
390  // error rates, and record of progress.
391  STRING UpdateErrorGraph(int iteration, double error_rate,
392  const GenericVector<char>& model_data,
393  TestCallback tester);
394 
395  protected:
396  // Alignment display window.
398  // CTC target display window.
400  // CTC output display window.
402  // Reconstructed image window.
404  // How often to display a debug image.
406  // Iteration at which the last checkpoint was dumped.
408  // Basename of files to save best models to.
409  STRING model_base_;
410  // Checkpoint filename.
412  // Training data.
415  // Name to use when saving best_trainer_.
417  // Number of available training stages.
419  // Checkpointing callbacks.
422  // TODO(rays) These are pointers, and must be deleted. Switch to unique_ptr
423  // when we can commit to c++11.
424  CheckPointReader checkpoint_reader_;
425  CheckPointWriter checkpoint_writer_;
426 
427  // ===Serialized data to ensure that a restart produces the same results.===
428  // These members are only serialized when serialize_amount != LIGHT.
429  // Best error rate so far.
431  // Snapshot of all error rates at best_iteration_.
433  // Iteration of best_error_rate_.
435  // Worst error rate since best_error_rate_.
437  // Snapshot of all error rates at worst_iteration_.
439  // Iteration of worst_error_rate_.
441  // Iteration at which the process will be thought stalled.
443  // Saved recognition models for computing test error for graph points.
446  // Saved trainer for reverting back to last known best.
448  // A subsidiary trainer running with a different learning rate until either
449  // *this or sub_trainer_ hits a new best.
450  LSTMTrainer* sub_trainer_;
451  // Error rate at which last best model was dumped.
453  // Current stage of training.
455  // History of best error rate against iteration. Used for computing the
456  // number of steps to each 2% improvement.
459  // Number of iterations since the best_error_rate_ was 2% more than it is now.
461  // Number of iterations that yielded a non-zero delta error and thus provided
462  // significant learning. learning_iteration_ <= training_iteration_.
463  // learning_iteration_ is used to measure rate of learning progress.
465  // Saved value of sample_iteration_ before looking for the next sample.
467  // How often to include a PERFECT training sample in backprop.
468  // A PERFECT training sample is used if the current
469  // training_iteration_ > last_perfect_training_iteration_ + perfect_delay_,
470  // so with perfect_delay_ == 0, all samples are used, and with
471  // perfect_delay_ == 4, at most 1 in 5 samples will be perfect.
473  // Value of training_iteration_ at which the last PERFECT training sample
474  // was used in back prop.
476  // Rolling buffers storing recent training errors are indexed by
477  // training_iteration % kRollingBufferSize_.
478  static const int kRollingBufferSize_ = 1000;
480  // Rounded mean percent trailing training errors in the buffers.
481  double error_rates_[ET_COUNT]; // One entry per ErrorTypes value (not only RMS).
482  // Traineddata file with optional dawgs + UNICHARSET and recoder.
483  TessdataManager mgr_;
484 };
485 
486 } // namespace tesseract.
487 
488 #endif // TESSERACT_LSTM_LSTMTRAINER_H_
bool SimpleTextOutput() const
Definition: lstmrecognizer.h:76
bool InitNetwork(const STRING &network_spec, int append_index, int net_flags, float weight_range, float learning_rate, float momentum, float adam_beta)
Definition: lstmtrainer.cpp:171
const DocumentCache & training_data() const
Definition: lstmtrainer.h:165
SubTrainerResult UpdateSubtrainer(STRING *log_msg)
Definition: lstmtrainer.cpp:547
TessResultCallback4< STRING, int, const double *, const TessdataManager &, int > * TestCallback
Definition: lstmtrainer.h:83
int null_char() const
Definition: lstmrecognizer.h:154
void FillErrorBuffer(double new_error, ErrorTypes type)
Definition: lstmtrainer.cpp:949
const double * error_rates() const
Definition: lstmtrainer.h:140
bool ComputeCTCTargets(const GenericVector< int > &truth_labels, NetworkIO *outputs, NetworkIO *targets)
Definition: lstmtrainer.cpp:1119
bool DeSerialize(const TessdataManager *mgr, TFile *fp)
Definition: lstmtrainer.cpp:468
bool ComputeTextTargets(const NetworkIO &outputs, const GenericVector< int > &truth_labels, NetworkIO *targets)
Definition: lstmtrainer.cpp:1099
bool IsRecoding() const
Definition: lstmrecognizer.h:79
int learning_iteration_
Definition: lstmtrainer.h:464
GenericVector< char > worst_model_data_
Definition: lstmtrainer.h:445
void InitCharSet()
Definition: lstmtrainer.cpp:992
bool TransitionTrainingStage(float error_threshold)
Definition: lstmtrainer.cpp:421
double ActivationError() const
Definition: lstmtrainer.h:136
Definition: lstmtrainer.h:50
Definition: lstmtrainer.h:57
const ImageData * TrainOnLine(LSTMTrainer *samples_trainer, bool batch)
Definition: lstmtrainer.h:259
Definition: lstmtrainer.h:52
bool Init(const char *data_file_name)
Definition: tessdatamanager.cpp:55
int stall_iteration_
Definition: lstmtrainer.h:442
FileReader file_reader_
Definition: lstmtrainer.h:420
bool LoadAllTrainingData(const GenericVector< STRING > &filenames, CachingStrategy cache_strategy, bool randomly_rotate)
Definition: lstmtrainer.cpp:300
bool TryLoadingCheckpoint(const char *filename, const char *old_traineddata)
Definition: lstmtrainer.cpp:128
CachingStrategy
Definition: imagedata.h:42
double worst_error_rates_[ET_COUNT]
Definition: lstmtrainer.h:438
Definition: imagedata.h:105
void UpdateErrorBuffer(double new_error, ErrorTypes type)
Definition: lstmtrainer.cpp:1248
int training_iteration() const
Definition: lstmrecognizer.h:61
GenericVector< int > best_error_iterations_
Definition: lstmtrainer.h:458
void RollErrorBuffers()
Definition: lstmtrainer.cpp:1261
int checkpoint_iteration_
Definition: lstmtrainer.h:407
SerializeAmount
Definition: lstmtrainer.h:56
int32_t improvement_steps_
Definition: lstmtrainer.h:460
TessResultCallback2< bool, const GenericVector< char > &, LSTMTrainer * > * CheckPointReader
Definition: lstmtrainer.h:69
Definition: unicharset.h:146
bool ReadLocalTrainingDump(const TessdataManager *mgr, const char *data, int size)
Definition: lstmtrainer.cpp:909
Definition: imagedata.h:314
virtual ~LSTMTrainer()
Definition: lstmtrainer.cpp:116
int learning_iteration() const
Definition: lstmtrainer.h:149
Definition: lstmtrainer.h:58
double error_rates_[ET_COUNT]
Definition: lstmtrainer.h:481
const GenericVector< char > & best_trainer() const
Definition: lstmtrainer.h:152
Trainability GridSearchDictParams(const ImageData *trainingdata, int iteration, double min_dict_ratio, double dict_ratio_step, double max_dict_ratio, double min_cert_offset, double cert_offset_step, double max_cert_offset, STRING *results)
Definition: lstmtrainer.cpp:243
bool(* FileReader)(const STRING &filename, GenericVector< char > *data)
Definition: genericvector.h:360
double ComputeErrorRates(const NetworkIO &deltas, double char_error, double word_error)
Definition: lstmtrainer.cpp:1130
Definition: serialis.h:77
double best_error_rates_[ET_COUNT]
Definition: lstmtrainer.h:432
double ComputeRMSError(const NetworkIO &deltas)
Definition: lstmtrainer.cpp:1150
Definition: baseapi.cpp:94
void SetNullChar()
Definition: lstmtrainer.cpp:1005
Definition: lstmtrainer.h:48
int debug_interval_
Definition: lstmtrainer.h:405
SubTrainerResult
Definition: lstmtrainer.h:63
int perfect_delay_
Definition: lstmtrainer.h:472
bool Serialize(SerializeAmount serialize_amount, const TessdataManager *mgr, TFile *fp) const
Definition: lstmtrainer.cpp:431
int ReduceLayerLearningRates(double factor, int num_samples, LSTMTrainer *samples_trainer)
Definition: lstmtrainer.cpp:609
void StartSubtrainer(STRING *log_msg)
Definition: lstmtrainer.cpp:517
double learning_rate() const
Definition: lstmrecognizer.h:67
bool MaintainCheckpointsSpecific(int iteration, const GenericVector< char > *train_model, const GenericVector< char > *rec_model, TestCallback tester, STRING *log_msg)
std::vector< int > MapRecoder(const UNICHARSET &old_chset, const UnicharCompress &old_recoder) const
Definition: lstmtrainer.cpp:957
int32_t null_char_
Definition: lstmrecognizer.h:290
int prev_sample_iteration_
Definition: lstmtrainer.h:466
STRING DumpFilename() const
Definition: lstmtrainer.cpp:940
Definition: unicharcompress.h:128
double best_error_rate() const
Definition: lstmtrainer.h:143
int sample_iteration() const
Definition: lstmrecognizer.h:64
void InitIterations()
Definition: lstmtrainer.cpp:218
double ComputeWinnerError(const NetworkIO &deltas)
Definition: lstmtrainer.cpp:1169
bool(* FileWriter)(const GenericVector< char > &data, const STRING &filename)
Definition: genericvector.h:363
Definition: tesscallback.h:1702
bool ReadSizedTrainingDump(const char *data, int size, LSTMTrainer *trainer) const
Definition: lstmtrainer.h:296
int worst_iteration_
Definition: lstmtrainer.h:440
Trainability PrepareForBackward(const ImageData *trainingdata, NetworkIO *fwd_outputs, NetworkIO *targets)
Definition: lstmtrainer.cpp:798
int CurrentTrainingStage() const
Definition: lstmtrainer.h:211
int32_t improvement_steps() const
Definition: lstmtrainer.h:150
FileWriter file_writer_
Definition: lstmtrainer.h:421
GenericVector< double > best_error_history_
Definition: lstmtrainer.h:457
Definition: scrollview.h:102
double ComputeWordError(STRING *truth_str, STRING *ocr_str)
Definition: lstmtrainer.cpp:1215
int InitTensorFlowNetwork(const std::string &tf_proto)
Definition: lstmtrainer.cpp:198
double NewSingleError(ErrorTypes type) const
Definition: lstmtrainer.h:154
TessResultCallback3< bool, SerializeAmount, const LSTMTrainer *, GenericVector< char > * > * CheckPointWriter
Definition: lstmtrainer.h:78
double LastSingleError(ErrorTypes type) const
Definition: lstmtrainer.h:160
double CharError() const
Definition: lstmtrainer.h:139
void DebugNetwork()
Definition: lstmtrainer.cpp:293
CheckPointReader checkpoint_reader_
Definition: lstmtrainer.h:424
Definition: lstmtrainer.h:66
ScrollView * ctc_win_
Definition: lstmtrainer.h:401
void PrepareLogMsg(STRING *log_msg) const
Definition: lstmtrainer.cpp:400
STRING UpdateErrorGraph(int iteration, double error_rate, const GenericVector< char > &model_data, TestCallback tester)
Definition: lstmtrainer.cpp:1280
void set_perfect_delay(int delay)
Definition: lstmtrainer.h:151
Definition: lstmtrainer.h:64
Definition: tessdatamanager.h:126
int training_stage_
Definition: lstmtrainer.h:454
Definition: lstmtrainer.h:49
int best_iteration() const
Definition: lstmtrainer.h:146
bool SaveTraineddata(const STRING &filename)
Definition: lstmtrainer.cpp:921
int num_training_stages_
Definition: lstmtrainer.h:418
STRING best_model_name_
Definition: lstmtrainer.h:416
void SaveRecognitionDump(GenericVector< char > *data) const
Definition: lstmtrainer.cpp:930
Definition: lstmtrainer.h:59
Definition: lstmtrainer.h:40
LSTMTrainer()
Definition: lstmtrainer.cpp:73
void ReduceLearningRates(LSTMTrainer *samples_trainer, STRING *log_msg)
Definition: lstmtrainer.cpp:590
Definition: strngs.h:45
Definition: lstmtrainer.h:41
bool EncodeString(const STRING &str, GenericVector< int > *labels) const
Definition: lstmtrainer.h:246
CheckPointWriter checkpoint_writer_
Definition: lstmtrainer.h:425
double ComputeCharError(const GenericVector< int > &truth_str, const GenericVector< int > &ocr_str)
Definition: lstmtrainer.cpp:1187
LSTMTrainer * sub_trainer_
Definition: lstmtrainer.h:450
Definition: lstmtrainer.h:51
Definition: tesscallback.h:1716
static const int kRollingBufferSize_
Definition: lstmtrainer.h:478
DocumentCache * mutable_training_data()
Definition: lstmtrainer.h:168
GenericVector< double > error_buffers_[ET_COUNT]
Definition: lstmtrainer.h:479
Definition: lstmtrainer.h:65
double worst_error_rate_
Definition: lstmtrainer.h:436
const UNICHARSET & GetUnicharset() const
Definition: lstmrecognizer.h:139
Definition: lstmtrainer.h:39
STRING checkpoint_name_
Definition: lstmtrainer.h:411
int size() const
Definition: genericvector.h:71
ScrollView * recon_win_
Definition: lstmtrainer.h:403
void InitCharSet(const TessdataManager &mgr)
Definition: lstmtrainer.h:113
Definition: lstmtrainer.h:43
UnicharCompress recoder_
Definition: lstmrecognizer.h:277
int32_t sample_iteration_
Definition: lstmrecognizer.h:287
void DisplayTargets(const NetworkIO &targets, const char *window_name, ScrollView **window)
Definition: lstmtrainer.cpp:1062
const ImageData * GetPageBySerial(int serial)
Definition: imagedata.h:337
Definition: lstmtrainer.h:38
bool SaveTrainingDump(SerializeAmount serialize_amount, const LSTMTrainer *trainer, GenericVector< char > *data) const
Definition: lstmtrainer.cpp:900
Definition: lstmrecognizer.h:53
ErrorTypes
Definition: lstmtrainer.h:37
Definition: networkio.h:39
bool randomly_rotate_
Definition: lstmtrainer.h:413
GenericVector< char > best_model_data_
Definition: lstmtrainer.h:444
Definition: blamer.h:43
Trainability
Definition: lstmtrainer.h:47
void LogIterations(const char *intro_str, STRING *log_msg) const
Definition: lstmtrainer.cpp:412
GenericVector< char > best_trainer_
Definition: lstmtrainer.h:447
ScrollView * align_win_
Definition: lstmtrainer.h:397
STRING model_base_
Definition: lstmtrainer.h:409
TessdataManager mgr_
Definition: lstmtrainer.h:483
Definition: lstmtrainer.h:89
void EmptyConstructor()
Definition: lstmtrainer.cpp:1014
float error_rate_of_last_saved_best_
Definition: lstmtrainer.h:452
ScrollView * target_win_
Definition: lstmtrainer.h:399
DocumentCache training_data_
Definition: lstmtrainer.h:414
bool empty() const
Definition: genericvector.h:90
bool DebugLSTMTraining(const NetworkIO &inputs, const ImageData &trainingdata, const NetworkIO &fwd_outputs, const GenericVector< int > &truth_labels, const NetworkIO &outputs)
Definition: lstmtrainer.cpp:1029
int best_iteration_
Definition: lstmtrainer.h:434
Definition: lstmtrainer.h:42
double best_error_rate_
Definition: lstmtrainer.h:430
int last_perfect_training_iteration_
Definition: lstmtrainer.h:475
void InitCharSet(const std::string &traineddata_path)
Definition: lstmtrainer.h:109
bool ReadTrainingDump(const GenericVector< char > &data, LSTMTrainer *trainer) const
Definition: lstmtrainer.h:291
bool MaintainCheckpoints(TestCallback tester, STRING *log_msg)
Definition: lstmtrainer.cpp:312