Program Listing for File set.h
↰ Return to documentation for file (include/shad/data_structures/set.h)
//===------------------------------------------------------------*- C++ -*-===//
//
// SHAD
//
// The Scalable High-performance Algorithms and Data Structure Library
//
//===----------------------------------------------------------------------===//
//
// Copyright 2018 Battelle Memorial Institute
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may not
// use this file except in compliance with the License. You may obtain a copy
// of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
// WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
// License for the specific language governing permissions and limitations
// under the License.
//
//===----------------------------------------------------------------------===//
#ifndef INCLUDE_SHAD_DATA_STRUCTURES_SET_H_
#define INCLUDE_SHAD_DATA_STRUCTURES_SET_H_
#include <algorithm>
#include <cstddef>
#include <cstdint>
#include <functional>
#include <iostream>
#include <iterator>
#include <tuple>
#include <utility>
#include <vector>
#include "shad/data_structures/abstract_data_structure.h"
#include "shad/data_structures/buffer.h"
#include "shad/data_structures/compare_and_hash_utils.h"
#include "shad/data_structures/local_set.h"
#include "shad/distributed_iterator_traits.h"
#include "shad/runtime/runtime.h"
namespace shad {
// Forward declaration of the global iterator type; Set names it as a friend
// and uses it for its iterator/const_iterator aliases before the full
// definition appears later in this header.
template <typename Set, typename T, typename NonConstT>
class set_iterator;
/// @brief Distributed set: every locality holds one LocalSet shard.
///
/// Point operations (Insert/Erase/Find and their asynchronous variants,
/// defined out of line below) select the owning locality as
/// `shad::hash<T>{}(element) % rt::numLocalities()` and then operate either on
/// the local shard or on the remote one through the runtime.
///
/// @tparam T Type of the stored elements.
/// @tparam ELEM_COMPARE Comparison functor forwarded to LocalSet
///         (defaults to MemCmp<T>).
template <typename T, typename ELEM_COMPARE = MemCmp<T>>
class Set : public AbstractDataStructure<Set<T, ELEM_COMPARE>> {
  template <typename>
  friend class AbstractDataStructure;
  friend class set_iterator<Set<T, ELEM_COMPARE>, const T, T>;

 public:
  using value_type = T;
  using SetT = Set<T, ELEM_COMPARE>;
  using LSetT = LocalSet<T, ELEM_COMPARE>;
  using ObjectID = typename AbstractDataStructure<SetT>::ObjectID;
  using ShadSetPtr = typename AbstractDataStructure<SetT>::SharedPtr;
  using BuffersVector = typename impl::BuffersVector<T, SetT>;

  // iterator and const_iterator are deliberately the same type: both expose
  // the elements as const T.
  using iterator = set_iterator<Set<T, ELEM_COMPARE>, const T, T>;
  using const_iterator = set_iterator<Set<T, ELEM_COMPARE>, const T, T>;
  using local_iterator = lset_iterator<LocalSet<T, ELEM_COMPARE>, const T>;
  using const_local_iterator =
      lset_iterator<LocalSet<T, ELEM_COMPARE>, const T>;

#ifdef DOXYGEN_IS_RUNNING
  /// @brief Documentation-only declaration of the factory.
  /// @param numEntries Value handed to the protected constructor.
  static ShadSetPtr Create(const size_t numEntries);
#endif

  /// @brief Returns the global object identifier of this set.
  ObjectID GetGlobalID() const { return oid_; }

  /// @brief Sum of the sizes of all the per-locality shards.
  size_t Size() const;

  /// @brief Inserts an element on its owning locality.
  /// @return The iterator/bool pair produced from the LocalSet insertion.
  std::pair<iterator, bool> Insert(const T& element);

  /// @brief Asynchronous variant of Insert, tracked by @p handle.
  void AsyncInsert(rt::Handle& handle, const T& element);

