// NOTE: the source-browser navigation text "Blob Blame Raw" was captured here; it is not part of the header.
#ifndef HANDLE_H
#define HANDLE_H


#include <cassert>

#include <atomic>
#include <utility>



class Shared;
template<typename T> class THandle;
template<typename T> class TWeakHandle;



/**
 * Control block that pairs one Shared object with two atomic counters.
 *
 * Every member is private; only the friend classes Shared, HandleBase and
 * WeakHandleBase can construct, read, modify or destroy a Counter.
 */
class Counter {
	friend class Shared;
	friend class HandleBase;
	friend class WeakHandleBase;

	using Type = unsigned int;

	// Starts at 1. WeakHandleBase increments/decrements it and deletes the
	// Counter when its decrement reaches zero.
	// NOTE(review): the initial 1 is presumably the reference held by the
	// Shared object itself - its release is not visible in this header.
	std::atomic<Type> refCount;

	// Starts at 0. HandleBase increments/decrements it and deletes the Shared
	// object when its decrement reaches zero.
	std::atomic<Type> ptrRefCount;

	// Address of the Shared object this Counter was created for.
	Shared* pointer;

	explicit Counter(Shared &shared):
		refCount(1),
		ptrRefCount(0),
		pointer(&shared)
	{ }

	// Both counters are expected to be zero by the time the Counter dies.
	~Counter()
		{ assert(refCount.load() == 0 && ptrRefCount.load() == 0); }

	Counter& operator=(const Counter&) = delete;
};



/**
 * Base class of objects that are referenced through THandle / TWeakHandle.
 *
 * Each instance allocates its own Counter on construction and keeps an
 * immutable pointer to it. Copying is disabled. HandleBase and WeakHandleBase
 * are friends so that they can reach the Counter.
 */
class Shared {
	friend class HandleBase;
	friend class WeakHandleBase;

public:
	using Handle = THandle<Shared>;

	// Creates the Counter bound to this object.
	Shared():
		counter(new Counter(*this))
	{ }

	// Declared here, defined out of line (not visible in this header).
	virtual ~Shared();

private:
	Shared(const Shared&) = delete;
	Shared& operator=(const Shared&) = delete;

	Counter * const counter;
};



class WeakHandleBase {
private:
	friend class HandleBase;
	Counter* counter;
	
	inline void set(Counter* counter) {
		if (this->counter == counter) return;
		if (counter) counter->refCount++;
		if (this->counter && !--this->counter->refCount) delete this->counter;
		this->counter = counter;
	}
	
protected:
	inline WeakHandleBase(): counter() { }
	inline void set(Shared* pointer) { set(pointer ? pointer->counter : nullptr); }
	inline void set(const WeakHandleBase &other) { set(other.counter); }
	inline void swap(WeakHandleBase &other) { std::swap(counter, other.counter); }
	
public:
	inline ~WeakHandleBase()
		{ reset(); }
	inline void reset()
		{ return set((Counter*)nullptr); }
	inline bool operator<(const WeakHandleBase &other) const
		{ return counter < other.counter; }
};



class HandleBase {
private:
	Shared *pointer;
	
protected:
	inline HandleBase(): pointer() { }
	
	inline void set(Shared *pointer) {
		if (this->pointer == pointer) return;
		if (pointer) pointer->counter->ptrRefCount++;
		if (this->pointer && !--this->pointer->counter->ptrRefCount) delete this->pointer;
		this->pointer = pointer;
	}
	
	inline void set(const WeakHandleBase &weak) {
		Counter *counter = weak.counter;
		if (!counter) { set((Shared*)nullptr); return; }
		if (pointer && pointer->counter == counter) return;
	
		Shared *pointer = nullptr;
		Counter::Type cnt = counter->refCount;
		while(cnt)
			if (counter->ptrRefCount.compare_exchange_weak(cnt, cnt+1))
				{ pointer = counter->pointer; break; }
		
		if (this->pointer && !--this->pointer->counter->ptrRefCount) delete this->pointer;
		this->pointer = pointer;
	}
	
	inline void swap(HandleBase &other) { std::swap(pointer, other.pointer); }
	
