#include #include namespace at::native { using fused_adagrad_fn = void (*)( const at::Tensor& param, const at::Tensor& grad, const at::Tensor& state_sum, const at::Tensor& state_step, const double lr, const double lr_decay, const double weight_decay, const double eps, const bool maximize, const float* grad_scale_ptr); DECLARE_DISPATCH(fused_adagrad_fn, fused_adagrad_stub); } // namespace at::native