  /// @brief Queues the element in the insertion buffers of its owner.
  void BufferedInsert(const T& element);

  /// @brief Asynchronous variant of BufferedInsert, tracked by @p handle.
  void BufferedAsyncInsert(rt::Handle& handle, const T& element);

  /// @brief Flushes every insertion buffer (buffers_.FlushAll()).
  void WaitForBufferedInsert() { buffers_.FlushAll(); }

  /// @brief Erases an element from its owning locality.
  void Erase(const T& element);

  /// @brief Asynchronous variant of Erase, tracked by @p handle.
  void AsyncErase(rt::Handle& handle, const T& element);

  /// @brief Calls LocalSet::Clear on the shard of every locality.
  void Clear() {
    auto clearLambda = [](const ObjectID& oid) {
      auto setPtr = SetT::GetPtr(oid);
      setPtr->localSet_.Clear();
    };
    rt::executeOnAll(clearLambda, oid_);
  }

  /// @brief Calls LocalSet::Reset(numElements) on the shard of every locality.
  /// @param numElements Value forwarded unchanged to each LocalSet::Reset.
  void Reset(size_t numElements) {
    auto resetLambda = [](const std::tuple<ObjectID, size_t>& t) {
      auto setPtr = SetT::GetPtr(std::get<0>(t));
      setPtr->localSet_.Reset(std::get<1>(t));
    };
    rt::executeOnAll(resetLambda, std::make_tuple(oid_, numElements));
  }

  /// @brief Looks an element up on its owning locality.
  /// @return The result of LocalSet::Find on the owner.
  bool Find(const T& element);

  /// @brief Asynchronous lookup; the result is written to @p found.
  void AsyncFind(rt::Handle& handle, const T& element, bool* found);

  /// @brief Applies @p function to the elements of every shard.
  /// @param function Must be convertible to `void (*)(const T&, Args&...)`.
  /// @param args Extra arguments copied into a tuple and shipped along.
  template <typename ApplyFunT, typename... Args>
  void ForEachElement(ApplyFunT&& function, Args&... args);

  /// @brief Asynchronous variant of ForEachElement, tracked by @p handle.
  /// @param function Must be convertible to
  ///        `void (*)(rt::Handle&, const T&, Args&...)`.
  template <typename ApplyFunT, typename... Args>
  void AsyncForEachElement(rt::Handle& handle, ApplyFunT&& function,
                           Args&... args);

  /// @brief Prints, locality by locality, a header line on std::cout followed
  /// by the output of LocalSet::PrintAllElements.
  void PrintAllElements() {
    auto printLambda = [](const ObjectID& oid) {
      auto setPtr = SetT::GetPtr(oid);
      std::cout << "---- Locality: " << rt::thisLocality() << std::endl;
      setPtr->localSet_.PrintAllElements();
    };
    rt::executeOnAll(printLambda, oid_);
  }

  // FIXME it should be protected
  /// @brief Inserts directly into the shard of the calling locality.
  void BufferEntryInsert(const T& element) { localSet_.Insert(element); }

  // Global (cross-locality) iteration.
  iterator begin() { return iterator::set_begin(this); }
  iterator end() { return iterator::set_end(this); }
  const_iterator cbegin() const { return const_iterator::set_begin(this); }
  const_iterator cend() const { return const_iterator::set_end(this); }
  const_iterator begin() const { return cbegin(); }
  const_iterator end() const { return cend(); }

  // Iteration restricted to the shard of the calling locality.
  local_iterator local_begin() {
    return local_iterator::lset_begin(&localSet_);
  }
  local_iterator local_end() { return local_iterator::lset_end(&localSet_); }

  // const-qualified so that, like cbegin()/cend(), they can be called on a
  // const Set; the local iterator factories accept a const LocalSet pointer.
  const_local_iterator clocal_begin() const {
    return const_local_iterator::lset_begin(&localSet_);
  }
  const_local_iterator clocal_end() const {
    return const_local_iterator::lset_end(&localSet_);
  }

