function.hpp
1 //*****************************************************************************
2 // Copyright 2017-2021 Intel Corporation
3 //
4 // Licensed under the Apache License, Version 2.0 (the "License");
5 // you may not use this file except in compliance with the License.
6 // You may obtain a copy of the License at
7 //
8 // http://www.apache.org/licenses/LICENSE-2.0
9 //
10 // Unless required by applicable law or agreed to in writing, software
11 // distributed under the License is distributed on an "AS IS" BASIS,
12 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 // See the License for the specific language governing permissions and
14 // limitations under the License.
15 //*****************************************************************************
16 
17 #pragma once
18 
19 #include <atomic>
20 #include <initializer_list>
21 #include <list>
22 #include <memory>
23 #include <string>
24 #include <vector>
25 
26 #include "ngraph/ngraph_visibility.hpp"
27 #include "ngraph/node.hpp"
28 #include "ngraph/op/parameter.hpp"
29 #include "ngraph/op/result.hpp"
30 #include "ngraph/op/sink.hpp"
31 
32 namespace ngraph
33 {
34  /// A user-defined function.
35  class NGRAPH_API Function
36  {
37  public:
38  static constexpr DiscreteTypeInfo type_info{"Function", 0};
39  const DiscreteTypeInfo& get_type_info() const { return type_info; }
40  Function(const NodeVector& results,
41  const ParameterVector& parameters,
42  const std::string& name = "");
43 
44  Function(const OutputVector& results,
45  const ParameterVector& parameters,
46  const std::string& name = "");
47 
48  Function(const std::shared_ptr<Node>& result,
49  const ParameterVector& parameters,
50  const std::string& name = "");
51 
52  Function(const ResultVector& results,
53  const ParameterVector& parameters,
54  const std::string& name = "");
55 
56  Function(const ResultVector& results,
57  const SinkVector& sinks,
58  const ParameterVector& parameters,
59  const std::string& name = "");
60 
61  Function(const OutputVector& results,
62  const SinkVector& sinks,
63  const ParameterVector& parameters,
64  const std::string& name = "");
65 
66  virtual ~Function() {}
67  /// Return the number of outputs for this function.
68  size_t get_output_size() const;
69 
70  /// Return the op that generates output i
71  std::shared_ptr<Node> get_output_op(size_t i) const;
72 
73  Output<Node> output(size_t i) const;
74 
75  /// Return the element type of output i
76  const element::Type& get_output_element_type(size_t i) const;
77 
78  /// Return the shape of element i
79  const Shape& get_output_shape(size_t i) const;
80 
81  /// Return the partial shape of element i
82  const PartialShape& get_output_partial_shape(size_t i) const;
83 
84  /// Check that there is a single result and return it.
85  std::shared_ptr<Node> get_result() const;
86 
87  /// \brief Get the unique name of the function.
88  /// \returns A const reference to the function's unique name.
89  const std::string& get_name() const;
90 
91  /// \brief Sets a friendly name for a function. This does not overwrite the unique name
92  /// of the function and is retrieved via get_friendly_name(). Used mainly for
93  /// debugging.
94  /// \param name is the friendly name to set
95  void set_friendly_name(const std::string& name);
96 
97  /// \brief Gets the friendly name for a function. If no friendly name has been set via
98  /// set_friendly_name then the function's unique name is returned.
99  /// \returns A const reference to the function's friendly name.
100  const std::string& get_friendly_name() const;
101 
102  std::vector<std::shared_ptr<Node>> get_ops() const;
103  std::vector<std::shared_ptr<Node>> get_ordered_ops() const;
104  void map_unordered_ops(std::function<void(Node*)> f) const;
105 
106  friend std::ostream& operator<<(std::ostream&, const Function&);
107  // updates graph and m_results list
108  void replace_node(std::shared_ptr<Node> old, std::shared_ptr<Node> repl);
109 
110  void validate_nodes_and_infer_types() const;
111 
112  /// \brief Returns the sum of the size of all nodes in the graph plus the size of
113  /// all constant data. This has little value beyond comparing the relative size of
114  /// graphs and should not be considered the actual memory consumption of a graph.
115  size_t get_graph_size() const;
116 
117  /// \brief Returns true if any of the op's defined in the function contains partial shape
118  bool is_dynamic() const;
119 
120  /// \brief Replace the `parameter_index`th parameter of the function with `parameter`.
121  ///
122  /// All users of the `parameter_index`th parameter are redirected to `parameter`, and the
123  /// `parameter_index`th entry in the function parameter list is replaced with `parameter`.
124  ///
125  /// \param parameter_index The index of the parameter to replace.
126  /// \param parameter The parameter to substitute for the `parameter_index`th parameter.
127  void replace_parameter(size_t parameter_index,
128  const std::shared_ptr<op::Parameter>& parameter);
129 
130  using topological_sort_t = std::function<std::vector<std::shared_ptr<Node>>(
131  const std::vector<std::shared_ptr<Node>>& root_nodes)>;
132  void set_topological_sort(topological_sort_t);
133 
134  virtual bool visit_attributes(AttributeVisitor& visitor);
135 
136  /// Return the function parameters
137  const ParameterVector& get_parameters() const { return m_parameters; };
138  /// Return a list of function's outputs
139  const ResultVector& get_results() const { return m_results; };
140  /// Index for parameter, or -1
141  int64_t get_parameter_index(const std::shared_ptr<op::Parameter>& parameter) const;
142 
143  /// Index for value or result referencing it, or -1
144  int64_t get_result_index(const Output<Node>& value) const;
145 
146  /// \brief Evaluate the function on inputs, putting results in outputs.
147  /// \param outputs Tensors for the outputs to compute. One for each result
148  /// \param inputs Tensors for the inputs. One for each inputs.
149  bool evaluate(const HostTensorVector& output_tensors,
150  const HostTensorVector& input_tensors) const;
151 
152  /// \brief Return a list of function's sinks.
153  const SinkVector& get_sinks() const { return m_sinks; }
154  /// \brief Add new sink nodes to the list. Method doesn't validate graph, it should be done
155  /// manually after all changes.
156  /// \param sinks new sink nodes
157  void add_sinks(const SinkVector& sinks);
158 
159  /// \brief Delete sink node from the list of sinks. Method doesn't delete node from graph.
160  /// \param sink Sink to delete
161  void remove_sink(const std::shared_ptr<op::Sink>& sink);
162 
163  /// \brief Add new Result nodes to the list. Method doesn't validate graph, it should be
164  /// done manually after all changes.
165  /// \param results new Result nodes
166  void add_results(const ResultVector& results);
167 
168  /// \brief Delete Result node from the list of results. Method will not delete node from
169  /// graph.
170  /// \param result Result node to delete
171  void remove_result(const std::shared_ptr<op::Result>& result);
172 
173  /// \brief Add new Parameter nodes to the list.
174  ///
175  /// Method doesn't change or validate graph, it should be done manually.
176  /// For example, if you want to replace `ReadValue` node by `Parameter`, you should do the
177  /// following steps:
178  /// * replace node `ReadValue` by `Parameter` in graph
179  /// * call add_parameter() to add new input to the list
180  /// * call graph validation to check correctness of changes
181  ///
182  /// \param params new Parameter nodes
183  void add_parameters(const ParameterVector& params);
184 
185  /// \brief Delete Parameter node from the list of parameters. Method will not delete node
186  /// from graph. You need to replace Parameter with other operation manually.
187  /// Attention: Indexing of parameters can be changed.
188  ///
189  /// Possible use of method is to replace input by variable. For it the following steps
190  /// should be done:
191  /// * `Parameter` node should be replaced by `ReadValue`
192  /// * call remove_parameter(param) to remove input from the list
193  /// * check if any parameter indexes are saved/used somewhere, update it for all inputs
194  /// because indexes can be changed
195  /// * call graph validation to check all changes
196  ///
197  /// \param param Parameter node to delete
198  void remove_parameter(const std::shared_ptr<op::Parameter>& param);
199 
200  private:
201  Function(const Function&) = delete;
202  Function(const Function&&) = delete;
203  Function& operator=(const Function&) = delete;
204  /// \brief Checks all the Parameter nodes are registered in the list of Function parameters
205  void check_all_parameters_registered() const;
206 
207  static std::atomic<size_t> m_next_instance_id;
208  std::string m_name;
209  const std::string m_unique_name;
210  size_t m_placement{0};
211  topological_sort_t m_topological_sorter;
212 
213  ResultVector m_results;
214 
215  // List of the nodes with side effect in graph.
216  // These nodes are not outputs of graph but should not be removed even if have no children.
217  SinkVector m_sinks;
218  ParameterVector m_parameters;
219  };
220 
221  template <>
222  class NGRAPH_API AttributeAdapter<std::shared_ptr<Function>>
223  : public DirectValueAccessor<std::shared_ptr<Function>>
224  {
225  public:
226  AttributeAdapter(std::shared_ptr<Function>& value)
228  {
229  }
230 
231  static constexpr DiscreteTypeInfo type_info{"AttributeAdapter<std::shared_ptr<Function>>",
232  0};
233  const DiscreteTypeInfo& get_type_info() const override { return type_info; }
234  };
235 } // namespace ngraph
An AttributeAdapter "captures" an attribute as an AT& and makes it available as a ValueAccessor<VAT>.
Definition: attribute_adapter.hpp:171
Visits the attributes of a node, primarily for serialization-like tasks.
Definition: attribute_visitor.hpp:71
Definition: attribute_adapter.hpp:79
A user-defined function.
Definition: function.hpp:36
void remove_sink(const std::shared_ptr< op::Sink > &sink)
Delete sink node from the list of sinks. Method doesn't delete node from graph.
void set_friendly_name(const std::string &name)
Sets a friendly name for a function. This does not overwrite the unique name of the function and is r...
const SinkVector & get_sinks() const
Return a list of function's sinks.
Definition: function.hpp:153
bool is_dynamic() const
Returns true if any of the ops defined in the function contains a partial shape.
size_t get_graph_size() const
Returns the sum of the size of all nodes in the graph plus the size of all constant data....
void add_sinks(const SinkVector &sinks)
Add new sink nodes to the list. Method doesn't validate graph, it should be done manually after all c...
size_t get_output_size() const
Return the number of outputs for this function.
const std::string & get_name() const
Get the unique name of the function.
void add_results(const ResultVector &results)
Add new Result nodes to the list. Method doesn't validate graph, it should be done manually after all...
const ParameterVector & get_parameters() const
Return the function parameters.
Definition: function.hpp:137
bool evaluate(const HostTensorVector &output_tensors, const HostTensorVector &input_tensors) const
Evaluate the function on inputs, putting results in outputs.
int64_t get_parameter_index(const std::shared_ptr< op::Parameter > &parameter) const
Index for parameter, or -1.
std::shared_ptr< Node > get_output_op(size_t i) const
Return the op that generates output i.
const element::Type & get_output_element_type(size_t i) const
Return the element type of output i.
std::shared_ptr< Node > get_result() const
Check that there is a single result and return it.
void add_parameters(const ParameterVector &params)
Add new Parameter nodes to the list.
const std::string & get_friendly_name() const
Gets the friendly name for a function. If no friendly name has been set via set_friendly_name then th...
const PartialShape & get_output_partial_shape(size_t i) const
Return the partial shape of element i.
const Shape & get_output_shape(size_t i) const
Return the shape of element i.
int64_t get_result_index(const Output< Node > &value) const
Index for value or result referencing it, or -1.
void remove_parameter(const std::shared_ptr< op::Parameter > &param)
Delete Parameter node from the list of parameters. Method will not delete node from graph....
void replace_parameter(size_t parameter_index, const std::shared_ptr< op::Parameter > &parameter)
Replace the `parameter_index`-th parameter of the function with `parameter`.
void remove_result(const std::shared_ptr< op::Result > &result)
Delete Result node from the list of results. Method will not delete node from graph.
const ResultVector & get_results() const
Return a list of function's outputs.
Definition: function.hpp:139
Definition: node.hpp:132
A handle for one of a node's outputs.
Definition: node_output.hpp:42
Class representing a shape that may be partially or totally dynamic.
Definition: partial_shape.hpp:46
Shape for a tensor.
Definition: shape.hpp:31
Definition: element_type.hpp:61
The Intel nGraph C++ API.
Definition: attribute_adapter.hpp:28
NGRAPH_API void replace_node(std::shared_ptr< Node > target, std::shared_ptr< Node > replacement, const std::vector< int64_t > &output_order)
Replace the node target with the node replacement, i.e., redirect all users and control dependencies ...
Definition: type.hpp:39