网络广告策划流程/成都网站建设seo
一般来说,从FATE框架中获得数据使用get_component('name').get_output_data()
。
但是这样子在目前的1.x的FATE中,只能以分类、回归的格式输出才能获得。
如果是图片、文本、token embedding等,用这种方式根本拿不到模型的输出。
经过跟FATE社区人员交涉,社区肯定了这种方法拿不出。并且给了个方法,在自定义的trainer中的predict
函数,直接保存输出。不在通过上述方法获得。
只能说现在只能先这样用了。
如何自定义trainer,在官方文档有。
trainer中的predict部分部分原代码如下,直接在这里面添加save model prediction就行:
def _predict(self, dataset: Dataset):pred_result = []# switch eval modedataset.eval()self.model.eval()labels = []# 直接在这里save predictionpred = self.model(images)torch.save('./xxxx',pred)length=len(dataset.get_sample_ids())ret_rs = torch.rand(length,1)ret_label = torch.rand(length, 1).int()return dataset.get_sample_ids(), ret_rs, ret_labeldef predict(self, dataset: Dataset):ids, ret_rs, ret_label=self._predict(dataset)if self.fed_mode:return self.format_predict_result(ids, ret_rs, ret_label, task_type=self.task_type)else:return ret_rs, ret_label
在上述代码我返回了一些假的数据,因为如果返回数据的格式不符合,Fateboard会直接报错,无法进入到下一步。所以放在那里,没用。