tesseract  v4.0.0-17-g361f3264
Open Source OCR Engine
weightmatrix.h
// File: weightmatrix.h
// Description: Hides distinction between float/int implementations.
// Author: Ray Smith
// Created: Tue Jun 17 09:05:39 PST 2014
//
// (C) Copyright 2014, Google Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
18 
19 #ifndef TESSERACT_LSTM_WEIGHTMATRIX_H_
20 #define TESSERACT_LSTM_WEIGHTMATRIX_H_
21 
22 #include <memory>
23 #include "genericvector.h"
24 #include "intsimdmatrix.h"
25 #include "matrix.h"
26 #include "tprintf.h"
27 
28 namespace tesseract {
29 
30 // Convenience instantiation of GENERIC_2D_ARRAY<double> with additional
31 // operations to write a strided vector, so the transposed form of the input
32 // is memory-contiguous.
33 class TransposedArray : public GENERIC_2D_ARRAY<double> {
34  public:
35  // Copies the whole input transposed, converted to double, into *this.
36  void Transpose(const GENERIC_2D_ARRAY<double>& input);
37  // Writes a vector of data representing a timestep (gradients or sources).
38  // The data is assumed to be of size1 in size (the strided dimension).
39  virtual ~TransposedArray();
40  void WriteStrided(int t, const float* data) {
41  int size1 = dim1();
42  for (int i = 0; i < size1; ++i) put(i, t, data[i]);
43  }
44  void WriteStrided(int t, const double* data) {
45  int size1 = dim1();
46  for (int i = 0; i < size1; ++i) put(i, t, data[i]);
47  }
48  // Prints the first and last num elements of the un-transposed array.
49  void PrintUnTransposed(int num) {
50  int num_features = dim1();
51  int width = dim2();
52  for (int y = 0; y < num_features; ++y) {
53  for (int t = 0; t < width; ++t) {
54  if (num == 0 || t < num || t + num >= width) {
55  tprintf(" %g", (*this)(y, t));
56  }
57  }
58  tprintf("\n");
59  }
60  }
61 }; // class TransposedArray
62 
63 // Generic weight matrix for network layers. Can store the matrix as either
64 // an array of floats or int8_t. Provides functions to compute the forward and
65 // backward steps with the matrix and updates to the weights.
66 class WeightMatrix {
67  public:
68  WeightMatrix() : int_mode_(false), use_adam_(false) {}
69  // Sets up the network for training. Initializes weights using weights of
70  // scale `range` picked according to the random number generator `randomizer`.
71  // Note the order is outputs, inputs, as this is the order of indices to
72  // the matrix, so the adjacent elements are multiplied by the input during
73  // a forward operation.
74  int InitWeightsFloat(int no, int ni, bool use_adam, float weight_range,
75  TRand* randomizer);
76  // Changes the number of outputs to the size of the given code_map, copying
77  // the old weight matrix entries for each output from code_map[output] where
78  // non-negative, and uses the mean (over all outputs) of the existing weights
79  // for all outputs with negative code_map entries. Returns the new number of
80  // weights.
81  int RemapOutputs(const std::vector<int>& code_map);
82 
83  // Converts a float network to an int network. Each set of input weights that
84  // corresponds to a single output weight is converted independently:
85  // Compute the max absolute value of the weight set.
86  // Scale so the max absolute value becomes INT8_MAX.
87  // Round to integer.
88  // Store a multiplicative scale factor (as a float) that will reproduce
89  // the original value, subject to rounding errors.
90  void ConvertToInt();
91  // Returns the size rounded up to an internal factor used by the SIMD
92  // implementation for its input.
93  int RoundInputs(int size) const {
94  if (multiplier_ == nullptr) return size;
95  return multiplier_->RoundInputs(size);
96  }
97 
98  // Accessors.
99  bool is_int_mode() const {
100  return int_mode_;
101  }
102  int NumOutputs() const { return int_mode_ ? wi_.dim1() : wf_.dim1(); }
103  // Provides one set of weights. Only used by peep weight maxpool.
104  const double* GetWeights(int index) const { return wf_[index]; }
105  // Provides access to the deltas (dw_).
106  double GetDW(int i, int j) const { return dw_(i, j); }
107 
108  // Allocates any needed memory for running Backward, and zeroes the deltas,
109  // thus eliminating any existing momentum.
110  void InitBackward();
111 
112  // Writes to the given file. Returns false in case of error.
113  bool Serialize(bool training, TFile* fp) const;
114  // Reads from the given file. Returns false in case of error.
115  bool DeSerialize(bool training, TFile* fp);
116  // As DeSerialize, but reads an old (float) format WeightMatrix for
117  // backward compatibility.
118  bool DeSerializeOld(bool training, TFile* fp);
119 
120  // Computes matrix.vector v = Wu.
121  // u is of size W.dim2() - 1 and the output v is of size W.dim1().
122  // u is imagined to have an extra element at the end with value 1, to
123  // implement the bias, but it doesn't actually have it.
124  // Asserts that the call matches what we have.
125  void MatrixDotVector(const double* u, double* v) const;
126  void MatrixDotVector(const int8_t* u, double* v) const;
127  // MatrixDotVector for peep weights, MultiplyAccumulate adds the
128  // component-wise products of *this[0] and v to inout.
129  void MultiplyAccumulate(const double* v, double* inout);
130  // Computes vector.matrix v = uW.
131  // u is of size W.dim1() and the output v is of size W.dim2() - 1.
132  // The last result is discarded, as v is assumed to have an imaginary
133  // last value of 1, as with MatrixDotVector.
134  void VectorDotMatrix(const double* u, double* v) const;
135  // Fills dw_[i][j] with the dot product u[i][] . v[j][], using elements
136  // from u and v, starting with u[i][offset] and v[j][offset].
137  // Note that (matching MatrixDotVector) v[last][] is missing, presumed 1.0.
138  // Runs parallel if requested. Note that inputs must be transposed.
139  void SumOuterTransposed(const TransposedArray& u, const TransposedArray& v,
140  bool parallel);
141  // Updates the weights using the given learning rate, momentum and adam_beta.
142  // num_samples is used in the Adam correction factor.
143  void Update(double learning_rate, double momentum, double adam_beta,
144  int num_samples);
145  // Adds the dw_ in other to the dw_ is *this.
146  void AddDeltas(const WeightMatrix& other);
147  // Sums the products of weight updates in *this and other, splitting into
148  // positive (same direction) in *same and negative (different direction) in
149  // *changed.
150  void CountAlternators(const WeightMatrix& other, double* same,
151  double* changed) const;
152 
153  void Debug2D(const char* msg);
154 
155  // Computes and returns the dot product of the two n-vectors u and v.
156  static double DotProduct(const double* u, const double* v, int n);
157  // Utility function converts an array of float to the corresponding array
158  // of double.
159  static void FloatToDouble(const GENERIC_2D_ARRAY<float>& wf,
161 
162  private:
163  // Computes matrix.vector v = Wu.
164  // u is of size starts.back()+extents.back() and the output v is of size
165  // starts.size().
166  // The weight matrix w, is of size starts.size()xMAX(extents)+add_bias_fwd.
167  // If add_bias_fwd, an extra element at the end of w[i] is the bias weight
168  // and is added to v[i].
169  static void MatrixDotVectorInternal(const GENERIC_2D_ARRAY<double>& w,
170  bool add_bias_fwd, bool skip_bias_back,
171  const double* u, double* v);
172 
173  private:
174  // Choice between float and 8 bit int implementations.
177  // Transposed copy of wf_, used only for Backward, and set with each Update.
179  // Which of wf_ and wi_ are we actually using.
180  bool int_mode_;
181  // True if we are running adam in this weight matrix.
182  bool use_adam_;
183  // If we are using wi_, then scales_ is a factor to restore the row product
184  // with a vector to the correct range.
186  // Weight deltas. dw_ is the new delta, and updates_ the momentum-decaying
187  // amount to be added to wf_/wi_.
190  // Iff use_adam_, the sum of squares of dw_. The number of samples is
191  // given to Update(). Serialized iff use_adam_.
193  // Holds the optimal integer multiplier for this machine.
194  std::unique_ptr<IntSimdMatrix> multiplier_;
195 };
196 
197 } // namespace tesseract.
198 
199 #endif // TESSERACT_LSTM_WEIGHTMATRIX_H_
void MultiplyAccumulate(int n, const double *u, const double *v, double *out)
Definition: functions.h:201
GENERIC_2D_ARRAY< double > dw_
Definition: weightmatrix.h:188
GENERIC_2D_ARRAY< double > updates_
Definition: weightmatrix.h:189
int RoundInputs(int size) const
Definition: weightmatrix.h:93
Definition: helpers.h:42
std::unique_ptr< IntSimdMatrix > multiplier_
Definition: weightmatrix.h:194
void put(ICOORD pos, const double &thing)
Definition: matrix.h:220
GENERIC_2D_ARRAY< int8_t > wi_
Definition: weightmatrix.h:176
GENERIC_2D_ARRAY< double > wf_
Definition: weightmatrix.h:175
Definition: intsimdmatrix.h:25
TransposedArray wf_t_
Definition: weightmatrix.h:178
GenericVector< double > scales_
Definition: weightmatrix.h:185
Definition: serialis.h:77
int dim1() const
Definition: matrix.h:206
bool int_mode_
Definition: weightmatrix.h:180
Definition: baseapi.cpp:94
WeightMatrix()
Definition: weightmatrix.h:68
int dim2() const
Definition: matrix.h:207
bool Serialize(FILE *fp) const
Definition: matrix.h:144
tesseract::TransposedArray
Definition: weightmatrix.h:33
void WriteStrided(int t, const float *data)
Definition: weightmatrix.h:40
tesseract::WeightMatrix
Definition: weightmatrix.h:66
bool is_int_mode() const
Definition: weightmatrix.h:99
double GetDW(int i, int j) const
Definition: weightmatrix.h:106
bool DeSerialize(bool swap, FILE *fp)
Definition: matrix.h:161
void Transpose(const GENERIC_2D_ARRAY< double > &input)
Definition: weightmatrix.cpp:42
GENERIC_2D_ARRAY< double > dw_sq_sum_
Definition: weightmatrix.h:192
int NumOutputs() const
Definition: weightmatrix.h:102
void WriteStrided(int t, const double *data)
Definition: weightmatrix.h:44
virtual int index(int column, int row) const
Definition: matrix.h:215
bool use_adam_
Definition: weightmatrix.h:182
void PrintUnTransposed(int num)
Definition: weightmatrix.h:49
const double * GetWeights(int index) const
Definition: weightmatrix.h:104