topk.hpp
//*****************************************************************************
// Copyright 2017-2021 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
16
17 #pragma once
18
#include <limits>
#include <memory>

#include "ngraph/axis_set.hpp"
#include "ngraph/op/constant.hpp"
#include "ngraph/op/op.hpp"
24
25 namespace ngraph
26 {
27  namespace op
28  {
29  namespace v1
30  {
31  /// \brief Computes indices and values of the k maximum/minimum values
32  /// for each slice along specified axis.
33  class NGRAPH_API TopK : public Op
34  {
35  public:
36  using SortType = TopKSortType;
37  using Mode = TopKMode;
38
39  static constexpr NodeTypeInfo type_info{"TopK", 1};
40  const NodeTypeInfo& get_type_info() const override { return type_info; }
41  /// \brief Constructs a TopK operation
42  TopK() = default;
43  /// \brief Constructs a TopK operation with two outputs: values and indices.
44  /// By default the indices output is described by i32 data type.
45  ///
46  /// \param data The input tensor
47  /// \param k Specifies how many maximum/minimum elements should be computed
48  /// (note: scalar input tensor)
49  /// \param axis The axis along which to compute top k indices
50  /// \param mode Specifies which operation (min or max) is used to select
51  /// the biggest element of two.
52  /// \param sort Specifies order of output elements and/or indices
53  /// Accepted values: none, index, value
54  /// \param index_element_type Specyfies type of produced indices
55  TopK(const Output<Node>& data,
56  const Output<Node>& k,
57  const int64_t axis,
58  const std::string& mode,
59  const std::string& sort,
60  const element::Type& index_element_type = element::i32);
61
62  TopK(const Output<Node>& data,
63  const Output<Node>& k,
64  const int64_t axis,
65  const Mode mode,
66  const SortType sort,
67  const element::Type& index_element_type = element::i32);
68
69  bool visit_attributes(AttributeVisitor& visitor) override;
70  void validate_and_infer_types() override;
71
72  virtual std::shared_ptr<Node>
73  clone_with_new_inputs(const OutputVector& new_args) const override;
74
75  virtual size_t get_version() const override { return 1; }
76  /// \brief Returns axis value after normalization
77  /// \note If input rank required to normalization is dynamic, the exception is
78  /// thrown
79  uint64_t get_axis() const;
80  /// \brief Returns axis value before normalization
81  int64_t get_provided_axis() const { return m_axis; }
82  void set_axis(const int64_t axis);
83  Mode get_mode() const { return m_mode; }
84  void set_mode(const Mode mode) { m_mode = mode; }
85  SortType get_sort_type() const { return m_sort; }
86  void set_sort_type(const SortType sort) { m_sort = sort; }
87  element::Type get_index_element_type() const { return m_index_element_type; }
88  void set_index_element_type(const element::Type& index_element_type)
89  {
90  m_index_element_type = index_element_type;
91  }
92  /// \brief Returns the value of K, if available
93  ///
94  /// \note If the second input to this op is a constant, the value is retrieved
95  /// and returned. If the input is not constant(dynamic) this method returns 0
96  size_t get_k() const;
97  void set_k(size_t k);
98  size_t get_default_output_index() const override { return no_default_index(); }
99  bool evaluate(const HostTensorVector& outputs,
100  const HostTensorVector& inputs) const override;
101
102  protected:
103  int64_t m_axis;
104  uint64_t m_normalized_axis;
105  Mode m_mode;
106  SortType m_sort;
107  element::Type m_index_element_type{element::i32};
108
109  virtual size_t read_k_from_constant_node(const std::shared_ptr<Node>& node,
110  const element::Type& k_element_type) const;
111
112  template <typename T>
113  size_t validate_and_get_k(const std::shared_ptr<op::Constant>& k_constant) const;
114  Shape compute_output_shape(const std::string& node_description,
115  const PartialShape input_partial_shape,
116  const int64_t k) const;
117  void set_axis(const Rank input_rank, const int64_t axis);
118  };
119  } // namespace v1
120
121  namespace v3
122  {
123  /// \brief Computes indices and values of the k maximum/minimum values
124  /// for each slice along specified axis.
125  class NGRAPH_API TopK : public v1::TopK
126  {
127  public:
128  static constexpr NodeTypeInfo type_info{"TopK", 3};
129  const NodeTypeInfo& get_type_info() const override { return type_info; }
130  /// \brief Constructs a TopK operation
131  TopK() = default;
132  /// \brief Constructs a TopK operation with two outputs: values and indices.
133  /// By default the indices output is described by i32 data type.
134  ///
135  /// \param data The input tensor
136  /// \param k Specifies how many maximum/minimum elements should be computed
137  /// (note: scalar input tensor)
138  /// \param axis The axis along which to compute top k indices
139  /// \param mode Specifies which operation (min or max) is used to select
140  /// the biggest element of two.
141  /// \param sort Specifies order of output elements and/or indices
142  /// Accepted values: none, index, value
143  /// \param index_element_type Specyfies type of produced indices
144  TopK(const Output<Node>& data,
145  const Output<Node>& k,
146  const int64_t axis,
147  const std::string& mode,
148  const std::string& sort,
149  const element::Type& index_element_type = element::i32);
150
151  TopK(const Output<Node>& data,
152  const Output<Node>& k,
153  const int64_t axis,
154  const Mode mode,
155  const SortType sort,
156  const element::Type& index_element_type = element::i32);
157  bool visit_attributes(AttributeVisitor& visitor) override;
158  void validate_and_infer_types() override;
159  virtual std::shared_ptr<Node>
160  clone_with_new_inputs(const OutputVector& new_args) const override;
161
162  bool evaluate(const HostTensorVector& outputs,
163  const HostTensorVector& inputs) const override;
164
165  protected:
166  virtual size_t
168  const element::Type& k_element_type) const override;
169  };
170  } // namespace v3
171  } // op
172 } // ngraph
Visits the attributes of a node, primarily for serialization-like tasks.
Definition: attribute_visitor.hpp:71
Class representing a dimension, which may be dynamic (undetermined until runtime),...
Definition: dimension.hpp:35
A handle for one of a node's outputs.
Definition: node_output.hpp:42
Class representing a shape that may be partially or totally dynamic.
Definition: partial_shape.hpp:46
Shape for a tensor.
Definition: shape.hpp:31
Definition: element_type.hpp:61
Root of all actual ops.
Definition: op.hpp:29
Computes indices and values of the k maximum/minimum values for each slice along specified axis.
Definition: topk.hpp:34
void validate_and_infer_types() override
Verifies that attributes and inputs are consistent and computes output shapes and element types....
virtual size_t get_version() const override
Definition: topk.hpp:75
bool evaluate(const HostTensorVector &outputs, const HostTensorVector &inputs) const override
Evaluates the op on input_values putting results in output_values.
TopK()=default
Constructs a TopK operation.
TopK(const Output< Node > &data, const Output< Node > &k, const int64_t axis, const std::string &mode, const std::string &sort, const element::Type &index_element_type=element::i32)
Constructs a TopK operation with two outputs: values and indices. By default the indices output is described by i32 data type.
const NodeTypeInfo & get_type_info() const override
Definition: topk.hpp:40
size_t get_k() const
Returns the value of K, if available.
size_t get_default_output_index() const override
Returns the output of the default output, or throws if there is none.
Definition: topk.hpp:98
int64_t get_provided_axis() const
Returns axis value before normalization.
Definition: topk.hpp:81
uint64_t get_axis() const
Returns axis value after normalization.
Computes indices and values of the k maximum/minimum values for each slice along specified axis.
Definition: topk.hpp:126
void validate_and_infer_types() override
Verifies that attributes and inputs are consistent and computes output shapes and element types....
bool evaluate(const HostTensorVector &outputs, const HostTensorVector &inputs) const override
Evaluates the op on input_values putting results in output_values.
const NodeTypeInfo & get_type_info() const override
Definition: topk.hpp:129
TopK(const Output< Node > &data, const Output< Node > &k, const int64_t axis, const std::string &mode, const std::string &sort, const element::Type &index_element_type=element::i32)
Constructs a TopK operation with two outputs: values and indices. By default the indices output is described by i32 data type.
TopK()=default
Constructs a TopK operation.
The Intel nGraph C++ API.