  // Lower-case aliases mirroring the naming of the standard containers.
  std::pair<iterator, bool> insert(const value_type& value) {
    return Insert(value);
  }

  /// @brief Hinted insertion; the hint is ignored.
  std::pair<iterator, bool> insert(const_iterator, const value_type& value) {
    return insert(value);
  }

  void buffered_async_insert(rt::Handle& h, const value_type& value) {
    BufferedAsyncInsert(h, value);
  }
  void buffered_async_wait(rt::Handle& h) { rt::waitForCompletion(h); }
  void buffered_async_flush() { WaitForBufferedInsert(); }

 private:
  ObjectID oid_;
  LocalSet<T, ELEM_COMPARE> localSet_;
  BuffersVector buffers_;

  // Argument bundle shipped to the owning locality by the point operations.
  struct ExeAtArgs {
    ObjectID oid;
    T element;
  };

 protected:
  /// @brief Builds the shard of the calling locality.
  ///
  /// The LocalSet is constructed with
  /// numEntries / (kSetDefaultNumEntriesPerBucket * numLocalities), clamped
  /// to a minimum of 1.
  ///
  /// @param oid Global identifier of the set.
  /// @param numEntries Value from which the LocalSet constructor argument is
  ///        derived.
  Set(ObjectID oid, const size_t numEntries)
      : oid_(oid),
        // The explicit <size_t> keeps std::max well-formed on platforms where
        // size_t is not `unsigned long` (a `1lu` literal would otherwise make
        // template argument deduction fail).
        localSet_(std::max<size_t>(
            numEntries / (constants::kSetDefaultNumEntriesPerBucket *
                          rt::numLocalities()),
            1)),
        buffers_(oid) {}
};
/// Adds the size of the calling locality's shard to the sizes reported, one
/// remote call at a time, by every other locality.
template <typename T, typename ELEM_COMPARE>
inline size_t Set<T, ELEM_COMPARE>::Size() const {
  // Runs on a remote locality and reports the size of its shard.
  auto fetchShardSize = [](const ObjectID& oid, size_t* res) {
    auto shard = SetT::GetPtr(oid);
    *res = shard->localSet_.size_;
  };

  size_t total = localSet_.size_;
  size_t shardSize(0);
  for (auto loc : rt::allLocalities()) {
    if (loc != rt::thisLocality()) {
      rt::executeAtWithRet(loc, fetchShardSize, oid_, &shardSize);
      total += shardSize;
    }
  }
  return total;
}
/// Inserts `element` into the shard of the locality selected by its hash and
/// converts the resulting local iterator into a global one.
template <typename T, typename ELEM_COMPARE>
inline std::pair<typename Set<T, ELEM_COMPARE>::iterator, bool>
Set<T, ELEM_COMPARE>::Insert(const T& element) {
  using itr_traits = distributed_iterator_traits<iterator>;

  size_t ownerId = shad::hash<T>{}(element) % rt::numLocalities();
  rt::Locality owner(ownerId);

  // Owner is the calling locality: insert straight into the local shard.
  if (owner == rt::thisLocality()) {
    auto localResult = localSet_.Insert(element);
    auto globalIt =
        itr_traits::iterator_from_local(begin(), end(), localResult.first);
    return std::make_pair(globalIt, localResult.second);
  }

  // Owner is remote: ship begin/end, the object id and the element, and let
  // the owner build the global iterator.
  auto remoteInsert =
      [](const std::tuple<iterator, iterator, ObjectID, T>& payload,
         std::pair<iterator, bool>* out) {
        auto shard = SetT::GetPtr(std::get<2>(payload));
        auto localResult = shard->localSet_.Insert(std::get<3>(payload));
        auto globalIt = itr_traits::iterator_from_local(
            std::get<0>(payload), std::get<1>(payload), localResult.first);
        *out = std::make_pair(globalIt, localResult.second);
      };

  std::pair<iterator, bool> result;
  rt::executeAtWithRet(owner, remoteInsert,
                       std::make_tuple(begin(), end(), oid_, element),
                       &result);
  return result;
}
/// Asynchronously inserts `element` on the locality selected by its hash;
/// the operation is associated with `handle`.
template <typename T, typename ELEM_COMPARE>
inline void Set<T, ELEM_COMPARE>::AsyncInsert(rt::Handle& handle,
                                              const T& element) {
  size_t ownerId = shad::hash<T>{}(element) % rt::numLocalities();
  rt::Locality owner(ownerId);

  if (owner == rt::thisLocality()) {
    // Local owner: no runtime hop needed.
    localSet_.AsyncInsert(handle, element);
    return;
  }

  auto remoteInsert = [](rt::Handle& h, const ExeAtArgs& payload) {
    auto shard = SetT::GetPtr(payload.oid);
    shard->localSet_.AsyncInsert(h, payload.element);
  };
  ExeAtArgs payload = {oid_, element};
  rt::asyncExecuteAt(handle, owner, remoteInsert, payload);
}
/// Hands `element` to the insertion buffers, tagged with the locality
/// selected by its hash.
template <typename T, typename ELEM_COMPARE>
inline void Set<T, ELEM_COMPARE>::BufferedInsert(const T& element) {
  size_t ownerId = shad::hash<T>{}(element) % rt::numLocalities();
  rt::Locality owner(ownerId);
  buffers_.Insert(element, owner);
}
/// Asynchronous counterpart of BufferedInsert: hands `element` to the
/// insertion buffers together with `handle` and the owning locality.
template <typename T, typename ELEM_COMPARE>
inline void Set<T, ELEM_COMPARE>::BufferedAsyncInsert(rt::Handle& handle,
                                                      const T& element) {
  size_t ownerId = shad::hash<T>{}(element) % rt::numLocalities();
  rt::Locality owner(ownerId);
  buffers_.AsyncInsert(handle, element, owner);
}
/// Erases `element` from the shard of the locality selected by its hash.
template <typename T, typename ELEM_COMPARE>
inline void Set<T, ELEM_COMPARE>::Erase(const T& element) {
  size_t ownerId = shad::hash<T>{}(element) % rt::numLocalities();
  rt::Locality owner(ownerId);

  if (owner == rt::thisLocality()) {
    // Local owner: erase directly.
    localSet_.Erase(element);
    return;
  }

  auto remoteErase = [](const ExeAtArgs& payload) {
    auto shard = SetT::GetPtr(payload.oid);
    shard->localSet_.Erase(payload.element);
  };
  ExeAtArgs payload = {oid_, element};
  rt::executeAt(owner, remoteErase, payload);
}
/// Asynchronously erases `element` from the locality selected by its hash;
/// the operation is associated with `handle`.
template <typename T, typename ELEM_COMPARE>
inline void Set<T, ELEM_COMPARE>::AsyncErase(rt::Handle& handle,
                                             const T& element) {
  size_t ownerId = shad::hash<T>{}(element) % rt::numLocalities();
  rt::Locality owner(ownerId);

  if (owner == rt::thisLocality()) {
    // Local owner: no runtime hop needed.
    localSet_.AsyncErase(handle, element);
    return;
  }

  auto remoteErase = [](rt::Handle& h, const ExeAtArgs& payload) {
    auto shard = SetT::GetPtr(payload.oid);
    shard->localSet_.AsyncErase(h, payload.element);
  };
  ExeAtArgs payload = {oid_, element};
  rt::asyncExecuteAt(handle, owner, remoteErase, payload);
}
/// @brief Looks `element` up on the locality selected by its hash.
///
/// @param element Element to search for.
/// @return The value returned by LocalSet::Find on the owning locality.
template <typename T, typename ELEM_COMPARE>
inline bool Set<T, ELEM_COMPARE>::Find(const T& element) {
  size_t targetId = shad::hash<T>{}(element) % rt::numLocalities();
  rt::Locality targetLocality(targetId);

  // Owner is the calling locality: answer from the local shard.
  if (targetLocality == rt::thisLocality()) {
    return localSet_.Find(element);
  }

  // Owner is remote: run the lookup there and copy the answer back.
  auto findLambda = [](const ExeAtArgs& args, bool* res) {
    auto setPtr = SetT::GetPtr(args.oid);
    *res = setPtr->localSet_.Find(args.element);
  };
  ExeAtArgs args = {oid_, element};
  // Initialized so the result is well defined even before the runtime
  // writes the remote answer into it.
  bool found = false;
  rt::executeAtWithRet(targetLocality, findLambda, args, &found);
  return found;
}
/// Asynchronously looks `element` up on the locality selected by its hash;
/// the answer is written through `found` and the operation is associated
/// with `handle`.
template <typename T, typename ELEM_COMPARE>
inline void Set<T, ELEM_COMPARE>::AsyncFind(rt::Handle& handle,
                                            const T& element, bool* found) {
  size_t ownerId = shad::hash<T>{}(element) % rt::numLocalities();
  rt::Locality owner(ownerId);

  if (owner == rt::thisLocality()) {
    // Local owner: delegate to the shard's own asynchronous lookup.
    localSet_.AsyncFind(handle, element, found);
    return;
  }

  // On the remote owner the lookup itself is a plain LocalSet::Find; the
  // handle passed to the lambda is not used.
  auto remoteFind = [](rt::Handle&, const ExeAtArgs& payload, bool* out) {
    auto shard = SetT::GetPtr(payload.oid);
    *out = shard->localSet_.Find(payload.element);
  };
  ExeAtArgs payload = {oid_, element};
  rt::asyncExecuteAtWithRet(handle, owner, remoteFind, payload, found);
}
/// Applies `function` to the elements of every shard. The callable is
/// converted to a plain function pointer and shipped, together with a copy
/// of `args`, to every locality, where a forEachAt over the shard's buckets
/// is launched.
template <typename T, typename ELEM_COMPARE>
template <typename ApplyFunT, typename... Args>
void Set<T, ELEM_COMPARE>::ForEachElement(ApplyFunT&& function, Args&... args) {
  using FunctionTy = void (*)(const T&, Args&...);
  using feArgs = std::tuple<ObjectID, FunctionTy, std::tuple<Args...>>;
  using ArgsTuple = std::tuple<LSetT*, FunctionTy, std::tuple<Args...>>;

  FunctionTy callback = std::forward<ApplyFunT>(function);
  feArgs shipped(oid_, callback, std::tuple<Args...>(args...));

  // Executed once per locality: rebinds the shipped arguments to the local
  // shard and iterates over its buckets.
  auto perLocality = [](const feArgs& packed) {
    auto shard = SetT::GetPtr(std::get<0>(packed));
    ArgsTuple localArgs(&shard->localSet_, std::get<1>(packed),
                        std::get<2>(packed));
    rt::forEachAt(rt::thisLocality(),
                  LSetT::template ForEachElementFunWrapper<ArgsTuple, Args...>,
                  localArgs, shard->localSet_.numBuckets_);
  };
  rt::executeOnAll(perLocality, shipped);
}
/// Asynchronous counterpart of ForEachElement: the callable (which also
/// receives an rt::Handle&) and a copy of `args` are shipped to every
/// locality, where an asyncForEachAt over the shard's buckets is launched
/// under `handle`.
template <typename T, typename ELEM_COMPARE>
template <typename ApplyFunT, typename... Args>
void Set<T, ELEM_COMPARE>::AsyncForEachElement(rt::Handle& handle,
                                               ApplyFunT&& function,
                                               Args&... args) {
  using FunctionTy = void (*)(rt::Handle&, const T&, Args&...);
  using feArgs = std::tuple<ObjectID, FunctionTy, std::tuple<Args...>>;
  using ArgsTuple = std::tuple<LSetT*, FunctionTy, std::tuple<Args...>>;

  FunctionTy callback = std::forward<ApplyFunT>(function);
  feArgs shipped{oid_, callback, std::tuple<Args...>(args...)};

  // Executed once per locality: rebinds the shipped arguments to the local
  // shard and iterates asynchronously over its buckets.
  auto perLocality = [](rt::Handle& h, const feArgs& packed) {
    auto shard = SetT::GetPtr(std::get<0>(packed));
    ArgsTuple localArgs = std::make_tuple(
        &shard->localSet_, std::get<1>(packed), std::get<2>(packed));
    rt::asyncForEachAt(
        h, rt::thisLocality(),
        LSetT::template AsyncForEachElementFunWrapper<ArgsTuple, Args...>,
        localArgs, shard->localSet_.numBuckets_);
  };
  rt::asyncExecuteOnAll(handle, perLocality, shipped);
}
/// @brief Forward iterator over all the elements of a distributed Set.
///
/// The iterator stores the id of the locality it currently points into, the
/// object id of the set, a local iterator valid on that locality and a copy
/// of the current element. Advancing walks the local shard first and then
/// moves on to the next non-empty locality. The end iterator is identified
/// by a locality id equal to rt::numLocalities().
///
/// @tparam SetT Distributed set type being iterated.
/// @tparam T Element type exposed by dereferencing (Set uses `const T`).
/// @tparam NonConstT Non-const element type used for the cached copy.
template <typename SetT, typename T, typename NonConstT>
class set_iterator {
 public:
  // Member types spelled out explicitly instead of inheriting from
  // std::iterator, which is deprecated since C++17. pointer/reference keep
  // the types the former base std::iterator<std::forward_iterator_tag, T>
  // provided.
  using iterator_category = std::forward_iterator_tag;
  using value_type = NonConstT;
  using difference_type = std::ptrdiff_t;
  using pointer = T*;
  using reference = T&;