	inline Shared* get() const
		{ return pointer; }
	
public:
	inline ~HandleBase()
		{ reset(); }
	
	inline void reset()
		{ return set((Shared*)nullptr); }
	inline operator bool() const
		{ return pointer; }
	inline bool operator<(const HandleBase &other) const
		{ return pointer < other.pointer; }
	
	inline bool operator==(const HandleBase &other) const
		{ return pointer == other.pointer; }
	inline bool operator==(const Shared *pointer) const
		{ return this->pointer == pointer; }
	inline friend bool operator==(const Shared *pointer, const HandleBase &handle)
		{ return pointer == handle.pointer; }
	
	inline bool operator!=(const HandleBase &other) const
		{ return pointer != other.pointer; }
	inline bool operator!=(const Shared *pointer) const
		{ return this->pointer != pointer; }
	inline friend bool operator!=(const Shared *pointer, const HandleBase &handle)
		{ return pointer != handle.pointer; }
};



template<typename T>
class TWeakHandle: public WeakHandleBase {
public:
	typedef T Type;
	typedef THandle<Type> Handle;
	
private:
	inline void typeChecker(const Type*) { }
	
public:
	inline TWeakHandle() { }
	inline TWeakHandle(const std::nullptr_t) { }
	inline TWeakHandle(TWeakHandle &&other) { WeakHandleBase::swap(other); }
	inline TWeakHandle(const TWeakHandle &other) { WeakHandleBase::set(other); }
	inline explicit TWeakHandle(const Handle &handle) { WeakHandleBase::set(handle.pointer()); }
	
	inline TWeakHandle& operator=(TWeakHandle &&other)
		{ WeakHandleBase::swap(other); return *this; }
	inline TWeakHandle& operator=(const TWeakHandle &other)
		{ WeakHandleBase::set(other); return *this; }
	inline TWeakHandle& operator=(const Handle &handle)
		{ WeakHandleBase::set(handle.pointer()); return *this; }
	
	inline void swap(TWeakHandle &other)
		{ WeakHandleBase::swap(other); return *this; }
		
	template<typename TT>
	inline operator TWeakHandle<TT>&()
		{ typeChecker((TT*)nullptr); return *reinterpret_cast<TWeakHandle<TT>*>(this); }
	template<typename TT>
	inline operator const TWeakHandle<TT>&() const
		{ typeChecker((TT*)nullptr); return *reinterpret_cast<const TWeakHandle<TT>*>(this); }
};



template<typename T>
class THandle: public HandleBase {
public:
	typedef T Type;
	typedef TWeakHandle<Type> Weak;
	
private:
	inline void typeChecker(const Type*) { }
	
public:
	inline THandle() { }
	inline THandle(const std::nullptr_t) { }
	inline THandle(THandle &&other) { HandleBase::swap(other); }
	inline THandle(const THandle &other) { HandleBase::set(other.get()); }
	inline explicit THandle(const Weak &weak) { HandleBase::set(weak); }
	inline explicit THandle(Type *pointer) { HandleBase::set(pointer); }
	
	inline THandle& operator=(THandle &&other)
		{ HandleBase::swap(other); return *this; }
	inline THandle& operator=(const THandle &other)
		{ HandleBase::set(other.get()); return *this; }
	inline THandle& operator=(const Weak &weak)
		{ HandleBase::set(weak); return *this; }
	inline THandle& operator=(Type *pointer)
		{ HandleBase::set(pointer); return *this; }
	
	inline void swap(THandle &other)
		{ HandleBase::swap(other); return *this; }
		
	inline Type* pointer() const { return (Type*)HandleBase::get(); }
	inline Type* operator->() const { assert(pointer()); return pointer(); }
	inline Type& operator*() const { assert(pointer()); return *pointer(); }
	
	template<typename TT>
	inline operator THandle<TT>&()
		{ typeChecker((TT*)nullptr); return *reinterpret_cast<THandle<TT>*>(this); }
	template<typename TT>
	inline operator const THandle<TT>&() const
		{ typeChecker((TT*)nullptr); return *reinterpret_cast<const THandle<TT>*>(this); }
};



#endif