binary_convolution.hpp
1 //*****************************************************************************
2 // Copyright 2017-2021 Intel Corporation
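3 //
4 // Licensed under the Apache License, Version 2.0 (the "License");
5 // you may not use this file except in compliance with the License.
6 // You may obtain a copy of the License at
7 //
8 //     http://www.apache.org/licenses/LICENSE-2.0
9 //
10 // Unless required by applicable law or agreed to in writing, software
11 // distributed under the License is distributed on an "AS IS" BASIS,
12 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.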
13 // See the License for the specific language governing permissions and
14 // limitations under the License.
15 //*****************************************************************************
16
17 #pragma once
18
19 #include "ngraph/coordinate_diff.hpp"
20 #include "ngraph/op/op.hpp"
21 #include "ngraph/op/util/attr_types.hpp"
22
23 namespace ngraph
24 {
25  namespace op
26  {
27  namespace v1
28  {
29  class NGRAPH_API BinaryConvolution : public Op
30  {
31  public:
32  static constexpr NodeTypeInfo type_info{"BinaryConvolution", 1};
33  const NodeTypeInfo& get_type_info() const override { return type_info; }
34  enum class BinaryConvolutionMode
35  {
36  // Input data and kernel values are interpreted as signed: a stored 0 means -1, a stored 1 means +1.
37  XNOR_POPCOUNT
38  };
39
40  /// \brief Constructs a binary convolution operation.
41  BinaryConvolution() = default;
42  /// \brief Constructs a binary convolution operation.
43  /// \param data The node producing the input data batch tensor.
44  /// \param kernel The node producing the filters tensor.
45  /// \param strides The strides.
48  /// \param dilations The dilations.
49  /// \param mode Defines how input tensor 0/1 values and weights 0/1 are interpreted.
52  ///
53  /// Output [N, C_OUT, R1, ... Rf]
55  const Output<Node>& kernel,
56  const Strides& strides,
59  const Strides& dilations,
60  BinaryConvolutionMode mode,
63
64  BinaryConvolution(const Output<Node>& data,
65  const Output<Node>& kernel,
66  const Strides& strides,
69  const Strides& dilations,
70  const std::string& mode,
73
74  size_t get_version() const override { return 1; }
75  void validate_and_infer_types() override;
76
77  bool visit_attributes(AttributeVisitor& visitor) override;
78
79  std::shared_ptr<Node>
80  clone_with_new_inputs(const OutputVector& new_args) const override;
81
82  /// \return The strides, as a const reference to the stored m_strides member.
83  const Strides& get_strides() const { return m_strides; }
84  void set_strides(const Strides& strides) { m_strides = strides; }
85  /// \return The dilations, as a const reference to the stored m_dilations member.
86  const Strides& get_dilations() const { return m_dilations; }
87  void set_dilations(const Strides& dilations) { m_dilations = dilations; }
88  /// \return The padding-below sizes (possibly negative).
91  /// \return The padding-above sizes (possibly negative).
94  /// \return The pad type for convolution.
97  /// \return The convolution mode (a BinaryConvolutionMode value), as a const reference to m_mode.
98  const BinaryConvolutionMode& get_mode() const { return m_mode; }
99  void set_mode(const BinaryConvolutionMode& mode) { m_mode = mode; }
100  /// \return The pad value.
103  protected:
104  BinaryConvolutionMode mode_from_string(const std::string& mode) const;
105  Strides m_strides;
106  Strides m_dilations;
109  BinaryConvolutionMode m_mode;
112  };
113  }
114  } // namespace op
115
116  NGRAPH_API
117  std::ostream& operator<<(std::ostream& s,
118  const op::v1::BinaryConvolution::BinaryConvolutionMode& type);
119
120  template <>
123  {
124  public:
127  {
128  }
129
130  static constexpr DiscreteTypeInfo type_info{
132  const DiscreteTypeInfo& get_type_info() const override { return type_info; }
133  };
134
135 } // namespace ngraph
An AttributeAdapter "captures" an attribute as an AT& and makes it available as a ValueAccessor<VAT>.
Visits the attributes of a node, primarily for serialization-like tasks.
Definition: attribute_visitor.hpp:71
A difference (signed) of tensor element coordinates.
Definition: coordinate_diff.hpp:30
Access an enum via a string.
A handle for one of a node's outputs.
Definition: node_output.hpp:42
Strides for a tensor.
Definition: strides.hpp:30
Root of all actual ops.
Definition: op.hpp:29
Definition: binary_convolution.hpp:30
const BinaryConvolutionMode & get_mode() const
Definition: binary_convolution.hpp:98
size_t get_version() const override
Definition: binary_convolution.hpp:74
Definition: binary_convolution.hpp:92
Definition: binary_convolution.hpp:89
BinaryConvolution()=default
Constructs a binary convolution operation.
const Strides & get_strides() const
Definition: binary_convolution.hpp:83
Definition: binary_convolution.hpp:95
const NodeTypeInfo & get_type_info() const override
Definition: binary_convolution.hpp:33
const Strides & get_dilations() const
Definition: binary_convolution.hpp:86