  using OIDT = typename SetT::ObjectID;
  using LSet = typename SetT::LSetT;
  using local_iterator_type = lset_iterator<LSet, T>;

  /// @brief Default constructor; the iterator refers to no set.
  set_iterator() {}

  /// @brief Builds an iterator from all of its components.
  set_iterator(uint32_t locID, const OIDT setOID, local_iterator_type& lit,
               T element) {
    data_ = {locID, setOID, lit, element};
  }

  /// @brief Builds an iterator from a local iterator, reading the element
  /// from it; a local end iterator yields the global end iterator.
  set_iterator(uint32_t locID, const OIDT setOID, local_iterator_type& lit) {
    auto setPtr = SetT::GetPtr(setOID);
    const LSet* lsetPtr = &(setPtr->localSet_);
    if (lit != local_iterator_type::lset_end(lsetPtr))
      data_ = itData(locID, setOID, lit, *lit);
    else
      *this = set_end(setPtr.get());
  }

  /// @brief Iterator to the first element, searching from locality 0 onwards.
  static set_iterator set_begin(const SetT* setPtr) {
    const LSet* lsetPtr = &(setPtr->localSet_);
    auto localEnd = local_iterator_type::lset_end(lsetPtr);
    if (static_cast<uint32_t>(rt::thisLocality()) == 0) {
      auto localBegin = local_iterator_type::lset_begin(lsetPtr);
      if (localBegin != localEnd) {
        return set_iterator(0, setPtr->oid_, localBegin);
      }
      // Locality 0 is empty: start from its end and advance to the first
      // element of the following localities (or to the global end).
      set_iterator beg(0, setPtr->oid_, localEnd, T());
      return ++beg;
    }
    // Not on locality 0: perform the same search there.
    auto getItLambda = [](const OIDT& setOID, set_iterator* res) {
      auto setPtr = SetT::GetPtr(setOID);
      const LSet* lsetPtr = &(setPtr->localSet_);
      auto localEnd = local_iterator_type::lset_end(lsetPtr);
      auto localBegin = local_iterator_type::lset_begin(lsetPtr);
      if (localBegin != localEnd) {
        *res = set_iterator(0, setOID, localBegin);
      } else {
        set_iterator beg(0, setOID, localEnd, T());
        *res = ++beg;
      }
    };
    set_iterator beg(0, setPtr->oid_, localEnd, T());
    rt::executeAtWithRet(rt::Locality(0), getItLambda, setPtr->oid_, &beg);
    return beg;
  }

