diff --git a/configure.py b/configure.py index 8572fa7fdb249d1ae6f39202b871d17cd44755f8..6279c4261044e3f33519ece5f2ac19af2acb505d 100644 --- a/configure.py +++ b/configure.py @@ -25,10 +25,12 @@ import re import subprocess import sys +# pylint: disable=g-import-not-at-top try: from shutil import which except ImportError: from distutils.spawn import find_executable as which +# pylint: enable=g-import-not-at-top _TF_BAZELRC = os.path.join(os.path.dirname(os.path.abspath(__file__)), '.tf_configure.bazelrc') @@ -485,7 +487,10 @@ def set_cc_opt_flags(environ_cp): cc_opt_flags = get_from_env_or_user_or_default(environ_cp, 'CC_OPT_FLAGS', question, default_cc_opt_flags) for opt in cc_opt_flags.split(): - write_to_bazelrc('build:opt --cxxopt=%s --copt=%s' % (opt, opt)) + host_opt = '-march=native' # It should be safe on the same build host. + write_to_bazelrc( + 'build:opt --cxxopt=%s --copt=%s' % (opt, opt) + + ' --host_cxxopt=%s --host_copt=%s' % (host_opt, host_opt)) def set_tf_cuda_clang(environ_cp): diff --git a/tensorflow/compiler/tf2xla/functionalize_control_flow.cc b/tensorflow/compiler/tf2xla/functionalize_control_flow.cc index 893175373f670b38b03b3492ddffb16102708160..6ef4860f35835e59be3452b57204d42c82d0816b 100644 --- a/tensorflow/compiler/tf2xla/functionalize_control_flow.cc +++ b/tensorflow/compiler/tf2xla/functionalize_control_flow.cc @@ -130,7 +130,9 @@ Status CopySubgraph(const Graph& graph, const Frame* frame, stack.push_back(src); } Node* src_copy = (*node_map)[e->src()->id()]; - int src_output = squash_src_outputs[e->src()->id()] ? 0 : e->src_output(); + int src_output = squash_src_outputs[e->src()->id()] && !e->IsControlEdge() + ? 
0 + : e->src_output(); Node* dst_copy = (*node_map)[e->dst()->id()]; output->AddEdge(src_copy, src_output, dst_copy, e->dst_input()); } diff --git a/tensorflow/compiler/tf2xla/kernels/gather_op.cc b/tensorflow/compiler/tf2xla/kernels/gather_op.cc index 2c5d910d58d8e486deb1f644cbda11261b93a84e..e420f21ca33fe7de9b33f404ce04eae62d9c041e 100644 --- a/tensorflow/compiler/tf2xla/kernels/gather_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/gather_op.cc @@ -77,18 +77,6 @@ xla::ComputationDataHandle XlaComputeGatherDynamicSlice( out_shape.dim_sizes()); } - // Degenerate case: single slice. - if (num_indices == 1) { - auto index = builder->Reshape(indices, {1}); - auto start_index = builder->Pad( - index, XlaHelpers::Zero(builder, index_type), - xla::MakeEdgePaddingConfig( - {{input_shape_pre_axis.dims(), input_shape_post_axis.dims()}})); - auto slice = - builder->DynamicSlice(input, start_index, slice_shape.dim_sizes()); - return builder->Reshape(slice, out_shape.dim_sizes()); - } - // Specify the shape of the loop-carried Tensor tuple. 
xla::PrimitiveType ptype; TF_CHECK_OK(DataTypeToPrimitiveType(dtype, &ptype)); diff --git a/tensorflow/compiler/xla/client/computation_builder.cc b/tensorflow/compiler/xla/client/computation_builder.cc index 24774c4c2a385d9aabd22a550bd8be3acf409d85..763d94e94c2167f47b3f0777a31815f02791aa9e 100644 --- a/tensorflow/compiler/xla/client/computation_builder.cc +++ b/tensorflow/compiler/xla/client/computation_builder.cc @@ -1309,7 +1309,7 @@ Status ComputationBuilder::SetReturnValue( } StatusOr ComputationBuilder::IsConstant( - const ComputationDataHandle& operand) { + const ComputationDataHandle& operand, int64 num_parameters) { if (!first_error_.ok()) { return first_error_; } @@ -1317,6 +1317,7 @@ StatusOr ComputationBuilder::IsConstant( IsConstantRequest request; *request.mutable_computation() = computation_.handle(); *request.mutable_operand() = operand; + request.set_num_parameters(num_parameters); IsConstantResponse response; VLOG(2) << "making IsConstant request"; @@ -1330,7 +1331,8 @@ StatusOr ComputationBuilder::IsConstant( } StatusOr> ComputationBuilder::ComputeConstant( - const ComputationDataHandle& operand, const Layout* output_layout) { + const ComputationDataHandle& operand, const Layout* output_layout, + tensorflow::gtl::ArraySlice parameters) { if (!first_error_.ok()) { return first_error_; } @@ -1341,6 +1343,9 @@ StatusOr> ComputationBuilder::ComputeConstant( if (output_layout != nullptr) { *request.mutable_output_layout() = *output_layout; } + for (const auto& param : parameters) { + *request.add_parameters() = param.ToProto(); + } ComputeConstantResponse response; diff --git a/tensorflow/compiler/xla/client/computation_builder.h b/tensorflow/compiler/xla/client/computation_builder.h index bc7ad06a3fec4aecbe11312982542c9d615b4911..8e1b4be1f3ebf8e3f530b053447f86f7a2f56fa7 100644 --- a/tensorflow/compiler/xla/client/computation_builder.h +++ b/tensorflow/compiler/xla/client/computation_builder.h @@ -746,11 +746,12 @@ class ComputationBuilder { 
ComputationDataHandle Recv(const Shape& shape, const ChannelHandle& handle); // Returns true if 'operand' is a compile-time constant. A compile-time - // constant does not depend on parameters, or on stateful operators such - // as `RngNormal` or `Infeed`. Unlike `ComputeConstant`, `IsConstant` tests - // whether a computation is a compile-time constant without evaluating the - // computation. - StatusOr IsConstant(const ComputationDataHandle& operand); + // constant does not depend on parameters with higher index than + // `num_parameters`, or on stateful operators such as `RngNormal` or `Infeed`. + // Unlike `ComputeConstant`, `IsConstant` tests whether a computation is a + // compile-time constant without evaluating the computation. + StatusOr IsConstant(const ComputationDataHandle& operand, + int64 num_parameters = 0); // Normalizes operand across spatial and batch dimensions for each feature. // @@ -795,7 +796,7 @@ class ComputationBuilder { float epsilon, int64 feature_index); // Computes the value of a constant indicated by a - // ComputationDataHandle. + // ComputationDataHandle using a non-optimized interpreter on the host. // // The operand must be from the computation currently being built - // i.e., returned from this builder with no intervening call to @@ -803,8 +804,11 @@ class ComputationBuilder { // that may stop working at any time. // // The operand must represent a constant value, which in this case - // means that it must not statically depend on a parameter to the - // computation that is being built. + // means that it must not statically depend on any parameter of the + // computation that is being built other than the ones specified on the + // parameter list. The parameters in the list will be indexed by their + // parameter id property so the number of parameters specified should be at + // least as many as the largest used parameter index. 
// // `IsConstant` can be used to test whether a computation is a compile-time // constant without evaluation it. `ComputeConstant` only succeeds for @@ -822,7 +826,8 @@ class ComputationBuilder { // will be stored using that layout. StatusOr> ComputeConstant( const ComputationDataHandle& operand, - const Layout* output_layout = nullptr); + const Layout* output_layout = nullptr, + tensorflow::gtl::ArraySlice parameters = {}); // Returns a new ComputationBuilder whose resultant Computation is used only // by this ComputationBuilder. The sub-ComputationBuilder has the same diff --git a/tensorflow/compiler/xla/service/buffer_assignment.cc b/tensorflow/compiler/xla/service/buffer_assignment.cc index 8536429846f87fd5c4b073cc4b13b3f1c5eb2e5c..b422b22df9cfbefb6611fcb229ed42e67fe3a0d8 100644 --- a/tensorflow/compiler/xla/service/buffer_assignment.cc +++ b/tensorflow/compiler/xla/service/buffer_assignment.cc @@ -101,6 +101,11 @@ BufferAllocationProto BufferAllocation::ToProto() const { proto_assigned->set_offset(buffer_offset_size.second.offset); proto_assigned->set_size(buffer_offset_size.second.size); } + std::sort(proto.mutable_assigned()->begin(), proto.mutable_assigned()->end(), + [](const BufferAllocationProto::Assigned& assign1, + const BufferAllocationProto::Assigned& assign2) { + return assign1.logical_buffer_id() < assign2.logical_buffer_id(); + }); return proto; } diff --git a/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc b/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc index 487ea003be643a9c7d48dc2d0037ba6a0ae498dd..f46764cba0ad6ef174a89951c251613c69b4b083 100644 --- a/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc +++ b/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc @@ -322,7 +322,7 @@ Status CpuCompiler::RunHloPasses(HloModule* module, bool is_aot_compile) { // binary size (and most AOT applications are single-threaded). // TODO(29630486) Support multi-threaded AOT. 
pipeline.AddPass(max_parallelism, - ShapeSizeBytesFunction(), module); + ShapeSizeBytesFunction()); } // Copy insertion should be performed immediately before IR emission to avoid // inserting unnecessary copies (later pass adds an instruction which diff --git a/tensorflow/compiler/xla/service/cpu/llvm_ir_runtime.cc b/tensorflow/compiler/xla/service/cpu/llvm_ir_runtime.cc index b49047283119fb2f10b9f68eaa37a7bdc27f63a6..81c29e4726c7be53b433be896f558f502e43c885 100644 --- a/tensorflow/compiler/xla/service/cpu/llvm_ir_runtime.cc +++ b/tensorflow/compiler/xla/service/cpu/llvm_ir_runtime.cc @@ -52,7 +52,7 @@ llvm::Function* EmitVectorF32TanhIfNeeded(llvm::Module* module, llvm::IRBuilder<> ir_builder(vector_tanh_body); llvm::FastMathFlags fast_math_flags; - fast_math_flags.setUnsafeAlgebra(); + fast_math_flags.setFast(); ir_builder.setFastMathFlags(fast_math_flags); llvm::Value* input = &*vector_tanh_function->arg_begin(); diff --git a/tensorflow/compiler/xla/service/cpu/parallel_task_assignment.cc b/tensorflow/compiler/xla/service/cpu/parallel_task_assignment.cc index c2213c8f2ef592c537daf9abe2ffa10b83a8fa4c..4a62a80fac0c89d8e1cf66f16f07fca0ffbaa2d3 100644 --- a/tensorflow/compiler/xla/service/cpu/parallel_task_assignment.cc +++ b/tensorflow/compiler/xla/service/cpu/parallel_task_assignment.cc @@ -101,11 +101,9 @@ class DefaultCostModel : public ParallelCostModel { const std::unique_ptr cost_analysis_; }; - ParallelTaskAssignment::ParallelTaskAssignment( const int64 max_parallelism, - const HloCostAnalysis::ShapeSizeFunction& shape_size, - HloModule* module) { + const HloCostAnalysis::ShapeSizeFunction& shape_size, HloModule* module) { VLOG(1) << "ParallelTaskAssignment max_parallelism: " << max_parallelism; // Run cost analysis on 'module'. 
auto cost_analysis = MakeUnique(shape_size); @@ -153,7 +151,6 @@ int64 ParallelTaskAssignment::GetTargetParallelTaskCount( StatusOr ParallelTaskAssigner::Run(HloModule* module) { XLA_VLOG_LINES(2, "ParallelTaskAssigner ENTRY"); XLA_VLOG_LINES(3, module->ToString()); - // Compute target parallel task counts for all instructions in 'module'. HloToParallelTasks hlo_to_parallel_tasks; ComputeTargetParallelTasks(module, &hlo_to_parallel_tasks); @@ -230,6 +227,9 @@ bool ParallelTaskAssigner::AssignParallelTasksHelper( void ParallelTaskAssigner::ComputeTargetParallelTasks( HloModule* module, HloToParallelTasks* hlo_to_parallel_tasks) { + ParallelTaskAssignment parallel_task_assignment(max_parallelism_, + shape_size_function_, module); + // Compute parallel task counts for all instructions in 'module'. for (auto* computation : module->computations()) { if (computation->IsFusionComputation()) { @@ -238,7 +238,7 @@ void ParallelTaskAssigner::ComputeTargetParallelTasks( for (auto* instruction : computation->instructions()) { // Query ParallelTaskAssignment for target parallel task count. const int64 target_parallel_task_count = - parallel_task_assignment_.GetTargetParallelTaskCount(instruction); + parallel_task_assignment.GetTargetParallelTaskCount(instruction); if (target_parallel_task_count > 1) { hlo_to_parallel_tasks->insert( {instruction, target_parallel_task_count}); diff --git a/tensorflow/compiler/xla/service/cpu/parallel_task_assignment.h b/tensorflow/compiler/xla/service/cpu/parallel_task_assignment.h index e036da5784f6151eb3b01107ec7f3ab820071a60..5801ec8d270cdaed7f2f65c24987a9ea643edb02 100644 --- a/tensorflow/compiler/xla/service/cpu/parallel_task_assignment.h +++ b/tensorflow/compiler/xla/service/cpu/parallel_task_assignment.h @@ -37,10 +37,9 @@ class ParallelTaskAssignment { // 'shape_size': shape size function used by HloCostAnalysis during parallel // task assignment. // 'module': the containing HloModule. 
- ParallelTaskAssignment( - const int64 max_parallelism, - const HloCostAnalysis::ShapeSizeFunction& shape_size, - HloModule* module); + ParallelTaskAssignment(const int64 max_parallelism, + const HloCostAnalysis::ShapeSizeFunction& shape_size, + HloModule* module); ~ParallelTaskAssignment() {} // Computes and returns the target parallel task count for 'instruction'. @@ -63,11 +62,9 @@ class ParallelTaskAssigner : public HloPassInterface { // 'max_parallelism': the maximum parallel task count per instruction. // 'shape_size': shape size function used by HloCostAnalysis during parallel // task assignment. - // 'module': the containing HloModule. ParallelTaskAssigner(const int64 max_parallelism, - const HloCostAnalysis::ShapeSizeFunction& shape_size, - HloModule* module) - : parallel_task_assignment_(max_parallelism, shape_size, module) {} + const HloCostAnalysis::ShapeSizeFunction& shape_size) + : max_parallelism_(max_parallelism), shape_size_function_(shape_size) {} ~ParallelTaskAssigner() override {} tensorflow::StringPiece name() const override { @@ -95,7 +92,8 @@ class ParallelTaskAssigner : public HloPassInterface { void ComputeTargetParallelTasks(HloModule* module, HloToParallelTasks* hlo_to_parallel_tasks); - ParallelTaskAssignment parallel_task_assignment_; + int64 max_parallelism_; + HloCostAnalysis::ShapeSizeFunction shape_size_function_; }; } // namespace cpu diff --git a/tensorflow/compiler/xla/service/elemental_ir_emitter.cc b/tensorflow/compiler/xla/service/elemental_ir_emitter.cc index fd4c332cba94513ec5b4cd88a842189e716f35d5..a945657712aae46093cd016d23114f26b8a2d926 100644 --- a/tensorflow/compiler/xla/service/elemental_ir_emitter.cc +++ b/tensorflow/compiler/xla/service/elemental_ir_emitter.cc @@ -1289,6 +1289,15 @@ llvm_ir::ElementGenerator ElementalIrEmitter::MakeElementGenerator( const int64 rank = ShapeUtil::Rank(input_hlo->shape()); llvm_ir::IrArray::Index slice_start_index(rank); llvm_ir::IrArray::Index slice_limit_index(rank); + // Slice 
starts at update[index - slice_start_index_adjusted], + // where adjusted value = slice_start_index when in bounds, and + // adjusted value = slice_start_index - input_dim, when wrapping. + llvm_ir::IrArray::Index slice_start_index_adjusted(rank); + + // Slice intersection gathers (ANDs) conditions on all ranks for which + // 'input' is set to 'update' + llvm::Value* slice_intersection = ir_builder_->getTrue(); + for (int64 i = 0; i < rank; ++i) { // Emit IR to read dynamic start indices from 'start_hlo'. llvm_ir::IrArray::Index dim_index(1, ir_builder_->getInt64(i)); @@ -1298,38 +1307,97 @@ llvm_ir::ElementGenerator ElementalIrEmitter::MakeElementGenerator( AsStringRef(IrName(hlo, StrCat("start_idx", i)))); slice_start_index[i] = ir_builder_->CreateZExtOrBitCast( start_index_value, index[i]->getType()); - // Emit IR to compute: slice_limit_index = start_index + update_dim - // NOTE: Although 'start_indices' is dynamic and could be - // out-of-range, we do not compute 'slice_limit_index' mod input dim - // size here, because subsequent array index calculations will be - // computed mod input dim size for safety. + + llvm::Value* input_dim_size = llvm::ConstantInt::get( + index[i]->getType(), input_hlo->shape().dimensions(i)); llvm::Value* update_dim_size = llvm::ConstantInt::get( index[i]->getType(), update_hlo->shape().dimensions(i)); + + // Generate code to handle wrapping semantics: + // slice_start_index[i] = slice_start_index[i] % input_dim_size; + // slice_limit_index[i] = slice_start_index[i] + update_dim_size. + // slice_start_index[i] is updated in place and it will now be in + // range. slice_limit_index[i] may be out of range, and it's being + // URem-ed below if so. + slice_start_index[i] = + ir_builder_->CreateURem(slice_start_index[i], input_dim_size); slice_limit_index[i] = ir_builder_->CreateAdd(slice_start_index[i], update_dim_size); - } - // Check if 'index' intersects start/end indices. 
- llvm::Value* slice_intersection = - llvm::ConstantInt::get(ir_builder_->getInt1Ty(), 1); - - for (int64 i = 0; i < rank; ++i) { - // Check that index[i] >= slice_start_index[i]. - slice_intersection = ir_builder_->CreateAnd( + // Test if slice_limit_index[i] is in bounds + llvm::Value* in_bounds = + ir_builder_->CreateICmpULE(slice_limit_index[i], input_dim_size); + llvm_ir::LlvmIfData if_in_bounds = + llvm_ir::EmitIfThenElse(in_bounds, "in_bounds", ir_builder_); + + // Handle true BB (slice_limit_index[i] <= input_dim_size). + SetToFirstInsertPoint(if_in_bounds.true_block, ir_builder_); + // Check that index[i] >= slice_start_index[i] && + // index[i] < slice_limit_index[i] + llvm::Value* slice_intersection_in_bounds = ir_builder_->CreateAnd( slice_intersection, ir_builder_->CreateICmpSGE(index[i], slice_start_index[i]), - "slice_intersection"); - - // Check that index[i] < slice_limit_index[i]. - slice_intersection = ir_builder_->CreateAnd( - slice_intersection, + "slice_intersection_in"); + slice_intersection_in_bounds = ir_builder_->CreateAnd( + slice_intersection_in_bounds, ir_builder_->CreateICmpSLT(index[i], slice_limit_index[i]), - "slice_intersection"); + "slice_intersection_in"); + + // Handle false BB (slice_limit_index[i] > input_dim_size). + SetToFirstInsertPoint(if_in_bounds.false_block, ir_builder_); + // Check that index[i] >= slice_start_index[i] || + // index[i] < slice_limit_index[i]%input_dim_size. + llvm::Value* index_wraps = ir_builder_->CreateICmpSLT( + index[i], + ir_builder_->CreateURem(slice_limit_index[i], input_dim_size)); + llvm::Value* slice_intersection_or = ir_builder_->CreateOr( + ir_builder_->CreateICmpSGE(index[i], slice_start_index[i]), + index_wraps, "slice_intersection_out"); + llvm::Value* slice_intersection_out_of_bounds = + ir_builder_->CreateAnd(slice_intersection, slice_intersection_or, + "slice_intersection_out"); + // Create value for slice_start_index_adjusted[i] when out of bounds. + // If within out-of-bounds if. 
+ llvm_ir::LlvmIfData if_start_needs_adjustment = + llvm_ir::EmitIfThenElse(index_wraps, "adjust_start", ir_builder_); + SetToFirstInsertPoint(if_start_needs_adjustment.true_block, + ir_builder_); + llvm::Value* slice_start_index_adjusted_oob = + ir_builder_->CreateSub(slice_start_index[i], input_dim_size); + SetToFirstInsertPoint(if_start_needs_adjustment.after_block, + ir_builder_); + llvm::PHINode* slice_start_index_adjusted_phi = + ir_builder_->CreatePHI(slice_start_index_adjusted_oob->getType(), + 2); + slice_start_index_adjusted_phi->addIncoming( + slice_start_index_adjusted_oob, + if_start_needs_adjustment.true_block); + slice_start_index_adjusted_phi->addIncoming( + slice_start_index[i], if_start_needs_adjustment.false_block); + // End of if within if. + + // After checking in/out of bounds. + SetToFirstInsertPoint(if_in_bounds.after_block, ir_builder_); + llvm::PHINode* phi_slice_intersection = + ir_builder_->CreatePHI(slice_intersection->getType(), 2); + phi_slice_intersection->addIncoming(slice_intersection_in_bounds, + if_in_bounds.true_block); + phi_slice_intersection->addIncoming( + slice_intersection_out_of_bounds, + if_start_needs_adjustment.after_block); + slice_intersection = phi_slice_intersection; + + llvm::PHINode* phi_index = + ir_builder_->CreatePHI(slice_start_index[i]->getType(), 2); + phi_index->addIncoming(slice_start_index[i], if_in_bounds.true_block); + phi_index->addIncoming(slice_start_index_adjusted_phi, + if_start_needs_adjustment.after_block); + slice_start_index_adjusted[i] = phi_index; } // Emit: // if (slice_intersection) -> return data from 'update'. - // else -> return data from 'index'. + // else -> return data from 'input'. 
llvm::Value* ret_value_addr = llvm_ir::EmitAllocaAtFunctionEntry( llvm_ir::PrimitiveTypeToIrType(hlo->shape().element_type(), module_), @@ -1337,7 +1405,7 @@ llvm_ir::ElementGenerator ElementalIrEmitter::MakeElementGenerator( llvm_ir::LlvmIfData if_data = llvm_ir::EmitIfThenElse( slice_intersection, "slice_intersection", ir_builder_); - // Handle true BB. + // Handle true BB (return data from 'update') SetToFirstInsertPoint(if_data.true_block, ir_builder_); // Compute update index for intersection case. llvm_ir::IrArray::Index update_index(rank); @@ -1346,14 +1414,14 @@ llvm_ir::ElementGenerator ElementalIrEmitter::MakeElementGenerator( index[i]->getType(), update_hlo->shape().dimensions(i)); // NOTE: Subtraction will be positive due to bounds checking above. update_index[i] = ir_builder_->CreateURem( - ir_builder_->CreateSub(index[i], slice_start_index[i]), + ir_builder_->CreateSub(index[i], slice_start_index_adjusted[i]), update_dim_size); } TF_ASSIGN_OR_RETURN(llvm::Value * true_value, operand_to_generator.at(update_hlo)(update_index)); ir_builder_->CreateStore(true_value, ret_value_addr); - // Handle false BB. + // Handle false BB (return data from 'input') SetToFirstInsertPoint(if_data.false_block, ir_builder_); TF_ASSIGN_OR_RETURN(llvm::Value * false_value, operand_to_generator.at(input_hlo)(index)); diff --git a/tensorflow/compiler/xla/service/executable.h b/tensorflow/compiler/xla/service/executable.h index 2d32e59d36c4e3026e0e151561db3076146fabe4..7e0d182b365c35788195e70dc35c3923ed8991bb 100644 --- a/tensorflow/compiler/xla/service/executable.h +++ b/tensorflow/compiler/xla/service/executable.h @@ -88,6 +88,16 @@ class Executable { tensorflow::gtl::ArraySlice> arguments); + // Populates `hlo_execution_profile` from `executor`. This is implicit in any + // Execute* API call that takes a hlo_execution_profile argument, but must be + // called explicitly for other (async, for example) variants after the stream + // has completed. 
+ virtual Status PopulateExecutionProfile( + HloExecutionProfile* hlo_execution_profile, + perftools::gputools::StreamExecutor* executor) { + return Status::OK(); + } + // Convenience wrapper for calling Executable::ExecuteOnStream. Sets up a // timer for the execution, sets up HLO profiling if enabled, and fills in the // given ExecutionProfile if non-null. The ExecuteOnStream overloads have diff --git a/tensorflow/compiler/xla/service/generic_transfer_manager.cc b/tensorflow/compiler/xla/service/generic_transfer_manager.cc index d3c83ea72e33b959e21d0cc9c1706d92bd659a5c..b4fbed1562945adeb52a9471453ed4fee0e35180 100644 --- a/tensorflow/compiler/xla/service/generic_transfer_manager.cc +++ b/tensorflow/compiler/xla/service/generic_transfer_manager.cc @@ -197,7 +197,7 @@ Status GenericTransferManager::ResetDevices( "Device reset is not yet supported on this platform (b/30481585)"); } -int64 GenericTransferManager::GetByteSizeRequirement(const Shape& shape) { +int64 GenericTransferManager::GetByteSizeRequirement(const Shape& shape) const { return ShapeUtil::ByteSizeOf(shape, /*pointer_size=*/sizeof(void*)); } diff --git a/tensorflow/compiler/xla/service/generic_transfer_manager.h b/tensorflow/compiler/xla/service/generic_transfer_manager.h index 26488d6ec651b75c753119a7ce818c692c6c03dd..ef9a50676a4171b56e8a77d2dedc05b1580e5ea5 100644 --- a/tensorflow/compiler/xla/service/generic_transfer_manager.h +++ b/tensorflow/compiler/xla/service/generic_transfer_manager.h @@ -78,7 +78,7 @@ class GenericTransferManager : public TransferManager { const Shape& shape, perftools::gputools::DeviceMemoryBase* region) override; - int64 GetByteSizeRequirement(const Shape& shape) override; + int64 GetByteSizeRequirement(const Shape& shape) const override; private: // The platform this transfer manager targets. 
diff --git a/tensorflow/compiler/xla/service/gpu/BUILD b/tensorflow/compiler/xla/service/gpu/BUILD index b9c4adce93a88cb48635993b6e9999528d78ec07..364b76b93c288f13f2bf447cebfc25f705d77826 100644 --- a/tensorflow/compiler/xla/service/gpu/BUILD +++ b/tensorflow/compiler/xla/service/gpu/BUILD @@ -470,6 +470,7 @@ cc_library( "//tensorflow/compiler/xla/service/llvm_ir:llvm_util", "//tensorflow/core:cuda_libdevice_path", "//tensorflow/core:lib", + "//tensorflow/core:lib_internal", "//tensorflow/core:stream_executor_no_cuda", "@llvm//:core", "@llvm//:support", diff --git a/tensorflow/compiler/xla/service/gpu/gpu_compiler.cc b/tensorflow/compiler/xla/service/gpu/gpu_compiler.cc index 2caa8f60517c66c1708e52481f01727f0008afd9..ceb0e530c151219c7fef4dd6bfa36013cb53d63c 100644 --- a/tensorflow/compiler/xla/service/gpu/gpu_compiler.cc +++ b/tensorflow/compiler/xla/service/gpu/gpu_compiler.cc @@ -67,6 +67,7 @@ limitations under the License. #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/util.h" #include "tensorflow/core/lib/core/status.h" +#include "tensorflow/core/lib/gtl/cleanup.h" #include "tensorflow/core/lib/io/path.h" #include "tensorflow/core/lib/strings/strcat.h" #include "tensorflow/core/platform/cuda_libdevice_path.h" @@ -226,46 +227,63 @@ tensorflow::Status PrepareHloModuleForIrEmitting( return pipeline.Run(hlo_module).status(); } -// Invokes the ptxas tool on the given PTX string, and dumps its output. -void DumpPtxasInfo(const string& ptx, int cc_major, int cc_minor) { +// Compiles the given PTX string using ptxas and returns the resulting machine +// code (i.e. a cubin) as a byte array. +StatusOr> CompilePtx(const string& ptx, int cc_major, + int cc_minor) { const string ptxas_path = - tensorflow::io::JoinPath(tensorflow::CudaRoot(), "bin/ptxas"); - // Do not log PTX stats if ptxas is not found at the given path. 
- if (!tensorflow::Env::Default()->FileExists(ptxas_path).ok()) { - LOG(WARNING) - << "Failed to dump PTX stats because ptxas is not found at path \"" - << ptxas_path << "\"."; - return; + tensorflow::io::JoinPath(tensorflow::CudaRoot(), "bin", "ptxas"); + VLOG(2) << "Using ptxas at " << ptxas_path; + auto env = tensorflow::Env::Default(); + TF_RETURN_IF_ERROR(env->FileExists(ptxas_path)); + + // Write ptx into a temporary file. + string ptx_path; + if (!env->LocalTempFilename(&ptx_path)) { + return InternalError("couldn't get temp PTX file name"); } + auto ptx_cleaner = tensorflow::gtl::MakeCleanup([&ptx_path] { + TF_CHECK_OK(tensorflow::Env::Default()->DeleteFile(ptx_path)); + }); - // Write `ptx` into a temporary file. - char tempdir_template[] = "/tmp/ptxXXXXXX"; - char* tempdir_name = mkdtemp(tempdir_template); - CHECK_NOTNULL(tempdir_name); - string ptx_path = tensorflow::io::JoinPath(tempdir_name, "ptx"); - TF_CHECK_OK( - tensorflow::WriteStringToFile(tensorflow::Env::Default(), ptx_path, ptx)); - LOG(INFO) << "ptx file written to: " << ptx_path; + TF_RETURN_IF_ERROR(tensorflow::WriteStringToFile(env, ptx_path, ptx)); + VLOG(2) << "ptx written to: " << ptx_path; // Invoke ptxas and collect its output. 
+ string cubin_path; + if (!env->LocalTempFilename(&cubin_path)) { + return InternalError("couldn't get temp CUBIN file name"); + } + auto cubin_cleaner = tensorflow::gtl::MakeCleanup([&cubin_path] { + TF_CHECK_OK(tensorflow::Env::Default()->DeleteFile(cubin_path)); + }); tensorflow::SubProcess ptxas_info_dumper; - ptxas_info_dumper.SetProgram(ptxas_path, - {ptxas_path, ptx_path, "-o", "/dev/null", "-v", - StrCat("-arch=sm_", cc_major, cc_minor)}); + std::vector ptxas_args = {ptxas_path, ptx_path, "-o", cubin_path, + StrCat("-arch=sm_", cc_major, cc_minor)}; + if (VLOG_IS_ON(2)) { + ptxas_args.push_back("-v"); + } + ptxas_info_dumper.SetProgram(ptxas_path, ptxas_args); ptxas_info_dumper.SetChannelAction(tensorflow::CHAN_STDERR, tensorflow::ACTION_PIPE); if (!ptxas_info_dumper.Start()) { - LOG(ERROR) << "Failed to launch ptxas."; - return; + return InternalError("Failed to launch ptxas"); } string stderr_output; int exit_status = ptxas_info_dumper.Communicate( /*stdin_input=*/nullptr, /*stdout_output=*/nullptr, &stderr_output); XLA_LOG_LINES(tensorflow::INFO, stderr_output); if (exit_status != 0) { - LOG(ERROR) << "ptxas exited with non-zero error code " << exit_status - << "."; + return InternalError("ptxas exited with non-zero error code %d", + exit_status); } + + // Read in the result of compilation and return it as a byte vector. + string cubin; + TF_RETURN_IF_ERROR(tensorflow::ReadFileToString(tensorflow::Env::Default(), + cubin_path, &cubin)); + std::vector cubin_vector(cubin.begin(), cubin.end()); + return cubin_vector; } } // namespace @@ -318,7 +336,7 @@ StatusOr> GpuCompiler::Compile( // BufferAssignment::ToString() includes a header, so no need for us to // print one ourselves. 
XLA_VLOG_LINES(2, buffer_assignment->ToString()); - + XLA_VLOG_LINES(2, module->ToString()); const string xla_dump_hlo_proto_to = module->config().debug_options().xla_dump_hlo_proto_to(); if (!xla_dump_hlo_proto_to.empty()) { @@ -359,15 +377,10 @@ StatusOr> GpuCompiler::Compile( /*optimized=*/false)); } - string* ptx; string libdevice_dir; { tensorflow::mutex_lock lock(mutex_); - // Reserve space for the PTX to be generated for this module. - generated_ptxes_.emplace_back(MakeUnique()); - ptx = generated_ptxes_.back().get(); - // Find the directory containing libdevice. To avoid searching for it every // time, we have a one-element cache, keyed on the module's config's // cuda_data_dir. @@ -389,8 +402,9 @@ StatusOr> GpuCompiler::Compile( cc_minor = 0; } - TF_ASSIGN_OR_RETURN(*ptx, CompileToPtx(&llvm_module, {cc_major, cc_minor}, - module->config(), libdevice_dir)); + TF_ASSIGN_OR_RETURN(string ptx, + CompileToPtx(&llvm_module, {cc_major, cc_minor}, + module->config(), libdevice_dir)); if (!ir_dump_directory.empty()) { TF_RETURN_IF_ERROR(llvm_ir::DumpIRToDirectory( @@ -405,10 +419,10 @@ StatusOr> GpuCompiler::Compile( VLOG(2) << "LLVM module after optimizations:"; XLA_VLOG_LINES(2, llvm_ir::DumpModuleToString(llvm_module)); VLOG(2) << "PTX:"; - XLA_VLOG_LINES(2, *ptx); - if (VLOG_IS_ON(2)) { - DumpPtxasInfo(*ptx, cc_major, cc_minor); - } + XLA_VLOG_LINES(2, ptx); + + const std::vector cubin = + CompilePtxOrGetCachedResult(ptx, cc_major, cc_minor); auto thunk_schedule = MakeUnique( ir_emitter.ConsumeThunkSequence(), std::move(stream_assignment), @@ -417,7 +431,8 @@ StatusOr> GpuCompiler::Compile( XLA_VLOG_LINES(2, thunk_schedule->ToString()); auto* gpu_executable = - new GpuExecutable(*ptx, std::move(thunk_schedule), std::move(module), + new GpuExecutable(ptx, cubin, {cc_major, cc_minor}, + std::move(thunk_schedule), std::move(module), std::move(buffer_assignment), ShapeSizeBytesFunction()); if (embed_ir_in_executable) { DCHECK_NE("", ir_module_string_before_opt); @@ 
-426,6 +441,61 @@ StatusOr> GpuCompiler::Compile( return std::unique_ptr(gpu_executable); } +std::vector GpuCompiler::CompilePtxOrGetCachedResult(const string& ptx, + int cc_major, + int cc_minor) { + bool inserted; + decltype(compilation_cache_.begin()) iter; + // Pointers into compilation_cache_ where the ptx and (optional) cubin are + // stored. + const string* cache_ptx = nullptr; + CompilationCacheValue* cache_value = nullptr; + + { + tensorflow::mutex_lock lock(mutex_); + std::tie(iter, inserted) = compilation_cache_.emplace( + std::piecewise_construct, + std::forward_as_tuple(ptx, cc_major, cc_minor), + std::forward_as_tuple()); + cache_ptx = &iter->first.ptx; + cache_value = &iter->second; + } + + // Compile the ptx if it wasn't in the cache before we called this function. + // Other threads asking for the same compilation key will block on + // cache_value->mutex_ until compilation is done. + { + tensorflow::mutex_lock lock(cache_value->mutex_); + if (inserted) { + CHECK(!cache_value->compilation_done); + if (!ptx.empty()) { + StatusOr> maybe_cubin = + CompilePtx(*cache_ptx, cc_major, cc_minor); + if (maybe_cubin.ok()) { + cache_value->cubin_data = std::move(maybe_cubin).ValueOrDie(); + VLOG(2) << "Compiled PTX size:" << ptx.size() + << " CUBIN size: " << cache_value->cubin_data.size(); + } else { + LOG(WARNING) + << "Failed to compile ptx to cubin. Will attempt to let " + "GPU driver compile the ptx. 
" + << maybe_cubin.status(); + } + } + cache_value->compilation_done = true; + cache_value->compilation_done_cv_.notify_all(); + } else { + while (!cache_value->compilation_done) { + cache_value->compilation_done_cv_.wait(lock); + } + } + } + + CHECK(cache_value != nullptr); + CHECK(cache_value->compilation_done); + return cache_value->cubin_data; +} + StatusOr>> GpuCompiler::Compile( std::vector> modules, std::vector> stream_execs) { diff --git a/tensorflow/compiler/xla/service/gpu/gpu_compiler.h b/tensorflow/compiler/xla/service/gpu/gpu_compiler.h index 7a4c4b00d9ad0d895d6b326d2e58f3becdac56d0..ee67e65caf2434fc74503d07c6fccb98de70d96c 100644 --- a/tensorflow/compiler/xla/service/gpu/gpu_compiler.h +++ b/tensorflow/compiler/xla/service/gpu/gpu_compiler.h @@ -26,6 +26,8 @@ limitations under the License. #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/core/lib/gtl/array_slice.h" +#include "tensorflow/core/lib/gtl/optional.h" +#include "tensorflow/core/lib/hash/hash.h" #include "tensorflow/core/platform/macros.h" #include "tensorflow/core/platform/mutex.h" #include "tensorflow/core/platform/stream_executor_no_cuda.h" @@ -86,10 +88,57 @@ class GpuCompiler : public LLVMCompiler { string cached_cuda_data_dir_ GUARDED_BY(mutex_); string cached_libdevice_dir_ GUARDED_BY(mutex_); - // The list of PTX strings generated by this GpuCompiler. We let GpuCompiler - // to own them because they need to be alive across the life span of the - // StreamExecutor (b/24776264). - std::vector> generated_ptxes_ GUARDED_BY(mutex_); + // Tries to compile the given ptx string to cubin. Returns a vector with the + // compiled cubin. If compilation was unsuccessful, returns an empty vector. + std::vector CompilePtxOrGetCachedResult(const string& ptx, + int cc_major, int cc_minor); + + // The compilation_cache_ map is a cache from {ptx string, cc_major, cc_minor} + // -> cubin so we don't recompile the same ptx twice. 
This is important for + // some interactive workflows. (We also cache at the HLO level, but sometimes + // we can't realize that two modules are the same until we lower to ptx.) + // + // Compilation of distinct PTX happens in parallel. If more than one thread + // attempts to compile the same PTX, the first thread to obtain + // cache_value_->mutex_ performs the compilation. The rest wait() on + // cache_value_->compilation_done_cv_ until the compilation is done. + // + // If compiling the ptx fails, we return an empty cubin, cross our fingers, + // and leave compilation up to the driver. + struct CompilationCacheKey { + CompilationCacheKey(std::string ptx, int cc_major, int cc_minor) + : ptx(std::move(ptx)), cc_major(cc_major), cc_minor(cc_minor) {} + string ptx; + int cc_major; + int cc_minor; + }; + struct CompilationCacheHash { + size_t operator()(const CompilationCacheKey& key) const { + return tensorflow::Hash64Combine( + tensorflow::Hash64Combine(tensorflow::Hash64(key.ptx), key.cc_major), + key.cc_minor); + } + }; + struct CompilationCacheEq { + size_t operator()(const CompilationCacheKey& a, + const CompilationCacheKey& b) const { + return a.cc_major == b.cc_major && a.cc_minor == b.cc_minor && + a.ptx == b.ptx; + } + }; + struct CompilationCacheValue { + bool compilation_done = false; + std::vector cubin_data; + // mutex and condition variable to serialize compilation completing. + tensorflow::mutex mutex_; + tensorflow::condition_variable compilation_done_cv_; + }; + + // Don't even think about switching this to FlatMap; iterator stability is + // critical here.
+ std::unordered_map + compilation_cache_ GUARDED_BY(mutex_); TF_DISALLOW_COPY_AND_ASSIGN(GpuCompiler); }; diff --git a/tensorflow/compiler/xla/service/gpu/gpu_executable.cc b/tensorflow/compiler/xla/service/gpu/gpu_executable.cc index 254d0d770560b32298533f04139ab2f6c9a167ce..c6f23f9b0506186c4f76a887e6a540dafdd79962 100644 --- a/tensorflow/compiler/xla/service/gpu/gpu_executable.cc +++ b/tensorflow/compiler/xla/service/gpu/gpu_executable.cc @@ -108,13 +108,16 @@ class HloExecutionProfiler { // Implementation note: HLO profiling is always enabled for GPU executables, // since we can use timers around thunks. GpuExecutable::GpuExecutable( - tensorflow::StringPiece ptx, + const string& ptx, const std::vector& cubin, + std::pair compute_capability, std::unique_ptr thunk_schedule, std::unique_ptr hlo_module, std::unique_ptr assignment, HloCostAnalysis::ShapeSizeFunction shape_size_function) : Executable(std::move(hlo_module)), ptx_(ptx), + cubin_(cubin), + compute_capability_(compute_capability), thunk_schedule_(std::move(thunk_schedule)), assignment_(std::move(assignment)), shape_size_function_(std::move(shape_size_function)) {} @@ -125,6 +128,16 @@ Status GpuExecutable::ExecuteThunks( HloExecutionProfile* hlo_execution_profile) { se::Stream* main_stream = run_options->stream(); + std::pair stream_compute_compatibility; + main_stream->parent()->GetDeviceDescription().cuda_compute_capability( + &stream_compute_compatibility.first, + &stream_compute_compatibility.second); + TF_RET_CHECK(stream_compute_compatibility == compute_capability_) + << "Compute capability mismatch; expected {" << compute_capability_.first + << ", " << compute_capability_.second << "}, but was {" + << stream_compute_compatibility.first << ", " + << stream_compute_compatibility.second << "}"; + bool do_profile = hlo_execution_profile != nullptr; if (do_profile) { LOG(WARNING) << "PROFILING: profiling is enabled"; diff --git a/tensorflow/compiler/xla/service/gpu/gpu_executable.h 
b/tensorflow/compiler/xla/service/gpu/gpu_executable.h index 748a8f521bc5293d58de19ab52f4bdecec6cb1e5..a3815370c19af1da612bc6d9663cc0f8896062f7 100644 --- a/tensorflow/compiler/xla/service/gpu/gpu_executable.h +++ b/tensorflow/compiler/xla/service/gpu/gpu_executable.h @@ -47,7 +47,10 @@ namespace gpu { // This is an immutable data type after initialization, and thus thread safe. class GpuExecutable : public Executable { public: - GpuExecutable(tensorflow::StringPiece ptx, + // cubin (i.e. the compiled ptx) may be empty, in which case we leave + // compilation up to the GPU driver. + GpuExecutable(const string& ptx, const std::vector& cubin, + std::pair compute_capability, std::unique_ptr thunk_schedule, std::unique_ptr hlo_module, std::unique_ptr assignment, @@ -64,6 +67,13 @@ class GpuExecutable : public Executable { // Returns the compiled PTX for the computation. tensorflow::StringPiece ptx() const { return ptx_; } + // Returns the cubin (compiled PTX) stored in this GpuExecutable. May be + // empty, in which case compilation is left up to the GPU driver. + const std::vector& cubin() const { return cubin_; } + + // Both overloads of ExecuteOnStream will fail if the compute capability of + // the stream doesn't match the compute capability passed to this object's + // constructor. StatusOr ExecuteOnStream( const ServiceExecutableRunOptions* run_options, tensorflow::gtl::ArraySlice @@ -110,8 +120,17 @@ class GpuExecutable : public Executable { // This string should be modified only before ExecuteOnStream. string ir_module_string_; - // The reference to the compiled PTX for the computation. - const tensorflow::StringPiece ptx_; + // The PTX for the computation. + const string ptx_; + + // The GPU machine code for the computation, targeting GPUs at + // compute_capability_. + // + // May be empty, in which case we leave compilation up to the GPU driver. + const std::vector cubin_; + + // The compute capability of the GPU we're targeting with this GpuExecutable. 
+ std::pair compute_capability_; // The thunks to be invoked by this GpuExecutable. They are generated by the // IrEmitter. diff --git a/tensorflow/compiler/xla/service/gpu/kernel_thunk.cc b/tensorflow/compiler/xla/service/gpu/kernel_thunk.cc index 69399e36c4c4faa7c6ed5c79a3f094490f022001..96606993696354f36e143b3b994bbe6afb902df3 100644 --- a/tensorflow/compiler/xla/service/gpu/kernel_thunk.cc +++ b/tensorflow/compiler/xla/service/gpu/kernel_thunk.cc @@ -48,6 +48,12 @@ tensorflow::Status KernelThunk::Initialize(const GpuExecutable& executable) { // StreamExecutor uses the latter. loader_spec_->AddCudaPtxInMemory( se::port::StringPiece(ptx.data(), ptx.size()), kernel_name_); + + if (!executable.cubin().empty()) { + loader_spec_->AddCudaCubinInMemory( + reinterpret_cast(executable.cubin().data()), kernel_name_); + } + return tensorflow::Status::OK(); } diff --git a/tensorflow/compiler/xla/service/hlo_instruction.cc b/tensorflow/compiler/xla/service/hlo_instruction.cc index e09899e48d389716226661c5ae61defb066362aa..5107ac782d7c93dfa17969338bf97c9fd9bb1516 100644 --- a/tensorflow/compiler/xla/service/hlo_instruction.cc +++ b/tensorflow/compiler/xla/service/hlo_instruction.cc @@ -1901,12 +1901,13 @@ std::vector HloInstruction::ExtraAttributesToString() const { if (has_sharding()) { extra.push_back(StrCat("sharding=", sharding().ToString())); } - if (!control_successors_.empty()) { - extra.push_back(StrCat( - "control-successors=", - Join(control_successors_, ", ", [](string* out, HloInstruction* succ) { - StrAppend(out, succ->name()); - }))); + if (!control_predecessors_.empty()) { + extra.push_back(StrCat("control-predecessors={", + Join(control_predecessors_, ", ", + [](string* out, HloInstruction* pre) { + StrAppend(out, pre->name()); + }), + "}")); } return extra; } diff --git a/tensorflow/compiler/xla/service/hlo_runner.cc b/tensorflow/compiler/xla/service/hlo_runner.cc index aaa4e3a2e3a7523a6779f86ba0c20cb48c23f600..f463e57d995c0f0549872a1a0bf20a3ead626dc8 100644 
--- a/tensorflow/compiler/xla/service/hlo_runner.cc +++ b/tensorflow/compiler/xla/service/hlo_runner.cc @@ -41,11 +41,21 @@ namespace se = ::perftools::gputools; namespace xla { /*static*/ StatusOr> -HloRunner::ReadModuleFromHloProtoFile(const char* filename, +HloRunner::ReadModuleFromHloProtoFile(const std::string& filename, const DebugOptions& debug_options) { HloProto proto; - TF_RETURN_IF_ERROR(tensorflow::ReadBinaryProto(tensorflow::Env::Default(), - filename, &proto)); + + const Status s = + tensorflow::ReadBinaryProto(tensorflow::Env::Default(), filename, &proto); + + if (!s.ok()) { + const Status s2 = + tensorflow::ReadTextProto(tensorflow::Env::Default(), filename, &proto); + if (!s2.ok()) { + return Status(s2.code(), s.error_message() + "\n" + s2.error_message()); + } + } + TF_ASSIGN_OR_RETURN( HloModuleConfig config, HloModule::CreateModuleConfigFromProto(proto.hlo_module())); @@ -56,7 +66,7 @@ HloRunner::ReadModuleFromHloProtoFile(const char* filename, } /*static*/ StatusOr> -HloRunner::ReadModuleFromHloTextDumpFile(const char* filename, +HloRunner::ReadModuleFromHloTextDumpFile(const std::string& filename, const DebugOptions& debug_options) { string hlo_string; TF_RETURN_IF_ERROR(tensorflow::ReadFileToString(tensorflow::Env::Default(), @@ -66,6 +76,19 @@ HloRunner::ReadModuleFromHloTextDumpFile(const char* filename, return tools::Parse(hlo_string, config); } +/*static*/ StatusOr> HloRunner::ReadModule( + const std::string& filename, const DebugOptions& debug_options) { + auto module = HloRunner::ReadModuleFromHloProtoFile(filename, debug_options); + if (module.ok()) { + return module; + } + const std::string e = module.status().error_message(); + module = HloRunner::ReadModuleFromHloTextDumpFile(filename, debug_options); + return module.ok() ? 
std::move(module) + : Status(module.status().code(), + e + "\n" + module.status().error_message()); +} + // Define this in .cc file to avoid having to include eigen or forward declare // these types in the header. struct HloRunner::EigenThreadPoolWrapper { diff --git a/tensorflow/compiler/xla/service/hlo_runner.h b/tensorflow/compiler/xla/service/hlo_runner.h index b0e2b980e22d3348f63a528cd49f0b814540d549..a5732848c6b4191faf8d7b07c749132ca8b14413 100644 --- a/tensorflow/compiler/xla/service/hlo_runner.h +++ b/tensorflow/compiler/xla/service/hlo_runner.h @@ -44,15 +44,23 @@ class HloRunner { ~HloRunner(); - // Reads the binary proto file in xla.HloProto format, creates and returns the - // HloModule. + // Reads the proto file in xla.HloProto format, creates and returns the + // HloModule. Will try to parse the filename as binary proto, then try as + // text proto if that fails. static StatusOr> ReadModuleFromHloProtoFile( - const char* filename, const DebugOptions& debug_options); + const std::string& filename, const DebugOptions& debug_options); // Reads the hlo text dump file in HloModule::ToString format, creates and // returns the HloModule. static StatusOr> ReadModuleFromHloTextDumpFile( - const char* filename, const DebugOptions& debug_options); + const std::string& filename, const DebugOptions& debug_options); + + // Tries to parse the filename specified first as binary proto format, then + // as a textual proto format, then textual IR, then gives up if all of them fail. + // ReadModuleFromHloProtoFile or ReadModuleFromHloTextDumpFile should be used + // explicitly when you know the format, this if you don't. + static StatusOr> ReadModule( + const std::string& filename, const DebugOptions& debug_options); // Executes the given module with given literals as input and returns the // result as a Literal.
The LiteralPtr type accepts Literal* or diff --git a/tensorflow/compiler/xla/service/llvm_ir/llvm_util.cc b/tensorflow/compiler/xla/service/llvm_ir/llvm_util.cc index 5dff4b5778970dd473c5f158b3828a850847d1ff..956c0d5f05288e32c626f247ce8356c60d17808d 100644 --- a/tensorflow/compiler/xla/service/llvm_ir/llvm_util.cc +++ b/tensorflow/compiler/xla/service/llvm_ir/llvm_util.cc @@ -555,8 +555,9 @@ int64 ByteSizeOf(const Shape& shape, const llvm::DataLayout& data_layout) { llvm::FastMathFlags GetFastMathFlags(bool fast_math_enabled) { llvm::FastMathFlags flags; if (fast_math_enabled) { - // UnsafeAlgebra implies NoInfs, NoNaNs, NoSignedZeros, and AllowReciprocal. - flags.setUnsafeAlgebra(); + // Fast implies AllowReassoc, NoInfs, NoNaNs, NoSignedZeros, + // AllowReciprocal, AllowContract, and ApproxFunc. + flags.setFast(); } return flags; } diff --git a/tensorflow/compiler/xla/service/service.cc b/tensorflow/compiler/xla/service/service.cc index bac33d8102e07766531a4ce6eac77aff4971bfef..71afbee456b0f5eb67cb092d84f8e95ea1038c54 100644 --- a/tensorflow/compiler/xla/service/service.cc +++ b/tensorflow/compiler/xla/service/service.cc @@ -490,14 +490,20 @@ Service::ExecuteParallelAndRegisterResult( std::vector> arguments, Backend* backend, tensorflow::gtl::ArraySlice device_handles, - tensorflow::gtl::ArraySlice result_tags) { + tensorflow::gtl::ArraySlice result_tags, + ExecutionProfile* profile) { // Streams where the computation are launched, so we can wait on the streams // to complete. std::vector::SmartPtr> streams; + std::vector> timers; // Global data handles for the computation results, one for each computation. std::vector result_handles; + // Device ID to stream executor, populated only with devices that are being + // profiled. 
+ std::map index_to_profiled_streams; + TF_ASSIGN_OR_RETURN(DeviceAssignment device_assignment, backend->computation_placer()->AssignDevices( options_.number_of_replicas(), executables.size())); @@ -510,6 +516,21 @@ Service::ExecuteParallelAndRegisterResult( backend->BorrowStream(replicas[replica])); streams.push_back(std::move(stream)); + if (replica == 0 && profile != nullptr) { + timers.emplace_back( + new perftools::gputools::Timer(streams.back()->parent())); + streams.back() + ->InitTimer(timers.back().get()) + .ThenStartTimer(timers.back().get()); + CHECK(timers.front() != nullptr); + } + + if (replica == 0 && + executables[i]->module_config().debug_options().xla_hlo_profile() && + executables[i]->hlo_profiling_enabled()) { + index_to_profiled_streams[i] = streams.back().get(); + } + // Set up run options. ExecutableRunOptions options; options.set_stream(streams.back().get()); @@ -526,6 +547,10 @@ Service::ExecuteParallelAndRegisterResult( perftools::gputools::DeviceMemoryBase result, executables[i]->ExecuteAsyncOnStream(&run_options, arguments[i])); + if (replica == 0 && profile != nullptr) { + streams.back()->ThenStopTimer(timers.back().get()); + } + // All replicas share the same device address for the result allocation, // so only one of the replicas need to register the result handle. if (replica == 0) { @@ -543,6 +568,69 @@ Service::ExecuteParallelAndRegisterResult( } } + // For every stream that had profiling enabled, obtain and debug-dump the HLO + // profile. 
+ for (auto& index_to_profiled_stream : index_to_profiled_streams) { + int64 device = index_to_profiled_stream.first; + se::Stream* stream = index_to_profiled_stream.second; + HloExecutionProfile hlo_profile; + TF_RETURN_IF_ERROR(executables[device]->PopulateExecutionProfile( + &hlo_profile, stream->parent())); + + std::unordered_set profiled_computations = + hlo_profile.profiled_computations(); + // To ensure we print the profiles in a stable order, iterate over the + // computations in post order. + auto& module = executables[device]->module(); + std::list all_computations = + module.MakeComputationPostOrder(); + for (xla::HloComputation* computation : all_computations) { + if (profiled_computations.count(computation) > 0) { + string profile_string = hlo_profile.ToString( + *computation, streams[0]->parent()->GetDeviceDescription(), + executables[device]->CreateCostAnalysis().get()); + if (!profile_string.empty()) { + LOG(INFO) << "HLO profile for execution on device " << device + << ":\n"; + XLA_LOG_LINES(tensorflow::INFO, profile_string); + } + } + } + hlo_graph_dumper::MaybeDumpHloModule(module, "Service::Execute", + &hlo_profile); + } + + if (profile != nullptr) { + CHECK(!timers.empty()); + std::vector timer_nanoseconds; + timer_nanoseconds.reserve(timers.size()); + for (auto& timer : timers) { + timer_nanoseconds.push_back(timer->Nanoseconds()); + } + uint64 nanoseconds = + *std::max_element(timer_nanoseconds.begin(), timer_nanoseconds.end()); + + // Merge in run-time profile information from execution_profile on the + // zeroth device. + profile->MergeFrom(executables[0]->execution_profile()); + + // Overall execution time (in nanoseconds) from the executor timer. + profile->set_compute_and_transfer_time_ns(nanoseconds); + + // TODO(b/28123297): On GPU we end up including transfer time in + // the compute time this way. Instead, we should get the correct + // value by measuring it.
Setting the field here at least lets + // benchmarks provide *some* value for GPU computations. + // + // TODO(b/28447609): The value in compute_and_transfer_time_ns is actually + // the compute time without the transfer time, so this way we get the + // correct compute time. We should instead have the correct value for + // compute_and_transfer_time and set compute_time to the compute time. + if (profile->compute_time_ns() == 0) { + profile->set_compute_time_ns(profile->compute_and_transfer_time_ns()); + } + } + return result_handles; } @@ -715,14 +803,16 @@ tensorflow::Status Service::ExecuteParallel(const ExecuteParallelRequest* arg, // Execute the generated executables in parallel and return the device // handles for each computation's output. + ExecutionProfile profile; TF_ASSIGN_OR_RETURN( std::vector outputs, ExecuteParallelAndRegisterResult(executable_ptrs, all_arguments, execute_backend_.get(), device_handles, - computation_names)); + computation_names, &profile)); for (const GlobalDataHandle& output : outputs) { ExecuteResponse response; *response.mutable_output() = output; + *response.mutable_profile() = profile; *result->add_responses() = response; } @@ -1082,8 +1172,9 @@ tensorflow::Status Service::IsConstant(const IsConstantRequest* arg, return InvalidArgument("computations may not be empty"); } - TF_ASSIGN_OR_RETURN(bool is_constant, - user_computation->IsConstant(arg->operand())); + TF_ASSIGN_OR_RETURN( + bool is_constant, + user_computation->IsConstant(arg->operand(), arg->num_parameters())); result->set_is_constant(is_constant); return tensorflow::Status::OK(); @@ -1101,8 +1192,9 @@ tensorflow::Status Service::ComputeConstant(const ComputeConstantRequest* arg, return InvalidArgument("computations may not be empty"); } - TF_ASSIGN_OR_RETURN(bool is_constant, - user_computation->IsConstant(arg->operand())); + TF_ASSIGN_OR_RETURN( + bool is_constant, + user_computation->IsConstant(arg->operand(), arg->parameters_size())); if (!is_constant) { return 
InvalidArgument("Operand to ComputeConstant depends on parameter."); } @@ -1141,8 +1233,18 @@ tensorflow::Status Service::ComputeConstant(const ComputeConstantRequest* arg, /*include_unreachable_instructions=*/ false)); + std::vector parameters(arg->parameters_size()); + for (int64 i = 0; i < arg->parameters_size(); ++i) { + parameters[i] = Literal(arg->parameters(i)); + } + std::vector parameter_ptrs; + std::transform(parameters.begin(), parameters.end(), + std::back_inserter(parameter_ptrs), + [](const Literal& literal) { return &literal; }); + HloEvaluator evaluator; - TF_ASSIGN_OR_RETURN(auto result_literal, evaluator.Evaluate(*module, {})); + TF_ASSIGN_OR_RETURN(auto result_literal, + evaluator.Evaluate(*module, parameter_ptrs)); // Since the shape_with_output_layout option in ExecutionOption is // non-effective to the Evaluator results, explicit relayout here. if (arg->has_output_layout()) { diff --git a/tensorflow/compiler/xla/service/service.h b/tensorflow/compiler/xla/service/service.h index 2452259f736054b5bf1f03fc5103d65eded7f398..6646be2e9aa43763b93bcea7a1df9d10580f162c 100644 --- a/tensorflow/compiler/xla/service/service.h +++ b/tensorflow/compiler/xla/service/service.h @@ -327,7 +327,8 @@ class Service : public ServiceInterface { arguments, Backend* backend, tensorflow::gtl::ArraySlice device_handles, - tensorflow::gtl::ArraySlice result_tags); + tensorflow::gtl::ArraySlice result_tags, + ExecutionProfile* profile); // Convenience function for adding a function to a user computation. template diff --git a/tensorflow/compiler/xla/service/transfer_manager.h b/tensorflow/compiler/xla/service/transfer_manager.h index f63d91604cf40edfae98b56a8bacdbded697ffc3..057bdffe93164e9bb7271157556961575666359d 100644 --- a/tensorflow/compiler/xla/service/transfer_manager.h +++ b/tensorflow/compiler/xla/service/transfer_manager.h @@ -119,7 +119,7 @@ class TransferManager { // Determines the byte size requirement for the given shape on the underlying // architecture. 
This will be used to allocate an appropriately sized memory // region for a host-to-device transfer. - virtual int64 GetByteSizeRequirement(const Shape& shape) = 0; + virtual int64 GetByteSizeRequirement(const Shape& shape) const = 0; // Transfer a memory block of the given size from the device source into the // 'destination' buffer. diff --git a/tensorflow/compiler/xla/service/user_computation.cc b/tensorflow/compiler/xla/service/user_computation.cc index 006c814996df9b209e6cd4d75bc04689c4e297c5..e9d182509b5356d32b667b7921e2843d30faeb9b 100644 --- a/tensorflow/compiler/xla/service/user_computation.cc +++ b/tensorflow/compiler/xla/service/user_computation.cc @@ -1482,14 +1482,15 @@ UserComputation::ComputeProgramShape( namespace { -// A visitor which checks whether an operation is a compile-time constant. That -// is, the operation does not depend on any parameter instructions. The visitor -// walks the computation starting at a given operation and sets is_constant to -// false iff a parameter or RNG operation is encountered. +// A visitor which checks whether an operation is pure functional, meaning that +// it doesn't depend on any parameter with an index of num_parameters or higher. +// The visitor walks the computation starting at a given operation and sets +// is_functional to false if such a parameter, an RNG operation, or another op treated as non-functional (e.g. infeed, outfeed, send, recv) is encountered.
+void PureFunctionalVisitor(const SessionComputation& session_computation, + const ComputationDataHandle& handle, + int64 num_parameters, std::set* visited, + bool* is_functional) { + if (visited->count(handle.handle()) != 0 || !*is_functional) { return; } @@ -1497,7 +1498,7 @@ void ConstantVisitor(const SessionComputation& session_computation, session_computation.requests().at(handle.handle()); switch (request.request().op_case()) { case OpRequest::kRngRequest: - *is_constant = false; + *is_functional = false; break; case OpRequest::kConstantRequest: @@ -1506,41 +1507,43 @@ void ConstantVisitor(const SessionComputation& session_computation, case OpRequest::kGetTupleElementRequest: { const GetTupleElementRequest& get_tuple_element_request = request.request().get_tuple_element_request(); - ConstantVisitor(session_computation, get_tuple_element_request.operand(), - visited, is_constant); + PureFunctionalVisitor(session_computation, + get_tuple_element_request.operand(), num_parameters, + visited, is_functional); break; } case OpRequest::kSliceRequest: { const SliceRequest& slice_request = request.request().slice_request(); - ConstantVisitor(session_computation, slice_request.operand(), visited, - is_constant); + PureFunctionalVisitor(session_computation, slice_request.operand(), + num_parameters, visited, is_functional); break; } case OpRequest::kDynamicSliceRequest: { const DynamicSliceRequest& dynamic_slice_request = request.request().dynamic_slice_request(); - ConstantVisitor(session_computation, dynamic_slice_request.operand(), - visited, is_constant); - ConstantVisitor(session_computation, - dynamic_slice_request.start_indices(), visited, - is_constant); + PureFunctionalVisitor(session_computation, + dynamic_slice_request.operand(), num_parameters, + visited, is_functional); + PureFunctionalVisitor(session_computation, + dynamic_slice_request.start_indices(), + num_parameters, visited, is_functional); break; } case OpRequest::kDynamicUpdateSliceRequest: { const 
DynamicUpdateSliceRequest& dynamic_update_slice_request = request.request().dynamic_update_slice_request(); - ConstantVisitor(session_computation, - dynamic_update_slice_request.operand(), visited, - is_constant); - ConstantVisitor(session_computation, - dynamic_update_slice_request.update(), visited, - is_constant); - ConstantVisitor(session_computation, - dynamic_update_slice_request.start_indices(), visited, - is_constant); + PureFunctionalVisitor(session_computation, + dynamic_update_slice_request.operand(), + num_parameters, visited, is_functional); + PureFunctionalVisitor(session_computation, + dynamic_update_slice_request.update(), + num_parameters, visited, is_functional); + PureFunctionalVisitor(session_computation, + dynamic_update_slice_request.start_indices(), + num_parameters, visited, is_functional); break; } @@ -1549,7 +1552,8 @@ void ConstantVisitor(const SessionComputation& session_computation, request.request().concatenate_request(); for (const ComputationDataHandle& handle : concatenate_request.operands()) { - ConstantVisitor(session_computation, handle, visited, is_constant); + PureFunctionalVisitor(session_computation, handle, num_parameters, + visited, is_functional); } break; } @@ -1557,61 +1561,63 @@ void ConstantVisitor(const SessionComputation& session_computation, case OpRequest::kConvolveRequest: { const ConvolveRequest& convolve_request = request.request().convolve_request(); - ConstantVisitor(session_computation, convolve_request.lhs(), visited, - is_constant); - ConstantVisitor(session_computation, convolve_request.rhs(), visited, - is_constant); + PureFunctionalVisitor(session_computation, convolve_request.lhs(), + num_parameters, visited, is_functional); + PureFunctionalVisitor(session_computation, convolve_request.rhs(), + num_parameters, visited, is_functional); break; } case OpRequest::kCrossReplicaSumRequest: { // TODO(b/33009255): Implmement constant folding for cross replica sum. 
- *is_constant = false; + *is_functional = false; break; } case OpRequest::kInfeedRequest: { - *is_constant = false; + *is_functional = false; break; } case OpRequest::kOutfeedRequest: { - *is_constant = false; + *is_functional = false; break; } case OpRequest::kCallRequest: { const CallRequest& call_request = request.request().call_request(); for (const ComputationDataHandle& handle : call_request.operands()) { - ConstantVisitor(session_computation, handle, visited, is_constant); + PureFunctionalVisitor(session_computation, handle, num_parameters, + visited, is_functional); } // TODO(b/32495713): We aren't checking the to_apply computation itself, // so we conservatively say that computations containing the Call op - // cannot be constant. We cannot set is_constant=false in other similar + // cannot be constant. We cannot set is_functional=false in other similar // cases since we're already relying on IsConstant to return true. - *is_constant = false; + *is_functional = false; break; } case OpRequest::kCustomCallRequest: { - *is_constant = false; + *is_functional = false; break; } case OpRequest::kSendRequest: { - *is_constant = false; + *is_functional = false; break; } case OpRequest::kRecvRequest: { - *is_constant = false; + *is_functional = false; break; } case OpRequest::kMapRequest: { const MapRequest& map_request = request.request().map_request(); for (const ComputationDataHandle& handle : map_request.operands()) { - ConstantVisitor(session_computation, handle, visited, is_constant); + PureFunctionalVisitor(session_computation, handle, num_parameters, + visited, is_functional); } // TODO(b/32495713): We aren't checking the to_apply computation itself. 
break; @@ -1619,10 +1625,10 @@ void ConstantVisitor(const SessionComputation& session_computation, case OpRequest::kReduceRequest: { const ReduceRequest& reduce_request = request.request().reduce_request(); - ConstantVisitor(session_computation, reduce_request.operand(), visited, - is_constant); - ConstantVisitor(session_computation, reduce_request.init_value(), visited, - is_constant); + PureFunctionalVisitor(session_computation, reduce_request.operand(), + num_parameters, visited, is_functional); + PureFunctionalVisitor(session_computation, reduce_request.init_value(), + num_parameters, visited, is_functional); // TODO(b/32495713): We aren't checking the to_apply computation itself. break; } @@ -1630,10 +1636,12 @@ void ConstantVisitor(const SessionComputation& session_computation, case OpRequest::kReduceWindowRequest: { const ReduceWindowRequest& reduce_window_request = request.request().reduce_window_request(); - ConstantVisitor(session_computation, reduce_window_request.operand(), - visited, is_constant); - ConstantVisitor(session_computation, reduce_window_request.init_value(), - visited, is_constant); + PureFunctionalVisitor(session_computation, + reduce_window_request.operand(), num_parameters, + visited, is_functional); + PureFunctionalVisitor(session_computation, + reduce_window_request.init_value(), num_parameters, + visited, is_functional); // TODO(b/32495713): We aren't checking the to_apply computation itself. 
break; } @@ -1641,13 +1649,15 @@ void ConstantVisitor(const SessionComputation& session_computation, case OpRequest::kSelectAndScatterRequest: { const SelectAndScatterRequest& select_and_scatter_request = request.request().select_and_scatter_request(); - ConstantVisitor(session_computation, select_and_scatter_request.operand(), - visited, is_constant); - ConstantVisitor(session_computation, select_and_scatter_request.source(), - visited, is_constant); - ConstantVisitor(session_computation, - select_and_scatter_request.init_value(), visited, - is_constant); + PureFunctionalVisitor(session_computation, + select_and_scatter_request.operand(), + num_parameters, visited, is_functional); + PureFunctionalVisitor(session_computation, + select_and_scatter_request.source(), num_parameters, + visited, is_functional); + PureFunctionalVisitor(session_computation, + select_and_scatter_request.init_value(), + num_parameters, visited, is_functional); // TODO(b/32495713): We aren't checking the select and scatter // computations themselves. 
break; @@ -1656,76 +1666,80 @@ void ConstantVisitor(const SessionComputation& session_computation, case OpRequest::kBroadcastRequest: { const BroadcastRequest& broadcast_request = request.request().broadcast_request(); - ConstantVisitor(session_computation, broadcast_request.operand(), visited, - is_constant); + PureFunctionalVisitor(session_computation, broadcast_request.operand(), + num_parameters, visited, is_functional); break; } case OpRequest::kReshapeRequest: { const ReshapeRequest& reshape_request = request.request().reshape_request(); - ConstantVisitor(session_computation, reshape_request.operand(), visited, - is_constant); + PureFunctionalVisitor(session_computation, reshape_request.operand(), + num_parameters, visited, is_functional); break; } case OpRequest::kReverseRequest: { const ReverseRequest& reverse_request = request.request().reverse_request(); - ConstantVisitor(session_computation, reverse_request.operand(), visited, - is_constant); + PureFunctionalVisitor(session_computation, reverse_request.operand(), + num_parameters, visited, is_functional); break; } case OpRequest::kPadRequest: { const PadRequest& pad_request = request.request().pad_request(); - ConstantVisitor(session_computation, pad_request.operand(), visited, - is_constant); - ConstantVisitor(session_computation, pad_request.padding_value(), visited, - is_constant); + PureFunctionalVisitor(session_computation, pad_request.operand(), + num_parameters, visited, is_functional); + PureFunctionalVisitor(session_computation, pad_request.padding_value(), + num_parameters, visited, is_functional); break; } case OpRequest::kParameterRequest: { - *is_constant = false; + const ParameterRequest& parameter_request = + request.request().parameter_request(); + if (parameter_request.parameter() >= num_parameters) { + *is_functional = false; + } break; } case OpRequest::kConvertRequest: { const ConvertRequest& convert_request = request.request().convert_request(); - ConstantVisitor(session_computation, 
convert_request.operand(), visited, - is_constant); + PureFunctionalVisitor(session_computation, convert_request.operand(), + num_parameters, visited, is_functional); break; } case OpRequest::kWhileRequest: { const WhileRequest& while_request = request.request().while_request(); - ConstantVisitor(session_computation, while_request.init(), visited, - is_constant); + PureFunctionalVisitor(session_computation, while_request.init(), + num_parameters, visited, is_functional); // TODO(b/32495713): We aren't checking the condition and body // computations themselves. - *is_constant = false; + *is_functional = false; break; } case OpRequest::kTernaryOpRequest: { const TernaryOpRequest& ternary_op_request = request.request().ternary_op_request(); - ConstantVisitor(session_computation, ternary_op_request.lhs(), visited, - is_constant); - ConstantVisitor(session_computation, ternary_op_request.rhs(), visited, - is_constant); - ConstantVisitor(session_computation, ternary_op_request.ehs(), visited, - is_constant); + PureFunctionalVisitor(session_computation, ternary_op_request.lhs(), + num_parameters, visited, is_functional); + PureFunctionalVisitor(session_computation, ternary_op_request.rhs(), + num_parameters, visited, is_functional); + PureFunctionalVisitor(session_computation, ternary_op_request.ehs(), + num_parameters, visited, is_functional); break; } case OpRequest::kTransposeRequest: { const TransposeRequest& transpose_request = request.request().transpose_request(); - ConstantVisitor(session_computation, transpose_request.operand(), visited, - is_constant); + PureFunctionalVisitor(session_computation, transpose_request.operand(), + num_parameters, visited, is_functional); break; } @@ -1734,7 +1748,8 @@ void ConstantVisitor(const SessionComputation& session_computation, request.request().variadic_op_request(); for (const ComputationDataHandle& handle : variadic_op_request.operands()) { - ConstantVisitor(session_computation, handle, visited, is_constant); + 
PureFunctionalVisitor(session_computation, handle, num_parameters, + visited, is_functional); } break; } @@ -1742,67 +1757,74 @@ void ConstantVisitor(const SessionComputation& session_computation, case OpRequest::kUnaryOpRequest: { const UnaryOpRequest& unary_op_request = request.request().unary_op_request(); - ConstantVisitor(session_computation, unary_op_request.operand(), visited, - is_constant); + PureFunctionalVisitor(session_computation, unary_op_request.operand(), + num_parameters, visited, is_functional); break; } case OpRequest::kBatchNormTrainingRequest: { const BatchNormTrainingRequest& batch_norm_training_request = request.request().batch_norm_training_request(); - ConstantVisitor(session_computation, - batch_norm_training_request.operand(), visited, - is_constant); - ConstantVisitor(session_computation, batch_norm_training_request.scale(), - visited, is_constant); - ConstantVisitor(session_computation, batch_norm_training_request.offset(), - visited, is_constant); + PureFunctionalVisitor(session_computation, + batch_norm_training_request.operand(), + num_parameters, visited, is_functional); + PureFunctionalVisitor(session_computation, + batch_norm_training_request.scale(), num_parameters, + visited, is_functional); + PureFunctionalVisitor(session_computation, + batch_norm_training_request.offset(), + num_parameters, visited, is_functional); break; } case OpRequest::kBatchNormInferenceRequest: { const BatchNormInferenceRequest& batch_norm_inference_request = request.request().batch_norm_inference_request(); - ConstantVisitor(session_computation, - batch_norm_inference_request.operand(), visited, - is_constant); - ConstantVisitor(session_computation, batch_norm_inference_request.scale(), - visited, is_constant); - ConstantVisitor(session_computation, - batch_norm_inference_request.offset(), visited, - is_constant); - ConstantVisitor(session_computation, batch_norm_inference_request.mean(), - visited, is_constant); - ConstantVisitor(session_computation, - 
batch_norm_inference_request.variance(), visited, - is_constant); + PureFunctionalVisitor(session_computation, + batch_norm_inference_request.operand(), + num_parameters, visited, is_functional); + PureFunctionalVisitor(session_computation, + batch_norm_inference_request.scale(), + num_parameters, visited, is_functional); + PureFunctionalVisitor(session_computation, + batch_norm_inference_request.offset(), + num_parameters, visited, is_functional); + PureFunctionalVisitor(session_computation, + batch_norm_inference_request.mean(), num_parameters, + visited, is_functional); + PureFunctionalVisitor(session_computation, + batch_norm_inference_request.variance(), + num_parameters, visited, is_functional); break; } case OpRequest::kBatchNormGradRequest: { const BatchNormGradRequest& batch_norm_grad_request = request.request().batch_norm_grad_request(); - ConstantVisitor(session_computation, batch_norm_grad_request.operand(), - visited, is_constant); - ConstantVisitor(session_computation, batch_norm_grad_request.scale(), - visited, is_constant); - ConstantVisitor(session_computation, batch_norm_grad_request.mean(), - visited, is_constant); - ConstantVisitor(session_computation, batch_norm_grad_request.variance(), - visited, is_constant); - ConstantVisitor(session_computation, - batch_norm_grad_request.grad_output(), visited, - is_constant); + PureFunctionalVisitor(session_computation, + batch_norm_grad_request.operand(), num_parameters, + visited, is_functional); + PureFunctionalVisitor(session_computation, + batch_norm_grad_request.scale(), num_parameters, + visited, is_functional); + PureFunctionalVisitor(session_computation, batch_norm_grad_request.mean(), + num_parameters, visited, is_functional); + PureFunctionalVisitor(session_computation, + batch_norm_grad_request.variance(), num_parameters, + visited, is_functional); + PureFunctionalVisitor(session_computation, + batch_norm_grad_request.grad_output(), + num_parameters, visited, is_functional); break; } case 
OpRequest::kBinaryOpRequest: { const BinaryOpRequest& binary_op_request = request.request().binary_op_request(); - ConstantVisitor(session_computation, binary_op_request.lhs(), visited, - is_constant); - ConstantVisitor(session_computation, binary_op_request.rhs(), visited, - is_constant); + PureFunctionalVisitor(session_computation, binary_op_request.lhs(), + num_parameters, visited, is_functional); + PureFunctionalVisitor(session_computation, binary_op_request.rhs(), + num_parameters, visited, is_functional); break; } @@ -1817,8 +1839,8 @@ void ConstantVisitor(const SessionComputation& session_computation, } // namespace -StatusOr UserComputation::IsConstant( - const ComputationDataHandle& handle) { +StatusOr UserComputation::IsConstant(const ComputationDataHandle& handle, + int64 num_parameters) { tensorflow::mutex_lock lock(mutex_); // Verify that the handle is valid. @@ -1829,7 +1851,8 @@ StatusOr UserComputation::IsConstant( bool is_constant = true; std::set visited; - ConstantVisitor(session_computation_, handle, &visited, &is_constant); + PureFunctionalVisitor(session_computation_, handle, num_parameters, &visited, + &is_constant); return is_constant; } diff --git a/tensorflow/compiler/xla/service/user_computation.h b/tensorflow/compiler/xla/service/user_computation.h index dabf68e298ed2600d5248b7b8c7b1e014efedb14..ac879ce55a75f6241a39f935b79017be46c1816b 100644 --- a/tensorflow/compiler/xla/service/user_computation.h +++ b/tensorflow/compiler/xla/service/user_computation.h @@ -250,9 +250,11 @@ class UserComputation { StatusOr> ComputeProgramShape( VersionedComputationHandle::Version version) const; - // Returns true if the given data handle does not depend on any - // parameters. That is, the value can be computed at compile time. - StatusOr IsConstant(const ComputationDataHandle& handle); + // Returns true if the given data handle does not depend on any parameter with + // index higher then num_parameters. 
That is, the value can be computed at + // compile time if we know the first num_parameters arguments. + StatusOr IsConstant(const ComputationDataHandle& handle, + int64 num_parameters); // Returns the output shape of the operation indicated by the given handle. StatusOr GetShape(const ComputationDataHandle& handle); diff --git a/tensorflow/compiler/xla/tests/compute_constant_test.cc b/tensorflow/compiler/xla/tests/compute_constant_test.cc index b2e9743af79d0e4658451e7a9522c338036851ba..d423c78476dde18d209b5efac9e8f77da41bfeb4 100644 --- a/tensorflow/compiler/xla/tests/compute_constant_test.cc +++ b/tensorflow/compiler/xla/tests/compute_constant_test.cc @@ -71,24 +71,27 @@ class ComputeConstantTest : public ::testing::Test { StatusOr> ComputeConstantLiteral( Client* client, const ComputationDataHandle& operand, - ComputationBuilder* builder, Layout* output_layout = nullptr) { - TF_ASSIGN_OR_RETURN(auto computed, - builder->ComputeConstant(operand, output_layout)); + ComputationBuilder* builder, Layout* output_layout = nullptr, + tensorflow::gtl::ArraySlice parameters = {}) { + TF_ASSIGN_OR_RETURN(auto computed, builder->ComputeConstant( + operand, output_layout, parameters)); return std::move(computed); } template - StatusOr ComputeConstantScalar(Client* client, - const ComputationDataHandle& operand, - ComputationBuilder* builder) { - TF_ASSIGN_OR_RETURN(auto literal, - ComputeConstantLiteral(client, operand, builder)); + StatusOr ComputeConstantScalar( + Client* client, const ComputationDataHandle& operand, + ComputationBuilder* builder, + tensorflow::gtl::ArraySlice parameters = {}) { + TF_ASSIGN_OR_RETURN( + auto literal, + ComputeConstantLiteral(client, operand, builder, nullptr, parameters)); return literal->Get({}); } bool IsConstant(const ComputationDataHandle& operand, - ComputationBuilder* builder) { - StatusOr result = builder->IsConstant(operand); + ComputationBuilder* builder, int64 num_parameters = 0) { + StatusOr result = builder->IsConstant(operand, 
num_parameters); EXPECT_TRUE(result.ok()) << result.status(); return result.ok() ? result.ValueOrDie() : false; } @@ -138,7 +141,25 @@ TEST_F(ComputeConstantTest, ScalarRng) { } } -TEST_F(ComputeConstantTest, DirectParam) { +TEST_F(ComputeConstantTest, Param) { + for (ClientType client_type : client_types) { + Client* client = ClientOrDie(platform_, client_type); + ComputationBuilder b(client, TestName()); + auto param = b.Parameter(0, ShapeUtil::MakeShape(F32, {}), "lhs"); + auto computation = b.Add(param, b.ConstantR0(1.5f)); + + std::vector arguments; + arguments.emplace_back(*Literal::CreateR0(42.5f)); + EXPECT_TRUE(IsConstant(computation, &b, arguments.size())); + + auto value = + ComputeConstantScalar(client, computation, &b, arguments); + ASSERT_TRUE(value.ok()) << value.status(); + EXPECT_EQ(value.ValueOrDie(), 44.0f); + } +} + +TEST_F(ComputeConstantTest, DirectParamMissing) { for (ClientType client_type : client_types) { Client* client = ClientOrDie(platform_, client_type); ComputationBuilder b(client, TestName()); @@ -152,7 +173,7 @@ TEST_F(ComputeConstantTest, DirectParam) { } } -TEST_F(ComputeConstantTest, IndirectParam) { +TEST_F(ComputeConstantTest, IndirectParamMissing) { for (ClientType client_type : client_types) { Client* client = ClientOrDie(platform_, client_type); ComputationBuilder b(client, TestName()); diff --git a/tensorflow/compiler/xla/tests/dynamic_ops_test.cc b/tensorflow/compiler/xla/tests/dynamic_ops_test.cc index 19252f50f25eee42e4e492b7f0e2ec3960c62126..ab8047c7480f43ba1fd7ca3ad22448e0dd890089 100644 --- a/tensorflow/compiler/xla/tests/dynamic_ops_test.cc +++ b/tensorflow/compiler/xla/tests/dynamic_ops_test.cc @@ -250,9 +250,6 @@ class DynamicUpdateSliceTest : public ClientLibraryTestBase { // Slice at dimension boundaries. RunR1({0, 1, 2, 3, 4, 5, 6, 7}, {8, 9, 10}, {5}, {0, 1, 2, 3, 4, 8, 9, 10}); - // Slice at dimension boundaries, but with sizes that cause indices to wrap. 
- RunR1({0, 1, 2, 3, 4, 5, 6, 7}, {8, 9, 10}, {6}, - {0, 1, 2, 3, 4, 5, 8, 9}); // Zero-sized update. RunR1({0, 1, 2, 3, 4, 5, 6, 7}, {}, {2}, {0, 1, 2, 3, 4, 5, 6, 7}); @@ -269,9 +266,6 @@ class DynamicUpdateSliceTest : public ClientLibraryTestBase { // Slice at dimension boundaries. RunR2({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}, {{10, 11}}, {2, 1}, {{1, 2, 3}, {4, 5, 6}, {7, 10, 11}}); - // Slice at dimension boundaries, but with sizes that cause indices to wrap. - RunR2({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}, {{10, 11}}, {2, 2}, - {{1, 2, 3}, {4, 5, 6}, {7, 8, 10}}); // Zero-sized update. RunR2({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}, {{}}, {2, 1}, {{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}); @@ -289,10 +283,20 @@ class DynamicUpdateSliceTest : public ClientLibraryTestBase { RunR3( {{{1, 2}, {3, 4}, {5, 6}}, {{7, 8}, {9, 10}, {11, 12}}}, {{{13}, {15}}}, {1, 1, 1}, {{{1, 2}, {3, 4}, {5, 6}}, {{7, 8}, {9, 13}, {11, 15}}}); + } + + template + void TestWrap() { // Slice at dimension boundaries, but with sizes that cause indices to wrap. + RunR1({0, 1, 2, 3, 4, 5, 6, 7}, {8, 9, 10}, {6}, + {10, 1, 2, 3, 4, 5, 8, 9}); + // R2 Shape: [3, 3] + RunR2({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}, {{10, 11}}, {2, 2}, + {{1, 2, 3}, {4, 5, 6}, {11, 8, 10}}); + // R3 Shape: [2, 3, 2] RunR3( {{{1, 2}, {3, 4}, {5, 6}}, {{7, 8}, {9, 10}, {11, 12}}}, {{{13}, {15}}}, - {1, 2, 1}, {{{1, 2}, {3, 4}, {5, 6}}, {{7, 8}, {9, 10}, {11, 13}}}); + {1, 2, 1}, {{{1, 2}, {3, 4}, {5, 6}}, {{7, 15}, {9, 10}, {11, 13}}}); } template @@ -425,6 +429,12 @@ XLA_TEST_F(DynamicUpdateSliceTest, Int64R3) { TestR3(); } XLA_TEST_F(DynamicUpdateSliceTest, UInt64R3) { TestR3(); } +XLA_TEST_F(DynamicUpdateSliceTest, Int32Wrap) { TestWrap(); } + +XLA_TEST_F(DynamicUpdateSliceTest, Int64Wrap) { TestWrap(); } + +XLA_TEST_F(DynamicUpdateSliceTest, UInt64Wrap) { TestWrap(); } + XLA_TEST_F(DynamicUpdateSliceTest, Int32R1Pred) { // Slice at dimension start. 
RunR1({false, false, true, true, false, true, true, false}, @@ -497,19 +507,13 @@ XLA_TEST_F(DynamicUpdateSliceTest, R3ContiguousMultipleElements) { RunR3Contiguous(operand_shape, /*index=*/1, /*size=*/2); } -// TODO(b/34128753) CPU and GPU failed on 2016-01-06. Appears not to handle -// wrapping as expected. -XLA_TEST_F(DynamicUpdateSliceTest, - DISABLED_ON_CPU(DISABLED_ON_GPU(R3ContiguousMultipleWrapping))) { +XLA_TEST_F(DynamicUpdateSliceTest, R3ContiguousMultipleWrapping) { // Multiple element, wrapping. std::vector operand_shape({4, 5, 2}); RunR3Contiguous(operand_shape, /*index=*/3, /*size=*/2); } -// TODO(b/34128753) CPU and GPU failed on 2016-01-06. Appears not to handle -// wrapping as expected. -XLA_TEST_F(DynamicUpdateSliceTest, - DISABLED_ON_CPU(DISABLED_ON_GPU(R3ContiguousTooLarge))) { +XLA_TEST_F(DynamicUpdateSliceTest, R3ContiguousTooLarge) { // Multiple element, update size larger than operand. std::vector operand_shape({4, 5, 2}); RunR3Contiguous(operand_shape, /*index=*/5, /*size=*/2); diff --git a/tensorflow/compiler/xla/tests/while_test.cc b/tensorflow/compiler/xla/tests/while_test.cc index 71a1b0abee51ba2819daed23208b0da8d5107207..49f673f5f0bf9b844ab4030383784208b4e2c58a 100644 --- a/tensorflow/compiler/xla/tests/while_test.cc +++ b/tensorflow/compiler/xla/tests/while_test.cc @@ -357,6 +357,111 @@ TEST_F(WhileTest, WhileWithVectorResultIntoTuple) { ComputeAndCompareTuple(&builder, *expected, {}, ErrorSpec(0.0001)); } +// TODO(b/63003356): 11-06-2017: fails on all back-ends with incorrect result. +TEST_F(WhileTest, DISABLED_WhileWithPermutationAndTupleResult) { + std::vector shape_elements = { + ShapeUtil::MakeShape(S32, {}), ShapeUtil::MakeShape(F32, {3}), + ShapeUtil::MakeShape(F32, {3}), ShapeUtil::MakeShape(F32, {3})}; + Shape result_shape = ShapeUtil::MakeTupleShape(shape_elements); + + // Create a computation for the condition. + // Repeat for N iterations. 
+ const int N = 2; + Computation condition; + { + ComputationBuilder builder(client_, "condition"); + auto prev = builder.Parameter(0, result_shape, "prev"); + auto iteration = builder.GetTupleElement(prev, 0); + builder.Gt(builder.ConstantR0(N), iteration); + condition = builder.Build().ConsumeValueOrDie(); + } + + // Create a computation for the body. + // Add 1 to the iteration variable and permute the weights. + Computation body; + { + ComputationBuilder builder(client_, "body"); + auto prev = builder.Parameter(0, result_shape, "prev"); + auto iteration = builder.GetTupleElement(prev, 0); + auto w1 = builder.GetTupleElement(prev, 1); + auto w2 = builder.GetTupleElement(prev, 2); + auto w3 = builder.GetTupleElement(prev, 3); + auto result = builder.Tuple( + {builder.Add(iteration, builder.ConstantR0(1)), w3, w1, w2}); + body = builder.Build().ConsumeValueOrDie(); + } + + // Create a While node with computations for the condition and the body. + ComputationBuilder builder(client_, "while"); + auto init = builder.Tuple( + {builder.ConstantR0(0), builder.ConstantR1(3, 1.f), + builder.ConstantR1(3, 2.f), builder.ConstantR1(3, 3.f)}); + auto result = builder.While(condition, body, init); + VLOG(2) << "result = " + << ShapeUtil::HumanString( + *builder.GetShape(result).ConsumeValueOrDie()); + + auto expected_counter = Literal::CreateR0(N); + auto expected_w1 = Literal::CreateR1({1.0f, 1.0f, 1.0f}); + auto expected_w2 = Literal::CreateR1({2.0f, 2.0f, 2.0f}); + auto expected_w3 = Literal::CreateR1({3.0f, 3.0f, 3.0f}); + auto expected = Literal::MakeTuple({expected_counter.get(), expected_w2.get(), + expected_w3.get(), expected_w1.get()}); + VLOG(2) << "expected = " << ShapeUtil::HumanString(expected->shape()); + ComputeAndCompareTuple(&builder, *expected, {}, ErrorSpec(0.0001)); +} + +// TODO(b/63003356): 11-06-2017: fails on all back-ends with incorrect result. 
+TEST_F(WhileTest, DISABLED_WhileWithPermutationAndVectorResult) { + std::vector shape_elements = { + ShapeUtil::MakeShape(S32, {}), ShapeUtil::MakeShape(F32, {3}), + ShapeUtil::MakeShape(F32, {3}), ShapeUtil::MakeShape(F32, {3})}; + Shape result_shape = ShapeUtil::MakeTupleShape(shape_elements); + + // Create a computation for the condition. + // Repeat for N iterations. + const int N = 2; + Computation condition; + { + ComputationBuilder builder(client_, "condition"); + auto prev = builder.Parameter(0, result_shape, "prev"); + auto iteration = builder.GetTupleElement(prev, 0); + builder.Gt(builder.ConstantR0(N), iteration); + condition = builder.Build().ConsumeValueOrDie(); + } + + // Create a computation for the body. + // Add 1 to the iteration variable permute the weights. + Computation body; + { + ComputationBuilder builder(client_, "body"); + auto prev = builder.Parameter(0, result_shape, "prev"); + auto iteration = builder.GetTupleElement(prev, 0); + auto w1 = builder.GetTupleElement(prev, 1); + auto w2 = builder.GetTupleElement(prev, 2); + auto w3 = builder.GetTupleElement(prev, 3); + auto result = builder.Tuple( + {builder.Add(iteration, builder.ConstantR0(1)), w3, w1, w2}); + body = builder.Build().ConsumeValueOrDie(); + } + + // Create a While node with computations for the condition and the body. 
+ ComputationBuilder builder(client_, "while"); + auto init = builder.Tuple( + {builder.ConstantR0(0), builder.ConstantR1(3, 1.f), + builder.ConstantR1(3, 2.f), builder.ConstantR1(3, 3.f)}); + auto xla_while = builder.While(condition, body, init); + + auto add12 = builder.Add(builder.GetTupleElement(xla_while, 1), + builder.GetTupleElement(xla_while, 2)); + auto result = builder.Add(add12, builder.GetTupleElement(xla_while, 3)); + VLOG(2) << "result = " + << ShapeUtil::HumanString( + *builder.GetShape(result).ConsumeValueOrDie()); + std::vector expected = {6.f, 6.f, 6.f}; + ComputeAndCompareR1(&builder, expected, {}, ErrorSpec(0.0001)); +} + // Tests a while node when the result type T is a Tuple. // // tuple> result(0, vector(10, 0.0f)); @@ -899,6 +1004,51 @@ TEST_F(WhileTest, WhileThatTurnsScalarParameterToTupleElement) { ErrorSpec(1e-6)); } +// Tests loop where the init value comes from two sources (constant and +// parameter). +// +// int32 result = (0, 1); +// while (result[0] + result[1] < 30) { +// result[0] = result[0] + 1; +// result[1] = result[1] + 1; +// } +TEST_F(WhileTest, WhileWithMixedTupleElements) { + auto result_shape = ShapeUtil::MakeTupleShape( + {ShapeUtil::MakeShape(S32, {}), ShapeUtil::MakeShape(S32, {})}); + + ComputationBuilder outer(client_, "outer"); + auto p = + outer.Tuple({outer.ConstantR0(0), + outer.Parameter(0, ShapeUtil::MakeShape(S32, {}), "t")}); + + ComputationBuilder cond(client_, "cond"); + auto params = cond.Parameter(0, result_shape, "prev"); + auto cond_t = cond.Add(cond.GetTupleElement(params, 1), + cond.GetTupleElement(params, 0)); + cond.Lt(cond_t, cond.ConstantR0(30)); + + ComputationBuilder body(client_, "body"); + auto body_t = body.Parameter(0, result_shape, "t"); + + auto tuple = body.Tuple( + {body.Add(body.GetTupleElement(params, 0), body.ConstantR0(1)), + body.Add(body.GetTupleElement(params, 1), body.ConstantR0(1))}); + + TF_ASSERT_OK_AND_ASSIGN(auto cond_computation, cond.Build()); + 
TF_ASSERT_OK_AND_ASSIGN(auto body_computation, body.Build()); + outer.While(cond_computation, body_computation, p); + + TF_ASSERT_OK_AND_ASSIGN( + std::unique_ptr parameter_data, + client_->TransferToServer(*Literal::CreateR0(1))); + + auto add1 = Literal::CreateR0(15); + auto add2 = Literal::CreateR0(16); + auto expected = Literal::MakeTuple({add1.get(), add2.get()}); + ComputeAndCompareTuple(&outer, *expected, {parameter_data.get()}, + ErrorSpec(1e-6)); +} + // Tests nested while loops. // // int32 result = 0; diff --git a/tensorflow/compiler/xla/tools/parser/hlo_parser.cc b/tensorflow/compiler/xla/tools/parser/hlo_parser.cc index 5de73ee866388e6b5b1a330b282792d1da4100c2..6c2e37e3b5cdd73157279fb171d3332aa9854184 100644 --- a/tensorflow/compiler/xla/tools/parser/hlo_parser.cc +++ b/tensorflow/compiler/xla/tools/parser/hlo_parser.cc @@ -58,6 +58,7 @@ class HloParser { string* root_name); bool ParseInstruction(HloComputation::Builder* builder, string* root_name); bool ParseSharding(HloInstruction* instruction); + bool ParseControlPredecessors(HloInstruction* instruction); bool ParseLiteral(std::unique_ptr* literal, const Shape& shape); bool ParseTupleLiteral(std::unique_ptr* literal, const Shape& shape); bool ParseNonTupleLiteral(std::unique_ptr* literal, @@ -436,10 +437,35 @@ bool HloParser::ParseInstruction(HloComputation::Builder* builder, return TokenError(StrCat("parsing not yet implemented for op: ", HloOpcodeString(opcode))); } - // Parse "sharding=". - if (lexer_.GetKind() == TokKind::kComma) { - if (!ParseSharding(instruction)) { - return false; + + bool has_sharding = false; + bool has_control = false; + while (EatIfPresent(TokKind::kComma)) { + string attribute_name; + if (!ParseAttributeName(&attribute_name)) { + return TokenError("expects ', sharding=' or ', control-predecessors='"); + } + + if (attribute_name == "sharding") { + // Parse "sharding=". 
+ if (has_sharding) { + return TokenError("expects at most 1 'sharding='"); + } + has_sharding = true; + if (!ParseSharding(instruction)) { + return false; + } + } else if (attribute_name == "control-predecessors") { + // Parse "control-predecessors" + if (has_control) { + return TokenError("expects at most 1 'control-predecessors='"); + } + has_control = true; + if (!ParseControlPredecessors(instruction)) { + return false; + } + } else { + return TokenError(StrCat("unexpected attribute: ", attribute_name)); } } @@ -449,15 +475,6 @@ bool HloParser::ParseInstruction(HloComputation::Builder* builder, // ::= '{' 'replicated'? 'maximal'? ('device=' int)? shape? ('devices=' ('[' // dims ']')* device_list)? '}' dims ::= int_list device_list ::= int_list bool HloParser::ParseSharding(HloInstruction* instruction) { - if (!ParseToken(TokKind::kComma, - "expects ',' in front of an extra attribute")) { - return false; - } - string attribute_name; - if (!ParseAttributeName(&attribute_name) || attribute_name != "sharding") { - return TokenError("expects attribute name: sharding"); - } - if (!ParseToken(TokKind::kLbrace, "expected '{' to start sharding attribute")) { return false; @@ -577,6 +594,34 @@ bool HloParser::ParseSharding(HloInstruction* instruction) { return true; } +// '{' name+ '}' +bool HloParser::ParseControlPredecessors(HloInstruction* instruction) { + if (!ParseToken(TokKind::kLbrace, + "expects '{' at the beginning of control predecessors")) { + return false; + } + do { + string name; + if (!ParseName(&name)) { + return TokenError("expects a control predecessor"); + } + HloInstruction* pre = + tensorflow::gtl::FindPtrOrNull(instruction_pool_, name); + if (!pre) { + return TokenError( + StrCat("control predecessor ", name, " is not defined: ")); + } + Status status = pre->AddControlDependencyTo(instruction); + if (!status.ok()) { + return TokenError(StrCat("error adding control dependency for: ", name, + " status: ", status.ToString())); + } + } while 
(EatIfPresent(TokKind::kComma)); + + return ParseToken(TokKind::kRbrace, + "expects '}' at the end of control predecessors"); +} + bool HloParser::SetValueInLiteral(int64 value, int64 linear_index, Literal* literal) { const Shape& shape = literal->shape(); diff --git a/tensorflow/compiler/xla/tools/parser/hlo_parser_test.cc b/tensorflow/compiler/xla/tools/parser/hlo_parser_test.cc index e065af7da61fb6096c820825cbd81d4a788c93fe..359256f0646367f8af13439b30067624defcd44c 100644 --- a/tensorflow/compiler/xla/tools/parser/hlo_parser_test.cc +++ b/tensorflow/compiler/xla/tools/parser/hlo_parser_test.cc @@ -214,7 +214,7 @@ R"(HloModule TwoSendRecvBothWayRecvFist_module: ENTRY %TwoSendRecvBothWayRecvFist.v3 () -> f32[] { %recv = f32[] recv(), channel_id=15, sharding={maximal device=1} ROOT %constant = f32[] constant(2.1), sharding={maximal device=0} - %send = () send(f32[] %constant), channel_id=16, sharding={maximal device=0} + %send = () send(f32[] %constant), channel_id=16, sharding={maximal device=0}, control-predecessors={%recv} } )" diff --git a/tensorflow/compiler/xla/xla.proto b/tensorflow/compiler/xla/xla.proto index ce3c3eee68ad7f7ebb42836e3cae14803f8650d7..710bb6ff25bf649693165c5e9fb6bc50e81db4ca 100644 --- a/tensorflow/compiler/xla/xla.proto +++ b/tensorflow/compiler/xla/xla.proto @@ -361,6 +361,7 @@ message WaitForExecutionResponse { message IsConstantRequest { ComputationHandle computation = 1; ComputationDataHandle operand = 2; + int64 num_parameters = 3; } message IsConstantResponse { @@ -371,6 +372,7 @@ message ComputeConstantRequest { ComputationHandle computation = 1; ComputationDataHandle operand = 2; Layout output_layout = 3; + repeated LiteralProto parameters = 4; } message ComputeConstantResponse { diff --git a/tensorflow/contrib/batching/BUILD b/tensorflow/contrib/batching/BUILD index 8b7df4a84c558f662405a28a42426583d5ab39cd..a111cfecb366fe245150cc71d2c43662d0d69090 100644 --- a/tensorflow/contrib/batching/BUILD +++ 
b/tensorflow/contrib/batching/BUILD @@ -82,6 +82,7 @@ cc_library( tf_cc_test( name = "adaptive_shared_batch_scheduler_test", srcs = ["adaptive_shared_batch_scheduler_test.cc"], + tags = ["manual"], # b/69013768 deps = [ ":adaptive_shared_batch_scheduler", "//tensorflow/contrib/batching/test_util:fake_clock_env", diff --git a/tensorflow/contrib/batching/adaptive_shared_batch_scheduler.h b/tensorflow/contrib/batching/adaptive_shared_batch_scheduler.h index a0606427a526ffc67e10d12a084eabc64564e4ab..6ed177e001758ad8c566c7965e1ec10ae5235fc8 100644 --- a/tensorflow/contrib/batching/adaptive_shared_batch_scheduler.h +++ b/tensorflow/contrib/batching/adaptive_shared_batch_scheduler.h @@ -399,7 +399,7 @@ ASBSQueue::~ASBSQueue() { template Status ASBSQueue::Schedule(std::unique_ptr* task) { - bool added_new_batch = false; + ASBSBatch* new_batch = nullptr; size_t size = (*task)->size(); if (size > options_.max_batch_size) { return errors::InvalidArgument("Task size ", size, @@ -418,15 +418,14 @@ Status ASBSQueue::Schedule(std::unique_ptr* task) { current_batch_ = nullptr; } if (!current_batch_) { - added_new_batch = true; num_enqueued_batches_++; - current_batch_ = + current_batch_ = new_batch = new ASBSBatch(this, scheduler_->GetEnv()->NowMicros()); } current_batch_->AddTask(std::move(*task)); num_enqueued_tasks_++; } - if (added_new_batch) scheduler_->AddBatch(current_batch_); + if (new_batch != nullptr) scheduler_->AddBatch(new_batch); return Status::OK(); } diff --git a/tensorflow/contrib/boosted_trees/python/training/functions/gbdt_batch.py b/tensorflow/contrib/boosted_trees/python/training/functions/gbdt_batch.py index 4d9fd753233f520b961559e2dc3adbafc5bb1d2f..cebe3474ca9251971c23bde9e82564189c1ee624 100644 --- a/tensorflow/contrib/boosted_trees/python/training/functions/gbdt_batch.py +++ b/tensorflow/contrib/boosted_trees/python/training/functions/gbdt_batch.py @@ -208,7 +208,7 @@ def extract_features(features, feature_columns): if tensor.dtype == dtypes.float32: if 
len(tensor.shape) > 1 and tensor.shape[1] > 1: unstacked = array_ops.unstack(tensor, axis=1) - for i in xrange(len(unstacked)): + for i in range(len(unstacked)): dense_float_names.append(_FEATURE_NAME_TEMPLATE % (key, i)) dense_floats.append(array_ops.reshape(unstacked[i], [-1, 1])) else: diff --git a/tensorflow/contrib/cmake/external/farmhash.cmake b/tensorflow/contrib/cmake/external/farmhash.cmake index 96fade8b53273afdc379c7c13017e4917ee534f3..0cd0c1030c73d5218411f281d2b077af217e8275 100644 --- a/tensorflow/contrib/cmake/external/farmhash.cmake +++ b/tensorflow/contrib/cmake/external/farmhash.cmake @@ -15,8 +15,8 @@ include (ExternalProject) set(farmhash_INCLUDE_DIR ${CMAKE_CURRENT_BINARY_DIR}/external/farmhash_archive ${CMAKE_CURRENT_BINARY_DIR}/external/farmhash_archive/util) -set(farmhash_URL https://github.com/google/farmhash/archive/34c13ddfab0e35422f4c3979f360635a8c050260.zip) -set(farmhash_HASH SHA256=e3d37a59101f38fd58fb799ed404d630f0eee18bfc2a2433910977cc8fea9c28) +set(farmhash_URL https://mirror.bazel.build/github.com/google/farmhash/archive/816a4ae622e964763ca0862d9dbd19324a1eaf45.tar.gz) +set(farmhash_HASH SHA256=6560547c63e4af82b0f202cb710ceabb3f21347a4b996db565a411da5b17aba0) set(farmhash_BUILD ${CMAKE_CURRENT_BINARY_DIR}/farmhash/src/farmhash) set(farmhash_INSTALL ${CMAKE_CURRENT_BINARY_DIR}/farmhash/install) set(farmhash_INCLUDES ${farmhash_BUILD}) diff --git a/tensorflow/contrib/cmake/external/fft2d.cmake b/tensorflow/contrib/cmake/external/fft2d.cmake index a35c24e9e01101f837ba961c06429c981ddc4648..d3af2a46761c0f7f0b5db134af8400fc93f2f095 100644 --- a/tensorflow/contrib/cmake/external/fft2d.cmake +++ b/tensorflow/contrib/cmake/external/fft2d.cmake @@ -15,7 +15,7 @@ include (ExternalProject) -set(fft2d_URL http://www.kurims.kyoto-u.ac.jp/~ooura/fft.tgz) +set(fft2d_URL https://mirror.bazel.build/www.kurims.kyoto-u.ac.jp/~ooura/fft.tgz) set(fft2d_HASH SHA256=52bb637c70b971958ec79c9c8752b1df5ff0218a4db4510e60826e0cb79b5296) set(fft2d_BUILD 
${CMAKE_CURRENT_BINARY_DIR}/fft2d/) set(fft2d_INSTALL ${CMAKE_CURRENT_BINARY_DIR}/fft2d/src) diff --git a/tensorflow/contrib/cmake/external/gemmlowp.cmake b/tensorflow/contrib/cmake/external/gemmlowp.cmake index 54a9e96ce58c5501217368b0d12089aa14696b71..3b146657bfc9bdd54db14839195af45972e67aff 100644 --- a/tensorflow/contrib/cmake/external/gemmlowp.cmake +++ b/tensorflow/contrib/cmake/external/gemmlowp.cmake @@ -14,8 +14,8 @@ # ============================================================================== include (ExternalProject) -set(gemmlowp_URL http://github.com/google/gemmlowp/archive/010bb3e71a26ca1d0884a167081d092b43563996.tar.gz) -set(gemmlowp_HASH SHA256=861cc6d9d902861f54fd77e1ab79286477dcc559b2a283e75b9c22d37b61f6ae) +set(gemmlowp_URL https://mirror.bazel.build/github.com/google/gemmlowp/archive/010bb3e71a26ca1d0884a167081d092b43563996.zip) +set(gemmlowp_HASH SHA256=dd2557072bde12141419cb8320a9c25e6ec41a8ae53c2ac78c076a347bb46d9d) set(gemmlowp_BUILD ${CMAKE_CURRENT_BINARY_DIR}/gemmlowp/src/gemmlowp) set(gemmlowp_INCLUDE_DIR ${CMAKE_CURRENT_BINARY_DIR}/gemmlowp/src/gemmlowp) diff --git a/tensorflow/contrib/cmake/external/jemalloc.cmake b/tensorflow/contrib/cmake/external/jemalloc.cmake index e4737a1dd825409133cdfd8a54f20dac819c0d5b..198ba13e64e4b6df57c4325a0104b1a6745d173a 100644 --- a/tensorflow/contrib/cmake/external/jemalloc.cmake +++ b/tensorflow/contrib/cmake/external/jemalloc.cmake @@ -15,7 +15,7 @@ include (ExternalProject) set(jemalloc_INCLUDE_DIRS ${CMAKE_CURRENT_BINARY_DIR}/jemalloc/src/jemalloc/include) -set(jemalloc_URL https://github.com/jemalloc/jemalloc-cmake/archive/jemalloc-cmake.4.3.1.tar.gz) +set(jemalloc_URL https://mirror.bazel.build/github.com/jemalloc/jemalloc-cmake/archive/jemalloc-cmake.4.3.1.tar.gz) set(jemalloc_HASH SHA256=f9be9a05fe906deb5c1c8ca818071a7d2e27d66fd87f5ba9a7bf3750bcedeaf0) set(jemalloc_BUILD ${CMAKE_CURRENT_BINARY_DIR}/jemalloc/src/jemalloc) diff --git a/tensorflow/contrib/cmake/external/sqlite.cmake 
b/tensorflow/contrib/cmake/external/sqlite.cmake index 6d06193824b32557c1d2195c940ff9c698be1bdf..785039a46983747557607562675349c150e064ad 100644 --- a/tensorflow/contrib/cmake/external/sqlite.cmake +++ b/tensorflow/contrib/cmake/external/sqlite.cmake @@ -15,7 +15,7 @@ include (ExternalProject) set(sqlite_INCLUDE_DIR ${CMAKE_CURRENT_BINARY_DIR}/external/sqlite) -set(sqlite_URL http://www.sqlite.org/2017/sqlite-amalgamation-3200000.zip) +set(sqlite_URL https://mirror.bazel.build/www.sqlite.org/2017/sqlite-amalgamation-3200000.zip) set(sqlite_HASH SHA256=208780b3616f9de0aeb50822b7a8f5482f6515193859e91ed61637be6ad74fd4) set(sqlite_BUILD ${CMAKE_CURRENT_BINARY_DIR}/sqlite/src/sqlite) set(sqlite_INSTALL ${CMAKE_CURRENT_BINARY_DIR}/sqlite/install) diff --git a/tensorflow/contrib/cmake/tf_core_kernels.cmake b/tensorflow/contrib/cmake/tf_core_kernels.cmake index f978c8ccd5a454ca4a89de0ab5d757b566295c60..5b62598aa58fb2ce37694055fc576a6fc308dc3e 100644 --- a/tensorflow/contrib/cmake/tf_core_kernels.cmake +++ b/tensorflow/contrib/cmake/tf_core_kernels.cmake @@ -70,6 +70,7 @@ if(tensorflow_BUILD_CONTRIB_KERNELS) "${tensorflow_source_dir}/tensorflow/contrib/cudnn_rnn/kernels/cudnn_rnn_ops.cc" "${tensorflow_source_dir}/tensorflow/contrib/cudnn_rnn/ops/cudnn_rnn_ops.cc" "${tensorflow_source_dir}/tensorflow/contrib/data/kernels/prefetching_kernels.cc" + "${tensorflow_source_dir}/tensorflow/contrib/data/ops/dataset_ops.cc" "${tensorflow_source_dir}/tensorflow/contrib/data/ops/prefetching_ops.cc" "${tensorflow_source_dir}/tensorflow/contrib/factorization/kernels/clustering_ops.cc" "${tensorflow_source_dir}/tensorflow/contrib/factorization/kernels/masked_matmul_ops.cc" diff --git a/tensorflow/contrib/cmake/tf_core_ops.cmake b/tensorflow/contrib/cmake/tf_core_ops.cmake index 4a61ed7a3548b1992ddc71acb8a7761e252296ea..03c168795cc2455327f0b7bbf40fd1fd1eebb34e 100644 --- a/tensorflow/contrib/cmake/tf_core_ops.cmake +++ b/tensorflow/contrib/cmake/tf_core_ops.cmake @@ -81,6 +81,7 @@ 
GENERATE_CONTRIB_OP_LIBRARY(boosted_trees_prediction "${tensorflow_source_dir}/t GENERATE_CONTRIB_OP_LIBRARY(boosted_trees_quantiles "${tensorflow_source_dir}/tensorflow/contrib/boosted_trees/ops/quantile_ops.cc") GENERATE_CONTRIB_OP_LIBRARY(boosted_trees_stats_accumulator "${tensorflow_source_dir}/tensorflow/contrib/boosted_trees/ops/stats_accumulator_ops.cc") GENERATE_CONTRIB_OP_LIBRARY(cudnn_rnn "${tensorflow_source_dir}/tensorflow/contrib/cudnn_rnn/ops/cudnn_rnn_ops.cc") +GENERATE_CONTRIB_OP_LIBRARY(data_dataset "${tensorflow_source_dir}/tensorflow/contrib/data/ops/dataset_ops.cc") GENERATE_CONTRIB_OP_LIBRARY(data_prefetching "${tensorflow_source_dir}/tensorflow/contrib/data/ops/prefetching_ops.cc") GENERATE_CONTRIB_OP_LIBRARY(factorization_clustering "${tensorflow_source_dir}/tensorflow/contrib/factorization/ops/clustering_ops.cc") GENERATE_CONTRIB_OP_LIBRARY(factorization_factorization "${tensorflow_source_dir}/tensorflow/contrib/factorization/ops/factorization_ops.cc") diff --git a/tensorflow/contrib/cmake/tf_python.cmake b/tensorflow/contrib/cmake/tf_python.cmake index 68234911a3fda9df6c65f32b088d0968a6f37c00..61900450475e9d5734789c5be97f8a7fc636bebc 100755 --- a/tensorflow/contrib/cmake/tf_python.cmake +++ b/tensorflow/contrib/cmake/tf_python.cmake @@ -224,6 +224,7 @@ add_python_module("tensorflow/python/grappler") add_python_module("tensorflow/python/keras") add_python_module("tensorflow/python/keras/activations") add_python_module("tensorflow/python/keras/applications") +add_python_module("tensorflow/python/keras/applications/inception_resnet_v2") add_python_module("tensorflow/python/keras/applications/inception_v3") add_python_module("tensorflow/python/keras/applications/mobilenet") add_python_module("tensorflow/python/keras/applications/resnet50") @@ -776,6 +777,8 @@ GENERATE_PYTHON_OP_LIB("contrib_boosted_trees_stats_accumulator_ops" DESTINATION 
${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow/contrib/boosted_trees/python/ops/gen_stats_accumulator_ops.py) GENERATE_PYTHON_OP_LIB("contrib_cudnn_rnn_ops" DESTINATION ${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow/contrib/cudnn_rnn/ops/gen_cudnn_rnn_ops.py) +GENERATE_PYTHON_OP_LIB("contrib_data_dataset_ops" + DESTINATION ${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow/contrib/data/python/ops/gen_dataset_ops.py) GENERATE_PYTHON_OP_LIB("contrib_data_prefetching_ops" DESTINATION ${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow/contrib/data/python/ops/gen_prefetching_ops.py) GENERATE_PYTHON_OP_LIB("contrib_factorization_clustering_ops" diff --git a/tensorflow/contrib/crf/python/ops/crf.py b/tensorflow/contrib/crf/python/ops/crf.py index 7166e38b28365a6dbce9cf134f81b08a57c722de..c8adb0369b98947d2d29374ee8ada1185815d3cd 100644 --- a/tensorflow/contrib/crf/python/ops/crf.py +++ b/tensorflow/contrib/crf/python/ops/crf.py @@ -360,8 +360,8 @@ class CrfDecodeForwardRnnCell(rnn_cell.RNNCell): scope: Unused variable scope of this cell. Returns: - backpointers: [batch_size, num_tags], containing backpointers. - new_state: [batch_size, num_tags], containing new score values. + backpointers: A [batch_size, num_tags] matrix of backpointers. + new_state: A [batch_size, num_tags] matrix of new score values. """ # For simplicity, in shape comments, denote: # 'batch_size' by 'B', 'max_seq_len' by 'T' , 'num_tags' by 'O' (output). @@ -385,7 +385,7 @@ class CrfDecodeBackwardRnnCell(rnn_cell.RNNCell): """Initialize the CrfDecodeBackwardRnnCell. Args: - num_tags + num_tags: An integer. """ self._num_tags = num_tags @@ -401,8 +401,9 @@ class CrfDecodeBackwardRnnCell(rnn_cell.RNNCell): """Build the CrfDecodeBackwardRnnCell. Args: - inputs: [batch_size, num_tags], backpointer of next step (in time order). - state: [batch_size, 1], next position's tag index. + inputs: A [batch_size, num_tags] matrix of + backpointer of next step (in time order). 
+ state: A [batch_size, 1] matrix of tag index of next step. scope: Unused variable scope of this cell. Returns: @@ -426,16 +427,16 @@ def crf_decode(potentials, transition_params, sequence_length): This is a function for tensor. Args: - potentials: A [batch_size, max_seq_len, num_tags] tensor, matrix of + potentials: A [batch_size, max_seq_len, num_tags] tensor of unary potentials. - transition_params: A [num_tags, num_tags] tensor, matrix of + transition_params: A [num_tags, num_tags] matrix of binary potentials. - sequence_length: A [batch_size] tensor, containing sequence lengths. + sequence_length: A [batch_size] vector of true sequence lengths. Returns: - decode_tags: A [batch_size, max_seq_len] tensor, with dtype tf.int32. + decode_tags: A [batch_size, max_seq_len] matrix, with dtype `tf.int32`. Contains the highest scoring tag indicies. - best_score: A [batch_size] tensor, containing the score of decode_tags. + best_score: A [batch_size] vector, containing the score of `decode_tags`. """ # For simplicity, in shape comments, denote: # 'batch_size' by 'B', 'max_seq_len' by 'T' , 'num_tags' by 'O' (output). diff --git a/tensorflow/contrib/data/BUILD b/tensorflow/contrib/data/BUILD index eaede0e00ecf1986873d50709d135d3f4b3ac9cd..7bcf5a5f4dcd6293644725a2ccf78a763da3d9eb 100644 --- a/tensorflow/contrib/data/BUILD +++ b/tensorflow/contrib/data/BUILD @@ -35,8 +35,19 @@ tf_custom_op_library( ], ) +# TODO(mrry): Move the kernels out of the core library into this library. 
+tf_custom_op_library( + name = "_dataset_ops.so", + srcs = [ + "ops/dataset_ops.cc", + ], +) + tf_gen_op_libs( - op_lib_names = ["prefetching_ops"], + op_lib_names = [ + "dataset_ops", + "prefetching_ops", + ], ) filegroup( diff --git a/tensorflow/contrib/data/__init__.py b/tensorflow/contrib/data/__init__.py index 6c46acf20442c2cc435829afa57e8383b493d6af..0c7e793689204ba18dcab03c87902103e5802e45 100644 --- a/tensorflow/contrib/data/__init__.py +++ b/tensorflow/contrib/data/__init__.py @@ -30,6 +30,7 @@ See the @{$datasets$Importing Data} Programmer's Guide for an overview. @@make_saveable_from_iterator @@read_batch_features @@unbatch +@@parallel_interleave @@rejection_resample @@sloppy_interleave @@ -40,8 +41,8 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -# pylint: disable=unused-import +# pylint: disable=unused-import from tensorflow.contrib.data.python.ops.batching import batch_and_drop_remainder from tensorflow.contrib.data.python.ops.batching import dense_to_sparse_batch from tensorflow.contrib.data.python.ops.batching import unbatch @@ -50,6 +51,7 @@ from tensorflow.contrib.data.python.ops.dataset_ops import get_single_element from tensorflow.contrib.data.python.ops.enumerate_ops import enumerate_dataset from tensorflow.contrib.data.python.ops.error_ops import ignore_errors from tensorflow.contrib.data.python.ops.grouping import group_by_window +from tensorflow.contrib.data.python.ops.interleave_ops import parallel_interleave from tensorflow.contrib.data.python.ops.interleave_ops import sloppy_interleave from tensorflow.contrib.data.python.ops.iterator_ops import make_saveable_from_iterator from tensorflow.contrib.data.python.ops.readers import FixedLengthRecordDataset diff --git a/tensorflow/contrib/data/ops/dataset_ops.cc b/tensorflow/contrib/data/ops/dataset_ops.cc new file mode 100644 index 0000000000000000000000000000000000000000..1574384cb2bf5578bc5ccd13d2792e30b6359996 --- /dev/null 
+++ b/tensorflow/contrib/data/ops/dataset_ops.cc @@ -0,0 +1,232 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include "tensorflow/core/framework/common_shape_fns.h" +#include "tensorflow/core/framework/op.h" +#include "tensorflow/core/framework/op_def_builder.h" +#include "tensorflow/core/framework/shape_inference.h" + +namespace tensorflow { + +// -------------------------------------------------------------------------- + +// The ops in this section can be composed to define an input +// pipeline. Each op produces a DT_VARIANT tensor that represents +// a DAG of "dataset" objects. A "dataset" object can be converted +// to a stateful "iterator" by passing the "dataset" to the +// "MakeIterator" op. +// +// TODO(b/65524810): DT_VARIANT tensors that represent "dataset" objects are +// not presently serializable. To avoid issues with constant folding, ensure +// that any "source dataset" ops (i.e. ops that output a dataset and do not +// take one as input) are marked "stateful". + +REGISTER_OP("IgnoreErrorsDataset") + .Input("input_dataset: variant") + .Output("handle: variant") + .Attr("output_types: list(type) >= 1") + .Attr("output_shapes: list(shape) >= 1") + .SetShapeFn(shape_inference::ScalarShape) + .Doc(R"doc( +Creates a dataset that contains the elements of `input_dataset` ignoring errors.
+)doc"); + +REGISTER_OP("MapAndBatchDataset") + .Input("input_dataset: variant") + .Input("other_arguments: Targuments") + .Input("batch_size: int64") + .Input("num_parallel_batches: int64") + .Output("handle: variant") + .Attr("f: func") + .Attr("Targuments: list(type) >= 0") + .Attr("output_types: list(type) >= 1") + .Attr("output_shapes: list(shape) >= 1") + .SetShapeFn(shape_inference::ScalarShape) + .Doc(R"doc( +Creates a dataset that applies `f` to the outputs of `input_dataset` and then +batches `batch_size` of them. + +Unlike a "MapDataset", which applies `f` sequentially, this dataset invokes up +to `batch_size * num_parallel_batches` copies of `f` in parallel. + +batch_size: A scalar representing the number of elements to accumulate in a + batch. It determines the number of concurrent invocations of `f` that process + elements from `input_dataset` in parallel. +num_parallel_batches: A scalar representing the number of batches to create in + parallel. Processing multiple batches in parallel benefits workloads prone to + stragglers. +)doc"); + +REGISTER_OP("ScanDataset") + .Input("input_dataset: variant") + .Input("initial_state: Tstate") + .Input("other_arguments: Targuments") + .Output("handle: variant") + .Attr("f: func") + .Attr("Tstate: list(type) >= 1") + .Attr("Targuments: list(type) >= 0") + .Attr("output_types: list(type) >= 1") + .Attr("output_shapes: list(shape) >= 1") + .SetShapeFn(shape_inference::ScalarShape) + .Doc(R"doc( +Creates a dataset that successively reduces `f` over the elements of `input_dataset`.
+)doc"); + +REGISTER_OP("ParallelInterleaveDataset") + .Input("input_dataset: variant") + .Input("other_arguments: Targuments") + .Input("cycle_length: int64") + .Input("block_length: int64") + .Input("sloppy: bool") + .Output("handle: variant") + .Attr("f: func") + .Attr("Targuments: list(type) >= 0") + .Attr("output_types: list(type) >= 1") + .Attr("output_shapes: list(shape) >= 1") + .SetShapeFn(shape_inference::ScalarShape) + .Doc(R"doc( +Creates a dataset that applies `f` to the outputs of `input_dataset`. + +The resulting dataset is similar to the `InterleaveDataset`, with the exception +that if retrieving the next value from a dataset would cause the requester to +block, it will skip that input dataset. This dataset is especially useful +when loading data from variable-latency datastores (e.g. HDFS, GCS), as it +allows the training step to proceed so long as some data is available. + +!! WARNING !! This dataset is not deterministic! + +f: A function mapping elements of `input_dataset`, concatenated with + `other_arguments`, to a Dataset variant that contains elements matching + `output_types` and `output_shapes`. +)doc"); + +REGISTER_OP("GroupByWindowDataset") + .Input("input_dataset: variant") + .Input("key_func_other_arguments: Tkey_func_other_arguments") + .Input("reduce_func_other_arguments: Treduce_func_other_arguments") + .Input( + "window_size_func_other_arguments: Twindow_size_func_other_arguments") + .Output("handle: variant") + .Attr("key_func: func") + .Attr("reduce_func: func") + .Attr("window_size_func: func") + .Attr("Tkey_func_other_arguments: list(type) >= 0") + .Attr("Treduce_func_other_arguments: list(type) >= 0") + .Attr("Twindow_size_func_other_arguments: list(type) >= 0") + .Attr("output_types: list(type) >= 1") + .Attr("output_shapes: list(shape) >= 1") + .SetShapeFn(shape_inference::ScalarShape) + .Doc(R"doc( +Creates a dataset that computes a windowed group-by on `input_dataset`. + +// TODO(mrry): Support non-int64 keys.
+ +key_func: A function mapping an element of `input_dataset`, concatenated + with `key_func_other_arguments` to a scalar value of type DT_INT64. +)doc"); + +REGISTER_OP("DenseToSparseBatchDataset") + .Input("input_dataset: variant") + .Input("batch_size: int64") + .Input("row_shape: int64") + .Output("handle: variant") + // NOTE(mrry): the 0th and 2nd elements will be DT_INT64. + .Attr("output_types: list(type) >= 1") + // NOTE(mrry): the 1st and 2nd elements will be vectors. + .Attr("output_shapes: list(shape) >= 1") + .SetShapeFn(shape_inference::ScalarShape) + .Doc(R"doc( +Creates a dataset that yields a SparseTensor for each element of the input. + +input_dataset: A handle to an input dataset. Must have a single component. +batch_size: A scalar representing the number of elements to accumulate in a + batch. +row_shape: A vector representing the dense shape of each row in the produced + SparseTensor. The shape may be partially specified, using `-1` to indicate + that a particular dimension should use the maximum size of all batch elements. +)doc"); + +REGISTER_OP("SqlDataset") + .Input("driver_name: string") + .Input("data_source_name: string") + .Input("query: string") + .Output("handle: variant") + .Attr("output_types: list(type) >= 1") + .Attr("output_shapes: list(shape) >= 1") + .SetIsStateful() // TODO(b/65524810): Source dataset ops must be marked + // stateful to inhibit constant folding. + .SetShapeFn(shape_inference::ScalarShape) + .Doc(R"doc( +Creates a dataset that executes a SQL query and emits rows of the result set. + +driver_name: The database type. Currently, the only supported type is 'sqlite'. +data_source_name: A connection string to connect to the database. +query: A SQL query to execute. 
+)doc"); + +REGISTER_OP("DatasetToSingleElement") + .Input("dataset: variant") + .Output("components: output_types") + .Attr("output_types: list(type) >= 1") + .Attr("output_shapes: list(shape) >= 1") + .SetShapeFn([](shape_inference::InferenceContext* c) { + shape_inference::ShapeHandle unused; + TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 0, &unused)); + std::vector<PartialTensorShape> output_shapes; + TF_RETURN_IF_ERROR(c->GetAttr("output_shapes", &output_shapes)); + if (output_shapes.size() != c->num_outputs()) { + return errors::InvalidArgument( + "`output_shapes` must be the same length as `output_types` (", + output_shapes.size(), " vs. ", c->num_outputs()); + } + for (size_t i = 0; i < output_shapes.size(); ++i) { + shape_inference::ShapeHandle output_shape_handle; + TF_RETURN_IF_ERROR(c->MakeShapeFromPartialTensorShape( + output_shapes[i], &output_shape_handle)); + c->set_output(static_cast<int>(i), output_shape_handle); + } + return Status::OK(); + }) + .Doc(R"doc( +Outputs the single element from the given dataset. + +dataset: A handle to a dataset that contains a single element. +components: The components of the single element of `input`. +)doc"); + +REGISTER_OP("SerializeIterator") + .Input("resource_handle: resource") + .Output("serialized: variant") + .SetShapeFn(shape_inference::ScalarShape) + .Doc(R"doc( +Converts the given `resource_handle` representing an iterator to a variant tensor. + +resource_handle: A handle to an iterator resource. +serialized: A variant tensor storing the state of the iterator contained in the + resource. +)doc"); + +REGISTER_OP("DeserializeIterator") + .Input("resource_handle: resource") + .Input("serialized: variant") + .SetShapeFn(shape_inference::NoOutputs) + .Doc(R"doc( +Converts the given variant tensor to an iterator and stores it in the given resource. + +resource_handle: A handle to an iterator resource. +serialized: A variant tensor storing the state of the iterator contained in the + resource.
+)doc"); + +} // namespace tensorflow diff --git a/tensorflow/contrib/data/python/kernel_tests/dataset_serialization_test_base.py b/tensorflow/contrib/data/python/kernel_tests/dataset_serialization_test_base.py index 8713640985b1e23da378603af265eec894023e34..df9147af6c03925ac9f372c561000eaa6e7f328e 100644 --- a/tensorflow/contrib/data/python/kernel_tests/dataset_serialization_test_base.py +++ b/tensorflow/contrib/data/python/kernel_tests/dataset_serialization_test_base.py @@ -206,9 +206,16 @@ class DatasetSerializationTestBase(test.TestCase): # Generate `break_point` items from ds_fn1 and save checkpoint. self.gen_outputs(ds_fn1, [], break_point) + actual = [] # Build graph for ds_fn2 but load checkpoint for ds_fn1. - actual = self.gen_outputs( - ds_fn2, [], break_point, ckpt_saved=True, verify_exhausted=True) + with ops.Graph().as_default() as g: + _, get_next_op, saver = self._build_graph(ds_fn2) + with self.test_session(graph=g) as sess: + self._restore(saver, sess) + for _ in range(num_outputs - break_point): + actual.append(sess.run(get_next_op)) + with self.assertRaises(errors.OutOfRangeError): + sess.run(get_next_op) self.match(expected, actual) diff --git a/tensorflow/contrib/data/python/kernel_tests/iterator_ops_test.py b/tensorflow/contrib/data/python/kernel_tests/iterator_ops_test.py index bda9a2a4a37e9c3d35ff99041d1150ffc43f4c43..271d80a54b5a3e1a09cdf37e4f5e659fb67a78f9 100644 --- a/tensorflow/contrib/data/python/kernel_tests/iterator_ops_test.py +++ b/tensorflow/contrib/data/python/kernel_tests/iterator_ops_test.py @@ -21,6 +21,7 @@ import os import numpy as np from tensorflow.contrib.data.python.ops import dataset_ops +from tensorflow.contrib.data.python.ops import gen_dataset_ops from tensorflow.contrib.data.python.ops import readers from tensorflow.core.protobuf import config_pb2 from tensorflow.python.client import session @@ -33,7 +34,6 @@ from tensorflow.python.framework import ops from tensorflow.python.framework import test_util from 
tensorflow.python.ops import array_ops from tensorflow.python.ops import functional_ops -from tensorflow.python.ops import gen_dataset_ops from tensorflow.python.ops import gradients_impl from tensorflow.python.ops import io_ops from tensorflow.python.ops import math_ops diff --git a/tensorflow/contrib/data/python/kernel_tests/range_dataset_op_test.py b/tensorflow/contrib/data/python/kernel_tests/range_dataset_op_test.py index f59ac760dc83a504e563f055b91f1002cb0c80fc..329dc80ba5a29ade74ae8dfd12d37e5c1e2a9f73 100644 --- a/tensorflow/contrib/data/python/kernel_tests/range_dataset_op_test.py +++ b/tensorflow/contrib/data/python/kernel_tests/range_dataset_op_test.py @@ -21,6 +21,7 @@ import os from tensorflow.contrib.data.python.ops import dataset_ops from tensorflow.contrib.data.python.ops import enumerate_ops +from tensorflow.contrib.data.python.ops import gen_dataset_ops from tensorflow.contrib.data.python.ops import iterator_ops as contrib_iterator_ops from tensorflow.python.data.ops import iterator_ops from tensorflow.python.framework import constant_op @@ -29,7 +30,6 @@ from tensorflow.python.framework import errors from tensorflow.python.framework import ops from tensorflow.python.framework import tensor_shape from tensorflow.python.ops import array_ops -from tensorflow.python.ops import gen_dataset_ops from tensorflow.python.ops import io_ops from tensorflow.python.ops import parsing_ops from tensorflow.python.ops import variables diff --git a/tensorflow/contrib/data/python/kernel_tests/reader_dataset_ops_test.py b/tensorflow/contrib/data/python/kernel_tests/reader_dataset_ops_test.py index 3ae8f71d77fa6ecf08e42bedac702b8f75eec309..8033f1d38806767ce08043d10c42dd376087765c 100644 --- a/tensorflow/contrib/data/python/kernel_tests/reader_dataset_ops_test.py +++ b/tensorflow/contrib/data/python/kernel_tests/reader_dataset_ops_test.py @@ -21,6 +21,7 @@ import gzip import os import zlib +from tensorflow.contrib.data.python.ops import gen_dataset_ops from 
tensorflow.contrib.data.python.ops import iterator_ops as contrib_iterator_ops from tensorflow.contrib.data.python.ops import readers from tensorflow.core.example import example_pb2 @@ -33,7 +34,6 @@ from tensorflow.python.framework import ops from tensorflow.python.framework import tensor_shape from tensorflow.python.lib.io import python_io from tensorflow.python.ops import array_ops -from tensorflow.python.ops import gen_dataset_ops from tensorflow.python.ops import io_ops from tensorflow.python.ops import parsing_ops from tensorflow.python.platform import test diff --git a/tensorflow/contrib/data/python/ops/BUILD b/tensorflow/contrib/data/python/ops/BUILD index 1b81cf5be9190ffab646192fb9a72fd3da7deee1..727c5d1c38ba30c32968a3cf33f7c03163f060d4 100644 --- a/tensorflow/contrib/data/python/ops/BUILD +++ b/tensorflow/contrib/data/python/ops/BUILD @@ -11,20 +11,6 @@ load( ) load("//tensorflow:tensorflow.bzl", "tf_custom_op_py_library") -py_library( - name = "dataset_ops", - srcs = [ - "dataset_ops.py", - ], - srcs_version = "PY2AND3", - deps = [ - ":transformation_ops", - "//tensorflow/python:util", - "//tensorflow/python/data/ops:dataset_ops", - "//tensorflow/python/data/util:nest", - ], -) - py_library( name = "iterator_ops", srcs = [ @@ -73,6 +59,7 @@ py_library( ], srcs_version = "PY2AND3", deps = [ + ":gen_dataset_ops", "//tensorflow/python:array_ops", "//tensorflow/python:control_flow_ops", "//tensorflow/python:dataset_ops_gen", @@ -128,6 +115,31 @@ tf_custom_op_py_library( ], ) +tf_gen_op_wrapper_py( + name = "gen_dataset_ops", + out = "gen_dataset_ops.py", + deps = ["//tensorflow/contrib/data:dataset_ops_op_lib"], +) + +tf_custom_op_py_library( + name = "dataset_ops", + srcs = ["dataset_ops.py"], + dso = ["//tensorflow/contrib/data:_dataset_ops.so"], + kernels = [ + "//tensorflow/contrib/data:dataset_ops_op_lib", + ], + srcs_version = "PY2AND3", + deps = [ + ":gen_dataset_ops", + ":transformation_ops", + "//tensorflow/contrib/util:util_py", + 
"//tensorflow/python:platform", + "//tensorflow/python:util", + "//tensorflow/python/data/ops:dataset_ops", + "//tensorflow/python/data/util:nest", + ], +) + filegroup( name = "all_files", srcs = glob( diff --git a/tensorflow/contrib/data/python/ops/batching.py b/tensorflow/contrib/data/python/ops/batching.py index abc9212a87550745490b974d25a929a66287f785..e6e5f716b62b8d715eecf0c5a79d1c22d34c06b2 100644 --- a/tensorflow/contrib/data/python/ops/batching.py +++ b/tensorflow/contrib/data/python/ops/batching.py @@ -17,6 +17,7 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +from tensorflow.contrib.data.python.ops import gen_dataset_ops from tensorflow.python.data.ops import dataset_ops from tensorflow.python.data.util import nest from tensorflow.python.framework import dtypes @@ -24,7 +25,6 @@ from tensorflow.python.framework import ops from tensorflow.python.framework import tensor_shape from tensorflow.python.framework import tensor_util from tensorflow.python.ops import array_ops -from tensorflow.python.ops import gen_dataset_ops from tensorflow.python.ops import math_ops diff --git a/tensorflow/contrib/data/python/ops/dataset_ops.py b/tensorflow/contrib/data/python/ops/dataset_ops.py index 45d6dbe7438957029b4d6b71e181cb1fc3596ecb..c4c4426809aa7b5a1c80a0d6f797b9e140be4dea 100644 --- a/tensorflow/contrib/data/python/ops/dataset_ops.py +++ b/tensorflow/contrib/data/python/ops/dataset_ops.py @@ -20,15 +20,21 @@ from __future__ import print_function from tensorflow.contrib.data.python.ops import batching from tensorflow.contrib.data.python.ops import enumerate_ops from tensorflow.contrib.data.python.ops import error_ops +from tensorflow.contrib.data.python.ops import gen_dataset_ops from tensorflow.contrib.data.python.ops import grouping +from tensorflow.contrib.util import loader from tensorflow.python.data.ops import dataset_ops from tensorflow.python.data.util import nest -from tensorflow.python.ops 
import gen_dataset_ops from tensorflow.python.ops import gen_io_ops +from tensorflow.python.platform import resource_loader from tensorflow.python.util import deprecation +_dataset_ops = loader.load_op_library( + resource_loader.get_path_to_datafile("../../_dataset_ops.so")) + + class Dataset(dataset_ops.Dataset): """Represents a potentially large set of elements. diff --git a/tensorflow/contrib/data/python/ops/error_ops.py b/tensorflow/contrib/data/python/ops/error_ops.py index 238bb52b0205f9ab66f479f1b92e72ab6e38725b..51a279107235f95eba2030291aab9d294f6d2b2d 100644 --- a/tensorflow/contrib/data/python/ops/error_ops.py +++ b/tensorflow/contrib/data/python/ops/error_ops.py @@ -17,9 +17,9 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +from tensorflow.contrib.data.python.ops import gen_dataset_ops from tensorflow.python.data.ops import dataset_ops from tensorflow.python.data.util import nest -from tensorflow.python.ops import gen_dataset_ops def ignore_errors(): diff --git a/tensorflow/contrib/data/python/ops/grouping.py b/tensorflow/contrib/data/python/ops/grouping.py index 6df7b22fb69bb14c41a26bd630a825442f67ee23..1c7c94b3c84a8c48ba9237c323fc13777d25f43d 100644 --- a/tensorflow/contrib/data/python/ops/grouping.py +++ b/tensorflow/contrib/data/python/ops/grouping.py @@ -17,12 +17,12 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +from tensorflow.contrib.data.python.ops import gen_dataset_ops from tensorflow.python.data.ops import dataset_ops from tensorflow.python.data.util import nest from tensorflow.python.framework import dtypes from tensorflow.python.framework import function from tensorflow.python.framework import ops -from tensorflow.python.ops import gen_dataset_ops def group_by_window(key_func, diff --git a/tensorflow/contrib/data/python/ops/interleave_ops.py b/tensorflow/contrib/data/python/ops/interleave_ops.py index 
74a919c1fff62cfa79b0877a3d081077ca6776f0..ce23e95697c9116635e6335dc7b1fdc6de514732 100644 --- a/tensorflow/contrib/data/python/ops/interleave_ops.py +++ b/tensorflow/contrib/data/python/ops/interleave_ops.py @@ -17,12 +17,12 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +from tensorflow.contrib.data.python.ops import gen_dataset_ops from tensorflow.python.data.ops import dataset_ops from tensorflow.python.data.util import nest from tensorflow.python.framework import dtypes from tensorflow.python.framework import function from tensorflow.python.framework import ops -from tensorflow.python.ops import gen_dataset_ops from tensorflow.python.util import deprecation diff --git a/tensorflow/contrib/data/python/ops/iterator_ops.py b/tensorflow/contrib/data/python/ops/iterator_ops.py index d736029fb035e573b70e8b19570e4e8ceca3c005..32d2f42c9352fa35e3671ed549ad85efce2546d7 100644 --- a/tensorflow/contrib/data/python/ops/iterator_ops.py +++ b/tensorflow/contrib/data/python/ops/iterator_ops.py @@ -17,8 +17,8 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +from tensorflow.contrib.data.python.ops import gen_dataset_ops from tensorflow.python.framework import ops -from tensorflow.python.ops import gen_dataset_ops from tensorflow.python.training import saver diff --git a/tensorflow/contrib/data/python/ops/readers.py b/tensorflow/contrib/data/python/ops/readers.py index 2e1c3153ca78e20e2628e8754b9827b817f8c732..f22298b757c73dac096603335b475119e5971df4 100644 --- a/tensorflow/contrib/data/python/ops/readers.py +++ b/tensorflow/contrib/data/python/ops/readers.py @@ -18,6 +18,7 @@ from __future__ import division from __future__ import print_function from tensorflow.contrib.data.python.ops import dataset_ops as contrib_dataset_ops +from tensorflow.contrib.data.python.ops import gen_dataset_ops from tensorflow.python.data.ops import dataset_ops from 
tensorflow.python.data.ops import readers from tensorflow.python.data.util import nest @@ -25,7 +26,6 @@ from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.framework import sparse_tensor as sparse_tensor_lib from tensorflow.python.framework import tensor_shape -from tensorflow.python.ops import gen_dataset_ops from tensorflow.python.ops import parsing_ops from tensorflow.python.platform import gfile from tensorflow.python.util import deprecation diff --git a/tensorflow/contrib/data/python/ops/scan_ops.py b/tensorflow/contrib/data/python/ops/scan_ops.py index 5acaed48a3d73e93706bdd0b5b2d614b0c565ab7..87bbbb7d19b15955b507308ce2ea286f602efd37 100644 --- a/tensorflow/contrib/data/python/ops/scan_ops.py +++ b/tensorflow/contrib/data/python/ops/scan_ops.py @@ -19,11 +19,11 @@ from __future__ import print_function import collections +from tensorflow.contrib.data.python.ops import gen_dataset_ops from tensorflow.python.data.ops import dataset_ops from tensorflow.python.data.util import nest from tensorflow.python.framework import function from tensorflow.python.framework import ops -from tensorflow.python.ops import gen_dataset_ops class _ScanDataset(dataset_ops.Dataset): diff --git a/tensorflow/contrib/eager/python/examples/mnist/mnist.py b/tensorflow/contrib/eager/python/examples/mnist/mnist.py index 3dd920415df8b6d367d690132b647318ab8ba25b..bfb7d5a9002787f6544d383de58150661ac2bde3 100644 --- a/tensorflow/contrib/eager/python/examples/mnist/mnist.py +++ b/tensorflow/contrib/eager/python/examples/mnist/mnist.py @@ -191,9 +191,9 @@ def main(_): train_dir = None test_dir = None summary_writer = tf.contrib.summary.create_summary_file_writer( - train_dir, flush_secs=10) + train_dir, flush_millis=10000) test_summary_writer = tf.contrib.summary.create_summary_file_writer( - test_dir, flush_secs=10, name='test') + test_dir, flush_millis=10000, name='test') checkpoint_prefix = os.path.join(FLAGS.checkpoint_dir, 
'ckpt') with tf.device(device): diff --git a/tensorflow/contrib/eager/python/examples/rnn_colorbot/rnn_colorbot.py b/tensorflow/contrib/eager/python/examples/rnn_colorbot/rnn_colorbot.py index 318962c634e0d050b35da5efc405400380c1b759..609cbd28772c3ae8da70648ca5b1b264a8a255e2 100644 --- a/tensorflow/contrib/eager/python/examples/rnn_colorbot/rnn_colorbot.py +++ b/tensorflow/contrib/eager/python/examples/rnn_colorbot/rnn_colorbot.py @@ -248,9 +248,9 @@ def main(_): log_dir = os.path.join(FLAGS.dir, "summaries") tf.gfile.MakeDirs(log_dir) train_summary_writer = tf.contrib.summary.create_summary_file_writer( - os.path.join(log_dir, "train"), flush_secs=10) + os.path.join(log_dir, "train"), flush_millis=10000) test_summary_writer = tf.contrib.summary.create_summary_file_writer( - os.path.join(log_dir, "eval"), flush_secs=10, name="eval") + os.path.join(log_dir, "eval"), flush_millis=10000, name="eval") with tf.device(device): for epoch in range(FLAGS.num_epochs): diff --git a/tensorflow/contrib/eager/python/network.py b/tensorflow/contrib/eager/python/network.py index 97feaec30ed066503ef8ce75cbd5af04ea2ef6bf..c6e628b074e8638fd15a35f2df87609e0ad46000 100644 --- a/tensorflow/contrib/eager/python/network.py +++ b/tensorflow/contrib/eager/python/network.py @@ -182,6 +182,40 @@ def _make_custom_getter_for_deferred_restorations(): return _custom_getter, deferred_restorations +def _make_prefix_stripping_map_fn(scope_name): + """Closure for stripping the scope name of a Network. + + Implemented as a closure rather than a member function to avoid reference + cycles in deferred restorations (this function should not have a reference to + the Network which created it). + + Args: + scope_name: The Network.scope_name to strip from variables. + Returns: + A scope_name-stripping default `map_fn` for the Network. + """ + + def _strip_variable_prefix(original_variable_name): + """The default map_func for saving or restoring variables. 
+ + Strips the variable prefix for the Network on which save/restore was called, + and leaves other variable names fully qualified in the checkpoint. + + Args: + original_variable_name: The _shared_name of the variable (no :0 + suffix) to map. + Returns: + The checkpoint name of the variable. + """ + scope_name_with_slash = scope_name + "/" + if original_variable_name.startswith(scope_name_with_slash): + return original_variable_name[len(scope_name_with_slash):] + else: + return original_variable_name + + return _strip_variable_prefix + + class Network(base.Layer): """Represents the composition of a set of Layers. @@ -488,24 +522,6 @@ class Network(base.Layer): "at https://github.com/tensorflow/tensorflow/issues/new if this is " "important to you") - def _strip_variable_prefix(self, original_variable_name): - """The default map_func for saving or restoring variables. - - Strips the variable prefix for the Network on which save/restore was called, - and leaves other variable names fully qualified in the checkpoint. - - Args: - original_variable_name: The _shared_name of the variable (no :0 - suffix) to map. - Returns: - The checkpoint name of the variable. - """ - scope_name_with_slash = self.scope_name + "/" - if original_variable_name.startswith(scope_name_with_slash): - return original_variable_name[len(scope_name_with_slash):] - else: - return original_variable_name - def save(self, save_path, global_step=None, map_func=None): """Save variables from the Network to a checkpoint. 
@@ -543,7 +559,7 @@ class Network(base.Layer): save_path = os.path.join(save_path, self.name) user_map_func = map_func if map_func is None: - map_func = self._strip_variable_prefix + map_func = _make_prefix_stripping_map_fn(self.scope_name) variable_map = {} for variable in self.variables: mapped_name = map_func(variable._shared_name) @@ -737,7 +753,7 @@ class Network(base.Layer): save_path = os.path.join(save_path, self.name) user_map_func = map_func if map_func is None: - map_func = self._strip_variable_prefix + map_func = _make_prefix_stripping_map_fn(self.scope_name) # Step one is to restore any existing variables from the checkpoint. existing_variables_by_checkpoint_name = self._restore_existing_variables( save_path=save_path, diff --git a/tensorflow/contrib/eager/python/network_test.py b/tensorflow/contrib/eager/python/network_test.py index c621f527c28306131bdba56d8427eaa787ba150b..14adbafe5735bd2a3d3961402e8ef3e6a7be333b 100644 --- a/tensorflow/contrib/eager/python/network_test.py +++ b/tensorflow/contrib/eager/python/network_test.py @@ -67,7 +67,7 @@ class NetworkTest(test.TestCase): original_output, self.evaluate(net(input_value))) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes(assert_no_eager_garbage=True) def testTrainableAttribute(self): net = network.Network() self.assertTrue(net.trainable) @@ -75,7 +75,7 @@ class NetworkTest(test.TestCase): net.trainable = False self.assertTrue(net.trainable) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes(assert_no_eager_garbage=True) def testNetworkCall(self): net = MyNetwork(name="abcd") net(constant_op.constant([[2.0]])) # Force variables to be created. 
@@ -85,6 +85,7 @@ class NetworkTest(test.TestCase): result = net(constant_op.constant([[2.0]])) self.assertEqual(34.0, self.evaluate(result)) + # TODO(allenl): This test creates garbage in some Python versions @test_util.run_in_graph_and_eager_modes() def testNetworkSaveRestoreAlreadyBuilt(self): net = MyNetwork(name="abcd") @@ -96,6 +97,7 @@ class NetworkTest(test.TestCase): self._save_modify_load_network_built(net, global_step=None) self._save_modify_load_network_built(net, global_step=10) + # TODO(allenl): This test creates garbage in some Python versions @test_util.run_in_graph_and_eager_modes() def testSaveRestoreDefaultGlobalStep(self): net = MyNetwork(name="abcd") @@ -106,6 +108,7 @@ class NetworkTest(test.TestCase): save_path = net.save(self.get_temp_dir()) self.assertIn("abcd-4242", save_path) + # TODO(allenl): This test creates garbage in some Python versions @test_util.run_in_graph_and_eager_modes() def testNetworkSaveAndRestoreIntoUnbuilt(self): save_dir = self.get_temp_dir() @@ -377,25 +380,25 @@ class NetworkTest(test.TestCase): gc.set_debug(previous_gc_debug_flags) gc.enable() - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes(assert_no_eager_garbage=True) def testAnonymousNoNameInitially(self): net = MyNetwork() with self.assertRaisesRegexp(ValueError, "does not yet have a final name"): net.name # pylint: disable=pointless-statement - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes(assert_no_eager_garbage=True) def testExplicitHasNameInitially(self): net = MyNetwork(name="abcd") self.assertEqual("abcd", net.name) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes(assert_no_eager_garbage=True) def testUsingResourceVariables(self): net = MyNetwork() net(constant_op.constant([[0.]])) self.assertIsInstance(net.trainable_weights[0], resource_variable_ops.ResourceVariable) - @test_util.run_in_graph_and_eager_modes() + 
@test_util.run_in_graph_and_eager_modes(assert_no_eager_garbage=True) def testDuplicateNameError(self): one = constant_op.constant([[1.]]) net = MyNetwork(name="foo") @@ -405,7 +408,7 @@ class NetworkTest(test.TestCase): net1 = MyNetwork(name="foo") net1(one) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes(assert_no_eager_garbage=True) def testWrappingInVariableScope(self): with variable_scope.variable_scope("outside_scope"): net = MyNetwork() @@ -440,7 +443,7 @@ class NetworkTest(test.TestCase): actual=net.trainable_weights[0].name) self.assertEqual("explicit_name", net.first.name) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes(assert_no_eager_garbage=True) def testWrappingInAnonymousVariableScope(self): # Named outside variable_scopes are not supported at the moment. However, # blank-named top level variable scopes do not change variable names, and so @@ -455,20 +458,20 @@ class NetworkTest(test.TestCase): net(one) self.assertTrue(was_called[0]) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes(assert_no_eager_garbage=True) def testReasonableSlashError(self): with self.assertRaisesRegexp( ValueError, "not allowed in Network names"): MyNetwork(name="slash/slash") - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes(assert_no_eager_garbage=True) def testNoVariableScopeNames(self): with self.assertRaisesRegexp( ValueError, "VariableScopes are not valid Network names"): with variable_scope.variable_scope("some_scope") as vs: MyNetwork(name=vs) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes(assert_no_eager_garbage=True) def testVariableScopeNameCollision(self): with variable_scope.variable_scope("abcd"): pass @@ -478,7 +481,7 @@ class NetworkTest(test.TestCase): one = constant_op.constant([[1.]]) net(one) - @test_util.run_in_graph_and_eager_modes() + 
@test_util.run_in_graph_and_eager_modes(assert_no_eager_garbage=True) def testNetworkVariablesDoNotInterfere(self): core.Dense(1, use_bias=True) # Should not interfere with naming. net1 = MyNetwork() @@ -1007,6 +1010,7 @@ class NetworkTest(test.TestCase): class SequentialTest(test.TestCase): + @test_util.assert_no_garbage_created def testTwoLayers(self): # Create a sequential network with one layer. net = network.Sequential([core.Dense(1, use_bias=False)]) @@ -1028,6 +1032,7 @@ class SequentialTest(test.TestCase): l2.trainable_variables[0].assign([[11.0]]) self.assertEqual(231.0, net(constant_op.constant([[7.0]])).numpy()) + @test_util.assert_no_garbage_created def testFunctions(self): # Create a sequential network with one function. net = network.Sequential([nn_ops.relu]) @@ -1038,6 +1043,7 @@ class SequentialTest(test.TestCase): net.add(math_ops.negative) self.assertEqual(-2.0, net(two).numpy()) + @test_util.assert_no_garbage_created def testTrainingLayer(self): net = network.Sequential([core.Dropout(0.99999)]) two = constant_op.constant(2.0) @@ -1051,6 +1057,7 @@ class SequentialTest(test.TestCase): # Should only fail spuriously 1 in 10^100 runs. self.fail("Didn't see dropout happen after 20 tries.") + @test_util.assert_no_garbage_created def testTrainingFunction(self): # Output depends on value of "training". 
def add_training(input_value, training=None): diff --git a/tensorflow/contrib/estimator/BUILD b/tensorflow/contrib/estimator/BUILD index a0f83ac10555913b5be177f0f2b00b2b0e30494a..6eb2cfdaca7840c4a5dd8cffc9620aaf3f96a1de 100644 --- a/tensorflow/contrib/estimator/BUILD +++ b/tensorflow/contrib/estimator/BUILD @@ -7,6 +7,7 @@ package( licenses(["notice"]) # Apache 2.0 load("//tensorflow:tensorflow.bzl", "py_test") +load("//tensorflow:tensorflow.bzl", "cuda_py_test") filegroup( name = "all_files", @@ -30,6 +31,7 @@ py_library( ":head", ":logit_fns", ":multi_head", + ":replicate_model_fn", "//tensorflow/python:util", ], ) @@ -227,9 +229,69 @@ py_test( "//tensorflow/python:string_ops", "//tensorflow/python/estimator:metric_keys", "//tensorflow/python/estimator:model_fn", - "//tensorflow/python/estimator:prediction_keys", + "//tensorflow/python/ops/losses", "//tensorflow/python/saved_model:signature_constants", "//third_party/py/numpy", "@six_archive//:six", ], ) + +py_library( + name = "replicate_model_fn", + srcs = [ + "python/estimator/replicate_model_fn.py", + ], + srcs_version = "PY2AND3", + deps = [ + "//tensorflow/core:protos_all_py", + "//tensorflow/python:array_ops", + "//tensorflow/python:control_flow_ops", + "//tensorflow/python:device", + "//tensorflow/python:device_lib", + "//tensorflow/python:framework_ops", + "//tensorflow/python:gradients", + "//tensorflow/python:math_ops", + "//tensorflow/python:platform", + "//tensorflow/python:state_ops", + "//tensorflow/python:training", + "//tensorflow/python:variable_scope", + "//tensorflow/python:variables", + "//tensorflow/python/estimator:export_output", + "//tensorflow/python/estimator:model_fn", + "//tensorflow/python/estimator:util", + "@six_archive//:six", + ], +) + +cuda_py_test( + name = "replicate_model_fn_test", + size = "small", + srcs = ["python/estimator/replicate_model_fn_test.py"], + additional_deps = [ + "//tensorflow/python/estimator", + "//tensorflow/python/estimator:dnn", + 
"//tensorflow/python/estimator:export_export", + "//tensorflow/python/estimator:export_output", + "//tensorflow/python/estimator:model_fn", + "//tensorflow/python/estimator:numpy_io", + "//tensorflow/python/estimator:optimizers", + "//tensorflow/python/estimator:prediction_keys", + "//tensorflow/python/feature_column", + "//tensorflow/python/ops/losses", + "//tensorflow/python/saved_model:signature_constants", + "//tensorflow/python:array_ops", + "//tensorflow/python:client_testlib", + "//tensorflow/python:control_flow_ops", + "//tensorflow/python:framework_for_generated_wrappers", + "//tensorflow/python:framework_test_lib", + "//tensorflow/python:math_ops", + "//tensorflow/python:metrics", + "//tensorflow/python:platform", + "//tensorflow/python:summary", + "//tensorflow/python:training", + "//tensorflow/python:variable_scope", + "//tensorflow/python:variables", + ":replicate_model_fn", + ], + tags = ["requires-gpu-sm35"], +) diff --git a/tensorflow/contrib/estimator/python/estimator/head.py b/tensorflow/contrib/estimator/python/estimator/head.py index 7c992c99ed3fb05d5f2c306304b7084584c201e4..e344ee3c3eab22d217570a8c8073f72998e77b03 100644 --- a/tensorflow/contrib/estimator/python/estimator/head.py +++ b/tensorflow/contrib/estimator/python/estimator/head.py @@ -172,7 +172,8 @@ def multi_label_head(n_classes, weight_column: A string or a `_NumericColumn` created by `tf.feature_column.numeric_column` defining feature column representing weights. It is used to down weight or boost examples during training. It - will be multiplied by the loss of the example. + will be multiplied by the loss of the example. Per-class weighting is + not supported. thresholds: Iterable of floats in the range `(0, 1)`. Accuracy, precision and recall metrics are evaluated for each threshold value. The threshold is applied to the predicted probabilities, i.e. 
above the threshold is diff --git a/tensorflow/contrib/estimator/python/estimator/head_test.py b/tensorflow/contrib/estimator/python/estimator/head_test.py index 972ce6163d5b0f580b08888bd69dff0d40fefa34..fd8c53f6a94bf741c02e814ca96bfcea050589c4 100644 --- a/tensorflow/contrib/estimator/python/estimator/head_test.py +++ b/tensorflow/contrib/estimator/python/estimator/head_test.py @@ -226,7 +226,7 @@ class MultiLabelHead(test.TestCase): def test_weight_should_not_impact_prediction(self): n_classes = 4 - head = head_lib.multi_label_head(n_classes, weight_column='label_weights') + head = head_lib.multi_label_head(n_classes, weight_column='example_weights') self.assertEqual(n_classes, head.logits_dimension) logits = np.array( @@ -237,7 +237,7 @@ class MultiLabelHead(test.TestCase): spec = head.create_estimator_spec( features={ 'x': np.array(((42,),), dtype=np.int32), - 'label_weights': weights_2x1, + 'example_weights': weights_2x1, }, mode=model_fn.ModeKeys.PREDICT, logits=logits) @@ -549,7 +549,7 @@ class MultiLabelHead(test.TestCase): def test_eval_with_weights(self): n_classes = 2 - head = head_lib.multi_label_head(n_classes, weight_column='label_weights') + head = head_lib.multi_label_head(n_classes, weight_column='example_weights') logits = np.array([[-10., 10.], [-15., 10.]], dtype=np.float32) labels = np.array([[1, 0], [1, 1]], dtype=np.int64) @@ -563,7 +563,7 @@ class MultiLabelHead(test.TestCase): spec = head.create_estimator_spec( features={ 'x': np.array([[41], [42]], dtype=np.int32), - 'label_weights': np.array([[1.], [2.]], dtype=np.float32), + 'example_weights': np.array([[1.], [2.]], dtype=np.float32), }, mode=model_fn.ModeKeys.EVAL, logits=logits, @@ -605,7 +605,7 @@ class MultiLabelHead(test.TestCase): def test_train_create_loss_large_logits(self): """Tests head.create_loss for train mode and large logits.""" n_classes = 2 - head = head_lib.multi_label_head(n_classes, weight_column='label_weights') + head = head_lib.multi_label_head(n_classes, 
weight_column='example_weights') logits = np.array([[-10., 10.], [-15., 10.]], dtype=np.float32) labels = np.array([[1, 0], [1, 1]], dtype=np.int64) @@ -623,7 +623,7 @@ class MultiLabelHead(test.TestCase): actual_weighted_sum_loss, actual_example_weight_sum, _ = head.create_loss( features={ 'x': np.array(((42,),), dtype=np.int32), - 'label_weights': weights + 'example_weights': weights }, mode=model_fn.ModeKeys.TRAIN, logits=logits, @@ -742,7 +742,7 @@ class MultiLabelHead(test.TestCase): def test_train_with_weights(self): n_classes = 2 - head = head_lib.multi_label_head(n_classes, weight_column='label_weights') + head = head_lib.multi_label_head(n_classes, weight_column='example_weights') logits = np.array([[-10., 10.], [-15., 10.]], dtype=np.float32) labels = np.array([[1, 0], [1, 1]], dtype=np.int64) @@ -761,7 +761,7 @@ class MultiLabelHead(test.TestCase): spec = head.create_estimator_spec( features={ 'x': np.array([[41], [42]], dtype=np.int32), - 'label_weights': np.array([[1.], [2.]], dtype=np.float32), + 'example_weights': np.array([[1.], [2.]], dtype=np.float32), }, mode=model_fn.ModeKeys.TRAIN, logits=logits, diff --git a/tensorflow/contrib/estimator/python/estimator/logit_fns.py b/tensorflow/contrib/estimator/python/estimator/logit_fns.py index fc5efa4d7b98123ae968f98d4a54900e2d63570d..09c2862ccd3f90de4153a2095afc9c3d3f9476c1 100644 --- a/tensorflow/contrib/estimator/python/estimator/logit_fns.py +++ b/tensorflow/contrib/estimator/python/estimator/logit_fns.py @@ -84,7 +84,7 @@ def call_logit_fn(logit_fn, features, mode, params, config): result_is_valid_dictionary = ( isinstance(logit_fn_results, dict) and - all([(isinstance(k, str) and isinstance(v, ops.Tensor)) + all([(isinstance(k, six.string_types) and isinstance(v, ops.Tensor)) for k, v in six.iteritems(logit_fn_results)])) result_is_tensor = isinstance(logit_fn_results, ops.Tensor) diff --git a/tensorflow/contrib/estimator/python/estimator/logit_fns_test.py 
b/tensorflow/contrib/estimator/python/estimator/logit_fns_test.py index 3279e920018bae8ca9520a6372f6b71971da7b52..074ece6cca2865b9057ab5ce874a210d3d9ac2e0 100644 --- a/tensorflow/contrib/estimator/python/estimator/logit_fns_test.py +++ b/tensorflow/contrib/estimator/python/estimator/logit_fns_test.py @@ -46,7 +46,7 @@ class LogitFnTest(test.TestCase): def test_simple_call_multi_logit_fn(self): def dummy_logit_fn(features): - return {'head1': features['f1'], 'head2': features['f2']} + return {u'head1': features['f1'], 'head2': features['f2']} features = { 'f1': constant_op.constant([[2., 3.]]), diff --git a/tensorflow/contrib/estimator/python/estimator/multi_head.py b/tensorflow/contrib/estimator/python/estimator/multi_head.py index 64b2a9dee83801b5d6d852a3485fc0cc81417ff0..69dbfcee62af526cc92f8699f7137acbcdc03052 100644 --- a/tensorflow/contrib/estimator/python/estimator/multi_head.py +++ b/tensorflow/contrib/estimator/python/estimator/multi_head.py @@ -161,12 +161,52 @@ class _MultiHead(head_lib._Head): # pylint:disable=protected-access def create_loss(self, features, mode, logits, labels): """See `Head`.""" - # TODO(roumposg): Implement it. - raise NotImplementedError('create_loss not yet implemented for MultiHead.') + # TODO(roumposg): Add support for logits as single Tensor (with + # _split_logits utility). + if not isinstance(logits, dict): + raise ValueError('logits must be a dict. 
Single Tensor support coming ' + 'soon.') + weighted_sum_losses = [] + example_weight_sums = [] + labels_by_head = {} + for head in self._heads: + (weighted_sum_loss, + example_weight_sum, processed_labels) = head.create_loss( + features, mode, logits[head.name], labels[head.name]) + weighted_sum_losses.append(weighted_sum_loss) + example_weight_sums.append(example_weight_sum) + labels_by_head[head.name] = processed_labels + + weighted_sum_losses = tuple(weighted_sum_losses) + with ops.name_scope('merge_losses', + values=weighted_sum_losses + (self._head_weights or + tuple())): + if self._head_weights: + head_weighted_losses = [] + head_weighted_example_weight_sums = [] + for loss, example_weight_sum, weight in zip(weighted_sum_losses, + example_weight_sums, + self._head_weights): + head_weighted_losses.append(math_ops.multiply(loss, weight)) + head_weighted_example_weight_sums.append(math_ops.multiply( + example_weight_sum, weight)) + merged_weighted_sum_loss = math_ops.add_n(head_weighted_losses) + merged_example_weight_sum = math_ops.add_n( + head_weighted_example_weight_sums) + else: + merged_weighted_sum_loss = math_ops.add_n(weighted_sum_losses) + merged_example_weight_sum = math_ops.add_n(example_weight_sums) + + return head_lib.LossSpec( + weighted_sum_loss=merged_weighted_sum_loss, + example_weight_sum=merged_example_weight_sum, + processed_labels=labels_by_head) def create_estimator_spec( self, features, mode, logits, labels=None, train_op_fn=None): """See `_Head`.""" + # TODO(roumposg): Add support for logits as single Tensor (with + # _split_logits utility). if not isinstance(logits, dict): raise ValueError('logits must be a dict. Given: {}'.format(logits)) if labels and not isinstance(labels, dict): @@ -183,6 +223,8 @@ class _MultiHead(head_lib._Head): # pylint:disable=protected-access labels=labels[head_name] if labels else None, train_op_fn=_no_op_train_fn)) + # TODO(roumposg): Add LOSS and LOSS_MEAN summaries for the total head- + # combined loss. 
if mode == model_fn.ModeKeys.TRAIN: if train_op_fn is None: raise ValueError('train_op_fn can not be None in TRAIN mode.') diff --git a/tensorflow/contrib/estimator/python/estimator/multi_head_test.py b/tensorflow/contrib/estimator/python/estimator/multi_head_test.py index 48027035cecffc3ce8aacf8ae917f5eb9e9b2473..16177aebd53cbff5c8fd727477ac5d18c9f8bce5 100644 --- a/tensorflow/contrib/estimator/python/estimator/multi_head_test.py +++ b/tensorflow/contrib/estimator/python/estimator/multi_head_test.py @@ -178,7 +178,7 @@ class MultiHeadTest(test.TestCase): # (1 - labels) * (logits > 0) * logits => # head1: expected_unweighted_loss = [[10., 10.], [15., 0.]] # head2: expected_unweighted_loss = [[20., 20., 20.], [30., 0., 0]] - # Average over classes, weighted sum ober batch and heads. + # Average over classes, weighted sum over batch and heads. expected_loss_head1 = 17.5 expected_loss_head2 = 30.0 expected_loss = 1. * expected_loss_head1 + 2. * expected_loss_head2 @@ -231,18 +231,25 @@ class MultiHeadTest(test.TestCase): logits = {'head1': np.array([[-10., 10.], [-15., 10.]], dtype=np.float32)} labels = {'head1': np.array([[1, 0], [1, 1]], dtype=np.int64)} - with self.assertRaisesRegexp( - NotImplementedError, - r'create_loss not yet implemented for MultiHead\.'): - multi_head.create_loss( - features={'x': np.array(((42,),), dtype=np.int32)}, - mode=model_fn.ModeKeys.TRAIN, - logits=logits, - labels=labels) + loss = multi_head.create_loss( + features={'x': np.array(((42,),), dtype=np.int32)}, + mode=model_fn.ModeKeys.TRAIN, + logits=logits, + labels=labels)[0] + tol = 1e-3 + with self.test_session(): + # Unreduced loss of the head is [[(10 + 10) / 2], (15 + 0) / 2] + # (averaged over classes, sum-reduced over examples). 
+ self.assertAllClose(17.5, loss.eval(), rtol=tol, atol=tol) def test_train_create_loss_two_heads_with_weights(self): - head1 = head_lib.multi_label_head(n_classes=2, name='head1') - head2 = head_lib.multi_label_head(n_classes=3, name='head2') + # Use different example weighting for each head weighting. + weights1 = np.array([[1.], [2.]], dtype=np.float32) + weights2 = np.array([[2.], [3.]]) + head1 = head_lib.multi_label_head(n_classes=2, name='head1', + weight_column='weights1') + head2 = head_lib.multi_label_head(n_classes=3, name='head2', + weight_column='weights2') multi_head = multi_head_lib.multi_head( [head1, head2], head_weights=[1., 2.]) @@ -255,14 +262,27 @@ class MultiHeadTest(test.TestCase): 'head1': np.array([[1, 0], [1, 1]], dtype=np.int64), 'head2': np.array([[0, 1, 0], [1, 1, 0]], dtype=np.int64), } - with self.assertRaisesRegexp( - NotImplementedError, - r'create_loss not yet implemented for MultiHead\.'): - multi_head.create_loss( - features={'x': np.array(((42,),), dtype=np.int32)}, - mode=model_fn.ModeKeys.TRAIN, - logits=logits, - labels=labels) + weighted_sum_loss, example_weight_sum, _ = multi_head.create_loss( + features={ + 'x': np.array(((42,),), dtype=np.int32), + 'weights1': weights1, + 'weights2': weights2 + }, + mode=model_fn.ModeKeys.TRAIN, + logits=logits, + labels=labels) + tol = 1e-3 + with self.test_session(): + # loss of the first head is [[(10 + 10) / 2], [(15 + 0) / 2]] + # = [10, 7.5] + # weighted_sum_loss = 1 * 10 + 2 * 7.5 = 25 + # loss of the second head is [[(20 + 20 + 20) / 3], [(30 + 0 + 0) / 3]] + # = [20, 10] + # weighted_sum_loss = 2 * 20 + 3 * 10 = 70 + # head-weighted merge = 1 * 25 + 2 * 70 = 165 + self.assertAllClose(165, weighted_sum_loss.eval(), rtol=tol, atol=tol) + # example_weight_sum = 1 * (1 + 2) + 2 * (2 + 3) = 13 + self.assertAllClose(13., example_weight_sum.eval(), rtol=tol, atol=tol) def test_train_one_head(self): head1 = head_lib.multi_label_head(n_classes=2, name='head1') @@ -332,7 +352,7 @@ class 
MultiHeadTest(test.TestCase): # (1 - labels) * (logits > 0) * logits => # head1: expected_unweighted_loss = [[10., 10.], [15., 0.]] # head2: expected_unweighted_loss = [[20., 20., 20.], [30., 0., 0]] - # Average over classes, weighted sum ober batch and heads. + # Average over classes, weighted sum over batch and heads. expected_loss_head1 = 17.5 expected_loss_head2 = 30.0 expected_loss = 1. * expected_loss_head1 + 2. * expected_loss_head2 diff --git a/tensorflow/contrib/estimator/python/estimator/replicate_model_fn.py b/tensorflow/contrib/estimator/python/estimator/replicate_model_fn.py new file mode 100644 index 0000000000000000000000000000000000000000..7005a647db599dfa386f34406911febe1d9d5651 --- /dev/null +++ b/tensorflow/contrib/estimator/python/estimator/replicate_model_fn.py @@ -0,0 +1,470 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Utilities to replicate model_fn's over local GPUs. + +This file contains util that allow to replicate `Estimator.model_fn` over +GPUs. Replicated version of a `model_fn` is returned that can subsequently +be used with `Estimator`. 
+""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import copy + +import six + +from tensorflow.core.framework import node_def_pb2 +from tensorflow.python.client import device_lib +from tensorflow.python.estimator import model_fn as model_fn_lib +from tensorflow.python.estimator import util +from tensorflow.python.estimator.export import export_output as export_output_lib +from tensorflow.python.framework import device as framework_device +from tensorflow.python.framework import ops as ops_lib +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import control_flow_ops +from tensorflow.python.ops import gradients as gradients_lib +from tensorflow.python.ops import math_ops +from tensorflow.python.ops import state_ops +from tensorflow.python.ops import variable_scope +from tensorflow.python.ops import variables as variables_lib +from tensorflow.python.platform import tf_logging +from tensorflow.python.training import training_util + + +def replicate_model_fn(model_fn, optimizer_fn, devices=None): + """Replicate `Estimator.model_fn` over GPUs within a single host. + + The given `model_fn` specifies a single forward pass of a model. To replicate + such a model over GPUs, each GPU gets its own instance of the forward pass + (a.k.a. a tower). The input features and labels get sharded into the chunks + that correspond to the number of GPUs. Each tower computes its own loss based + on its input. For each such loss, gradients are computed. After that, the + available losses are summed to form aggregated loss. The available + gradients are summed too. Then, they update weights using the specified + optimizer. + + If `devices` are `None`, then all available GPUs are going to be used for + replication. If no GPUs are available, then the model is going to be + placed on the CPU. 
+ + Two modes of local replication over available GPUs are supported: + 1) If exactly 1 GPU is detected, then variables and operations are placed + onto GPU. + 2) If more than 1 GPU is detected, then variables are going to be placed on + the CPU. Replicas of operations are placed on each individual GPU. + + Here is an example of how one might use their `model_fn` to run over GPUs: + ```python + def optimizer_fn(): + return tf.train.GradientDescentOptimizer(learning_rate=0.001) + ... + def model_fn(...): # See `model_fn` in `Estimator`. + loss = ... + if mode == tf.estimator.ModeKeys.TRAIN: + # See the section below on `EstimatorSpec.train_op`. + return EstimatorSpec(mode=mode, loss=loss, train_op=tf.no_op()) + + # No change for `ModeKeys.EVAL` or `ModeKeys.PREDICT`. + return EstimatorSpec(...) + ... + classifier = tf.estimator.Estimator( + model_fn=replicate_model_fn.replicate_model_fn(model_fn, optimizer_fn)) + ``` + + On `EstimatorSpec.train_op`: + `model_fn` returns `EstimatorSpec.train_op` for + `tf.estimator.ModeKeys.TRAIN`. It is typically derived using an optimizer. + `replicate_model_fn` ignores the returned `EstimatorSpec.train_op`, so there + is no need to use an optimizer inside the user's `model_fn`. The + `EstimatorSpec.loss` subgraph is going to be executed, while + `EstimatorSpec.train_op` isn't going to be executed. One could pass + `train_op=tf.no_op()` to `EstimatorSpec`. + + On sharding input features and labels: + Input features and labels are split for consumption by each tower. They are + split across the dimension 0. Features and labels need to be batch major. + + On reduction algorithms: + Certain algorithms were chosen for aggregating results of computations on + multiple towers: + - Losses from all towers are reduced using sum. + - Gradients are reduced using sum for each trainable variable. + - `eval_metric_ops` are reduced per metric using `reduce_mean`. 
+ - `EstimatorSpec.predictions` and `EstimatorSpec.export_outputs` are + reduced using concatenation. + - For all other fields of `EstimatorSpec` the values of the first tower + are taken. + + On replication of variables: + Variables are not duplicated between towers. Instead, they are placed on a + single device as defined above and shared across towers. + + Other current limitations: + - `predictions` are not supported for `ModeKeys.EVAL`. That is required for + `tf.contrib.estimator.add_metrics`. + + Args: + model_fn: `model_fn` as defined in `Estimator`. See the section above about + the train_op argument of `EstimatorSpec`. + optimizer_fn: a function that returns an optimizer instance. The function + may accept one `params` argument. This is the `params` argument as + defined by `Estimator`. See the `Estimator` documentation for details. + devices: Optional list of devices to replicate the model across. This + argument can be used to replicate only on the subset of available GPUs. + If `None`, then all available GPUs are going to be used for replication. + If no GPUs are available, then the model is going to be placed on the CPU. + + Returns: + A replicated version of the supplied `model_fn`. Returned function that + conforms to the requirements of `Estimator`'s `model_fn` and can be used + instead of the supplied `model_fn`. + """ + if not devices: + devices = _get_local_devices('GPU') or _get_local_devices('CPU') + + is_a_single_gpu_case = len(devices) == 1 and 'GPU' in devices[0] + local_ps_device = '/{}:0'.format('GPU' if is_a_single_gpu_case else 'CPU') + + tf_logging.info('Replicating the `model_fn` across {}. 
Local parameter ' + 'server device is going to be {}.'.format( + devices, local_ps_device)) + + def replicated_model_fn(mode, features, labels, params=None, config=None): + """Replicated version of `model_fn` to be used instead.""" + feature_shards, label_shards = _split_batch( + features, labels, len(devices), device=local_ps_device) + tower_specs = _get_loss_towers( + model_fn=model_fn, + mode=mode, + features=feature_shards, + labels=label_shards, + params=params, + config=config, + devices=devices, + local_ps_device=local_ps_device) + + if mode == model_fn_lib.ModeKeys.TRAIN: + train_op = _minimize_towers(tower_specs, + _call_optimizer_fn(optimizer_fn, params)) + return _train_spec( + tower_specs, train_op, aggregation_device=local_ps_device) + elif mode == model_fn_lib.ModeKeys.EVAL: + return _eval_spec(tower_specs, aggregation_device=local_ps_device) + elif mode == model_fn_lib.ModeKeys.PREDICT: + return _predict_spec(tower_specs, aggregation_device=local_ps_device) + + return replicated_model_fn + + +def _get_local_devices(device_type): + local_device_protos = device_lib.list_local_devices() + return [ + device.name + for device in local_device_protos + if device.device_type == device_type + ] + + +def _split_batch(features, labels, number_of_shards, device): + """Split input features and labels into batches.""" + + def split_dictionary(dictionary): + shards = [{} for _ in range(number_of_shards)] + for name, tensor in six.iteritems(dictionary): + for i, shard in enumerate(array_ops.split(tensor, number_of_shards)): + shards[i][name] = shard + return shards + + with ops_lib.name_scope('split_inputs'): + with ops_lib.device(device): + if isinstance(features, dict): + feature_shards = split_dictionary(features) + else: + feature_shards = array_ops.split(features, number_of_shards) + + if labels is None: + label_shards = None + elif isinstance(labels, dict): + label_shards = split_dictionary(labels) + else: + label_shards = array_ops.split(labels, 
number_of_shards) + return feature_shards, label_shards + + +_DEFAULT_NAME_SCOPE_PATTERN = 'tower_{}' + + +def _get_loss_towers(model_fn, + mode, + features, + labels, + params, + config, + devices, + local_ps_device, + name_scope_pattern=_DEFAULT_NAME_SCOPE_PATTERN): + """Replicate the loss computation across devices.""" + tower_specs = [] + + model_fn_args = util.fn_args(model_fn) + optional_params = {} + if 'params' in model_fn_args: + optional_params['params'] = copy.deepcopy(params) + if 'config' in model_fn_args: + optional_params['config'] = copy.deepcopy(config) + + for i, device in enumerate(devices): + is_the_first_tower = (i == 0) + + device_setter = _local_device_setter( + worker_device=device, ps_device=local_ps_device) + + # We would like to preserve the names of the variables and ops that a user + # might be relying on. Names with prefix are going to resolve to variables + # and ops of the first tower. + name_scope = name_scope_pattern + if is_the_first_tower: + name_scope = '' + + with variable_scope.variable_scope('', reuse=not is_the_first_tower): + with ops_lib.name_scope(name_scope.format(i)): + with ops_lib.device(device_setter): + labels_shard = None + if labels: + labels_shard = labels[i] + + tower_specs.append( + model_fn( + mode=mode, + features=features[i], + labels=labels_shard, + **optional_params)) + return tower_specs + + +def _local_device_setter(ps_device, worker_device): + """A device setter that distributes Var/Ops to PS/workers.""" + ps_ops = ['Variable', 'VariableV2', 'VarHandleOp'] + + def local_device_chooser(op): + current_device = framework_device.DeviceSpec.from_string(op.device or '') + + node_def = op if isinstance(op, node_def_pb2.NodeDef) else op.node_def + if node_def.op in ps_ops: + ps_device_spec = framework_device.DeviceSpec.from_string( + '{}'.format(ps_device)) + + ps_device_spec.merge_from(current_device) + return ps_device_spec.to_string() + else: + worker_device_spec = 
def _minimize_towers(tower_specs, optimizer):
  """Sums per-tower gradients for each variable and applies them once.

  Args:
    tower_specs: iterable of objects exposing a `loss` tensor.
    optimizer: object whose `apply_gradients` receives the summed
      `(gradient, variable)` pairs together with the global step.

  Returns:
    Whatever `optimizer.apply_gradients` returns.
  """
  grads_by_variable = {}
  for spec in tower_specs:
    tower_loss = spec.loss
    # Gradients are built on the device that holds this tower's loss.
    with ops_lib.device(tower_loss.device):
      trainable = variables_lib.trainable_variables()
      tower_grads = gradients_lib.gradients(tower_loss, trainable)

      for variable, gradient in zip(trainable, tower_grads):
        if gradient is None:
          continue
        grads_by_variable.setdefault(variable, []).append(gradient)

  with ops_lib.name_scope('gradient_aggregating'):
    # Each sum is placed on the device of the variable it will update.
    grads_and_vars = [
        (_compute_sum_on_device(gradients, variable.device), variable)
        for variable, gradients in six.iteritems(grads_by_variable)
    ]

  return optimizer.apply_gradients(
      grads_and_vars, global_step=training_util.get_global_step())
def _reduce_metric_variables(number_of_towers):
  """Builds an op that folds every tower's metric variables into tower one.

  Args:
    number_of_towers: how many towers contributed to the
      `GraphKeys.METRIC_VARIABLES` collection.

  Returns:
    A no-op when there is a single tower; otherwise a grouped op that adds each
    replica to its first-tower counterpart and then resets the replica to
    zeros.

  Raises:
    ValueError: if the number of collected metric variables is not a multiple
      of `number_of_towers`.
  """
  if number_of_towers == 1:
    return control_flow_ops.no_op()

  metric_variables = ops_lib.get_collection(ops_lib.GraphKeys.METRIC_VARIABLES)
  variables_per_tower, leftover = divmod(
      len(metric_variables), number_of_towers)

  if leftover != 0:
    raise ValueError(
        'Different `EstimatorSpec.eval_metric_ops` across `model_fn()` calls.'
        ' Expected {} local variables, but got {} instead.'.format(
            variables_per_tower * number_of_towers, len(metric_variables)))

  # `metric_variables` holds `variables_per_tower * number_of_towers` entries.
  # Every tower is produced by the same model_fn, so the first
  # `variables_per_tower` entries belong to the first tower and the replica of
  # entry `i` in a later tower `t` sits at index `i + variables_per_tower * t`.
  # Replica values are added into the first-tower variable and then zeroed out,
  # so that running `_reduce_metric_variables` again is idempotent. A metric
  # computed from the first tower's variables is then an estimate over all
  # `number_of_towers` towers.
  zero_out_ops = []
  for first_tower_index in range(variables_per_tower):
    # Every `variables_per_tower`-th entry after the first-tower variable;
    # the first-tower variable itself is not included.
    replicas = metric_variables[first_tower_index + variables_per_tower::
                                variables_per_tower]

    accumulate_op = state_ops.assign_add(metric_variables[first_tower_index],
                                         math_ops.add_n(replicas))

    # Replicas may only be cleared after their values have been accumulated.
    with ops_lib.control_dependencies([accumulate_op]):
      for replica in replicas:
        zeros = array_ops.zeros(array_ops.shape(replica), dtype=replica.dtype)
        zero_out_ops.append(state_ops.assign(replica, zeros))

  return control_flow_ops.group(*zero_out_ops)
def _dict_concat(*dicts):
  """Merges dictionaries into one dictionary of per-key value lists.

  Args:
    *dicts: dictionaries to merge; `None` entries are skipped.

  Returns:
    A dict mapping every key seen to the list of values found under that key,
    in the order the dictionaries were passed.
  """
  values_by_key = {}
  for mapping in dicts:
    if mapping is not None:
      for key, value in six.iteritems(mapping):
        if key not in values_by_key:
          values_by_key[key] = []
        values_by_key[key].append(value)
  return values_by_key
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for utilities that replicate `Estimator.model_fn` over GPUs.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import re +import shutil +import tempfile +import numpy as np +import six + +from tensorflow.contrib.estimator.python.estimator import replicate_model_fn +from tensorflow.python.estimator import estimator as estimator_lib +from tensorflow.python.estimator import model_fn as model_fn_lib +from tensorflow.python.estimator.canned import dnn +from tensorflow.python.estimator.canned import optimizers +from tensorflow.python.estimator.canned import prediction_keys +from tensorflow.python.estimator.export import export +from tensorflow.python.estimator.export import export_output +from tensorflow.python.estimator.inputs import numpy_io +from tensorflow.python.feature_column import feature_column +from tensorflow.python.framework import constant_op +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import ops as ops_lib +from tensorflow.python.framework import test_util +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import control_flow_ops +from tensorflow.python.ops import math_ops +from tensorflow.python.ops import metrics as metrics_lib +from tensorflow.python.ops import variable_scope +from tensorflow.python.ops import variables +from tensorflow.python.ops.losses import losses +from tensorflow.python.platform import gfile +from 
class DNNClassifierIntegrationTest(test_util.TensorFlowTestCase):
  """Trains, evaluates, predicts and exports a replicated `DNNClassifier`."""

  def setUp(self):
    """Creates the temporary model directory used by the estimator."""
    self._model_dir = tempfile.mkdtemp()

  def test_complete_flow(self):
    """Runs train -> evaluate -> predict -> export on a replicated model_fn."""
    n_classes = 3
    input_dimension = 2
    batch_size = 12

    data = np.linspace(
        0., n_classes - 1., batch_size * input_dimension, dtype=np.float32)
    x_data = data.reshape(batch_size, input_dimension)
    y_data = np.reshape(self._as_label(data[:batch_size]), (batch_size, 1))
    train_input_fn = numpy_io.numpy_input_fn(
        x={'x': x_data},
        y=y_data,
        batch_size=batch_size,
        num_epochs=None,
        shuffle=True)
    eval_input_fn = numpy_io.numpy_input_fn(
        x={'x': x_data}, y=y_data, batch_size=batch_size, shuffle=False)
    predict_input_fn = numpy_io.numpy_input_fn(
        x={'x': x_data}, batch_size=batch_size, shuffle=False)

    feature_columns = [
        feature_column.numeric_column('x', shape=(input_dimension,))
    ]

    estimator = dnn.DNNClassifier(
        hidden_units=(2, 2),
        feature_columns=feature_columns,
        n_classes=n_classes,
        model_dir=self._model_dir)

    def optimizer_fn():
      return optimizers.get_optimizer_instance('Adagrad', learning_rate=0.05)

    # TODO(isaprykin): Switch Estimator to use allow_soft_placement=True
    # during export_savedmodel and then switch this test to replicate over
    # GPUs instead of CPUs.
    estimator = estimator_lib.Estimator(
        model_fn=replicate_model_fn.replicate_model_fn(
            estimator.model_fn,
            optimizer_fn,
            devices=['/cpu:0', '/cpu:0', '/cpu:0']),
        model_dir=estimator.model_dir,
        config=estimator.config,
        params=estimator.params)

    num_steps = 10
    estimator.train(train_input_fn, steps=num_steps)

    scores = estimator.evaluate(eval_input_fn)
    self.assertEqual(num_steps, scores[ops_lib.GraphKeys.GLOBAL_STEP])
    self.assertIn('loss', scores)

    predicted_proba = np.array([
        x[prediction_keys.PredictionKeys.PROBABILITIES]
        for x in estimator.predict(predict_input_fn)
    ])
    self.assertAllEqual((batch_size, n_classes), predicted_proba.shape)

    feature_spec = feature_column.make_parse_example_spec(feature_columns)
    serving_input_receiver_fn = export.build_parsing_serving_input_receiver_fn(
        feature_spec)
    # The export base directory is separate from `self._model_dir`, so
    # `tearDown` does not remove it; register its removal here to avoid
    # leaking a temporary directory on every run.
    export_dir_base = tempfile.mkdtemp()
    self.addCleanup(shutil.rmtree, export_dir_base, ignore_errors=True)
    export_dir = estimator.export_savedmodel(export_dir_base,
                                             serving_input_receiver_fn)
    self.assertTrue(gfile.Exists(export_dir))

  def _as_label(self, data_in_float):
    """Rounds float data to the nearest integer and casts it to int64."""
    return np.rint(data_in_float).astype(np.int64)

  def tearDown(self):
    """Clears cached file writers and deletes the model directory."""
    if self._model_dir:
      writer_cache.FileWriterCache.clear()
      shutil.rmtree(self._model_dir)
train_op=control_flow_ops.no_op()) # This train_op isn't actually used. + + def optimizer_fn(self, params): + return gradient_descent.GradientDescentOptimizer(params['learning_rate']) + + @property + def params(self): + params = {} + params['learning_rate'] = 1.0 + return params + + def test_train(self): + features = np.array([[1.0], [2.0]]) + labels = np.array([[1.0], [2.0]]) + + with self.test_session() as session: + replicated_model_fn = replicate_model_fn.replicate_model_fn( + self.model_fn, self.optimizer_fn, devices=['/gpu:0', '/gpu:1']) + estimator_spec = replicated_model_fn(model_fn_lib.ModeKeys.TRAIN, + features, labels, self.params) + session.run(variables.global_variables_initializer()) + + # loss = feature * c - label + total_loss = (1.0 * 10 - 1.0) + (2.0 * 10 - 2.0) + self.assertEqual(total_loss, session.run(estimator_spec.loss)) + + # loss' of c is 3. + # new value of c = 10 - learning rate * 3 = 7.0. + session.run(estimator_spec.train_op) + with variable_scope.variable_scope('', reuse=True): + c = variable_scope.get_variable('c', dtype=dtypes.float64) + self.assertEqual(7.0, session.run(c)) + + def test_train_spec_with_optimizer_without_params(self): + + def optimizer_fn_without_params(): + return gradient_descent.GradientDescentOptimizer(learning_rate=1.0) + + features = np.array([[1.0], [2.0]]) + labels = np.array([[1.0], [2.0]]) + + with self.test_session() as session: # pylint: disable=unused-variable + replicated_model_fn = replicate_model_fn.replicate_model_fn( + self.model_fn, + optimizer_fn_without_params, + devices=['/gpu:0', '/gpu:1']) + # This call is going to fail if `replicated_model_fn` is still passing + # `params` inside `optimizer_fn`, even though the latter doesn't take any: + estimator_spec = replicated_model_fn(model_fn_lib.ModeKeys.TRAIN, + features, labels, self.params) + del estimator_spec + + def test_eval(self): + features = np.array([[0.01], [0.002]]) + labels = np.array([[0.01], [0.02]]) + + with self.test_session() as 
session: + replicated_model_fn = replicate_model_fn.replicate_model_fn( + self.model_fn, self.optimizer_fn, devices=['/gpu:0', '/gpu:1']) + estimator_spec = replicated_model_fn(model_fn_lib.ModeKeys.EVAL, features, + labels, self.params) + session.run(variables.local_variables_initializer()) + session.run(variables.global_variables_initializer()) + + accuracy, a = estimator_spec.eval_metric_ops['accuracy'] + auc, b = estimator_spec.eval_metric_ops['auc'] + + session.run([a, b]) + accuracy = session.run(accuracy) + auc = session.run(auc) + + # Accuracy is 0.0 (no match) in the first tower. + # Accuracy is 1.0 (match) in the second tower, since the feature + # times weight "c" happened to be equal to the label. + total_loss = ((0.01 * 10 - 0.01) + (0.002 * 10 - 0.02)) + + self.assertNear((0.0 + 1.0) / 2.0, accuracy, 0.01) + self.assertEqual(0, auc) + self.assertNear(total_loss, session.run(estimator_spec.loss), 0.01) + + def test_predict(self): + features = np.array([[0.01], [0.002]]) + labels = np.array([[0.01], [0.02]]) + + with self.test_session() as session: + replicated_model_fn = replicate_model_fn.replicate_model_fn( + self.model_fn, self.optimizer_fn, devices=['/gpu:0', '/gpu:1']) + estimator_spec = replicated_model_fn(model_fn_lib.ModeKeys.PREDICT, + features, labels, self.params) + session.run(variables.global_variables_initializer()) + + self.assertAllClose({ + 'probabilities': np.array([[0.1], [0.02]]) + }, session.run(estimator_spec.predictions)) + + def test_train_single_tower(self): + features = np.array([[1.0], [2.0]]) + labels = np.array([[1.0], [2.0]]) + + with self.test_session() as session: + replicated_model_fn = replicate_model_fn.replicate_model_fn( + self.model_fn, self.optimizer_fn) + estimator_spec = replicated_model_fn(model_fn_lib.ModeKeys.TRAIN, + features, labels, self.params) + session.run(variables.global_variables_initializer()) + + # loss = feature * c - label + total_loss = (1.0 * 10 - 1.0) + (2.0 * 10 - 2.0) + 
self.assertEqual(total_loss, session.run(estimator_spec.loss)) + + # loss' of c is 3. + # new value of c = 10 - learning rate * 3 = 7.0. + session.run(estimator_spec.train_op) + with variable_scope.variable_scope('', reuse=True): + c = variable_scope.get_variable('c', dtype=dtypes.float64) + self.assertEqual(7.0, session.run(c)) + + def test_eval_single_tower(self): + features = np.array([[0.01], [0.002]]) + labels = np.array([[0.01], [0.02]]) + + with self.test_session() as session: + replicated_model_fn = replicate_model_fn.replicate_model_fn( + self.model_fn, self.optimizer_fn, devices=['/gpu:0']) + estimator_spec = replicated_model_fn(model_fn_lib.ModeKeys.EVAL, features, + labels, self.params) + session.run(variables.local_variables_initializer()) + session.run(variables.global_variables_initializer()) + + accuracy, a = estimator_spec.eval_metric_ops['accuracy'] + auc, b = estimator_spec.eval_metric_ops['auc'] + + session.run([a, b]) + accuracy = session.run(accuracy) + auc = session.run(auc) + + # Accuracy is 0.0 (no match) in the first tower. + # Accuracy is 1.0 (match) in the second tower, since the feature + # times weight "c" happened to be equal to the label. 
class GetLossTowersTest(test_util.TensorFlowTestCase):
  """Checks the per-tower specs built by `_get_loss_towers`."""

  def model_fn(self, mode, features, labels, params):
    """Builds a spec whose loss is the summed absolute prediction error.

    Predictions are three fixed values plus `features[0]`, each shifted by the
    variable `c` (initialised to 0.25); labels are the same three fixed values
    plus `labels[0]`. `params` is accepted but not read.
    """
    c = variable_scope.get_variable(
        'c',
        initializer=constant_op.constant(0.25, dtype=dtypes.float64),
        dtype=dtypes.float64)

    predictions = math_ops.add(np.array([0.1, 0.2, 0.3, features[0]]), c)
    labels = np.array([0.1, 0.2, 0.3, labels[0]])

    loss = losses.absolute_difference(
        labels=labels, predictions=predictions, reduction=losses.Reduction.SUM)

    return model_fn_lib.EstimatorSpec(mode=mode, loss=math_ops.reduce_sum(loss))

  def test_gradients_are_computed(self):
    """Two towers get distinct devices and names but share one variable."""
    with self.test_session() as session:
      tower_specs = replicate_model_fn._get_loss_towers(
          self.model_fn,
          mode=None,
          features=[[0.6], [1.6]],
          labels=[[0.6], [0.6]],
          params=None,
          config=None,
          devices=['/gpu:0', '/gpu:1'],
          local_ps_device='/gpu:0',
          name_scope_pattern='test_tower_{}')
      session.run(variables.global_variables_initializer())

      self.assertEqual(len(tower_specs), 2)

      # The first tower's loss op carries no name-scope prefix.
      self.assertEqual('/device:GPU:0', tower_specs[0].loss.device)
      self.assertEqual('Sum:0', tower_specs[0].loss.name)
      self.assertEqual(1.0, session.run(tower_specs[0].loss))

      # The second tower's loss op is prefixed using `name_scope_pattern`.
      self.assertEqual('/device:GPU:1', tower_specs[1].loss.device)
      self.assertEqual('test_tower_1/Sum:0', tower_specs[1].loss.name)
      # The input batch for the second tower had a loss that is 1.0
      # bigger: 0.6 vs 1.6.
      self.assertEqual(2.0, session.run(tower_specs[1].loss))

      # Only one variable exists even though `model_fn` ran twice.
      self.assertEqual(1, len(variables.global_variables()))
      self.assertEqual(1, len(variables.trainable_variables()))

      with variable_scope.variable_scope('', reuse=True):
        c = variable_scope.get_variable('c', dtype=dtypes.float64)
        self.assertEqual(0.25, session.run(c))
replicate_model_fn._split_batch( + features, labels, 1, device='/gpu:0') + + feature_shards, label_shards = self.evaluate_shards( + feature_shards, label_shards) + + self.assertAllEqual([[0.0, 1.0, 2.0, 3.0]], feature_shards) + self.assertAllEqual([[10.0, 11.0, 12.0, 13.0]], label_shards) + + def test_half_split_in_dictionary(self): + with self.test_session() as session: # pylint: disable=unused-variable + features = {'first': [0.0, 1.0, 2.0, 3.0], 'second': [4.0, 5.0, 6.0, 7.0]} + labels = [10.0, 11.0, 12.0, 13.0] + + feature_shards, label_shards = replicate_model_fn._split_batch( + features, labels, 2, device='/gpu:0') + + self.assertAllEqual([0.0, 1.0], feature_shards[0]['first'].eval()) + self.assertAllEqual([4.0, 5.0], feature_shards[0]['second'].eval()) + self.assertAllEqual([2.0, 3.0], feature_shards[1]['first'].eval()) + self.assertAllEqual([6.0, 7.0], feature_shards[1]['second'].eval()) + self.assertAllEqual([10.0, 11.0], label_shards[0].eval()) + self.assertAllEqual([12.0, 13.0], label_shards[1].eval()) + + def test_one_batch_in_dictionary(self): + with self.test_session() as session: # pylint: disable=unused-variable + features = {'first': [0.0, 1.0, 2.0, 3.0], 'second': [4.0, 5.0, 6.0, 7.0]} + labels = [10.0, 11.0, 12.0, 13.0] + + feature_shards, label_shards = replicate_model_fn._split_batch( + features, labels, 1, device='/gpu:0') + + self.assertAllEqual([0.0, 1.0, 2.0, 3.0], + feature_shards[0]['first'].eval()) + self.assertAllEqual([4.0, 5.0, 6.0, 7.0], + feature_shards[0]['second'].eval()) + self.assertAllEqual([10.0, 11.0, 12.0, 13.0], label_shards[0].eval()) + + def test_feature_and_label_dictionaries(self): + with self.test_session() as session: # pylint: disable=unused-variable + features = {'first': [0.0, 1.0, 2.0, 3.0], 'second': [4.0, 5.0, 6.0, 7.0]} + labels = {'first': [10.0, 11.0], 'second': [12.0, 13.0]} + + feature_shards, label_shards = replicate_model_fn._split_batch( + features, labels, 2, device='/gpu:0') + + 
class TrainSpecTest(test_util.TensorFlowTestCase):
  """Checks the spec that `_train_spec` assembles from per-tower specs."""

  expected_predictions = {}

  def create_estimator_spec(self, loss):
    """Wraps `loss` into a TRAIN-mode `EstimatorSpec`."""
    spec_kwargs = {
        'mode': model_fn_lib.ModeKeys.TRAIN,
        'loss': loss,
        'train_op': loss,  # Not used; currently required.
        'predictions': self.expected_predictions,
    }
    return model_fn_lib.EstimatorSpec(**spec_kwargs)

  def create_constant_loss(self, loss_value):
    """Returns a float64 constant tensor holding `loss_value`."""
    return constant_op.constant(loss_value, dtype=dtypes.float64)

  def test_example(self):
    """The merged spec keeps the train op and sums the tower losses."""
    with self.test_session() as session:
      tower_losses = [self.create_constant_loss(value) for value in [2, 4, 6]]
      tower_specs = [self.create_estimator_spec(loss) for loss in tower_losses]

      expected_train_op = tower_losses[1]

      merged_spec = replicate_model_fn._train_spec(
          tower_specs, expected_train_op, aggregation_device='/gpu:0')

      self.assertEqual(expected_train_op, merged_spec.train_op)
      self.assertEqual(2 + 4 + 6, session.run(merged_spec.loss))
      self.assertEqual(self.expected_predictions, merged_spec.predictions)
metrics_lib.accuracy(labels, predictions), + 'auc': metrics_lib.auc(labels, predictions) + } + return metrics + + def test_example(self): + with self.test_session() as session: + tower_losses = map(self.create_constant_loss, [2, 4, 6]) + tower_metrics = map(self.create_eval_metrics, [0, 0.2, 0.3]) + tower_specs = [ + self.create_estimator_spec(l, m) + for l, m in zip(tower_losses, tower_metrics) + ] + session.run(variables.local_variables_initializer()) + + estimator_spec = replicate_model_fn._eval_spec( + tower_specs, aggregation_device='/device:GPU:0') + + accuracy, a = estimator_spec.eval_metric_ops['accuracy'] + auc, b = estimator_spec.eval_metric_ops['auc'] + + self.assertEqual('/device:CPU:0', accuracy.device) + self.assertEqual('/device:CPU:0', auc.device) + + session.run([a, b]) + accuracy = session.run(accuracy) + auc = session.run(auc) + + self.assertNear((12 - 2) / 12, accuracy, 0.01) + self.assertEqual(0, auc) + self.assertEqual(2 + 4 + 6, session.run(estimator_spec.loss)) + + def test_handles_single_tower(self): + with self.test_session() as session: + tower_losses = map(self.create_constant_loss, [5]) + tower_metrics = map(self.create_eval_metrics, [0.2]) + tower_specs = [ + self.create_estimator_spec(l, m) + for l, m in zip(tower_losses, tower_metrics) + ] + session.run(variables.local_variables_initializer()) + + estimator_spec = replicate_model_fn._eval_spec( + tower_specs, aggregation_device='/device:GPU:0') + + accuracy, a = estimator_spec.eval_metric_ops['accuracy'] + auc, b = estimator_spec.eval_metric_ops['auc'] + + self.assertEqual('/device:CPU:0', accuracy.device) + self.assertEqual('/device:CPU:0', auc.device) + + session.run([a, b]) + accuracy = session.run(accuracy) + auc = session.run(auc) + + self.assertNear((4 - 1) / 4, accuracy, 0.01) + self.assertEqual(0, auc) + self.assertEqual(5, session.run(estimator_spec.loss)) + + +class PredictSpecTest(test_util.TensorFlowTestCase): + + def model_fn(self, mode, features, labels, params): + c = 
class ReduceMetricVariablesTest(test_util.TensorFlowTestCase):
  """Checks `_reduce_metric_variables` on hand-built metric variables."""

  def create_metric_variable(self, initial_value, name):
    """Creates a non-trainable variable in the METRIC_VARIABLES collection."""
    return variable_scope.variable(
        initial_value,
        trainable=False,
        collections=[ops_lib.GraphKeys.METRIC_VARIABLES],
        validate_shape=True,
        name=name)

  def create_tower_metrics(self, tower_id):
    """Adds three metric variables scaled by `tower_id + 1`.

    Per tower: scalar 1.3, scalar 2.3 and vector [3.3, 3.5, 3.7], each
    multiplied by `tower_id + 1`.
    """
    with variable_scope.variable_scope('', reuse=(tower_id != 0)):
      self.create_metric_variable(1.3 * (tower_id + 1), 'total')
      self.create_metric_variable(2.3 * (tower_id + 1), 'count')
      self.create_metric_variable(
          np.array([3.3, 3.5, 3.7]) * (tower_id + 1), 'total')

  def _assert_reduced_into_first_tower(self, local_metrics):
    """Asserts three towers were summed into the first and replicas zeroed.

    1st tower = 1.3,  2.3,  [3.3, 3.5, 3.7]
    2nd tower = 2.6,  4.6,  [6.6, 7.0, 7.4]
    3rd tower = 3.9,  6.9,  [9.9, 10.5, 11.1]
    Reduced   = 7.8, 13.8,  [19.8, 21.0, 22.2]
    """
    self.assertNear(7.8, local_metrics[0], 0.01)
    self.assertNear(13.8, local_metrics[1], 0.01)
    # 3.7 * (1 + 2 + 3) == 22.2.
    self.assertAllClose([19.8, 21., 22.2], local_metrics[2], 0.01)
    self.assertNear(0.0, local_metrics[3], 0.01)
    self.assertNear(0.0, local_metrics[4], 0.01)
    self.assertAllClose([0.0, 0.0, 0.0], local_metrics[5], 0.01)
    self.assertNear(0.0, local_metrics[6], 0.01)
    self.assertNear(0.0, local_metrics[7], 0.01)
    self.assertAllClose([0.0, 0.0, 0.0], local_metrics[8], 0.01)

  def test_example(self):
    """A single reduction accumulates all towers into the first one."""
    with self.test_session() as session:
      for tower_id in range(3):
        self.create_tower_metrics(tower_id)

      session.run(
          variables.variables_initializer(
              ops_lib.get_collection(ops_lib.GraphKeys.METRIC_VARIABLES)))

      session.run(
          replicate_model_fn._reduce_metric_variables(number_of_towers=3))

      # Towers are accumulated in the first tower.
      local_metrics = session.run(
          ops_lib.get_collection(ops_lib.GraphKeys.METRIC_VARIABLES))

      self._assert_reduced_into_first_tower(local_metrics)

  def test_reduce_is_idempotent(self):
    """Running the reduction repeatedly gives the same result as once."""
    with self.test_session() as session:
      for tower_id in range(3):
        self.create_tower_metrics(tower_id)

      session.run(
          variables.variables_initializer(
              ops_lib.get_collection(ops_lib.GraphKeys.METRIC_VARIABLES)))

      for _ in range(20):
        session.run(
            replicate_model_fn._reduce_metric_variables(number_of_towers=3))

      local_metrics = session.run(
          ops_lib.get_collection(ops_lib.GraphKeys.METRIC_VARIABLES))

      self._assert_reduced_into_first_tower(local_metrics)

  def test_handles_single_tower(self):
    """With one tower the metric variables are left untouched."""
    with self.test_session() as session:
      self.create_tower_metrics(0)
      session.run(
          variables.variables_initializer(
              ops_lib.get_collection(ops_lib.GraphKeys.METRIC_VARIABLES)))

      session.run(
          replicate_model_fn._reduce_metric_variables(number_of_towers=1))

      local_metrics = session.run(
          ops_lib.get_collection(ops_lib.GraphKeys.METRIC_VARIABLES))

      self.assertNear(1.3, local_metrics[0], 0.01)
      self.assertNear(2.3, local_metrics[1], 0.01)
      self.assertAllClose([3.3, 3.5, 3.7], local_metrics[2], 0.01)

  def test_doesnt_accept_uneven_number_of_variables(self):
    """A variable count not divisible by the tower count is rejected."""
    with self.test_session() as session:
      for tower_id in range(3):
        self.create_tower_metrics(tower_id)
      self.create_metric_variable(-1.0, 'oddball')

      session.run(
          variables.variables_initializer(
              ops_lib.get_collection(ops_lib.GraphKeys.METRIC_VARIABLES)))

      # 10 variables over 3 towers: 10 // 3 == 3 per tower, so 9 are expected.
      with self.assertRaisesRegexp(
          ValueError, 'Expected 9 local variables, but got 10 instead'):
        session.run(
            replicate_model_fn._reduce_metric_variables(number_of_towers=3))
export_output.RegressionOutput(predictions['probabilities']), + } + + return model_fn_lib.EstimatorSpec( + mode=mode, + loss=math_ops.reduce_sum(loss), + eval_metric_ops=metrics, + predictions=predictions, + train_op=loss, # This train_op isn't actually used. + export_outputs=export_outputs) + + def replicate_estimator_spec(self, session): + features = np.array([0.01, 0.002]) + labels = np.array([0.01, 0.02]) + + replicated_model_fn = replicate_model_fn.replicate_model_fn( + self.model_fn, self.optimizer_fn, devices=['/gpu:0', '/gpu:1']) + estimator_spec = replicated_model_fn(model_fn_lib.ModeKeys.PREDICT, + features, labels, {}) + session.run(variables.global_variables_initializer()) + return estimator_spec + + def test_merde_predict_output(self): + with self.test_session() as session: + estimator_spec = self.replicate_estimator_spec(session) + self.assertAllClose( + { + 'probabilities': np.array([0.1, 0.02]) + }, + session.run(estimator_spec.export_outputs[ + signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY].outputs)) + + def test_merge_classification_output_scores_classes(self): + with self.test_session() as session: + estimator_spec = self.replicate_estimator_spec(session) + self.assertAllClose( + [0.1, 0.02], + session.run( + estimator_spec.export_outputs['classification_output'].scores)) + self.assertAllEqual( + [b'split_inputs/split:0', b'split_inputs/split:1'], + session.run( + estimator_spec.export_outputs['classification_output'].classes)) + + def test_merge_classification_output_scores(self): + with self.test_session() as session: + estimator_spec = self.replicate_estimator_spec(session) + self.assertAllClose( + [0.1, 0.02], + session.run( + estimator_spec.export_outputs['classification_scores'].scores)) + self.assertEqual( + None, estimator_spec.export_outputs['classification_scores'].classes) + + def test_merge_classification_output_classes(self): + with self.test_session() as session: + estimator_spec = self.replicate_estimator_spec(session) + 
self.assertAllEqual( + [b'split_inputs/split:0', b'split_inputs/split:1'], + session.run( + estimator_spec.export_outputs['classification_classes'].classes)) + self.assertEqual( + None, estimator_spec.export_outputs['classification_classes'].scores) + + def test_merge_regression_output(self): + with self.test_session() as session: + estimator_spec = self.replicate_estimator_spec(session) + self.assertAllClose( + [0.1, 0.02], + session.run(estimator_spec.export_outputs['regression_output'].value)) + + +class GetLocalDevicesTest(test_util.TensorFlowTestCase): + + def test_there_is_at_least_a_cpu(self): + self.assertTrue(replicate_model_fn._get_local_devices('CPU')) + + def test_there_is_no_xpu(self): + self.assertFalse( + replicate_model_fn._get_local_devices('XPU')) # XPU doesn't exist. + + def test_whether_there_is_a_gpu(self): + self.assertEqual( + len(replicate_model_fn._get_local_devices('GPU')), + test.is_gpu_available()) + + +class LocalDeviceSetterTest(test_util.TensorFlowTestCase): + + def test_vars_are_on_ps_but_ops_are_on_workers(self): + local_device_setter = replicate_model_fn._local_device_setter( + ps_device='/device:GPU:3', worker_device='/device:GPU:2') + + with ops_lib.device(local_device_setter): + c = variables.Variable(0.01) + self.assertEqual('/device:GPU:3', c.device) + + cc = variables.Variable(0.02) + self.assertEqual('/device:GPU:3', cc.device) + + ccc = variables.Variable(0.03) + self.assertEqual('/device:GPU:3', ccc.device) + + c_op = array_ops.concat(c, axis=0) + self.assertEqual('/device:GPU:2', c_op.device) + + cc_op = array_ops.concat(cc, axis=0) + self.assertEqual('/device:GPU:2', cc_op.device) + + +class ComputeSumWithDevicePlacementTest(test_util.TensorFlowTestCase): + + def test_example(self): + with self.test_session() as session: + total = replicate_model_fn._compute_sum_on_device( + [1.0, 2.0, 3.0, 4.0], device='/device:GPU:0', name='test_sum') + + self.assertEqual('/device:GPU:0', total.device) + self.assertEqual('test_sum', 
total.op.name) + self.assertEqual(10.0, session.run(total)) + + +class ConcatTensorDictsTest(test_util.TensorFlowTestCase): + + def test_example(self): + tensor_dicts = [ + { + 'a': np.array([1.0, 2.0]), + 'b': np.array([11.0]), + 'c': np.array([21.0]), + }, + { + 'a': np.array([3.0]), + 'b': np.array([12.0, 13.0]), + }, + { + 'b': np.array([14.0]), + }, + ] + + with self.test_session() as session: + self.assertAllClose({ + 'a': np.array([1.0, 2.0, 3.0]), + 'b': np.array([11.0, 12.0, 13.0, 14.0]), + 'c': np.array([21.0]), + }, session.run(replicate_model_fn._concat_tensor_dicts(*tensor_dicts))) + + +if __name__ == '__main__': + test.main() diff --git a/tensorflow/contrib/framework/BUILD b/tensorflow/contrib/framework/BUILD index 891425fd8cae6fbbf60d30cbd9137c049073456c..e8dad886a1409babdf4ea47b9cd05def1f1ce25e 100644 --- a/tensorflow/contrib/framework/BUILD +++ b/tensorflow/contrib/framework/BUILD @@ -24,6 +24,7 @@ tf_custom_op_py_library( "python/framework/__init__.py", "python/framework/checkpoint_utils.py", "python/framework/experimental.py", + "python/framework/graph_util.py", "python/framework/tensor_util.py", "python/ops/__init__.py", "python/ops/accumulate_n_v2.py", @@ -32,6 +33,7 @@ tf_custom_op_py_library( "python/ops/checkpoint_ops.py", "python/ops/ops.py", "python/ops/prettyprint_ops.py", + "python/ops/sort_ops.py", "python/ops/variables.py", ], dso = [ @@ -231,6 +233,17 @@ py_test( ], ) +py_test( + name = "graph_util_test", + srcs = ["python/framework/graph_util_test.py"], + srcs_version = "PY2AND3", + deps = [ + ":framework_py", + "//tensorflow/python:client_testlib", + "//tensorflow/python:platform", + ], +) + py_test( name = "tensor_util_test", srcs = ["python/framework/tensor_util_test.py"], @@ -307,6 +320,20 @@ py_test( ], ) +py_test( + name = "sort_ops_test", + size = "medium", + srcs = ["python/ops/sort_ops_test.py"], + srcs_version = "PY2AND3", + deps = [ + ":framework_py", + "//tensorflow/python:array_ops", + 
"//tensorflow/python:client_testlib", + "//tensorflow/python:random_ops", + "//third_party/py/numpy", + ], +) + filegroup( name = "all_files", srcs = glob( diff --git a/tensorflow/contrib/framework/__init__.py b/tensorflow/contrib/framework/__init__.py index 8421ba7c0423c6ed274f92ba74930822d0171e05..3f592611830e40a30392239c85486a2fad15a2a2 100644 --- a/tensorflow/contrib/framework/__init__.py +++ b/tensorflow/contrib/framework/__init__.py @@ -79,6 +79,8 @@ See the @{$python/contrib.framework} guide. @@load_embedding_initializer @@load_linear_multiclass_bias_initializer @@load_variable_slot_initializer + +@@sort """ from __future__ import absolute_import diff --git a/tensorflow/contrib/framework/python/framework/__init__.py b/tensorflow/contrib/framework/python/framework/__init__.py index c8e6a4685498a4d89cef44f6a9a3acbe7557cb67..2d49771ab756359712a3ee0b23649c231678f952 100644 --- a/tensorflow/contrib/framework/python/framework/__init__.py +++ b/tensorflow/contrib/framework/python/framework/__init__.py @@ -21,6 +21,7 @@ from __future__ import print_function # pylint: disable=wildcard-import from tensorflow.contrib.framework.python.framework.checkpoint_utils import * from tensorflow.contrib.framework.python.framework.experimental import experimental +from tensorflow.contrib.framework.python.framework.graph_util import * from tensorflow.contrib.framework.python.framework.tensor_util import * # pylint: enable=wildcard-import from tensorflow.python.util import decorator_utils diff --git a/tensorflow/contrib/framework/python/framework/graph_util.py b/tensorflow/contrib/framework/python/framework/graph_util.py new file mode 100644 index 0000000000000000000000000000000000000000..8ab8711db4650921e0d366a91adfe2f68b5a42f9 --- /dev/null +++ b/tensorflow/contrib/framework/python/framework/graph_util.py @@ -0,0 +1,128 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. 
def fuse_op(graph_def, input_nodes, output_nodes, output_dtypes,
            output_quantized, op_name, op_type):
  """Fuse subgraph between input_nodes and output_nodes into a single custom op.

  The returned graph contains, in order: every node reachable from
  `input_nodes` (in their original sequence order), the new fused node, a copy
  of each output node rewired to read from the fused node, and every node that
  is not reachable from `output_nodes`.

  Args:
    graph_def: A graph_pb2.GraphDef proto.
    input_nodes: input nodes to the subgraph to be fused.
    output_nodes: output nodes to the subgraph to be fused.
    output_dtypes: A list of output datatypes for the custom op
    output_quantized: A boolean flag that indicates if output is quantized
    op_name: fused op name.
    op_type: fused op type.
  Returns:
    The GraphDef of the new graph.

  Raises:
    TypeError: If 'graph_def' is not a graph_pb2.GraphDef proto, if
      'input_nodes' or 'output_nodes' is a string rather than a list, or if a
      node inside the fused region depends on a node upstream of the inputs
      that is not itself listed in 'input_nodes'.
  """

  if not isinstance(graph_def, graph_pb2.GraphDef):
    raise TypeError("graph_def must be a graph_pb2.GraphDef proto.")

  if isinstance(input_nodes, six.string_types):
    raise TypeError("input_nodes must be a list.")

  if isinstance(output_nodes, six.string_types):
    raise TypeError("output_nodes must be a list.")

  name_to_input_name, name_to_node, name_to_seq_num = _extract_graph_summary(
      graph_def)
  _assert_nodes_are_present(name_to_node, input_nodes + output_nodes)

  # Nodes upto and including input_nodes
  reachable_by_input = _bfs_for_reachable_nodes(input_nodes, name_to_input_name)
  # Nodes upto and including output_nodes
  reachable_by_output = _bfs_for_reachable_nodes(output_nodes,
                                                 name_to_input_name)

  # Set views of the two lists for O(1) membership tests.
  input_nodes_set = set(input_nodes)
  output_nodes_set = set(output_nodes)

  nodes_post_output = []
  for node in graph_def.node:
    n = _node_name(node.name)
    if n not in reachable_by_output:
      # Not upstream of any output: keep it after the fused op.
      nodes_post_output.append(n)
      continue
    if n in reachable_by_input or n in output_nodes_set:
      continue

    # n is between input and output, i.e., part of the fused op. Walk its
    # upstream dependencies breadth-first and check that the walk only ever
    # reaches the pre-input region through a declared input node.
    # `seen` stops shared ancestors being expanded repeatedly and guarantees
    # termination on a cyclic graph. `head` is a read cursor, so dequeuing is
    # O(1) instead of shifting the list.
    to_visit = [n]
    seen = set(to_visit)
    head = 0
    while head < len(to_visit):
      cur_node = to_visit[head]
      head += 1
      if cur_node in input_nodes_set:
        # Declared inputs are the legal boundary; do not look past them.
        continue
      if cur_node in reachable_by_input:
        raise TypeError("Node %s uses input %s not in input_nodes." %
                        (n, cur_node))
      for upstream in name_to_input_name[cur_node]:
        if upstream not in seen:
          seen.add(upstream)
          to_visit.append(upstream)

  # Add all nodes upto the input nodes
  out = graph_pb2.GraphDef()
  reachable_by_input_sorted = sorted(
      reachable_by_input, key=lambda name: name_to_seq_num[name])
  for node in reachable_by_input_sorted:
    out.node.extend([copy.deepcopy(name_to_node[node])])

  # Add the custom op
  new_node = node_def_pb2.NodeDef()
  new_node.input.extend(input_nodes)
  new_node.attr["_output_types"].list.type[:] = output_dtypes
  new_node.attr["_output_quantized"].b = output_quantized
  new_node.op = op_type
  new_node.name = op_name
  out.node.extend([new_node])

  # Add the nodes in the output of the custom op
  for index, n in enumerate(output_nodes):
    # Each output node is rewired to a single fused-op output, so it must have
    # had exactly one input to begin with.
    assert len(name_to_node[n].input) == 1
    new_node = copy.deepcopy(name_to_node[n])
    del new_node.input[:]
    # Output 0 is addressed by the bare op name, output i > 0 as "name:i".
    new_node.input.append(op_name + (":" + str(index) if index != 0 else ""))
    out.node.extend([new_node])

  # Add the nodes post output_nodes
  for n in nodes_post_output:
    out.node.extend([copy.deepcopy(name_to_node[n])])

  out.library.CopyFrom(graph_def.library)
  out.versions.CopyFrom(graph_def.versions)
  return out
def GetNewNode(name, op, input_nodes):
  """Returns a NodeDef with the given name, op type and input node names."""
  node_def = node_def_pb2.NodeDef()
  node_def.op = op
  node_def.name = name
  node_def.input.extend(input_nodes)
  return node_def


class GraphUtilTest(test.TestCase):
  """Exercises `graph_util.fuse_op` on a five-node linear chain."""

  def testGraphUtil(self):
    # Chain: A (Placeholder) -> B -> C -> D -> E.
    chain = [
        GetNewNode('A', 'Placeholder', []),
        GetNewNode('B', 'Op1', ['A']),
        GetNewNode('C', 'Op1', ['B']),
        GetNewNode('D', 'Op1', ['C']),
        GetNewNode('E', 'Op1', ['D']),
    ]
    graph_def = graph_pb2.GraphDef()
    graph_def.node.extend(chain)

    # Fuse everything between input A and output D into one 'Op2' node.
    fused_graph_def = graph_util.fuse_op(
        graph_def, ['A'], ['D'], [types_pb2.DT_FLOAT], True, 'FusedOp', 'Op2')
    fused_nodes = fused_graph_def.node

    # B and C disappear into the fused op; A, D and E survive.
    self.assertEqual(len(fused_nodes), 4)
    self.assertEqual(fused_nodes[0].name, 'A')

    fused_op = fused_nodes[1]
    self.assertEqual(fused_op.name, 'FusedOp')
    self.assertEqual(fused_op.input[0], 'A')
    self.assertEqual(fused_op.op, 'Op2')
    self.assertEqual(fused_op.attr['_output_quantized'].b, True)
    self.assertEqual(fused_op.attr['_output_types'].list.type,
                     [types_pb2.DT_FLOAT])

    self.assertEqual(fused_nodes[2].name, 'D')
    self.assertEqual(fused_nodes[3].name, 'E')


if __name__ == '__main__':
  test.main()
def sort(values, axis=-1, direction='ASCENDING', name=None):
  """Sorts a tensor.

  Args:
    values: 1-D or higher numeric `Tensor`.
    axis: The axis along which to sort. The default is -1, which sorts the last
        axis.
    direction: The direction in which to sort the values (`'ASCENDING'` or
        `'DESCENDING'`).
    name: Optional name for the operation.

  Returns:
    A `Tensor` with the same dtype and shape as `values`, with the elements
        sorted along the given `axis`.

  Raises:
    ValueError: If axis is not a constant scalar, or the direction is invalid.
  """
  with framework_ops.name_scope(name, 'sort'):
    if direction not in _SORT_IMPL:
      known_directions = ', '.join(sorted(_SORT_IMPL.keys()))
      raise ValueError('%s should be one of %s' % (direction, known_directions))
    sort_impl = _SORT_IMPL[direction]

    # The axis has to be known at graph-construction time: it is accepted in
    # any form convertible to a tensor, but must fold to a constant scalar.
    axis_tensor = framework_ops.convert_to_tensor(axis, name='axis')
    constant_axis = tensor_util.constant_value(axis_tensor)
    if axis_tensor.shape.ndims != 0 or constant_axis is None:
      raise ValueError('axis must be a constant scalar')
    # Plain Python int; avoids NumPy casting errors further down.
    constant_axis = int(constant_axis)

    values_tensor = framework_ops.convert_to_tensor(values, name='values')

    return sort_impl(values_tensor, constant_axis)


def _descending_sort(values, axis):
  """Sorts values in reverse using `top_k`.

  `top_k` only operates on the last axis, so any other axis is first swapped
  into the last position and swapped back afterwards.

  Args:
    values: Tensor of numeric values.
    axis: Index of the axis which values should be sorted along.

  Returns:
    The sorted values.
  """
  # Asking top_k for as many elements as the axis holds yields a full sort.
  axis_length = array_ops.shape(values)[axis]
  rank = array_ops.rank(values)

  # Fast path: sorting the last axis needs no transposition.
  sorting_last_axis = axis == -1 or axis + 1 == values.get_shape().ndims
  if sorting_last_axis:
    return nn_ops.top_k(values, axis_length)[0]

  # Otherwise, transpose the array. Swap axes `axis` and `rank - 1`.
  if axis < 0:
    # Make axis a Tensor with the real axis index if needed.
    axis += rank

  # Axes up to axis are unchanged.
  leading_axes = math_ops.range(axis)
  # Position `axis` receives the last axis ...
  last_axis = [rank - 1]
  # ... axes in [axis + 1, rank - 1) are unchanged ...
  middle_axes = math_ops.range(axis + 1, rank - 1)
  # ... and the last position receives `axis`.
  sort_axis = [axis]
  swap_with_last = array_ops.concat(
      [leading_axes, last_axis, middle_axes, sort_axis], axis=0)

  swapped_input = array_ops.transpose(values, swap_with_last)
  swapped_sorted = nn_ops.top_k(swapped_input, axis_length)[0]
  # The permutation swaps exactly two elements, so it is an involution (its
  # own inverse): applying it again restores the original axis order.
  return array_ops.transpose(swapped_sorted, swap_with_last)


def _ascending_sort(values, axis):
  """Sorts ascending by negating, sorting descending and negating back."""
  negated_descending = _descending_sort(-values, axis)
  return -negated_descending


# Maps each accepted `direction` string to its implementation.
_SORT_IMPL = {
    'ASCENDING': _ascending_sort,
    'DESCENDING': _descending_sort,
}
class SortTest(test.TestCase):
  """Compares `sort_ops.sort` against `np.sort` and checks its error paths."""

  def testRandom_lowDimensionality(self):
    self._testRandom_lowDimensionality(negative_axis=False)

  def testRandom_lowDimensionality_negative(self):
    self._testRandom_lowDimensionality(negative_axis=True)

  def _testRandom_lowDimensionality(self, negative_axis):
    """Sorts 20 random rank-1 or rank-2 arrays along a random axis.

    Dimension sizes may be zero. With `negative_axis` set, the axis is given
    in its negative form.
    """
    np.random.seed(42)
    for _ in range(20):
      rank = np.random.randint(1, 3)
      dims = [np.random.randint(0, 20) for _ in range(rank)]
      data = np.random.random(dims)
      axis = np.random.choice(rank)
      if negative_axis:
        axis = -1 - axis
      with self.test_session():
        expected = np.sort(data, axis=axis)
        actual = sort_ops.sort(constant_op.constant(data), axis=axis).eval()
        self.assertAllEqual(expected, actual)

  def testRandom_highDimensionality(self):
    np.random.seed(100)
    for _ in range(20):
      rank = np.random.randint(5, 15)
      dims = [np.random.randint(1, 4) for _ in range(rank)]
      data = np.random.random(dims)
      axis = np.random.choice(rank)
      with self.test_session():
        expected = np.sort(data, axis=axis)
        actual = sort_ops.sort(constant_op.constant(data), axis=axis).eval()
        self.assertAllEqual(expected, actual)

  def testScalar(self):
    # Create an empty scalar where the static shape is unknown.
    random_shape = random_ops.random_uniform(
        [1], minval=0, maxval=1, dtype=dtypes.int32)
    zeros_length_1 = array_ops.zeros(random_shape, dtype=dtypes.int32)
    scalar = array_ops.zeros(zeros_length_1)

    # Graph construction succeeds; the failure only surfaces at run time.
    sorted_scalar = sort_ops.sort(scalar)
    with self.test_session():
      with self.assertRaises(errors.InvalidArgumentError):
        sorted_scalar.eval()

  def testNegativeOutOfBounds_staticShape(self):
    vector = constant_op.constant([3, 4, 5])
    with self.assertRaises(ValueError):
      sort_ops.sort(vector, axis=-4)

  def testDescending(self):
    data = np.random.random((10, 5, 5))
    with self.test_session():
      # Reversing an ascending sort along axis 0 gives the descending order.
      expected = np.sort(data, axis=0)[::-1]
      actual = sort_ops.sort(
          constant_op.constant(data), axis=0, direction='DESCENDING').eval()
      self.assertAllEqual(expected, actual)


if __name__ == '__main__':
  test.main()
class DataFeeder(object): self.random_state = np.random.RandomState( 42) if random_state is None else random_state - num_samples = list(self._x.values())[0].shape[ - 0] if x_is_dict else self._x.shape[0] + if x_is_dict: + num_samples = list(self._x.values())[0].shape[0] + elif tensor_util.is_tensor(self._x): + num_samples = self._x.shape[0].value # shape will be a Dimension, extract an int + else: + num_samples = self._x.shape[0] + if self._shuffle: self.indices = self.random_state.permutation(num_samples) else: diff --git a/tensorflow/contrib/metrics/python/ops/metric_ops.py b/tensorflow/contrib/metrics/python/ops/metric_ops.py index 33377a70c2506261b497c1b0fe8ab5ba0c680c7e..3dd1f1a627738a7e1f6eead8c8c0eaae237190a3 100644 --- a/tensorflow/contrib/metrics/python/ops/metric_ops.py +++ b/tensorflow/contrib/metrics/python/ops/metric_ops.py @@ -1982,7 +1982,7 @@ def streaming_sparse_precision_at_k(predictions, `predictions`, or if either `metrics_collections` or `updates_collections` are not a list or tuple. """ - return metrics.sparse_precision_at_k( + return metrics.precision_at_k( k=k, class_id=class_id, predictions=predictions, @@ -2323,7 +2323,7 @@ def streaming_sparse_average_precision_at_k(predictions, update: `Operation` that increments variables appropriately, and whose value matches `metric`. 
""" - return metrics.sparse_average_precision_at_k( + return metrics.average_precision_at_k( k=k, predictions=predictions, labels=labels, diff --git a/tensorflow/contrib/model_pruning/README.md b/tensorflow/contrib/model_pruning/README.md index a8427e60144445c008d032ff2cbfd801d294974c..764e126e0d64d5e6c6caf0a9f0d43a87995447eb 100644 --- a/tensorflow/contrib/model_pruning/README.md +++ b/tensorflow/contrib/model_pruning/README.md @@ -20,7 +20,7 @@ conv = tf.nn.conv2d(images, pruning.apply_mask(weights), stride, padding) This creates a convolutional layer with additional variables mask and threshold as shown below: ![Convolutional layer with mask and -threshold](./mask.png "Convolutional layer with mask and threshold") +threshold](https://storage.googleapis.com/download.tensorflow.org/example_images/mask.png "Convolutional layer with mask and threshold") Alternatively, the API also provides variant of tensorflow layers with these auxiliary variables built-in (see @@ -37,82 +37,20 @@ auxiliary variables built-in (see The pruning library allows for specification of the following hyper parameters: -| Hyperparameter | Type | Default | Description | -| ---------------------------- | ------- | ------------- | -------------- | -| name | string | model_pruning | Name of the | -: : : : pruning : -: : : : specification. : -: : : : Used for : -: : : : adding : -: : : : summaries and : -: : : : ops under a : -: : : : common : -: : : : tensorflow : -: : : : name_scope : -| begin_pruning_step | integer | 0 | The global | -: : : : step at which : -: : : : to begin : -: : : : pruning : -| end_pruning_step | integer | -1 | The global | -: : : : step at which : -: : : : to terminate : -: : : : pruning. 
: -: : : : Defaults to -1 : -: : : : implying that : -: : : : pruning : -: : : : continues till : -: : : : the training : -: : : : stops : -| do_not_prune | list of | [""] | list of layers | -: : strings : : that are not : -: : : : pruned : -| threshold_decay | float | 0.9 | The decay | -: : : : factor to use : -: : : : for : -: : : : exponential : -: : : : decay of the : -: : : : thresholds : -| pruning_frequency | integer | 10 | How often | -: : : : should the : -: : : : masks be : -: : : : updated? (in # : -: : : : of : -: : : : global_steps). : -| nbins | integer | 255 | Number of bins | -: : : : to use for : -: : : : histogram : -: : : : computation : -| initial_sparsity | float | 0.0 | Initial | -: : : : sparsity value : -| target_sparsity | float | 0.5 | Target | -: : : : sparsity value : -| sparsity_function_begin_step | integer | 0 | The global | -: : : : step at this : -: : : : which the : -: : : : gradual : -: : : : sparsity : -: : : : function : -: : : : begins to take : -: : : : effect : -| sparsity_function_end_step | integer | 100 | The global | -: : : : step used as : -: : : : the end point : -: : : : for the : -: : : : gradual : -: : : : sparsity : -: : : : function : -| sparsity_function_exponent | float | 3.0 | exponent = 1 | -: : : : is linearly : -: : : : varying : -: : : : sparsity : -: : : : between : -: : : : initial and : -: : : : final. : -: : : : exponent > 1 : -: : : : varies more : -: : : : slowly towards : -: : : : the end than : -: : : : the beginning : +|Hyperparameter | Type | Default | Description | +|:----------------------------|:-------:|:-------------:|:--------------| +| name | string | model_pruning | Name of the pruning specification. Used for adding summaries and ops under a common tensorflow name_scope | +| begin_pruning_step | integer | 0 | The global step at which to begin pruning | +| end_pruning_step | integer | -1 | The global step at which to terminate pruning. 
Defaults to -1 implying that pruning continues till the training stops | +| do_not_prune | list of strings | [""] | list of layers strings that are not pruned | +| threshold_decay | float | 0.9 | The decay factor to use for exponential decay of the thresholds | +| pruning_frequency | integer | 10 | How often should the masks be updated? (in # of global_steps) | +| nbins | integer | 255 | Number of bins to use for histogram computation | +| initial_sparsity | float | 0.0 | Initial sparsity value | +| target_sparsity | float | 0.5 | Target sparsity value | +| sparsity_function_begin_step | integer | 0 | The global step at this which the gradual sparsity function begins to take effect | +| sparsity_function_end_step | integer | 100 | The global step used as the end point for the gradual sparsity function | +| sparsity_function_exponent | float | 3.0 | exponent = 1 is linearly varying sparsity between initial and final. exponent > 1 varies more slowly towards the end than the beginning | The sparsity $$s_t$$ at global step $$t$$ is given by: @@ -190,6 +128,3 @@ Eval: ```shell $ bazel-bin/$examples_dir/cifar10/cifar10_eval --run_once ``` - -TODO(suyoggupta): Add figures showing the sparsity function, sparsity for -different layers etc. diff --git a/tensorflow/contrib/rnn/python/ops/rnn_cell.py b/tensorflow/contrib/rnn/python/ops/rnn_cell.py index 7e0e41477c9e70fcfcc0163a6348d3170fc43e73..5e85c125df8ca0d632fa9b0db86d942bb354631e 100644 --- a/tensorflow/contrib/rnn/python/ops/rnn_cell.py +++ b/tensorflow/contrib/rnn/python/ops/rnn_cell.py @@ -1362,24 +1362,25 @@ class LayerNormBasicLSTMCell(rnn_cell_impl.RNNCell): def output_size(self): return self._num_units - def _norm(self, inp, scope): + def _norm(self, inp, scope, dtype=dtypes.float32): shape = inp.get_shape()[-1:] gamma_init = init_ops.constant_initializer(self._norm_gain) beta_init = init_ops.constant_initializer(self._norm_shift) with vs.variable_scope(scope): # Initialize beta and gamma for use by layer_norm. 
- vs.get_variable("gamma", shape=shape, initializer=gamma_init) - vs.get_variable("beta", shape=shape, initializer=beta_init) + vs.get_variable("gamma", shape=shape, initializer=gamma_init, dtype=dtype) + vs.get_variable("beta", shape=shape, initializer=beta_init, dtype=dtype) normalized = layers.layer_norm(inp, reuse=True, scope=scope) return normalized def _linear(self, args): out_size = 4 * self._num_units proj_size = args.get_shape()[-1] - weights = vs.get_variable("kernel", [proj_size, out_size]) + dtype = args.dtype + weights = vs.get_variable("kernel", [proj_size, out_size], dtype=dtype) out = math_ops.matmul(args, weights) if not self._layer_norm: - bias = vs.get_variable("bias", [out_size]) + bias = vs.get_variable("bias", [out_size], dtype=dtype) out = nn_ops.bias_add(out, bias) return out @@ -1388,13 +1389,14 @@ class LayerNormBasicLSTMCell(rnn_cell_impl.RNNCell): c, h = state args = array_ops.concat([inputs, h], 1) concat = self._linear(args) + dtype = args.dtype i, j, f, o = array_ops.split(value=concat, num_or_size_splits=4, axis=1) if self._layer_norm: - i = self._norm(i, "input") - j = self._norm(j, "transform") - f = self._norm(f, "forget") - o = self._norm(o, "output") + i = self._norm(i, "input", dtype=dtype) + j = self._norm(j, "transform", dtype=dtype) + f = self._norm(f, "forget", dtype=dtype) + o = self._norm(o, "output", dtype=dtype) g = self._activation(j) if (not isinstance(self._keep_prob, float)) or self._keep_prob < 1: @@ -1403,7 +1405,7 @@ class LayerNormBasicLSTMCell(rnn_cell_impl.RNNCell): new_c = (c * math_ops.sigmoid(f + self._forget_bias) + math_ops.sigmoid(i) * g) if self._layer_norm: - new_c = self._norm(new_c, "state") + new_c = self._norm(new_c, "state", dtype=dtype) new_h = self._activation(new_c) * math_ops.sigmoid(o) new_state = rnn_cell_impl.LSTMStateTuple(new_c, new_h) diff --git a/tensorflow/contrib/summary/summary_ops.py b/tensorflow/contrib/summary/summary_ops.py index 
ecfa6baeff35ce6b25185a686d528ec73f17c1ee..56e31985936c22d9b5d6c85fff067118152e220d 100644 --- a/tensorflow/contrib/summary/summary_ops.py +++ b/tensorflow/contrib/summary/summary_ops.py @@ -246,8 +246,8 @@ def image(name, tensor, bad_color=None, max_images=3, family=None): """Writes an image summary if possible.""" def function(tag, scope): - if bad_color is None: - bad_color_ = constant_op.constant([255, 0, 0, 255], dtype=dtypes.uint8) + bad_color_ = (constant_op.constant([255, 0, 0, 255], dtype=dtypes.uint8) + if bad_color is None else bad_color) # Note the identity to move the tensor to the CPU. return gen_summary_ops.write_image_summary( context.context().summary_writer_resource, diff --git a/tensorflow/contrib/tpu/profiler/BUILD b/tensorflow/contrib/tpu/profiler/BUILD index f6309e2e72f75a4ba5b323b4d7348c49555d522e..0e1fca3d3c8b6f3a19b3e989dbee1863475796c5 100644 --- a/tensorflow/contrib/tpu/profiler/BUILD +++ b/tensorflow/contrib/tpu/profiler/BUILD @@ -95,3 +95,10 @@ tf_proto_library_cc( cc_api_version = 2, visibility = ["//visibility:public"], ) + +tf_proto_library_cc( + name = "tf_op_stats_proto", + srcs = ["tf_op_stats.proto"], + cc_api_version = 2, + visibility = ["//visibility:public"], +) diff --git a/tensorflow/contrib/tpu/profiler/tf_op_stats.proto b/tensorflow/contrib/tpu/profiler/tf_op_stats.proto new file mode 100644 index 0000000000000000000000000000000000000000..5b2dbb31243d401fbab31bab5bc86133896693fe --- /dev/null +++ b/tensorflow/contrib/tpu/profiler/tf_op_stats.proto @@ -0,0 +1,127 @@ +// This proto describes the format of tensorflow operation level stats for +// profiling (in tensorboard) purpose. + +syntax = "proto2"; + +package tensorflow.tpu; + +// Result proto for OpMetrics. +message OpMetricsResult { + // True if this OP is executed on the device; False if it is executed on the + // host. + optional bool on_device = 1; + reserved 2; // was uint32 id. + // Name of this OP. + optional string name = 3; + // Rank of this OP. 
+ optional uint64 rank = 4; + // The starting time in cycles of the last instance of this OP executed. + optional double last_starttime_in_cycles = 5; + // The ending time in cycles of the last instance of this OP executed. + optional double last_endtime_in_cycles = 6; + // If this OP (say A), is an immediate child of another OP (say B), this field + // stores the sum of duration in microseconds of A inside B. If A appears more + // than once in B, the duration of all A's appearances will be added together. + // This sum will be reset after the self-time of B is calculated so that it + // can be reused for a new parent OP. + optional double sum_of_duration_in_us_as_children = 7; + // Number of instances that this OP occurred. + optional uint64 occurrences = 8; + // Total time in microseconds spent in this OP (accumulated + // over all of its occurrences). + optional double total_time_in_us = 9; + // Total self time in microseconds spent in this OP + // (accumulated over all of its occurrences). + optional double total_self_time_in_us = 10; + // The total self time as a fraction of sum of all OP's + // total self time on the host. + optional double host_total_self_time_as_fraction_of_all_op_time = 11; + // Cumulative total self time in fraction on the host. + optional double host_cumulative_total_self_time_as_fraction_of_all_op_time = + 12; + // The total self time as a fraction of sum of all OP's + // total self time on the device. + optional double device_total_self_time_as_fraction_of_all_op_time = 13; + // Cumulative total self time in fraction on the device. + optional double device_cumulative_total_self_time_as_fraction_of_all_op_time = + 14; + // Total number of FLOPs incurred by this OP. + optional double total_flops = 15; + // Total time in microseconds that the MXU is occupied by this OP. + optional double total_bytes_accessed = 16; + // Total time in microseconds that the MXU is occupied by this OP. 
+ optional double mxu_occupancy_in_us = 17; + // Total time in microseconds that the XU is occupied by this OP. + optional double xu_occupancy_in_us = 18; + // Total DMA access stall time in microseconds. + optional double total_dma_stall_in_us = 19; +} + +// Result proto for OpMetricsDb. +message OpMetricsDbResult { + // A bunch of OpMetricsResults. + repeated OpMetricsResult metrics_db = 1; +} + +// Result proto for StepInfo. +message StepInfoResult { + // The (micro) step number. + optional uint32 step_num = 1; + // The step duration in picoseconds. + optional uint64 duration_ps = 2; + // The infeed duration in picoseconds. + // Can turn into a map if we want a variable number of ops. + optional uint64 infeed_duration_ps = 3; +} + +// Result proto for a sequence of steps. +message StepSequenceResult { + // A sequence of StepInfoResults. + repeated StepInfoResult step_sequence = 1; +} + +// Result proto for a StepDatabase. +message StepDatabaseResult { + // A map from core_id to StepSequenceResult. + map step_sequence_per_core = 1; +} + +// Result proto for Dashboard data. +message DashboardResult { + // The total iteration time in nanoseconds. + optional double iteration_time_ns = 1; + // The total number of iterations. + optional int32 num_iterations = 2; + // The total computation time in nanoseconds. + optional double computation_time_ns = 3; + // The total number of computations. + optional int32 num_computations = 4; +} + +// Result proto for HloExtraInfo. +message HloExtraInfoResult { + // Category of the HLO op given by the compiler. + optional string category = 1; + // The long name of the HLO that includes the dimensions. + optional string long_name = 2; +} + +// Result proto for HloExtraInfoMap. +message HloExtraInfoMapResult { + // A map from HLO name to HloExtraInfo. + map hlo_extrainfo_map = 1; +} + +// Result proto for TfStatsHelper. +message TfOpStats { + // The result for the TF-metric database. 
+ optional OpMetricsDbResult tf_metrics_db = 1; + // The result for the HLO-metric database. + optional OpMetricsDbResult hlo_metrics_db = 2; + // The result for the step database. + optional StepDatabaseResult step_db = 3; + // The result for the TPU dashboard. + optional DashboardResult dashboard = 4; + // The result for the HloExtraInfoMap. + optional HloExtraInfoMapResult hlo_extrainfo_map = 5; +} diff --git a/tensorflow/contrib/tpu/python/tpu/tpu_estimator.py b/tensorflow/contrib/tpu/python/tpu/tpu_estimator.py index 5a3b8314291951b5dfce091dccb0dc9e5f7af3b5..060b3f912926fbaa56bc1150e50434a7ad22c847 100644 --- a/tensorflow/contrib/tpu/python/tpu/tpu_estimator.py +++ b/tensorflow/contrib/tpu/python/tpu/tpu_estimator.py @@ -66,7 +66,7 @@ _CROSS_REPLICA_SUM_OP = 'CrossReplicaSum' _RESERVED_PARAMS_KEYS = [_BATCH_SIZE_KEY] # TODO(b/65703635): Flip the value and remove all dead code. -_WRAP_INPUT_FN_INTO_WHILE_LOOP = False +_WRAP_INPUT_FN_INTO_WHILE_LOOP = True def _create_global_step(graph): diff --git a/tensorflow/core/BUILD b/tensorflow/core/BUILD index 7b535da0b2f751ef258580b678bec5022d671b82..9530af637ef953c293472d926281de77cf626752 100644 --- a/tensorflow/core/BUILD +++ b/tensorflow/core/BUILD @@ -1414,16 +1414,19 @@ LIB_INTERNAL_PUBLIC_HEADERS = tf_additional_lib_hdrs() + [ "platform/tracing.h", ] +# Replicated for lib_internal and lib_internal_impl. 
+LIB_INTERNAL_DEFINES = (tf_additional_lib_defines() + [ + "TF_USE_SNAPPY", + ] + tf_additional_verbs_lib_defines() + + tf_additional_mpi_lib_defines() + + tf_additional_gdr_lib_defines()) + cc_library( name = "lib_internal", srcs = LIB_INTERNAL_PRIVATE_HEADERS, hdrs = LIB_INTERNAL_PUBLIC_HEADERS, copts = tf_copts(), - defines = tf_additional_lib_defines() + [ - "TF_USE_SNAPPY", - ] + tf_additional_verbs_lib_defines() + - tf_additional_mpi_lib_defines() + - tf_additional_gdr_lib_defines(), + defines = LIB_INTERNAL_DEFINES, linkopts = select({ "//tensorflow:freebsd": [], "//tensorflow:windows": [], @@ -1477,6 +1480,7 @@ cc_library( ), hdrs = LIB_INTERNAL_PUBLIC_HEADERS, copts = tf_copts(), + defines = LIB_INTERNAL_DEFINES, deps = tf_additional_lib_deps() + [ ":lib_hash_crc32c_accelerate_internal", ":lib_proto_parsing", diff --git a/tensorflow/core/api_def/api_test.cc b/tensorflow/core/api_def/api_test.cc index ceeb172fa0a9abf2ab7adcfc801b4bcb5fa04381..d95d958d5afaad58bdec82183be3d3a09cf4605d 100644 --- a/tensorflow/core/api_def/api_test.cc +++ b/tensorflow/core/api_def/api_test.cc @@ -46,92 +46,218 @@ constexpr char kDefaultApiDefDir[] = "tensorflow/core/api_def/base_api"; constexpr char kOverridesFilePath[] = "tensorflow/cc/ops/op_gen_overrides.pbtxt"; -constexpr char kApiDefFileFormat[] = "api_def_%c.pbtxt"; -constexpr char kAlphabet[] = "ABCDEFGHIJKLMNOPQRSTUVWXYZ"; +constexpr char kApiDefFileFormat[] = "api_def_%s.pbtxt"; +constexpr char kApiDefFilePattern[] = "api_def_*.pbtxt"; -// Get map from first character to ApiDefs for ops -// that start with that character. 
-std::unordered_map GenerateApiDef( - const OpList& ops, const OpGenOverrides& overrides) { +void FillBaseApiDef(ApiDef* api_def, const OpDef& op) { + api_def->set_graph_op_name(op.name()); + // Add arg docs + for (auto& input_arg : op.input_arg()) { + if (!input_arg.description().empty()) { + auto* api_def_in_arg = api_def->add_in_arg(); + api_def_in_arg->set_name(input_arg.name()); + api_def_in_arg->set_description(input_arg.description()); + } + } + for (auto& output_arg : op.output_arg()) { + if (!output_arg.description().empty()) { + auto* api_def_out_arg = api_def->add_out_arg(); + api_def_out_arg->set_name(output_arg.name()); + api_def_out_arg->set_description(output_arg.description()); + } + } + // Add attr docs + for (auto& attr : op.attr()) { + if (!attr.description().empty()) { + auto* api_def_attr = api_def->add_attr(); + api_def_attr->set_name(attr.name()); + api_def_attr->set_description(attr.description()); + } + } + // Add docs + api_def->set_summary(op.summary()); + api_def->set_description(op.description()); +} + +// Checks if arg1 should be before arg2 according to ordering in args. +bool CheckArgBefore(const ApiDef::Arg* arg1, const ApiDef::Arg* arg2, + const protobuf::RepeatedPtrField& args) { + for (auto& arg : args) { + if (arg.name() == arg2->name()) { + return false; + } else if (arg.name() == arg1->name()) { + return true; + } + } + return false; +} + +// Checks if attr1 should be before attr2 according to ordering in op_def. +bool CheckAttrBefore(const ApiDef::Attr* attr1, const ApiDef::Attr* attr2, + const OpDef& op_def) { + for (auto& attr : op_def.attr()) { + if (attr.name() == attr2->name()) { + return false; + } else if (attr.name() == attr1->name()) { + return true; + } + } + return false; +} + +// Applies renames to args. 
+void ApplyArgOverrides( + protobuf::RepeatedPtrField* args, + const protobuf::RepeatedPtrField& renames, + const protobuf::RepeatedPtrField& op_args, + const string& op_name) { + for (auto& rename : renames) { + // First check if rename is valid. + bool valid = false; + for (const auto& op_arg : op_args) { + if (op_arg.name() == rename.from()) { + valid = true; + } + } + QCHECK(valid) << rename.from() << " is not a valid argument for " + << op_name; + bool found_arg = false; + // If Arg is already in ApiDef, just update it. + for (int i = 0; i < args->size(); ++i) { + auto* arg = args->Mutable(i); + if (arg->name() == rename.from()) { + arg->set_rename_to(rename.to()); + found_arg = true; + break; + } + } + if (!found_arg) { // not in ApiDef, add a new arg. + auto* new_arg = args->Add(); + new_arg->set_name(rename.from()); + new_arg->set_rename_to(rename.to()); + } + } + // We don't really need a specific order here right now. + // However, it is clearer if order follows OpDef. + std::sort(args->pointer_begin(), args->pointer_end(), + [&](ApiDef::Arg* arg1, ApiDef::Arg* arg2) { + return CheckArgBefore(arg1, arg2, op_args); + }); +} + +// Returns existing attribute with the given name if such +// attribute exists. Otherwise, adds a new attribute and returns it. +ApiDef::Attr* FindOrAddAttr(ApiDef* api_def, const string attr_name) { + // If Attr is already in ApiDef, just update it. + for (int i = 0; i < api_def->attr_size(); ++i) { + auto* attr = api_def->mutable_attr(i); + if (attr->name() == attr_name) { + return attr; + } + } + // Add a new Attr. + auto* new_attr = api_def->add_attr(); + new_attr->set_name(attr_name); + return new_attr; +} + +// Applies renames and default values to attributes. 
+void ApplyAttrOverrides(ApiDef* api_def, const OpGenOverride& op_override, + const OpDef& op_def) { + for (auto& attr_rename : op_override.attr_rename()) { + auto* attr = FindOrAddAttr(api_def, attr_rename.from()); + attr->set_rename_to(attr_rename.to()); + } + + for (auto& attr_default : op_override.attr_default()) { + auto* attr = FindOrAddAttr(api_def, attr_default.name()); + *(attr->mutable_default_value()) = attr_default.value(); + } + // We don't really need a specific order here right now. + // However, it is clearer if order follows OpDef. + std::sort(api_def->mutable_attr()->pointer_begin(), + api_def->mutable_attr()->pointer_end(), + [&](ApiDef::Attr* attr1, ApiDef::Attr* attr2) { + return CheckAttrBefore(attr1, attr2, op_def); + }); +} + +void ApplyOverridesToApiDef(ApiDef* api_def, const OpDef& op, + const OpGenOverride& op_override) { + // Fill ApiDef with data based on op and op_override. + // Set visibility + if (op_override.skip()) { + api_def->set_visibility(ApiDef_Visibility_SKIP); + } else if (op_override.hide()) { + api_def->set_visibility(ApiDef_Visibility_HIDDEN); + } + // Add endpoints + if (!op_override.rename_to().empty()) { + api_def->add_endpoint()->set_name(op_override.rename_to()); + } else if (!op_override.alias().empty()) { + api_def->add_endpoint()->set_name(op.name()); + } + + for (auto& alias : op_override.alias()) { + auto* endpoint = api_def->add_endpoint(); + endpoint->set_name(alias); + } + + ApplyArgOverrides(api_def->mutable_in_arg(), op_override.input_rename(), + op.input_arg(), api_def->graph_op_name()); + ApplyArgOverrides(api_def->mutable_out_arg(), op_override.output_rename(), + op.output_arg(), api_def->graph_op_name()); + ApplyAttrOverrides(api_def, op_override, op); +} + +// Get map from ApiDef file path to corresponding ApiDefs proto. 
+std::unordered_map GenerateApiDef( + const string& api_def_dir, const OpList& ops, + const OpGenOverrides& overrides) { std::unordered_map name_to_override; for (const auto& op_override : overrides.op()) { name_to_override[op_override.name()] = op_override; } - std::unordered_map api_defs_map; + std::unordered_map api_defs_map; for (const auto& op : ops.op()) { CHECK(!op.name().empty()) << "Encountered empty op name: %s" << op.DebugString(); - const char file_id = toupper(op.name()[0]); - CHECK(isalpha(file_id)) << "Unexpected op name: " << op.name(); - ApiDef* api_def = api_defs_map[file_id].add_op(); - api_def->set_graph_op_name(op.name()); + string file_path = io::JoinPath(api_def_dir, kApiDefFileFormat); + file_path = strings::Printf(file_path.c_str(), op.name().c_str()); + ApiDef* api_def = api_defs_map[file_path].add_op(); + FillBaseApiDef(api_def, op); if (name_to_override.find(op.name()) != name_to_override.end()) { - const auto& op_override = name_to_override[op.name()]; - // Set visibility - if (op_override.skip()) { - api_def->set_visibility(ApiDef_Visibility_SKIP); - } else if (op_override.hide()) { - api_def->set_visibility(ApiDef_Visibility_HIDDEN); - } - // Add endpoints - if (!op_override.rename_to().empty()) { - auto* endpoint = api_def->add_endpoint(); - endpoint->set_name(op_override.rename_to()); - } else { - auto* endpoint = api_def->add_endpoint(); - endpoint->set_name(op.name()); - } - for (auto& alias : op_override.alias()) { - auto* endpoint = api_def->add_endpoint(); - endpoint->set_name(alias); - } - // Add attributes - for (auto& attr : op.attr()) { - auto* api_def_attr = api_def->add_attr(); - api_def_attr->set_name(attr.name()); - for (auto& attr_override : op_override.attr_default()) { - if (attr.name() == attr_override.name()) { - *(api_def_attr->mutable_default_value()) = attr_override.value(); - } - } - for (auto& attr_rename : op_override.attr_rename()) { - if (attr.name() == attr_rename.from()) { - 
api_def_attr->set_rename_to(attr_rename.to()); - } - } - } - } else { - auto* endpoint = api_def->add_endpoint(); - endpoint->set_name(op.name()); + ApplyOverridesToApiDef(api_def, op, name_to_override[op.name()]); } - // Add docs - api_def->set_summary(op.summary()); - api_def->set_description(op.description()); } return api_defs_map; } -// Reads golden api defs file with the given suffix. -string GetGoldenApiDefsStr(Env* env, const string& api_files_dir, char suffix) { - string file_path = strings::Printf( - io::JoinPath(api_files_dir, kApiDefFileFormat).c_str(), suffix); - if (env->FileExists(file_path).ok()) { +// Reads golden ApiDef files and returns a map from file name to ApiDef file +// contents. +std::unordered_map GetGoldenApiDefs( + Env* env, const string& api_files_dir) { + std::vector matching_paths; + TF_CHECK_OK(env->GetMatchingPaths( + io::JoinPath(api_files_dir, kApiDefFilePattern), &matching_paths)); + + std::unordered_map file_path_to_api_def; + for (auto& file_path : matching_paths) { string file_contents; - TF_EXPECT_OK(ReadFileToString(env, file_path, &file_contents)); - return file_contents; + TF_CHECK_OK(ReadFileToString(env, file_path, &file_contents)); + file_path_to_api_def[file_path] = file_contents; } - return ""; + return file_path_to_api_def; } void RunApiTest(bool update_api_def, const string& api_files_dir) { // Read C++ overrides file - string overrides_file_contents; + OpGenOverrides overrides; Env* env = Env::Default(); - TF_EXPECT_OK( - ReadFileToString(env, kOverridesFilePath, &overrides_file_contents)); + TF_EXPECT_OK(ReadTextProto(env, kOverridesFilePath, &overrides)); // Read all ops OpList ops; @@ -139,29 +265,22 @@ void RunApiTest(bool update_api_def, const string& api_files_dir) { const std::vector multi_line_fields = {"description"}; // Get expected ApiDefs - OpGenOverrides overrides; - auto new_api_defs_map = GenerateApiDef(ops, overrides); + const auto new_api_defs_map = GenerateApiDef(api_files_dir, ops, overrides); 
bool updated_at_least_one_file = false; + const auto golden_api_defs_map = GetGoldenApiDefs(env, api_files_dir); - for (char c : kAlphabet) { - string golden_api_defs_str = GetGoldenApiDefsStr(env, api_files_dir, c); - string new_api_defs_str = new_api_defs_map[c].DebugString(); + for (auto new_api_entry : new_api_defs_map) { + const auto& file_path = new_api_entry.first; + const auto& golden_api_defs_str = golden_api_defs_map.at(file_path); + string new_api_defs_str = new_api_entry.second.DebugString(); new_api_defs_str = PBTxtToMultiline(new_api_defs_str, multi_line_fields); if (golden_api_defs_str == new_api_defs_str) { continue; } if (update_api_def) { - string output_file_path = - io::JoinPath(api_files_dir, strings::Printf(kApiDefFileFormat, c)); - if (new_api_defs_str.empty()) { - std::cout << "Deleting " << output_file_path << "..." << std::endl; - TF_EXPECT_OK(env->DeleteFile(output_file_path)); - } else { - std::cout << "Updating " << output_file_path << "..." << std::endl; - TF_EXPECT_OK( - WriteStringToFile(env, output_file_path, new_api_defs_str)); - } + std::cout << "Updating " << file_path << "..." << std::endl; + TF_EXPECT_OK(WriteStringToFile(env, file_path, new_api_defs_str)); updated_at_least_one_file = true; } else { EXPECT_EQ(golden_api_defs_str, new_api_defs_str) @@ -170,6 +289,21 @@ void RunApiTest(bool update_api_def, const string& api_files_dir) { } } + for (const auto& golden_api_entry : golden_api_defs_map) { + const auto& file_path = golden_api_entry.first; + if (new_api_defs_map.find(file_path) == new_api_defs_map.end()) { + if (update_api_def) { + std::cout << "Deleting " << file_path << "..." 
<< std::endl; + TF_EXPECT_OK(env->DeleteFile(file_path)); + updated_at_least_one_file = true; + } else { + EXPECT_EQ("", golden_api_entry.second) + << "To update golden API files, run " + << "tensorflow/core/api_def/update_api_def.sh."; + } + } + } + if (update_api_def && !updated_at_least_one_file) { std::cout << "Api def files are already up to date." << std::endl; } diff --git a/tensorflow/core/api_def/base_api/api_def_A.pbtxt b/tensorflow/core/api_def/base_api/api_def_A.pbtxt deleted file mode 100644 index 8193d1bc624535c7894430284686e8664fb71a2d..0000000000000000000000000000000000000000 --- a/tensorflow/core/api_def/base_api/api_def_A.pbtxt +++ /dev/null @@ -1,670 +0,0 @@ -op { - graph_op_name: "Abort" - endpoint { - name: "Abort" - } - summary: "Raise a exception to abort the process when called." - description: <= 2." -} -op { - graph_op_name: "AdjustContrastv2" - endpoint { - name: "AdjustContrastv2" - } - summary: "Adjust the contrast of one or more images." - description: < [2.0132, 1.056] -``` - -@compatibility(numpy) -Equivalent to np.angle. -@end_compatibility -END -} -op { - graph_op_name: "Any" - endpoint { - name: "Any" - } - summary: "Computes the \"logical or\" of elements across dimensions of a tensor." - description: < l1 else 0.0 -accum = accum_new -END -} -op { - graph_op_name: "ApplyFtrlV2" - endpoint { - name: "ApplyFtrlV2" - } - summary: "Update \'*var\' according to the Ftrl-proximal scheme." - description: < l1 else 0.0 -accum = accum_new -END -} -op { - graph_op_name: "ApplyGradientDescent" - endpoint { - name: "ApplyGradientDescent" - } - summary: "Update \'*var\' by subtracting \'alpha\' * \'delta\' from it." -} -op { - graph_op_name: "ApplyMomentum" - endpoint { - name: "ApplyMomentum" - } - summary: "Update \'*var\' according to the momentum scheme. Set use_nesterov = True if you" - description: <= 2." 
+} diff --git a/tensorflow/core/api_def/base_api/api_def_AdjustContrastv2.pbtxt b/tensorflow/core/api_def/base_api/api_def_AdjustContrastv2.pbtxt new file mode 100644 index 0000000000000000000000000000000000000000..429a5e4434e011d1ba43847b9abf8877b4d41e7a --- /dev/null +++ b/tensorflow/core/api_def/base_api/api_def_AdjustContrastv2.pbtxt @@ -0,0 +1,36 @@ +op { + graph_op_name: "AdjustContrastv2" + endpoint { + name: "AdjustContrast" + } + in_arg { + name: "images" + description: < [2.0132, 1.056] +``` + +@compatibility(numpy) +Equivalent to np.angle. +@end_compatibility +END +} diff --git a/tensorflow/core/api_def/base_api/api_def_Any.pbtxt b/tensorflow/core/api_def/base_api/api_def_Any.pbtxt new file mode 100644 index 0000000000000000000000000000000000000000..09fd4e0b6036447dfe355ff56da29e276de62f2b --- /dev/null +++ b/tensorflow/core/api_def/base_api/api_def_Any.pbtxt @@ -0,0 +1,42 @@ +op { + graph_op_name: "Any" + endpoint { + name: "Any" + } + endpoint { + name: "ReduceAny" + } + in_arg { + name: "input" + description: < l1 else 0.0 +accum = accum_new +END +} diff --git a/tensorflow/core/api_def/base_api/api_def_ApplyFtrlV2.pbtxt b/tensorflow/core/api_def/base_api/api_def_ApplyFtrlV2.pbtxt new file mode 100644 index 0000000000000000000000000000000000000000..974f3adc196129f9fe83d098c22dc3cd237263d6 --- /dev/null +++ b/tensorflow/core/api_def/base_api/api_def_ApplyFtrlV2.pbtxt @@ -0,0 +1,75 @@ +op { + graph_op_name: "ApplyFtrlV2" + in_arg { + name: "var" + description: < l1 else 0.0 +accum = accum_new +END +} diff --git a/tensorflow/core/api_def/base_api/api_def_ApplyGradientDescent.pbtxt b/tensorflow/core/api_def/base_api/api_def_ApplyGradientDescent.pbtxt new file mode 100644 index 0000000000000000000000000000000000000000..2f38ebd1b8c89a1a65368d3da38cead73225ada5 --- /dev/null +++ b/tensorflow/core/api_def/base_api/api_def_ApplyGradientDescent.pbtxt @@ -0,0 +1,35 @@ +op { + graph_op_name: "ApplyGradientDescent" + in_arg { + name: "var" + description: < -1. 
+END + } + attr { + name: "scientific" + description: < -1. +END + } + attr { + name: "fill" + description: < -1. If empty, pads with spaces. +Another typical value is '0'. String cannot be longer than 1 character. +END + } + summary: "Converts each entry in the given tensor to strings. Supports many numeric" + description: <