For emb and statis, the most basic functionality to implement is the four CRUD operations: create, delete, update, and read. Create corresponds to registering and initializing variables, delete to feature filtering, update to applying updates, and read to pulling the latest values down from the PS.

These operations also have to happen in a fixed order: a variable must be registered before it can be initialized, and only after that can it be deleted, updated, or read.

However, emb and statis are both functions whose return values are the ops corresponding to embedding_op, which get added to the computation graph. Variable registration, initialization, and update therefore need some other way of being injected into the computation.

The mechanism used here is hooks, together with the hooks' priority mechanism to guarantee the order in which the ops are invoked: the smaller the priority value, the earlier the hook runs, and the Hook base class defaults to a priority of 2000.
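As a minimal sketch of that ordering rule (the hook classes here are made up for illustration; the real SimpleSession code is shown below), sorting hooks by their priority field is what makes a priority-1001 hook run before a default priority-2000 hook:

# Minimal sketch with hypothetical hook classes: lower priority value => earlier call.
class Hook(object):
    def __init__(self, priority=2000):      # base-class default priority
        self._priority = priority
    def create_session(self):
        pass

class EarlyHook(Hook):                       # hypothetical stand-in for ChiefHook/WorkerHook
    def __init__(self):
        super(EarlyHook, self).__init__(priority=1001)
    def create_session(self):
        print("runs first: register / initialize")

hooks = [Hook(), EarlyHook()]
hooks.sort(key=lambda h: h._priority)        # the same sort SimpleSession performs below
for h in hooks:
    h.create_session()                       # EarlyHook's create_session fires before the 2000 one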

1. Registration and initialization (create)

Registration and initialization are two OPs. Registration only needs to set up the VariableInfo on the client side, so every worker has to do it; it is therefore placed in the WorkerHook.

Initialization, in contrast, has to be carried out on each PS, and only one worker needs to send the request, so it is done only in the worker master's ChiefHook. The worker master's registration is also done in the ChiefHook (the worker master has no WorkerHook).

The ChiefHook's priority is 1001, and the WorkerHook's priority is 1002.

class SimpleSession(object):
    def __init__(self, hooks=None):
        ...
        if self._is_chief:
            self._hooks = self._hooks + [ChiefHook()]
        else:
            self._hooks = self._hooks + [WorkerHook()]
        def take_priority(elem):
            return elem._priority
        self._hooks.sort(key=take_priority)
        self._session = Session(self._hooks)
        ...
 
 
class ChiefHook(Hook):
    def __init__(self):
        super(ChiefHook, self).__init__(priority=1001)
 
    def create_session(self):
        scopes = list(get_model_scopes())
        if global_variables(scopes) is None or\
                len(global_variables(scopes)) == 0:
            return
        execute_with_retry(variable_registers(scopes))
        execute_with_retry(global_initializers(scopes))

Registration ultimately calls ps_register_variable_op; initialization uses the user-specified initializer, e.g. Ones, Zeros, xdl.TruncatedNormal, and so on.

Registration finishes on the client side; it only records the variable's info on each worker.

Status RawClient::RegisterVariable(const std::string& name, const VariableInfo& info) {
  std::lock_guard<std::mutex> lock(variable_info_mutex_);
  auto iter = args_.variable_info.find(name);
  if (iter != args_.variable_info.end()) {
    std::string arg_name = "slots";
    auto it2 = info.args.find(arg_name);
    if (it2 != info.args.end()) {
      auto it = args_.variable_info[name].args.find(arg_name);
      if (it == args_.variable_info[name].args.end()) {
        args_.variable_info[name].args[arg_name] = it2->second;
      } else {
        args_.variable_info[name].args[arg_name] += "|" + it2->second;
      }
    }
    return Status::Ok();
  }
  args_.variable_info[name] = info;
  init_variable_info_ = false;
  return Status::Ok();
}

Initialization uses a partitioner to split the data shape across the PS servers, then asks each PS to execute SimpleRun according to the UDF, completing the initialization of data and slots.

When worker0 performs initialization it calls the init op; when Process is called, the local VariableInfo is sent to the scheduler, and when training from scratch (no checkpoint) the scheduler invokes placement.
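A rough sketch of the idea behind the shape split (a hypothetical helper; an even row-wise split is just an assumption for illustration):

# Hypothetical sketch: split a dense variable's shape along dim 0 across the ps servers.
def split_shape(shape, num_ps):
    rows = shape[0]
    base, rem = divmod(rows, num_ps)
    parts = []
    for i in range(num_ps):
        part_rows = base + (1 if i < rem else 0)   # the first `rem` servers get one extra row
        parts.append([part_rows] + list(shape[1:]))
    return parts

print(split_shape([10, 8], 3))   # a [10, 8] variable -> [[4, 8], [3, 8], [3, 8]]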

2. Filter (delete)

Filters are the various user-specified filters; each one attaches a corresponding slot to the variable. "Delete" here means deleting the data row and the slot rows that correspond to a key.

Note that although fea_statis and fea_score are used together, FeatureScoreFilter only computes fea_score, while HashSlotsUpdateHook only updates fea_statis; the two computations are separate.
There are also slots that are not added by filters, such as __decay_rate and accumulate. These have no delete functionality and can only be updated; they are covered in the next part.

Below we take FeatureScoreFilter as an example to walk through how a filter works.

FeatureScoreFilter calls PsHashFeatureScoreFilterOp.

A splitter like Broadcast means the full contents of clk_weight, nonclk_weight, and train_threshold are broadcast to every PS;
the client then sends the "HashFeatureScoreFilter" UDF to each PS:

class PsHashFeatureScoreFilterOp : public xdl::OpKernelAsync {
 public:
  ...
  void Compute(OpKernelContext* ctx, Callback done) override {
    ...
    ps::client::UdfData udf("HashFeatureScoreFilter",
                            ps::client::UdfData(0),
                            ps::client::UdfData(1),
                            ps::client::UdfData(2));

    std::vector<ps::client::Partitioner*> spliters{
      new ps::client::partitioner::Broadcast,
      new ps::client::partitioner::Broadcast,
      new ps::client::partitioner::Broadcast};

    client->Process(udf, var_name_, client->Args(nonclk_weight_, clk_weight_,
                                                 train_threshold_),
                                                 spliters, {}, outputs, cb);
  }

On the PS side, the SimpleRun of HashFeatureScoreFilter is invoked:

class HashFeatureScoreFilter : public SimpleUdf<float, float, float> {
 public:
  virtual Status SimpleRun(UdfContext* ctx,
                           const float& nonclk_weight,
                           const float& clk_weight,
                           const float& score_threshold) const {
    Variable* variable = ctx->GetVariable();
    ...
    //2. compute fea score
    auto score = show_vector * nonclk_weight + clk_vector * (clk_weight - nonclk_weight);
    printf("HashFeatureScoreFilter for %s fea score min %f max %f\n",
           var_name.c_str(), score.minCoeff(), score.maxCoeff());

    //3. select keys and store fea score
    std::vector<size_t> ids;
    for (size_t i = 0; i < items.size(); ++i) {
      *fea_scores->Raw<float>(items[i].id) = score(i);
      if (score(i) < score_threshold) {
        ids.push_back(items[i].id);
      }
    }

    //4. delete
    ctx->GetServerLocker()->ChangeType(QRWLocker::kWrite);
    tbb::concurrent_vector<size_t> unfiltered_ids;
    size_t del_size = hashmap->EraseById(ctx->GetVariableName(), ids, &unfiltered_ids);
    ...
    return Status::Ok();
  }
};
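A worked example of the score computed in step 2 above (the weights are made up): since score = show * nonclk_weight + clk * (clk_weight - nonclk_weight), it is the same as weighting non-click impressions and clicks separately.

# Made-up weights, just to illustrate the formula from step 2 above.
nonclk_weight, clk_weight, score_threshold = 1.0, 5.0, 10.0

def fea_score(show, clk):
    # identical to the C++ expression; algebraically (show - clk) * nonclk_weight + clk * clk_weight
    return show * nonclk_weight + clk * (clk_weight - nonclk_weight)

print(fea_score(show=8, clk=0))   # 8.0  -> below score_threshold, this key would be erased
print(fea_score(show=8, clk=1))   # 12.0 -> above the threshold, the key is kept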

The hashmap performs the actual deletion, pushing the ids onto free_list_ so that they can be reused at feature-admission time.

virtual size_t EraseById(const std::string& variable_name, std::vector<size_t>& ids, tbb::concurrent_vector<size_t>* unfiltered_ids) {
    std::atomic<size_t> size(0);
    tbb::concurrent_vector<KeyType> keys;
    std::sort(ids.begin(), ids.end());
    tbb::parallel_for_each(begin(table_), end(table_), [&](const std::pair<KeyType, size_t>& pr) {
      auto iter = std::lower_bound(ids.begin(), ids.end(), pr.second);
      if (iter != ids.end() && *iter == pr.second) {
        keys.push_back(pr.first);
        free_list_.push(pr.second);
        size++;
      } else {
        unfiltered_ids->push_back(pr.second);
      }
    });
    for (auto&& key : keys) {
      table_.unsafe_erase(key);
    }
    LOG(INFO) << "Filter for " + variable_name + ", clear=" + std::to_string(keys.size()) + ", left=" + std::to_string(table_.size());
    return size;
  }
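A toy model of that reuse (not the real tbb-based hashmap): erasing a key releases its row id into the free list, and feature admission pops from that list before allocating a new row.

# Toy sketch of the id reuse done by free_list_.
class ToyHashMap(object):
    def __init__(self):
        self.table = {}          # key -> row id, like table_ above
        self.free_list = []      # like free_list_
        self.next_id = 0

    def erase(self, keys):
        for k in keys:
            self.free_list.append(self.table.pop(k))

    def get_or_add(self, key):   # what feature admission would do for an unseen key
        if key in self.table:
            return self.table[key]
        row = self.free_list.pop() if self.free_list else self._alloc()
        self.table[key] = row
        return row

    def _alloc(self):
        row = self.next_id
        self.next_id += 1
        return row

m = ToyHashMap()
m.get_or_add("fea_a"); m.get_or_add("fea_b")   # rows 0 and 1
m.erase(["fea_a"])                              # row 0 goes to the free list
print(m.get_or_add("fea_c"))                    # 0 -- the freed row is reused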

3. Update (modify)

In this part we describe how the filters' slots and the statis slots are computed and updated.

decay_rate comes built in with statis; its dimension is label_len+1, and it is computed by accumulating labels. HashSlotsUpdateHook updates the slot value, and HashFeatureDecayHook applies decay to the slot.

accumulate is a parameter of the AdaGrad optimizer; the optimizer takes care of updating data and its own parameters, so in this part we mainly cover the slot updates.

Notice that the two slots that carry a decay rate each need two hooks, one to update the slot and one to apply the decay; the slot is updated first and the decay is applied afterwards. The hook priorities also guarantee that slots are updated before the various filters are considered.

Below we take the update of the __decay_rate_xx slot as an example to explain slot updates, and introduce the statistical features along the way.

The update of __decay_rate_xx has two parts: first, computing the statistics delta and updating the slot with it; second, applying the decay rate to update the slot.

1) Computing the statistics delta and updating the slot

The delta to apply is first computed by FeaStatsCpuOp:

void FeaStatsFunctor<CpuDevice, T, I>::operator()(CpuDevice* d, const Tensor& sindex, const Tensor& ssegment,
                                                  const Tensor& fea_stat_input, Tensor* stat_delta) {
  ...
  size_t out_fea_stat_dim = in_fea_stat_dim + 1;
  TensorShape out_shape({ssegment.Shape()[0], out_fea_stat_dim});
  *stat_delta = Tensor(d, out_shape, DataTypeToEnum<T>::v());
  T* pout = stat_delta->Raw<T>();
  std::memset(pout, 0, sizeof(T) * out_shape.NumElements());
  ...

  #pragma omp parallel for
  for (size_t i = 0; i < id_num; ++i) {
    size_t sseg_idx = std::lower_bound(psseg, psseg + sseg_size, i + 1) - psseg;
    T* src = pclick + psindex[i] * in_fea_stat_dim;
    T* dst = pout + sseg_idx * out_fea_stat_dim;

    // show
    common::cpu_atomic_add<T>(1, dst);
    // others
    for (size_t j = 0; j < in_fea_stat_dim; j++) {
      common::cpu_atomic_add<T>(src[j], dst + j + 1);
    }
  }

};
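To make the delta concrete, here is a small sketch under the assumption that fea_stat_input carries one click label per sample (the real op encodes the sample-to-feature mapping through sindex/ssegment): each output row is [show_count, label_sum, ...] accumulated over the samples that hit that feature.

# Hypothetical mini example of the stat_delta produced above.
sample_to_feature = [0, 0, 1, 1]     # which unique feature each of 4 samples maps to
clicks = [0.0, 1.0, 1.0, 1.0]        # assumed fea_stat_input of dim 1 (the click label)

num_features, in_dim = 2, 1
stat_delta = [[0.0] * (in_dim + 1) for _ in range(num_features)]   # each row: [show, click]

for s, f in enumerate(sample_to_feature):
    stat_delta[f][0] += 1.0          # the "show" column
    stat_delta[f][1] += clicks[s]    # accumulate the label

print(stat_delta)                    # [[2.0, 1.0], [2.0, 2.0]]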

HashSlotsUpdateHook updates the slot by calling the client's ps_sparse_push_slots_op; the call chain is as follows:

(client)ps_sparse_push_slots_op → (client)HashPushSlots → (ps)BuildHashSlice → (ps)AssignAddSlotUpdater

The client's HashPushSlots code is as follows:

void Client::HashPushSlots(...) {
  ...
  std::vector<Data*> inputs = Args(ids_vec, name_vec, save_ratio_vec, true, insertable);
  UdfData slice_udf("BuildHashSlice", UdfData(0), UdfData(1), UdfData(2), UdfData(3), UdfData(4));
  std::vector<std::unique_ptr<Data>>* outputs = 
    new std::vector<std::unique_ptr<Data>>;
  std::vector<Partitioner*> splitter = {
    new partitioner::HashId,
    ...
  };
  std::vector<UdfData> udf_chain;
  for (size_t i = 0; i < slot_names.size(); i++) {
    UdfData one_updater = UdfData(updaters[i], slice_udf, UdfData(5+i*2), UdfData(6+i*2));
    std::string slot_name = slot_names[i];
    inputs.push_back(Args(slot_name)[0]);
    std::vector<ps::Tensor> grad = {grads[i]};
    inputs.push_back(Args(grad)[0]);
    splitter.push_back(new partitioner::Broadcast);
    splitter.push_back(new partitioner::HashData);
    udf_chain.push_back(one_updater);
  }
  ...

  Process(udf_chain, variable_name, inputs, splitter, 
          combiner, outputs, realcb);
}  

The HashId and HashData splitters in the splitter list separate out different pieces of data to send to the different PS servers.

Status SparseData::Split(PartitionerContext* ctx, Data* src, std::vector<Data*>* dst) {
    ...
    dst->clear();
    for (size_t i = 0; i < info->parts.size(); ++i) {
      WrapperData<std::vector<Tensor>>* result = new WrapperData<std::vector<Tensor>>();
      dst->emplace_back(result);
      ctx->AddDeleter(result);
    }
    for (size_t i = 0; i < data_vec.size(); ++i) {
      std::vector<Data*> one_dst;
      Status one_status = SplitOneSparseData(ctx, data_vec[i], &one_dst, i);
      if (!one_status.IsOk()) {
        return one_status;
      }
      for(size_t j = 0; j < one_dst.size(); ++j) {
        dynamic_cast<WrapperData<std::vector<Tensor>>*>((*dst)[j])->Internal().push_back(dynamic_cast<WrapperData<Tensor>*>(one_dst[j])->Internal());
      }
    }
    ...
}
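A rough sketch of what these two splitters achieve (the routing rule below is only an assumption, not the real HashId implementation): HashId groups the ids per server, while HashData sends each server only the slot/gradient rows that belong to its own ids.

# Hypothetical sketch: route ids and their data rows to ps servers.
def split_by_server(ids, rows, num_servers):
    ids_per_server = [[] for _ in range(num_servers)]    # HashId-style output
    rows_per_server = [[] for _ in range(num_servers)]   # HashData-style output
    for i, key in enumerate(ids):
        s = hash(key) % num_servers                      # assumed routing rule
        ids_per_server[s].append(key)
        rows_per_server[s].append(rows[i])
    return ids_per_server, rows_per_server

ids = [101, 202, 303, 404]
rows = [[0.1], [0.2], [0.3], [0.4]]                      # one grad/slot row per id
print(split_by_server(ids, rows, num_servers=2))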

The BuildHashSlice code was analyzed earlier; it performs feature admission.

AssignAddSlotUpdater then performs the slot update:

class AssignAddSlotUpdater : public SimpleUdf<vector<Slices>, std::string, vector<Tensor> > {
 public:
  virtual Status SimpleRun(
      UdfContext* ctx, const vector<Slices>& sslices, const std::string& slot_name, const vector<Tensor>& grad_tensors) const {
      ...
      Tensor* assigned_tensor;
      ...
      PS_CHECK_BOOL(assigned_slice_size == slice_size, Status::ArgumentError("var " + var_name + " slot " + slot_name + " AssignAddSlotUpdater shape mismatch."));
      CASES(grad_tensor.Type(), MultiThreadDo(slices.slice_id.size(), [&](const Range& r) {
                for (size_t i = r.begin; i < r.end; i++) {
                  int64_t slice = slices.slice_id[i];
                  if ((int64_t)slice == ps::HashMap::NOT_ADD_ID) {
                    continue;
                  }
                  T* grad = grad_tensor.Raw<T>(i);
                  T* data = assigned_tensor->Raw<T>(slice);
                  for (size_t j = 0; j < slice_size; j++) {
                    *data += *grad;
                    data++;grad++;
                  }
                ...
};

2) Applying decay to update the slot (only for fea_statis and __decay_rate_xx)

This is executed once per interval.

The call chain is (client) HashFeatureDecayHook → (client) ps_hash_feature_decay_op → (ps) HashFeatureDecay. On the client side, the op broadcasts the decay parameters to every PS:


    ...
    ps::client::UdfData udf("HashFeatureDecay",
                            ps::client::UdfData(0),
                            ps::client::UdfData(1),
                            ps::client::UdfData(2),                            
                            ps::client::UdfData(3));


    std::vector<ps::client::Partitioner*> spliters{
      new ps::client::partitioner::Broadcast,
      new ps::client::partitioner::Broadcast,
      new ps::client::partitioner::Broadcast,        
      new ps::client::partitioner::Broadcast};

    client->Process(udf, var_name_,
                    client->Args(slot_names_, decay_rates, decay_intervals, decay_mark),
                    spliters, {}, outputs, cb);
   ...

HashFeatureDecay then computes the decay:


    ...
      PS_CHECK_STATUS(variable->GetExistSlot(slots[i], &tensor));
      auto& shape = tensor->Shape();
      ...
      decay_rate = pow(decay_rate, ((current_decay_mark - *last_decay_mark) * 1.0 / decay_intervals[i]));

      CASES(tensor->Type(), MultiThreadDo(items.size(), [&](const Range& r) {
            for (size_t j = r.begin; j < r.end; ++j) {
              T* dst = tensor->Raw<T>(items[j].id);
              for (size_t k = 0; k < shape[1]; ++k) {
                *(dst + k) *= decay_rate;
              }
            }
            return Status::Ok();
          }));
   ...
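A small worked example of the exponent above (values made up): with decay_rate = 0.98, decay_intervals = 100, and three intervals elapsed since the last decay, every slot element is multiplied by 0.98^3.

# Made-up values, only to show how the effective decay factor is formed.
decay_rate, decay_interval = 0.98, 100
last_decay_mark, current_decay_mark = 1000, 1300     # three intervals have elapsed
factor = decay_rate ** ((current_decay_mark - last_decay_mark) / decay_interval)
print(factor)                                         # 0.98 ** 3 ≈ 0.941192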

4. Pulling the latest data (read)

In emb and statis, pulling fresh data corresponds to the two operations var.gather and var.gather_slots, which fetch the data and the slots for a given set of ids from the PS respectively.

The call chain of gather is: (client) ps_sparse_pull_op → (client) HashPull → (ps) BuildHashSlice → (ps) TransSlice → (client) HashData.Combine

HashId splits the ids by PS server, just like the HashData above. The BuildHashSlice code was also analyzed earlier: it performs feature admission, so if a hash key is not yet in the hashmap it is inserted and initialized. TransSlice then fetches the data.

Status SparseData::Combine(PartitionerContext* ctx, Data* src, size_t server_id, std::unique_ptr<Data>* output) {
  ...
  char* res_ptr = result->Raw<char>();
  shape.Set(0, 1);
  size_t single_size = shape.NumElements() * SizeOfType(type);
  MultiThreadDo(slices.ids[server_id].size(), [&](const Range& r) {
        char* src_ptr = data.Raw<char>() + r.begin * single_size;
        for (size_t i = r.begin; i < r.end; ++i) {
          size_t item = slices.ids[server_id][i];
          memcpy(res_ptr + item * single_size, src_ptr, single_size);
          src_ptr += single_size;
        }
        return Status::Ok();
      }, 1000);
  return Status::Ok();
}
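A sketch of why this Combine restores the original request order (the data here is illustrative): each server returns its rows in its own local order, and slices.ids[server_id] records where each of those rows belongs in the full request, so the memcpy scatters them back into place.

# Hypothetical sketch of the scatter-back performed by Combine.
# Suppose the worker asked for 5 ids, and HashId sent positions 0,2,4 to server 0 and 1,3 to server 1.
ids_per_server = {0: [0, 2, 4], 1: [1, 3]}                      # plays the role of slices.ids
rows_from_server = {0: ["d0", "d2", "d4"], 1: ["d1", "d3"]}     # each server's rows, local order

result = [None] * 5
for server, positions in ids_per_server.items():
    for local_i, global_pos in enumerate(positions):
        result[global_pos] = rows_from_server[server][local_i]  # same role as the memcpy above

print(result)   # ['d0', 'd1', 'd2', 'd3', 'd4'] -- original request order restored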

The call chain of gather_slots is: (client) ps_sparse_pull_slots_op → (client) HashPullSlots → (ps) BuildHashSlice → (ps) TransSlotsSlice → (ps) TransTensorSliceOffVector → (client) HashAuxiliaryData.Combine

BuildHashSlice performs feature admission (any hash key not yet in the hashmap is inserted and initialized), TransSlotsSlice fetches the slots, and Combine puts the slot values of the same id into a single row.

std::vector<Partitioner*> combiner = {
    new partitioner::HashAuxiliaryData({slots_dim_sum})
  };


class HashAuxiliaryData : public SparseData {
 public:
  HashAuxiliaryData(const std::vector<size_t>& dims, size_t id = 0) : SparseData(id), dims_(dims) {}
  virtual Status CombineInit(PartitionerContext* ctx, std::unique_ptr<Data>* output) override;
 protected:
  std::vector<size_t> dims_;
};


Status HashAuxiliaryData::CombineInit(PartitionerContext* ctx, std::unique_ptr<Data>* output) {
  ...
  std::vector<size_t> dims;
  dims.push_back(slices.id_size);
  for (size_t i = 0; i < dims_.size(); ++i) {
    dims.push_back(dims_[i]);
  }
  TensorShape shape(dims);
  DataType type = info->datatype;
  output->reset(new WrapperData<Tensor>(type, shape, new initializer::NoneInitializer));
  return Status::Ok();
}
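So the combined output has one row per id, whose width is slots_dim_sum, i.e. the requested slots laid side by side. A tiny illustration of that layout (the slot dims here are hypothetical):

# Hypothetical slot dims: the combined row is simply the slots concatenated per id,
# which is what the {slots_dim_sum} passed to HashAuxiliaryData expresses.
slot_dims = [2, 1, 3]
slots_dim_sum = sum(slot_dims)            # 6
slot_a = [1.0, 0.5]                       # example slot values for a single id
slot_b = [3.0]
slot_c = [0.1, 0.2, 0.3]
combined_row = slot_a + slot_b + slot_c
assert len(combined_row) == slots_dim_sum
print(combined_row)                       # [1.0, 0.5, 3.0, 0.1, 0.2, 0.3]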