  /// @brief End iterator: locality id rt::numLocalities() and object id 0.
  static set_iterator set_end(const SetT* setPtr) {
    local_iterator_type lend =
        local_iterator_type::lset_end(&(setPtr->localSet_));
    set_iterator end(rt::numLocalities(), OIDT(0), lend, T());
    return end;
  }

  bool operator==(const set_iterator& other) const {
    return (data_ == other.data_);
  }
  bool operator!=(const set_iterator& other) const { return !(*this == other); }

  /// @brief Returns a copy of the cached element.
  T operator*() const { return data_.element_; }

  /// @brief Advances to the next element, moving to following localities
  /// when the current shard is exhausted.
  set_iterator& operator++() {
    auto setPtr = SetT::GetPtr(data_.oid_);
    if (static_cast<uint32_t>(rt::thisLocality()) == data_.locId_) {
      // The iterator points into the calling locality: advance locally.
      const LSet* lsetPtr = &(setPtr->localSet_);
      auto lend = local_iterator_type::lset_end(lsetPtr);
      if (data_.lsetIt_ != lend) {
        ++(data_.lsetIt_);
      }
      if (data_.lsetIt_ != lend) {
        data_.element_ = *(data_.lsetIt_);
        return *this;
      } else {
        // find the local begin on next localities
        itData itd;
        for (uint32_t i = data_.locId_ + 1; i < rt::numLocalities(); ++i) {
          rt::executeAtWithRet(rt::Locality(i), getLocBeginIt, data_.oid_,
                               &itd);
          if (itd.locId_ != rt::numLocalities()) {
            // It Data is valid
            data_ = itd;
            return *this;
          }
        }
        // No further element anywhere: become the end iterator.
        data_ = itData(rt::numLocalities(), OIDT(0), lend, T());
        return *this;
      }
    }
    // The iterator points into another locality: advance it there.
    itData itd;
    rt::executeAtWithRet(rt::Locality(data_.locId_), getRemoteIt, data_, &itd);
    data_ = itd;
    return *this;
  }

