tesseract  v4.0.0-17-g361f3264
Open Source OCR Engine
networkscratch.h
1 // File: networkscratch.h
3 // Description: Scratch space for Network layers that hides distinction
4 // between float/int implementations.
5 // Author: Ray Smith
6 // Created: Thu Jun 19 10:50:29 PST 2014
7 //
8 // (C) Copyright 2014, Google Inc.
9 // Licensed under the Apache License, Version 2.0 (the "License");
10 // you may not use this file except in compliance with the License.
11 // You may obtain a copy of the License at
12 // http://www.apache.org/licenses/LICENSE-2.0
13 // Unless required by applicable law or agreed to in writing, software
14 // distributed under the License is distributed on an "AS IS" BASIS,
15 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
16 // See the License for the specific language governing permissions and
17 // limitations under the License.
19 
20 #ifndef TESSERACT_LSTM_NETWORKSCRATCH_H_
21 #define TESSERACT_LSTM_NETWORKSCRATCH_H_
22 
23 #include "genericvector.h"
24 #include "matrix.h"
25 #include "networkio.h"
26 #include "svutil.h"
27 #include "tprintf.h"
28 
29 namespace tesseract {
30 
31 // Generic scratch space for network layers. Provides NetworkIO that can store
32 // a complete set (over time) of intermediates, and GenericVector<double>
33 // scratch space that auto-frees after use. The aim here is to provide a set
34 // of temporary buffers to network layers that can be reused between layers
35 // and don't have to be reallocated on each call.
37  public:
38  NetworkScratch() : int_mode_(false) {}
39  ~NetworkScratch() = default;
40 
41  // Sets the network representation. If the representation is integer, then
42  // default (integer) NetworkIOs are separated from the always-float variety.
43  // This saves memory by having separate int-specific and float-specific
44  // stacks. If the network representation is float, then all NetworkIOs go
45  // to the float stack.
46  void set_int_mode(bool int_mode) {
47  int_mode_ = int_mode;
48  }
49 
50  // Class that acts like a NetworkIO (by having an implicit cast operator),
51  // yet actually holds a pointer to NetworkIOs in the source NetworkScratch,
52  // and knows how to unstack the borrowed pointers on destruction.
53  class IO {
54  public:
55  // The NetworkIO should be sized after construction.
56  IO(const NetworkIO& src, NetworkScratch* scratch)
57  : int_mode_(scratch->int_mode_ && src.int_mode()),
58  scratch_space_(scratch) {
60  : scratch_space_->float_stack_.Borrow();
61  }
62  // Default constructor for arrays. Use one of the Resize functions
63  // below to initialize and size.
64  IO() : int_mode_(false), network_io_(nullptr), scratch_space_(nullptr) {}
65 
66  ~IO() {
67  if (scratch_space_ == nullptr) {
68  ASSERT_HOST(network_io_ == nullptr);
69  } else if (int_mode_) {
71  } else {
73  }
74  }
75  // Resizes the array (and stride), avoiding realloc if possible, to the
76  // size from various size specs:
77  // Same time size, given number of features.
78  void Resize(const NetworkIO& src, int num_features,
79  NetworkScratch* scratch) {
80  if (scratch_space_ == nullptr) {
81  int_mode_ = scratch->int_mode_ && src.int_mode();
82  scratch_space_ = scratch;
84  : scratch_space_->float_stack_.Borrow();
85  }
86  network_io_->Resize(src, num_features);
87  }
88  // Resizes to a specific size as a temp buffer. No batches, no y-dim.
89  void Resize2d(bool int_mode, int width, int num_features,
90  NetworkScratch* scratch) {
91  if (scratch_space_ == nullptr) {
92  int_mode_ = scratch->int_mode_ && int_mode;
93  scratch_space_ = scratch;
95  : scratch_space_->float_stack_.Borrow();
96  }
97  network_io_->Resize2d(int_mode, width, num_features);
98  }
99  // Resize forcing a float representation with the width of src and the given
100  // number of features.
101  void ResizeFloat(const NetworkIO& src, int num_features,
102  NetworkScratch* scratch) {
103  if (scratch_space_ == nullptr) {
104  int_mode_ = false;
105  scratch_space_ = scratch;
107  }
108  network_io_->ResizeFloat(src, num_features);
109  }
110 
111  // Returns a ref to a NetworkIO that enables *this to be treated as if
112  // it were just a NetworkIO*.
114  return *network_io_;
115  }
117  return network_io_;
118  }
119  operator NetworkIO*() {
120  return network_io_;
121  }
122 
123  private:
124  // True if borrowed from the int stack, false if from the always-float stack.
125  bool int_mode_;
126  // The NetworkIO that we have borrowed from the scratch_space_.
128  // The source scratch_space_. Borrowed pointer, used to free the
129  // NetworkIO. Don't delete!
131  }; // class IO.
132 
133  // Class that acts like a fixed array of double (despite its name), yet uses
134  // space from a GenericVector<double> in the source NetworkScratch, and knows
135  // how to unstack the borrowed vector on destruction.
136  class FloatVec {
137  public:
138  // The array will have size elements in it, uninitialized.
139  FloatVec(int size, NetworkScratch* scratch)
140  : vec_(nullptr), scratch_space_(scratch) {
141  Init(size, scratch);
142  }
143  // Default constructor is for arrays. Use Init to setup.
144  FloatVec() : vec_(nullptr), data_(nullptr), scratch_space_(nullptr) {}
146  if (scratch_space_ != nullptr) scratch_space_->vec_stack_.Return(vec_);
147  }
148 
149  void Init(int size, NetworkScratch* scratch) {
150  if (scratch_space_ != nullptr && vec_ != nullptr)
152  scratch_space_ = scratch;
153  vec_ = scratch_space_->vec_stack_.Borrow();
154  vec_->resize_no_init(size);
155  data_ = &(*vec_)[0];
156  }
157 
158  // Use the cast operator instead of operator[] so the FloatVec can be used
159  // as a double* argument to a function call.
160  operator double*() const { return data_; }
161  double* get() { return data_; }
162 
163  private:
164  // Vector borrowed from the scratch space. Use Return to free it.
166  // Short-cut pointer to the underlying array.
167  double* data_;
168  // The source scratch_space_. Borrowed pointer, used to free the
169  // vector. Don't delete!
171  }; // class FloatVec
172 
173  // Class that acts like a 2-D array of double, yet actually uses space
174  // from the source NetworkScratch, and knows how to unstack the borrowed
175  // array on destruction.
177  public:
178  // Default constructor is for arrays. Use Init to setup.
179  GradientStore() : array_(nullptr), scratch_space_(nullptr) {}
181  if (scratch_space_ != nullptr) scratch_space_->array_stack_.Return(array_);
182  }
183 
184  void Init(int size1, int size2, NetworkScratch* scratch) {
185  if (scratch_space_ != nullptr && array_ != nullptr)
186  scratch_space_->array_stack_.Return(array_);
187  scratch_space_ = scratch;
188  array_ = scratch_space_->array_stack_.Borrow();
189  array_->Resize(size1, size2, 0.0);
190  }
191 
192  // Accessors to get to the underlying TransposedArray.
193  TransposedArray* get() const { return array_; }
194  const TransposedArray& operator*() const { return *array_; }
195 
196  private:
197  // Array borrowed from the scratch space. Use Return to free it.
199  // The source scratch_space_. Borrowed pointer, used to free the
200  // vector. Don't delete!
202  }; // class GradientStore
203 
204  // Class that does the work of holding a stack of objects, a stack pointer
205  // and a vector of in-use flags, so objects can be returned out of order.
206  // It is safe to attempt to Borrow/Return in multiple threads.
207  template<typename T> class Stack {
208  public:
209  Stack() : stack_top_(0) {
210  }
211 
212  // Lends out the next free item, creating one if none available, sets
213  // the used flags and increments the stack top.
214  T* Borrow() {
215  SVAutoLock lock(&mutex_);
216  if (stack_top_ == stack_.size()) {
217  stack_.push_back(new T);
218  flags_.push_back(false);
219  }
220  flags_[stack_top_] = true;
221  return stack_[stack_top_++];
222  }
223  // Takes back the given item, and marks it free. Item does not have to be
224  // the most recently lent out, but free slots don't get re-used until the
225  // blocking item is returned. The assumption is that there will only be
226  // small, temporary variations from true stack use. (Determined by the order
227  // of destructors within a local scope.)
228  void Return(T* item) {
229  SVAutoLock lock(&mutex_);
230  // Linear search will do.
231  int index = stack_top_ - 1;
232  while (index >= 0 && stack_[index] != item) --index;
233  if (index >= 0) flags_[index] = false;
234  while (stack_top_ > 0 && !flags_[stack_top_ - 1]) --stack_top_;
235  }
236 
237  private:
242  }; // class Stack.
243 
244  private:
245  // If true, the network weights are int8_t, if false, float.
246  bool int_mode_;
247  // Stacks of NetworkIO, GenericVector<double> and TransposedArray. Once
248  // allocated, they are not deleted until the NetworkScratch is deleted.
253 };
254 
255 } // namespace tesseract.
256 
257 #endif // TESSERACT_LSTM_NETWORKSCRATCH_H_
Stack()
Definition: networkscratch.h:209
GenericVector< double > * vec_
Definition: networkscratch.h:165
Definition: networkscratch.h:53
FloatVec()
Definition: networkscratch.h:144
Stack< NetworkIO > int_stack_
Definition: networkscratch.h:249
void Resize(const NetworkIO &src, int num_features, NetworkScratch *scratch)
Definition: networkscratch.h:78
bool int_mode_
Definition: networkscratch.h:246
~FloatVec()
Definition: networkscratch.h:145
Definition: networkscratch.h:36
TransposedArray * array_
Definition: networkscratch.h:198
IO()
Definition: networkscratch.h:64
Definition: baseapi.cpp:94
double * data_
Definition: networkscratch.h:167
Definition: networkscratch.h:207
NetworkScratch * scratch_space_
Definition: networkscratch.h:130
bool int_mode() const
Definition: networkio.h:127
Stack< NetworkIO > float_stack_
Definition: networkscratch.h:250
NetworkScratch()
Definition: networkscratch.h:38
NetworkIO * operator->()
Definition: networkscratch.h:116
Definition: weightmatrix.h:33
GradientStore()
Definition: networkscratch.h:179
void Return(T *item)
Definition: networkscratch.h:228
FloatVec(int size, NetworkScratch *scratch)
Definition: networkscratch.h:139
NetworkScratch * scratch_space_
Definition: networkscratch.h:170
void ResizeFloat(const NetworkIO &src, int num_features, NetworkScratch *scratch)
Definition: networkscratch.h:101
void Init(int size1, int size2, NetworkScratch *scratch)
Definition: networkscratch.h:184
int stack_top_
Definition: networkscratch.h:240
T * Borrow()
Definition: networkscratch.h:214
GenericVector< bool > flags_
Definition: networkscratch.h:239
Definition: svutil.h:78
void ResizeFloat(const NetworkIO &src, int num_features)
Definition: networkio.h:52
void set_int_mode(bool int_mode)
Definition: networkscratch.h:46
void Resize2d(bool int_mode, int width, int num_features)
Definition: networkio.cpp:40
void Resize(const NetworkIO &src, int num_features)
Definition: networkio.h:45
NetworkIO * network_io_
Definition: networkscratch.h:127
bool int_mode_
Definition: networkscratch.h:125
NetworkScratch * scratch_space_
Definition: networkscratch.h:201
Definition: networkscratch.h:136
Definition: genericvector.h:457
Definition: svutil.h:96
PointerVector< T > stack_
Definition: networkscratch.h:238
SVMutex mutex_
Definition: networkscratch.h:241
~IO()
Definition: networkscratch.h:66
void Resize2d(bool int_mode, int width, int num_features, NetworkScratch *scratch)
Definition: networkscratch.h:89
const TransposedArray & operator*() const
Definition: networkscratch.h:194
Definition: networkio.h:39
void Init(int size, NetworkScratch *scratch)
Definition: networkscratch.h:149
Stack< TransposedArray > array_stack_
Definition: networkscratch.h:252
void resize_no_init(int size)
Definition: genericvector.h:65
IO(const NetworkIO &src, NetworkScratch *scratch)
Definition: networkscratch.h:56
Stack< GenericVector< double > > vec_stack_
Definition: networkscratch.h:251
~GradientStore()
Definition: networkscratch.h:180
NetworkIO & operator*()
Definition: networkscratch.h:113
Definition: networkscratch.h:176