tesseract  v4.0.0-17-g361f3264
Open Source OCR Engine
series.h
1 // File: series.h
3 // Description: Runs networks in series on the same input.
4 // Author: Ray Smith
5 // Created: Thu May 02 08:20:06 PST 2013
6 //
7 // (C) Copyright 2013, Google Inc.
8 // Licensed under the Apache License, Version 2.0 (the "License");
9 // you may not use this file except in compliance with the License.
10 // You may obtain a copy of the License at
11 // http://www.apache.org/licenses/LICENSE-2.0
12 // Unless required by applicable law or agreed to in writing, software
13 // distributed under the License is distributed on an "AS IS" BASIS,
14 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15 // See the License for the specific language governing permissions and
16 // limitations under the License.
18 
19 #ifndef TESSERACT_LSTM_SERIES_H_
20 #define TESSERACT_LSTM_SERIES_H_
21 
22 #include "plumbing.h"
23 
24 namespace tesseract {
25 
26 // Runs two or more networks in series (layers) on the same input.
27 class Series : public Plumbing {
28  public:
29  // ni_ and no_ will be set by AddToStack.
30  explicit Series(const STRING& name);
31  virtual ~Series() = default;
32 
33  // Returns the shape output from the network given an input shape (which may
34  // be partially unknown ie zero).
35  StaticShape OutputShape(const StaticShape& input_shape) const override;
36 
37  STRING spec() const override {
38  STRING spec("[");
39  for (int i = 0; i < stack_.size(); ++i)
40  spec += stack_[i]->spec();
41  spec += "]";
42  return spec;
43  }
44 
45  // Sets up the network for training. Initializes weights using weights of
46  // scale `range` picked according to the random number generator `randomizer`.
47  // Returns the number of weights initialized.
48  int InitWeights(float range, TRand* randomizer) override;
49  // Recursively searches the network for softmaxes with old_no outputs,
50  // and remaps their outputs according to code_map. See network.h for details.
51  int RemapOutputs(int old_no, const std::vector<int>& code_map) override;
52 
53  // Sets needs_to_backprop_ to needs_backprop and returns true if
54  // needs_backprop || any weights in this network so the next layer forward
55  // can be told to produce backprop for this layer if needed.
56  bool SetupNeedsBackprop(bool needs_backprop) override;
57 
58  // Returns an integer reduction factor that the network applies to the
59  // time sequence. Assumes that any 2-d is already eliminated. Used for
60  // scaling bounding boxes of truth data.
61  // WARNING: if GlobalMinimax is used to vary the scale, this will return
62  // the last used scale factor. Call it before any forward, and it will return
63  // the minimum scale factor of the paths through the GlobalMinimax.
64  int XScaleFactor() const override;
65 
66  // Provides the (minimum) x scale factor to the network (of interest only to
67  // input units) so they can determine how to scale bounding boxes.
68  void CacheXScaleFactor(int factor) override;
69 
70  // Runs forward propagation of activations on the input line.
71  // See Network for a detailed discussion of the arguments.
72  void Forward(bool debug, const NetworkIO& input,
73  const TransposedArray* input_transpose, NetworkScratch* scratch,
74  NetworkIO* output) override;
75 
76  // Runs backward propagation of errors on the deltas line.
77  // See Network for a detailed discussion of the arguments.
78  bool Backward(bool debug, const NetworkIO& fwd_deltas,
79  NetworkScratch* scratch, NetworkIO* back_deltas) override;
80 
81  // Splits the series after the given index, returning the two parts and
82  // deletes itself. The first part, up to network with index last_start, goes
83  // into start, and the rest goes into end.
84  void SplitAt(int last_start, Series** start, Series** end);
85 
86  // Appends the elements of the src series to this, removing from src and
87  // deleting it.
88  void AppendSeries(Network* src);
89 };
90 
91 } // namespace tesseract.
92 
93 #endif // TESSERACT_LSTM_SERIES_H_
void AppendSeries(Network *src)
Definition: series.cpp:190
StaticShape OutputShape(const StaticShape &input_shape) const override
Definition: series.cpp:35
Definition: helpers.h:42
int InitWeights(float range, TRand *randomizer) override
Definition: series.cpp:47
PointerVector< Network > stack_
Definition: plumbing.h:136
Definition: static_shape.h:38
Definition: plumbing.h:30
Definition: networkscratch.h:36
void Forward(bool debug, const NetworkIO &input, const TransposedArray *input_transpose, NetworkScratch *scratch, NetworkIO *output) override
Definition: series.cpp:107
Definition: baseapi.cpp:94
int RemapOutputs(int old_no, const std::vector< int > &code_map) override
Definition: series.cpp:62
const STRING & name() const
Definition: network.h:138
void SplitAt(int last_start, Series **start, Series **end)
Definition: series.cpp:160
Definition: weightmatrix.h:33
bool SetupNeedsBackprop(bool needs_backprop) override
Definition: series.cpp:79
Definition: network.h:105
void CacheXScaleFactor(int factor) override
Definition: series.cpp:101
Definition: strngs.h:45
Series(const STRING &name)
Definition: series.cpp:29
bool Backward(bool debug, const NetworkIO &fwd_deltas, NetworkScratch *scratch, NetworkIO *back_deltas) override
Definition: series.cpp:129
virtual ~Series()=default
Definition: series.h:27
Definition: networkio.h:39
STRING spec() const override
Definition: series.h:37
int XScaleFactor() const override
Definition: series.cpp:92