  set_iterator operator++(int) {
    set_iterator tmp = *this;
    operator++();
    return tmp;
  }

  /// @brief Pair of local iterators delimiting a range on one locality.
  class local_iterator_range {
   public:
    local_iterator_range(local_iterator_type B, local_iterator_type E)
        : begin_(B), end_(E) {}
    local_iterator_type begin() { return begin_; }
    local_iterator_type end() { return end_; }

   private:
    local_iterator_type begin_;
    local_iterator_type end_;
  };

  /// @brief Portion of [B, E) that lives on the calling locality.
  static local_iterator_range local_range(set_iterator& B, set_iterator& E) {
    auto setPtr = SetT::GetPtr(B.data_.oid_);
    local_iterator_type lbeg, lend;
    uint32_t thisLocId = static_cast<uint32_t>(rt::thisLocality());
    // Use B/E's local iterators only when they point into this locality;
    // otherwise the whole local shard is covered on that side.
    if (B.data_.locId_ == thisLocId) {
      lbeg = B.data_.lsetIt_;
    } else {
      lbeg = local_iterator_type::lset_begin(&(setPtr->localSet_));
    }
    if (E.data_.locId_ == thisLocId) {
      lend = E.data_.lsetIt_;
    } else {
      lend = local_iterator_type::lset_end(&(setPtr->localSet_));
    }
    return local_iterator_range(lbeg, lend);
  }

