parameter.hpp
1 // Copyright (C) 2018-2021 Intel Corporation
2 // SPDX-License-Identifier: Apache-2.0
3 //
4 
5 #pragma once
6 
7 #include "ngraph/op/op.hpp"
8 
9 namespace ngraph
10 {
11  class Function;
12  namespace op
13  {
14  namespace v0
15  {
16  /// \brief A function parameter.
17  ///
18  /// Parameters are nodes that represent the arguments that will be passed to
19  /// user-defined functions. Function creation requires a sequence of parameters.
20  /// Basic graph operations do not need parameters attached to a function.
21  class NGRAPH_API Parameter : public op::Op
22  {
23  public:
24  static constexpr NodeTypeInfo type_info{"Parameter", 0};
25  const NodeTypeInfo& get_type_info() const override { return type_info; }
26  /// \brief Constructions a tensor-typed parameter node.
27  Parameter() = default;
28  /// \brief Constructions a tensor-typed parameter node.
29  ///
30  /// \param element_type The element type of the parameter.
31  /// \param pshape The partial shape of the parameter.
32  Parameter(const ngraph::element::Type& element_type, const PartialShape& pshape);
33 
34  bool visit_attributes(AttributeVisitor& visitor) override;
35 
36  void validate_and_infer_types() override;
37 
38  virtual std::shared_ptr<Node>
39  clone_with_new_inputs(const OutputVector& new_args) const override;
40 
41  bool is_relevant_to_shapes() const;
42  void set_is_relevant_to_shapes(bool is_relevant);
43 
44  const PartialShape& get_partial_shape() const { return m_partial_shape; }
45  PartialShape& get_partial_shape() { return m_partial_shape; }
46  void set_partial_shape(const PartialShape& partial_shape)
47  {
48  m_partial_shape = partial_shape;
49  }
50  const element::Type& get_element_type() const { return m_element_type; }
51  void set_element_type(const element::Type& element_type)
52  {
53  m_element_type = element_type;
54  }
55 
56  protected:
57  PartialShape m_partial_shape;
58  element::Type m_element_type;
59  bool m_is_relevant_to_shapes;
60  };
61  } // namespace v0
62  using v0::Parameter;
63  } // namespace op
64  using ParameterVector = std::vector<std::shared_ptr<op::Parameter>>;
65 
66  template <>
67  class NGRAPH_API AttributeAdapter<ParameterVector> : public VisitorAdapter
68  {
69  public:
70  AttributeAdapter(ParameterVector& ref);
71 
72  bool visit_attributes(AttributeVisitor& visitor) override;
73 
74  static constexpr DiscreteTypeInfo type_info{"AttributeAdapter<ParameterVector>", 0};
75  const DiscreteTypeInfo& get_type_info() const override { return type_info; }
76 
77  protected:
78  ParameterVector& m_ref;
79  };
80 } // namespace ngraph
const DiscreteTypeInfo & get_type_info() const override
Type info enables identification of the value accessor, as well as is_type and as_type.
Definition: parameter.hpp:75
An AttributeAdapter "captures" an attribute as an AT& and makes it available as a ValueAccessor<VAT>.
Definition: attribute_adapter.hpp:161
Visits the attributes of a node, primarily for serialization-like tasks.
Definition: attribute_visitor.hpp:59
Class representing a shape that may be partially or totally dynamic.
Definition: partial_shape.hpp:34
Adapters will see visitor.
Definition: attribute_adapter.hpp:185
Definition: element_type.hpp:51
Root of all actual ops.
Definition: op.hpp:17
A function parameter.
Definition: parameter.hpp:22
Parameter()=default
Constructs a tensor-typed parameter node.
void validate_and_infer_types() override
Verifies that attributes and inputs are consistent and computes output shapes and element types....
Parameter(const ngraph::element::Type &element_type, const PartialShape &pshape)
Constructs a tensor-typed parameter node.
const NodeTypeInfo & get_type_info() const override
Definition: parameter.hpp:25
The Intel nGraph C++ API.
Definition: attribute_adapter.hpp:16
Definition: type.hpp:27