@ -2712,9 +2712,12 @@ static const char * GGML_OP_LABEL[GGML_OP_COUNT] = {
" FLASH_ATTN " ,
" FLASH_FF " ,
" MAP_UNARY " ,
" MAP_BINARY " ,
} ;
static_assert ( GGML_OP_COUNT = = 36 , " GGML_OP_COUNT != 36 " ) ;
static_assert ( GGML_OP_COUNT = = 3 8, " GGML_OP_COUNT != 38 " ) ;
static const char * GGML_OP_SYMBOL [ GGML_OP_COUNT ] = {
" none " ,
@ -2757,9 +2760,12 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
" flash_attn(x) " ,
" flash_ff(x) " ,
" f(x) " ,
" f(x,y) " ,
} ;
static_assert ( GGML_OP_COUNT = = 36 , " GGML_OP_COUNT != 36 " ) ;
static_assert ( GGML_OP_COUNT = = 3 8, " GGML_OP_COUNT != 38 " ) ;
static_assert ( sizeof ( struct ggml_object ) % GGML_MEM_ALIGN = = 0 , " ggml_object size must be a multiple of GGML_MEM_ALIGN " ) ;
static_assert ( sizeof ( struct ggml_tensor ) % GGML_MEM_ALIGN = = 0 , " ggml_tensor size must be a multiple of GGML_MEM_ALIGN " ) ;
@ -4907,6 +4913,90 @@ struct ggml_tensor * ggml_flash_ff(
return result ;
}
// ggml_map_unary
struct ggml_tensor * ggml_map_unary_impl_f32 (
struct ggml_context * ctx ,
struct ggml_tensor * a ,
const ggml_unary_op_f32_t fun ,
bool inplace ) {
bool is_node = false ;
if ( ! inplace & & a - > grad ) {
is_node = true ;
}
struct ggml_tensor * addr_tensor = ggml_new_tensor_1d ( ctx , GGML_TYPE_I32 , sizeof ( void * ) / sizeof ( int32_t ) ) ;
* ( ( void ( * * ) ( void ) ) addr_tensor - > data ) = ( void ( * ) ( void ) ) fun ;
struct ggml_tensor * result = inplace ? ggml_view_tensor ( ctx , a ) : ggml_dup_tensor ( ctx , a ) ;
result - > op = GGML_OP_MAP_UNARY ;
result - > grad = is_node ? ggml_dup_tensor ( ctx , result ) : NULL ;
result - > src0 = a ;
result - > opt [ 0 ] = addr_tensor ;
return result ;
}
// Public entry point for the unary map op: forwards to
// ggml_map_unary_impl_f32 with inplace = false.
struct ggml_tensor * ggml_map_unary_f32(
        struct ggml_context        * ctx,
        struct ggml_tensor         * a,
        const  ggml_unary_op_f32_t   fun) {
    const bool inplace = false;

    return ggml_map_unary_impl_f32(ctx, a, fun, inplace);
}
// In-place flavour of the unary map op: forwards to
// ggml_map_unary_impl_f32 with inplace = true.
struct ggml_tensor * ggml_map_unary_inplace_f32(
        struct ggml_context        * ctx,
        struct ggml_tensor         * a,
        const  ggml_unary_op_f32_t   fun) {
    const bool inplace = true;

    return ggml_map_unary_impl_f32(ctx, a, fun, inplace);
}
// ggml_map_binary
struct ggml_tensor * ggml_map_binary_impl_f32 (
struct ggml_context * ctx ,
struct ggml_tensor * a ,
struct ggml_tensor * b ,
const ggml_binary_op_f32_t fun ,
bool inplace ) {
GGML_ASSERT ( ggml_are_same_shape ( a , b ) ) ;
bool is_node = false ;
if ( ! inplace & & ( a - > grad | | b - > grad ) ) {
is_node = true ;
}
struct ggml_tensor * addr_tensor = ggml_new_tensor_1d ( ctx , GGML_TYPE_I32 , sizeof ( void * ) / sizeof ( int32_t ) ) ;
* ( ( void ( * * ) ( void ) ) addr_tensor - > data ) = ( void ( * ) ( void ) ) fun ;
struct ggml_tensor * result = inplace ? ggml_view_tensor ( ctx , a ) : ggml_dup_tensor ( ctx , a ) ;
result - > op = GGML_OP_MAP_BINARY ;
result - > grad = is_node ? ggml_dup_tensor ( ctx , result ) : NULL ;
result - > src0 = a ;
result - > src1 = b ;
result - > opt [ 0 ] = addr_tensor ;
return result ;
}
// Public entry point for the binary map op: forwards to
// ggml_map_binary_impl_f32 with inplace = false.
struct ggml_tensor * ggml_map_binary_f32(
        struct ggml_context         * ctx,
        struct ggml_tensor          * a,
        struct ggml_tensor          * b,
        const  ggml_binary_op_f32_t   fun) {
    const bool inplace = false;

    return ggml_map_binary_impl_f32(ctx, a, b, fun, inplace);
}
// In-place flavour of the binary map op: forwards to
// ggml_map_binary_impl_f32 with inplace = true.
struct ggml_tensor * ggml_map_binary_inplace_f32(
        struct ggml_context         * ctx,
        struct ggml_tensor          * a,
        struct ggml_tensor          * b,
        const  ggml_binary_op_f32_t   fun) {
    const bool inplace = true;

    return ggml_map_binary_impl_f32(ctx, a, b, fun, inplace);
}
////////////////////////////////////////////////////////////////////////////////
void ggml_set_param (
@ -8875,6 +8965,111 @@ static void ggml_compute_forward_flash_ff(
}
}
// ggml_compute_forward_map_unary
static void ggml_compute_forward_map_unary_f32 (
const struct ggml_compute_params * params ,
const struct ggml_tensor * src0 ,
struct ggml_tensor * dst ,
const ggml_unary_op_f32_t fun ) {
GGML_ASSERT ( ggml_are_same_shape ( src0 , dst ) ) ;
if ( params - > type = = GGML_TASK_INIT | | params - > type = = GGML_TASK_FINALIZE ) {
return ;
}
const int n = ggml_nrows ( src0 ) ;
const int nc = src0 - > ne [ 0 ] ;
assert ( dst - > nb [ 0 ] = = sizeof ( float ) ) ;
assert ( src0 - > nb [ 0 ] = = sizeof ( float ) ) ;
for ( int i = 0 ; i < n ; i + + ) {
fun ( nc ,
( float * ) ( ( char * ) dst - > data + i * ( dst - > nb [ 1 ] ) ) ,
( float * ) ( ( char * ) src0 - > data + i * ( src0 - > nb [ 1 ] ) ) ) ;
}
}
static void ggml_compute_forward_map_unary (
const struct ggml_compute_params * params ,
const struct ggml_tensor * src0 ,
struct ggml_tensor * dst ,
const ggml_unary_op_f32_t fun ) {
switch ( src0 - > type ) {
case GGML_TYPE_F32 :
{
ggml_compute_forward_map_unary_f32 ( params , src0 , dst , fun ) ;
} break ;
case GGML_TYPE_Q4_0 :
case GGML_TYPE_Q4_1 :
case GGML_TYPE_I8 :
case GGML_TYPE_I16 :
case GGML_TYPE_I32 :
case GGML_TYPE_F16 :
case GGML_TYPE_COUNT :
{
GGML_ASSERT ( false ) ;
} break ;
}
}
// ggml_compute_forward_map_binary
static void ggml_compute_forward_map_binary_f32 (
const struct ggml_compute_params * params ,
const struct ggml_tensor * src0 ,
const struct ggml_tensor * src1 ,
struct ggml_tensor * dst ,
const ggml_binary_op_f32_t fun ) {
assert ( params - > ith = = 0 ) ;
assert ( ggml_are_same_shape ( src0 , src1 ) & & ggml_are_same_shape ( src0 , dst ) ) ;
if ( params - > type = = GGML_TASK_INIT | | params - > type = = GGML_TASK_FINALIZE ) {
return ;
}
const int n = ggml_nrows ( src0 ) ;
const int nc = src0 - > ne [ 0 ] ;
assert ( dst - > nb [ 0 ] = = sizeof ( float ) ) ;
assert ( src0 - > nb [ 0 ] = = sizeof ( float ) ) ;
assert ( src1 - > nb [ 0 ] = = sizeof ( float ) ) ;
for ( int i = 0 ; i < n ; i + + ) {
fun ( nc ,
( float * ) ( ( char * ) dst - > data + i * ( dst - > nb [ 1 ] ) ) ,
( float * ) ( ( char * ) src0 - > data + i * ( src0 - > nb [ 1 ] ) ) ,
( float * ) ( ( char * ) src1 - > data + i * ( src1 - > nb [ 1 ] ) ) ) ;
}
}
static void ggml_compute_forward_map_binary (
const struct ggml_compute_params * params ,
const struct ggml_tensor * src0 ,
const struct ggml_tensor * src1 ,
struct ggml_tensor * dst ,
const ggml_binary_op_f32_t fun ) {
switch ( src0 - > type ) {
case GGML_TYPE_F32 :
{
ggml_compute_forward_map_binary_f32 ( params , src0 , src1 , dst , fun ) ;
} break ;
case GGML_TYPE_Q4_0 :
case GGML_TYPE_Q4_1 :
case GGML_TYPE_I8 :
case GGML_TYPE_I16 :
case GGML_TYPE_I32 :
case GGML_TYPE_F16 :
case GGML_TYPE_COUNT :
{
GGML_ASSERT ( false ) ;
} break ;
}
}
/////////////////////////////////
static void ggml_compute_forward ( struct ggml_compute_params * params , struct ggml_tensor * tensor ) {
@ -9024,6 +9219,18 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
{
ggml_compute_forward_flash_ff ( params , tensor - > src0 , tensor - > src1 , tensor - > opt [ 0 ] , tensor - > opt [ 1 ] , tensor - > opt [ 2 ] , tensor ) ;
} break ;
case GGML_OP_MAP_UNARY :
{
const ggml_unary_op_f32_t fun = * ( ( ggml_unary_op_f32_t * ) tensor - > opt [ 0 ] - > data ) ;
ggml_compute_forward_map_unary ( params , tensor - > src0 , tensor , fun ) ;
}
break ;
case GGML_OP_MAP_BINARY :
{
const ggml_binary_op_f32_t fun = * ( ( ggml_binary_op_f32_t * ) tensor - > opt [ 0 ] - > data ) ;
ggml_compute_forward_map_binary ( params , tensor - > src0 , tensor - > src1 , tensor , fun ) ;
}
break ;
case GGML_OP_NONE :
{
// nop
@ -9283,6 +9490,11 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
{
GGML_ASSERT ( false ) ; // not supported
} break ;
case GGML_OP_MAP_UNARY :
case GGML_OP_MAP_BINARY :
{
GGML_ASSERT ( false ) ; // not supported
} break ;
case GGML_OP_NONE :
{
// nop
@ -9775,6 +9987,11 @@ void ggml_graph_compute(struct ggml_context * ctx, struct ggml_cgraph * cgraph)
work_size = MAX ( work_size , cur ) ;
} break ;
case GGML_OP_MAP_UNARY :
case GGML_OP_MAP_BINARY :
{
node - > n_tasks = 1 ;
} break ;
case GGML_OP_NONE :
{
node - > n_tasks = 1 ;