  /// @brief Range of localities from B's locality up to E's locality
  /// (inclusive), capped at rt::numLocalities().
  static rt::localities_range localities(set_iterator& B, set_iterator& E) {
    return rt::localities_range(rt::Locality(B.data_.locId_),
                                rt::Locality(std::min<uint32_t>(
                                    rt::numLocalities(), E.data_.locId_ + 1)));
  }

  /// @brief Wraps a local iterator of the calling locality into a global
  /// iterator; only the object id of B is used.
  static set_iterator iterator_from_local(set_iterator& B, set_iterator& E,
                                          local_iterator_type itr) {
    return set_iterator(static_cast<uint32_t>(rt::thisLocality()), B.data_.oid_,
                        itr);
  }

 private:
  // State of the iterator; it is also the payload exchanged with remote
  // localities while advancing.
  struct itData {
    // Every member is initialized so that a default-constructed iterator
    // never exposes indeterminate values (e.g. when compared).
    itData()
        : locId_(0),
          oid_(0),
          lsetIt_(nullptr, 0, 0, nullptr, nullptr),
          element_() {}
    itData(uint32_t locId, OIDT oid, local_iterator_type lsetIt, T element)
        : locId_(locId), oid_(oid), lsetIt_(lsetIt), element_(element) {}

    // Equality looks only at the locality id and the local iterator.
    bool operator==(const itData& other) const {
      return (locId_ == other.locId_) && (lsetIt_ == other.lsetIt_);
    }
    bool operator!=(const itData& other) const { return !(*this == other); }

