Program Listing for File abstract_data_structure.h

Return to documentation for file (include/shad/data_structures/abstract_data_structure.h)

//===------------------------------------------------------------*- C++ -*-===//
//
//                                     SHAD
//
//      The Scalable High-performance Algorithms and Data Structure Library
//
//===----------------------------------------------------------------------===//
//
// Copyright 2018 Battelle Memorial Institute
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may not
// use this file except in compliance with the License. You may obtain a copy
// of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
// WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
// License for the specific language governing permissions and limitations
// under the License.
//
//===----------------------------------------------------------------------===//

#ifndef INCLUDE_SHAD_DATA_STRUCTURES_ABSTRACT_DATA_STRUCTURE_H_
#define INCLUDE_SHAD_DATA_STRUCTURES_ABSTRACT_DATA_STRUCTURE_H_

#include <cstddef>
#include <cstdint>
#include <deque>
#include <limits>
#include <memory>
#include <mutex>
#include <tuple>
#include <utility>
#include <vector>

#include "shad/data_structures/object_identifier.h"
#include "shad/runtime/runtime.h"

namespace shad {

template <typename DataStructure>
class AbstractDataStructure {
 public:
  using ObjectID = ObjectIdentifier<DataStructure>;
  using SharedPtr = std::shared_ptr<DataStructure>;

  AbstractDataStructure() = default;

  template <typename... Args>
  static SharedPtr Create(Args... args) {
    auto catalogRef = Catalog::Instance();
    ObjectID id = catalogRef->GetNextID();
    std::tuple<ObjectID, Args...> tuple(id, args...);
    rt::executeOnAll(CreateFunWrapper<ObjectID, Args...>, tuple);
    return catalogRef->GetPtr(id);
  }

  static void Destroy(const ObjectID &oid) {
    auto catalogRef = Catalog::Instance();
    auto destroyLambda = [](const ObjectID &oid) {
      auto catalogRef = Catalog::Instance();
      catalogRef->Erase(oid);
    };

    rt::executeOnAll(destroyLambda, oid);
  }

  static SharedPtr GetPtr(ObjectID oid) {
    return Catalog::Instance()->GetPtr(oid);
  }

  virtual ObjectID GetGlobalID() const = 0;

 protected:
  template <typename... Args>
  static void UpdateCatalogAndConstruct(const ObjectID &oid, Args &&... args) {
    // Get a local instance on the remote node.
    auto catalogRef = Catalog::Instance();
    std::shared_ptr<DataStructure> ptr(
        new DataStructure(oid, std::forward<Args>(args)...));
    catalogRef->Insert(oid, ptr);
  }

  template <typename... Args, std::size_t... is>
  static void CreateFunInnerWrapper(const std::tuple<Args...> &&tuple,
                                    std::index_sequence<is...>) {
    UpdateCatalogAndConstruct(std::get<is>(tuple)...);
  }

  template <typename... Args>
  static void CreateFunWrapper(const std::tuple<Args...> &args) {
    CreateFunInnerWrapper(std::move(args), std::index_sequence_for<Args...>());
  }

  class Catalog {
   public:
    void Insert(const ObjectID &oid, const SharedPtr ce) {
      uint32_t locality = static_cast<uint32_t>(oid.GetOwnerLocality());
      std::lock_guard<rt::Lock> _(registerLock_);
      if (register_[locality].size() <= oid.GetLocalID()) {
        register_[locality].resize(oid.GetLocalID() + 1);
      }
      register_[locality][oid.GetLocalID()] = ce;
    }

    void Erase(const ObjectID &oid) {
      uint32_t locality = static_cast<uint32_t>(oid.GetOwnerLocality());
      if (rt::thisLocality() == oid.GetOwnerLocality()) {
        std::lock_guard<rt::Lock> _(registerLock_);
        oidCache_.push_back(oid);
        register_[locality][oid.GetLocalID()] = nullptr;
      } else {
        std::lock_guard<rt::Lock> _(registerLock_);
        register_[locality][oid.GetLocalID()] = nullptr;
      }
    }

    SharedPtr GetPtr(const ObjectID &oid) {
      uint32_t locality = static_cast<uint32_t>(oid.GetOwnerLocality());
      return register_[locality][oid.GetLocalID()];
    }

    static Catalog *Instance() {
      static Catalog instance;
      return &instance;
    }

    ObjectID GetNextID() {
      {
        std::lock_guard<rt::Lock> _(registerLock_);
        if (!oidCache_.empty()) {
          auto result = oidCache_.back();
          oidCache_.pop_back();
          return result;
        }
      }
      return ObjectIDCounter::Instance()++;
    }

   private:
    Catalog() : register_(rt::numLocalities()), oidCache_(), registerLock_() {}

    using ObjectIDCounter = ObjectIdentifierCounter<DataStructure>;

    std::vector<std::deque<SharedPtr>> register_;
    std::vector<ObjectID> oidCache_;
    rt::Lock registerLock_;
  };
};

}  // namespace shad

#endif  // INCLUDE_SHAD_DATA_STRUCTURES_ABSTRACT_DATA_STRUCTURE_H_