    uint32_t locId_;
    OIDT oid_;
    local_iterator_type lsetIt_;
    NonConstT element_;
  };

  itData data_;

  // Runs on a given locality: reports its first element, or an end marker
  // (locId_ == rt::numLocalities()) when its shard is empty.
  static void getLocBeginIt(const OIDT& setOID, itData* res) {
    auto setPtr = SetT::GetPtr(setOID);
    auto lsetPtr = &(setPtr->localSet_);
    auto localEnd = local_iterator_type::lset_end(lsetPtr);
    auto localBegin = local_iterator_type::lset_begin(lsetPtr);
    if (localBegin != localEnd) {
      *res = itData(static_cast<uint32_t>(rt::thisLocality()), setOID,
                    localBegin, *localBegin);
    } else {
      *res = itData(rt::numLocalities(), OIDT(0), localEnd, T());
    }
  }

  // Runs on the locality the iterator points into: advances the local
  // iterator and, when the shard is exhausted, searches the following
  // localities for their first element.
  static void getRemoteIt(const itData& itd, itData* res) {
    auto setPtr = SetT::GetPtr(itd.oid_);
    auto lsetPtr = &(setPtr->localSet_);
    auto localEnd = local_iterator_type::lset_end(lsetPtr);
    local_iterator_type cit = itd.lsetIt_;
    ++cit;
    if (cit != localEnd) {
      *res = itData(static_cast<uint32_t>(rt::thisLocality()), itd.oid_, cit,
                    *cit);
      return;
    } else {
      itData outitd;
      for (uint32_t i = itd.locId_ + 1; i < rt::numLocalities(); ++i) {
        rt::executeAtWithRet(rt::Locality(i), getLocBeginIt, itd.oid_, &outitd);
        if (outitd.locId_ != rt::numLocalities()) {
          // It Data is valid
          *res = outitd;
          return;
        }
      }
      // No further element anywhere: report the end marker.
      *res = itData(rt::numLocalities(), OIDT(0), localEnd, T());
    }
  }
};
} // namespace shad
#endif // INCLUDE_SHAD_DATA_STRUCTURES_SET_H_